diff options
Diffstat (limited to 'pkg/tcpip')
182 files changed, 6855 insertions, 68234 deletions
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD deleted file mode 100644 index f979d22f0..000000000 --- a/pkg/tcpip/BUILD +++ /dev/null @@ -1,49 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "sock_err_list", - out = "sock_err_list.go", - package = "tcpip", - prefix = "sockError", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*SockError", - "Linker": "*SockError", - }, -) - -go_library( - name = "tcpip", - srcs = [ - "errors.go", - "sock_err_list.go", - "socketops.go", - "tcpip.go", - "time_unsafe.go", - "timer.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sync", - "//pkg/tcpip/buffer", - "//pkg/waiter", - ], -) - -go_test( - name = "tcpip_test", - size = "small", - srcs = ["tcpip_test.go"], - library = ":tcpip", - deps = ["@com_github_google_go_cmp//cmp:go_default_library"], -) - -go_test( - name = "tcpip_x_test", - size = "small", - srcs = ["timer_test.go"], - deps = [":tcpip"], -) diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD deleted file mode 100644 index a984f1712..000000000 --- a/pkg/tcpip/adapters/gonet/BUILD +++ /dev/null @@ -1,37 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "gonet", - srcs = ["gonet.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - ], -) - -go_test( - name = "gonet_test", - size = "small", - srcs = ["gonet_test.go"], - library = ":gonet", - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/header", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - "@org_golang_x_net//nettest:go_default_library", - ], -) diff --git a/pkg/tcpip/adapters/gonet/gonet_state_autogen.go b/pkg/tcpip/adapters/gonet/gonet_state_autogen.go new file mode 100644 index 000000000..7a5c5419e --- /dev/null +++ b/pkg/tcpip/adapters/gonet/gonet_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package gonet diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go deleted file mode 100644 index 2b3ea4bdf..000000000 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ /dev/null @@ -1,725 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package gonet - -import ( - "context" - "fmt" - "io" - "net" - "reflect" - "strings" - "testing" - "time" - - "golang.org/x/net/nettest" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - NICID = 1 -) - -func TestTimeouts(t *testing.T) { - nc := NewTCPConn(nil, nil) - dlfs := []struct { - name string - f func(time.Time) error - }{ - {"SetDeadline", nc.SetDeadline}, - {"SetReadDeadline", nc.SetReadDeadline}, - {"SetWriteDeadline", nc.SetWriteDeadline}, - } - - for _, dlf := range dlfs { - if err := dlf.f(time.Time{}); err != nil { - t.Errorf("got %s(time.Time{}) = %v, want = %v", dlf.name, err, nil) - } - } -} - -func newLoopbackStack() (*stack.Stack, tcpip.Error) { - // Create the stack and add a NIC. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol}, - }) - - if err := s.CreateNIC(NICID, loopback.New()); err != nil { - return nil, err - } - - // Add default route. - s.SetRouteTable([]tcpip.Route{ - // IPv4 - { - Destination: header.IPv4EmptySubnet, - NIC: NICID, - }, - - // IPv6 - { - Destination: header.IPv6EmptySubnet, - NIC: NICID, - }, - }) - - return s, nil -} - -type testConnection struct { - wq *waiter.Queue - e *waiter.Entry - ch chan struct{} - ep tcpip.Endpoint -} - -func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, tcpip.Error) { - wq := &waiter.Queue{} - ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - return nil, err - } - - entry, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&entry, waiter.EventOut) - - err = ep.Connect(addr) - if _, ok := err.(*tcpip.ErrConnectStarted); ok { - <-ch - err = ep.LastError() - } - if err != nil { - return nil, err - } - - wq.EventUnregister(&entry) - wq.EventRegister(&entry, waiter.EventIn) - - return &testConnection{wq, &entry, ch, ep}, nil -} - -func (c *testConnection) close() { - c.wq.EventUnregister(c.e) - c.ep.Close() -} - -// TestCloseReader tests that Conn.Close() causes Conn.Read() to unblock. -func TestCloseReader(t *testing.T) { - s, err := newLoopbackStack() - if err != nil { - t.Fatalf("newLoopbackStack() = %v", err) - } - defer func() { - s.Close() - s.Wait() - }() - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - l, e := ListenTCP(s, addr, ipv4.ProtocolNumber) - if e != nil { - t.Fatalf("NewListener() = %v", e) - } - done := make(chan struct{}) - go func() { - defer close(done) - c, err := l.Accept() - if err != nil { - t.Errorf("l.Accept() = %v", err) - // Cannot call Fatalf in goroutine. Just return from the goroutine. - return - } - - // Give c.Read() a chance to block before closing the connection. - time.AfterFunc(time.Millisecond*50, func() { - c.Close() - }) - - buf := make([]byte, 256) - n, err := c.Read(buf) - if n != 0 || err != io.EOF { - t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, err) - } - }() - sender, err := connect(s, addr) - if err != nil { - t.Fatalf("connect() = %v", err) - } - - select { - case <-done: - case <-time.After(5 * time.Second): - t.Errorf("c.Read() didn't unblock") - } - sender.close() -} - -// TestCloseReaderWithForwarder tests that TCPConn.Close wakes TCPConn.Read when -// using tcp.Forwarder. -func TestCloseReaderWithForwarder(t *testing.T) { - s, err := newLoopbackStack() - if err != nil { - t.Fatalf("newLoopbackStack() = %v", err) - } - defer func() { - s.Close() - s.Wait() - }() - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - done := make(chan struct{}) - - fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { - defer close(done) - - var wq waiter.Queue - ep, err := r.CreateEndpoint(&wq) - if err != nil { - t.Fatalf("r.CreateEndpoint() = %v", err) - } - defer ep.Close() - r.Complete(false) - - c := NewTCPConn(&wq, ep) - - // Give c.Read() a chance to block before closing the connection. - time.AfterFunc(time.Millisecond*50, func() { - c.Close() - }) - - buf := make([]byte, 256) - n, e := c.Read(buf) - if n != 0 || e != io.EOF { - t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, e) - } - }) - s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) - - sender, err := connect(s, addr) - if err != nil { - t.Fatalf("connect() = %v", err) - } - - select { - case <-done: - case <-time.After(5 * time.Second): - t.Errorf("c.Read() didn't unblock") - } - sender.close() -} - -func TestCloseRead(t *testing.T) { - s, terr := newLoopbackStack() - if terr != nil { - t.Fatalf("newLoopbackStack() = %v", terr) - } - defer func() { - s.Close() - s.Wait() - }() - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { - var wq waiter.Queue - _, err := r.CreateEndpoint(&wq) - if err != nil { - t.Fatalf("r.CreateEndpoint() = %v", err) - } - // Endpoint will be closed in deferred s.Close (above). - }) - - s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) - - tc, terr := connect(s, addr) - if terr != nil { - t.Fatalf("connect() = %v", terr) - } - c := NewTCPConn(tc.wq, tc.ep) - - if err := c.CloseRead(); err != nil { - t.Errorf("c.CloseRead() = %v", err) - } - - buf := make([]byte, 256) - if n, err := c.Read(buf); err != io.EOF { - t.Errorf("c.Read() = (%d, %v), want (0, io.EOF)", n, err) - } - - if n, err := c.Write([]byte("abc123")); n != 6 || err != nil { - t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, err) - } -} - -func TestCloseWrite(t *testing.T) { - s, terr := newLoopbackStack() - if terr != nil { - t.Fatalf("newLoopbackStack() = %v", terr) - } - defer func() { - s.Close() - s.Wait() - }() - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { - var wq waiter.Queue - ep, err := r.CreateEndpoint(&wq) - if err != nil { - t.Fatalf("r.CreateEndpoint() = %v", err) - } - defer ep.Close() - r.Complete(false) - - c := NewTCPConn(&wq, ep) - - n, e := c.Read(make([]byte, 256)) - if n != 0 || e != io.EOF { - t.Errorf("c.Read() = (%d, %v), want (0, io.EOF)", n, e) - } - - if n, e = c.Write([]byte("abc123")); n != 6 || e != nil { - t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, e) - } - }) - - s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) - - tc, terr := connect(s, addr) - if terr != nil { - t.Fatalf("connect() = %v", terr) - } - c := NewTCPConn(tc.wq, tc.ep) - - if err := c.CloseWrite(); err != nil { - t.Errorf("c.CloseWrite() = %v", err) - } - - buf := make([]byte, 256) - n, err := c.Read(buf) - if err != nil || string(buf[:n]) != "abc123" { - t.Fatalf("c.Read() = (%d, %v), want (6, nil)", n, err) - } - - n, err = c.Write([]byte("abc123")) - got, ok := err.(*net.OpError) - want := "endpoint is closed for send" - if n != 0 || !ok || got.Op != "write" || got.Err == nil || !strings.HasSuffix(got.Err.Error(), want) { - t.Errorf("c.Write() = (%d, %v), want (0, OpError(Op: write, Err: %s))", n, err, want) - } -} - -func TestUDPForwarder(t *testing.T) { - s, terr := newLoopbackStack() - if terr != nil { - t.Fatalf("newLoopbackStack() = %v", terr) - } - defer func() { - s.Close() - s.Wait() - }() - - ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) - addr1 := tcpip.FullAddress{NICID, ip1, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip1) - ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4()) - addr2 := tcpip.FullAddress{NICID, ip2, 11311} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip2) - - done := make(chan struct{}) - fwd := udp.NewForwarder(s, func(r *udp.ForwarderRequest) { - defer close(done) - - var wq waiter.Queue - ep, err := r.CreateEndpoint(&wq) - if err != nil { - t.Fatalf("r.CreateEndpoint() = %v", err) - } - defer ep.Close() - - c := NewTCPConn(&wq, ep) - - buf := make([]byte, 256) - n, e := c.Read(buf) - if e != nil { - t.Errorf("c.Read() = %v", e) - } - - if _, e := c.Write(buf[:n]); e != nil { - t.Errorf("c.Write() = %v", e) - } - }) - s.SetTransportProtocolHandler(udp.ProtocolNumber, fwd.HandlePacket) - - c2, err := DialUDP(s, &addr2, nil, ipv4.ProtocolNumber) - if err != nil { - t.Fatal("DialUDP(bind port 5):", err) - } - - sent := "abc123" - sendAddr := fullToUDPAddr(addr1) - if n, err := c2.WriteTo([]byte(sent), sendAddr); err != nil || n != len(sent) { - t.Errorf("c1.WriteTo(%q, %v) = %d, %v, want = %d, %v", sent, sendAddr, n, err, len(sent), nil) - } - - buf := make([]byte, 256) - n, recvAddr, err := c2.ReadFrom(buf) - if err != nil || recvAddr.String() != sendAddr.String() { - t.Errorf("c1.ReadFrom() = %d, %v, %v, want = %d, %v, %v", n, recvAddr, err, len(sent), sendAddr, nil) - } -} - -// TestDeadlineChange tests that changing the deadline affects currently blocked reads. -func TestDeadlineChange(t *testing.T) { - s, err := newLoopbackStack() - if err != nil { - t.Fatalf("newLoopbackStack() = %v", err) - } - defer func() { - s.Close() - s.Wait() - }() - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - l, e := ListenTCP(s, addr, ipv4.ProtocolNumber) - if e != nil { - t.Fatalf("NewListener() = %v", e) - } - done := make(chan struct{}) - go func() { - defer close(done) - c, err := l.Accept() - if err != nil { - t.Errorf("l.Accept() = %v", err) - // Cannot call Fatalf in goroutine. Just return from the goroutine. - return - } - - c.SetDeadline(time.Now().Add(time.Minute)) - // Give c.Read() a chance to block before closing the connection. - time.AfterFunc(time.Millisecond*50, func() { - c.SetDeadline(time.Now().Add(time.Millisecond * 10)) - }) - - buf := make([]byte, 256) - n, err := c.Read(buf) - got, ok := err.(*net.OpError) - want := "i/o timeout" - if n != 0 || !ok || got.Err == nil || got.Err.Error() != want { - t.Errorf("c.Read() = (%d, %v), want (0, OpError(%s))", n, err, want) - } - }() - sender, err := connect(s, addr) - if err != nil { - t.Fatalf("connect() = %v", err) - } - - select { - case <-done: - case <-time.After(time.Millisecond * 500): - t.Errorf("c.Read() didn't unblock") - } - sender.close() -} - -func TestPacketConnTransfer(t *testing.T) { - s, e := newLoopbackStack() - if e != nil { - t.Fatalf("newLoopbackStack() = %v", e) - } - defer func() { - s.Close() - s.Wait() - }() - - ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) - addr1 := tcpip.FullAddress{NICID, ip1, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip1) - ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4()) - addr2 := tcpip.FullAddress{NICID, ip2, 11311} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip2) - - c1, err := DialUDP(s, &addr1, nil, ipv4.ProtocolNumber) - if err != nil { - t.Fatal("DialUDP(bind port 4):", err) - } - c2, err := DialUDP(s, &addr2, nil, ipv4.ProtocolNumber) - if err != nil { - t.Fatal("DialUDP(bind port 5):", err) - } - - c1.SetDeadline(time.Now().Add(time.Second)) - c2.SetDeadline(time.Now().Add(time.Second)) - - sent := "abc123" - sendAddr := fullToUDPAddr(addr2) - if n, err := c1.WriteTo([]byte(sent), sendAddr); err != nil || n != len(sent) { - t.Errorf("got c1.WriteTo(%q, %v) = %d, %v, want = %d, %v", sent, sendAddr, n, err, len(sent), nil) - } - recv := make([]byte, len(sent)) - n, recvAddr, err := c2.ReadFrom(recv) - if err != nil || n != len(recv) { - t.Errorf("got c2.ReadFrom() = %d, %v, want = %d, %v", n, err, len(recv), nil) - } - - if recv := string(recv); recv != sent { - t.Errorf("got recv = %q, want = %q", recv, sent) - } - - if want := fullToUDPAddr(addr1); !reflect.DeepEqual(recvAddr, want) { - t.Errorf("got recvAddr = %v, want = %v", recvAddr, want) - } - - if err := c1.Close(); err != nil { - t.Error("c1.Close():", err) - } - if err := c2.Close(); err != nil { - t.Error("c2.Close():", err) - } -} - -func TestConnectedPacketConnTransfer(t *testing.T) { - s, e := newLoopbackStack() - if e != nil { - t.Fatalf("newLoopbackStack() = %v", e) - } - defer func() { - s.Close() - s.Wait() - }() - - ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) - addr := tcpip.FullAddress{NICID, ip, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip) - - c1, err := DialUDP(s, &addr, nil, ipv4.ProtocolNumber) - if err != nil { - t.Fatal("DialUDP(bind port 4):", err) - } - c2, err := DialUDP(s, nil, &addr, ipv4.ProtocolNumber) - if err != nil { - t.Fatal("DialUDP(bind port 5):", err) - } - - c1.SetDeadline(time.Now().Add(time.Second)) - c2.SetDeadline(time.Now().Add(time.Second)) - - sent := "abc123" - if n, err := c2.Write([]byte(sent)); err != nil || n != len(sent) { - t.Errorf("got c2.Write(%q) = %d, %v, want = %d, %v", sent, n, err, len(sent), nil) - } - recv := make([]byte, len(sent)) - n, err := c1.Read(recv) - if err != nil || n != len(recv) { - t.Errorf("got c1.Read() = %d, %v, want = %d, %v", n, err, len(recv), nil) - } - - if recv := string(recv); recv != sent { - t.Errorf("got recv = %q, want = %q", recv, sent) - } - - if err := c1.Close(); err != nil { - t.Error("c1.Close():", err) - } - if err := c2.Close(); err != nil { - t.Error("c2.Close():", err) - } -} - -func makePipe() (c1, c2 net.Conn, stop func(), err error) { - s, e := newLoopbackStack() - if e != nil { - return nil, nil, nil, fmt.Errorf("newLoopbackStack() = %v", e) - } - - ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) - addr := tcpip.FullAddress{NICID, ip, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip) - - l, err := ListenTCP(s, addr, ipv4.ProtocolNumber) - if err != nil { - return nil, nil, nil, fmt.Errorf("NewListener: %v", err) - } - - c1, err = DialTCP(s, addr, ipv4.ProtocolNumber) - if err != nil { - l.Close() - return nil, nil, nil, fmt.Errorf("DialTCP: %v", err) - } - - c2, err = l.Accept() - if err != nil { - l.Close() - c1.Close() - return nil, nil, nil, fmt.Errorf("l.Accept: %v", err) - } - - stop = func() { - c1.Close() - c2.Close() - s.Close() - s.Wait() - } - - if err := l.Close(); err != nil { - stop() - return nil, nil, nil, fmt.Errorf("l.Close(): %v", err) - } - - return c1, c2, stop, nil -} - -func TestTCPConnTransfer(t *testing.T) { - c1, c2, _, err := makePipe() - if err != nil { - t.Fatal(err) - } - defer func() { - if err := c1.Close(); err != nil { - t.Error("c1.Close():", err) - } - if err := c2.Close(); err != nil { - t.Error("c2.Close():", err) - } - }() - - c1.SetDeadline(time.Now().Add(time.Second)) - c2.SetDeadline(time.Now().Add(time.Second)) - - const sent = "abc123" - - tests := []struct { - name string - c1 net.Conn - c2 net.Conn - }{ - {"connected to accepted", c1, c2}, - {"accepted to connected", c2, c1}, - } - - for _, test := range tests { - if n, err := test.c1.Write([]byte(sent)); err != nil || n != len(sent) { - t.Errorf("%s: got test.c1.Write(%q) = %d, %v, want = %d, %v", test.name, sent, n, err, len(sent), nil) - continue - } - - recv := make([]byte, len(sent)) - n, err := test.c2.Read(recv) - if err != nil || n != len(recv) { - t.Errorf("%s: got test.c2.Read() = %d, %v, want = %d, %v", test.name, n, err, len(recv), nil) - continue - } - - if recv := string(recv); recv != sent { - t.Errorf("%s: got recv = %q, want = %q", test.name, recv, sent) - } - } -} - -func TestTCPDialError(t *testing.T) { - s, e := newLoopbackStack() - if e != nil { - t.Fatalf("newLoopbackStack() = %v", e) - } - defer func() { - s.Close() - s.Wait() - }() - - ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) - addr := tcpip.FullAddress{NICID, ip, 11211} - - switch _, err := DialTCP(s, addr, ipv4.ProtocolNumber); err := err.(type) { - case *net.OpError: - if err.Err.Error() != (&tcpip.ErrNoRoute{}).String() { - t.Errorf("got DialTCP() = %s, want = %s", err, &tcpip.ErrNoRoute{}) - } - default: - t.Errorf("got DialTCP(...) = %v, want %s", err, &tcpip.ErrNoRoute{}) - } -} - -func TestDialContextTCPCanceled(t *testing.T) { - s, err := newLoopbackStack() - if err != nil { - t.Fatalf("newLoopbackStack() = %v", err) - } - defer func() { - s.Close() - s.Wait() - }() - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - ctx := context.Background() - ctx, cancel := context.WithCancel(ctx) - cancel() - - if _, err := DialContextTCP(ctx, s, addr, ipv4.ProtocolNumber); err != context.Canceled { - t.Errorf("got DialContextTCP(...) = %v, want = %v", err, context.Canceled) - } -} - -func TestDialContextTCPTimeout(t *testing.T) { - s, err := newLoopbackStack() - if err != nil { - t.Fatalf("newLoopbackStack() = %v", err) - } - defer func() { - s.Close() - s.Wait() - }() - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { - time.Sleep(time.Second) - r.Complete(true) - }) - s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) - - ctx := context.Background() - ctx, cancel := context.WithDeadline(ctx, time.Now().Add(100*time.Millisecond)) - defer cancel() - - if _, err := DialContextTCP(ctx, s, addr, ipv4.ProtocolNumber); err != context.DeadlineExceeded { - t.Errorf("got DialContextTCP(...) = %v, want = %v", err, context.DeadlineExceeded) - } -} - -func TestNetTest(t *testing.T) { - nettest.TestConn(t, makePipe) -} diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD deleted file mode 100644 index 23aa0ad05..000000000 --- a/pkg/tcpip/buffer/BUILD +++ /dev/null @@ -1,25 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "buffer", - srcs = [ - "prependable.go", - "view.go", - "view_unsafe.go", - ], - visibility = ["//visibility:public"], -) - -go_test( - name = "buffer_x_test", - size = "small", - srcs = [ - "view_test.go", - ], - deps = [ - ":buffer", - "//pkg/tcpip", - ], -) diff --git a/pkg/tcpip/buffer/buffer_state_autogen.go b/pkg/tcpip/buffer/buffer_state_autogen.go new file mode 100644 index 000000000..9f0e96ed1 --- /dev/null +++ b/pkg/tcpip/buffer/buffer_state_autogen.go @@ -0,0 +1,37 @@ +// automatically generated by stateify. + +package buffer + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (vv *VectorisedView) StateTypeName() string { + return "pkg/tcpip/buffer.VectorisedView" +} + +func (vv *VectorisedView) StateFields() []string { + return []string{ + "views", + "size", + } +} + +func (vv *VectorisedView) beforeSave() {} + +func (vv *VectorisedView) StateSave(stateSinkObject state.Sink) { + vv.beforeSave() + stateSinkObject.Save(0, &vv.views) + stateSinkObject.Save(1, &vv.size) +} + +func (vv *VectorisedView) afterLoad() {} + +func (vv *VectorisedView) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &vv.views) + stateSourceObject.Load(1, &vv.size) +} + +func init() { + state.Register((*VectorisedView)(nil)) +} diff --git a/pkg/tcpip/buffer/buffer_unsafe_state_autogen.go b/pkg/tcpip/buffer/buffer_unsafe_state_autogen.go new file mode 100644 index 000000000..5a5c40722 --- /dev/null +++ b/pkg/tcpip/buffer/buffer_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package buffer diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go deleted file mode 100644 index 78b2faa26..000000000 --- a/pkg/tcpip/buffer/view_test.go +++ /dev/null @@ -1,593 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package buffer_test contains tests for the buffer.VectorisedView type. -package buffer_test - -import ( - "bytes" - "io" - "reflect" - "testing" - "unsafe" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" -) - -// copy returns a deep-copy of the vectorised view. -func copyVV(vv buffer.VectorisedView) buffer.VectorisedView { - views := make([]buffer.View, 0, len(vv.Views())) - for _, v := range vv.Views() { - views = append(views, append(buffer.View(nil), v...)) - } - return buffer.NewVectorisedView(vv.Size(), views) -} - -// vv is an helper to build buffer.VectorisedView from different strings. -func vv(size int, pieces ...string) buffer.VectorisedView { - views := make([]buffer.View, len(pieces)) - for i, p := range pieces { - views[i] = []byte(p) - } - - return buffer.NewVectorisedView(size, views) -} - -var capLengthTestCases = []struct { - comment string - in buffer.VectorisedView - length int - want buffer.VectorisedView -}{ - { - comment: "Simple case", - in: vv(2, "12"), - length: 1, - want: vv(1, "1"), - }, - { - comment: "Case spanning across two Views", - in: vv(4, "123", "4"), - length: 2, - want: vv(2, "12"), - }, - { - comment: "Corner case with negative length", - in: vv(1, "1"), - length: -1, - want: vv(0), - }, - { - comment: "Corner case with length = 0", - in: vv(3, "12", "3"), - length: 0, - want: vv(0), - }, - { - comment: "Corner case with length = size", - in: vv(1, "1"), - length: 1, - want: vv(1, "1"), - }, - { - comment: "Corner case with length > size", - in: vv(1, "1"), - length: 2, - want: vv(1, "1"), - }, -} - -func TestCapLength(t *testing.T) { - for _, c := range capLengthTestCases { - orig := copyVV(c.in) - c.in.CapLength(c.length) - if !reflect.DeepEqual(c.in, c.want) { - t.Errorf("Test \"%s\" failed when calling CapLength(%d) on %v. Got %v. Want %v", - c.comment, c.length, orig, c.in, c.want) - } - } -} - -var trimFrontTestCases = []struct { - comment string - in buffer.VectorisedView - count int - want buffer.VectorisedView -}{ - { - comment: "Simple case", - in: vv(2, "12"), - count: 1, - want: vv(1, "2"), - }, - { - comment: "Case where we trim an entire View", - in: vv(2, "1", "2"), - count: 1, - want: vv(1, "2"), - }, - { - comment: "Case spanning across two Views", - in: vv(3, "1", "23"), - count: 2, - want: vv(1, "3"), - }, - { - comment: "Corner case with negative count", - in: vv(1, "1"), - count: -1, - want: vv(1, "1"), - }, - { - comment: " Corner case with count = 0", - in: vv(1, "1"), - count: 0, - want: vv(1, "1"), - }, - { - comment: "Corner case with count = size", - in: vv(1, "1"), - count: 1, - want: vv(0), - }, - { - comment: "Corner case with count > size", - in: vv(1, "1"), - count: 2, - want: vv(0), - }, -} - -func TestTrimFront(t *testing.T) { - for _, c := range trimFrontTestCases { - orig := copyVV(c.in) - c.in.TrimFront(c.count) - if !reflect.DeepEqual(c.in, c.want) { - t.Errorf("Test \"%s\" failed when calling TrimFront(%d) on %v. Got %v. Want %v", - c.comment, c.count, orig, c.in, c.want) - } - } -} - -var toViewCases = []struct { - comment string - in buffer.VectorisedView - want buffer.View -}{ - { - comment: "Simple case", - in: vv(2, "12"), - want: []byte("12"), - }, - { - comment: "Case with multiple views", - in: vv(2, "1", "2"), - want: []byte("12"), - }, - { - comment: "Empty case", - in: vv(0), - want: []byte(""), - }, -} - -func TestToView(t *testing.T) { - for _, c := range toViewCases { - got := c.in.ToView() - if !reflect.DeepEqual(got, c.want) { - t.Errorf("Test \"%s\" failed when calling ToView() on %v. Got %v. Want %v", - c.comment, c.in, got, c.want) - } - } -} - -var toCloneCases = []struct { - comment string - inView buffer.VectorisedView - inBuffer []buffer.View -}{ - { - comment: "Simple case", - inView: vv(1, "1"), - inBuffer: make([]buffer.View, 1), - }, - { - comment: "Case with multiple views", - inView: vv(2, "1", "2"), - inBuffer: make([]buffer.View, 2), - }, - { - comment: "Case with buffer too small", - inView: vv(2, "1", "2"), - inBuffer: make([]buffer.View, 1), - }, - { - comment: "Case with buffer larger than needed", - inView: vv(1, "1"), - inBuffer: make([]buffer.View, 2), - }, - { - comment: "Case with nil buffer", - inView: vv(1, "1"), - inBuffer: nil, - }, -} - -func TestToClone(t *testing.T) { - for _, c := range toCloneCases { - t.Run(c.comment, func(t *testing.T) { - got := c.inView.Clone(c.inBuffer) - if !reflect.DeepEqual(got, c.inView) { - t.Fatalf("got (%+v).Clone(%+v) = %+v, want = %+v", - c.inView, c.inBuffer, got, c.inView) - } - }) - } -} - -type readToTestCases struct { - comment string - vv buffer.VectorisedView - bytesToRead int - wantBytes string - leftVV buffer.VectorisedView -} - -func createReadToTestCases() []readToTestCases { - return []readToTestCases{ - { - comment: "large VV, short read", - vv: vv(30, "012345678901234567890123456789"), - bytesToRead: 10, - wantBytes: "0123456789", - leftVV: vv(20, "01234567890123456789"), - }, - { - comment: "largeVV, multiple views, short read", - vv: vv(13, "123", "345", "567", "8910"), - bytesToRead: 6, - wantBytes: "123345", - leftVV: vv(7, "567", "8910"), - }, - { - comment: "smallVV (multiple views), large read", - vv: vv(3, "1", "2", "3"), - bytesToRead: 10, - wantBytes: "123", - leftVV: vv(0, ""), - }, - { - comment: "smallVV (single view), large read", - vv: vv(1, "1"), - bytesToRead: 10, - wantBytes: "1", - leftVV: vv(0, ""), - }, - { - comment: "emptyVV, large read", - vv: vv(0, ""), - bytesToRead: 10, - wantBytes: "", - leftVV: vv(0, ""), - }, - } -} - -func TestVVReadToVV(t *testing.T) { - for _, tc := range createReadToTestCases() { - t.Run(tc.comment, func(t *testing.T) { - var readTo buffer.VectorisedView - inSize := tc.vv.Size() - copied := tc.vv.ReadToVV(&readTo, tc.bytesToRead) - if got, want := copied, len(tc.wantBytes); got != want { - t.Errorf("incorrect number of bytes copied returned in ReadToVV got: %d, want: %d, tc: %+v", got, want, tc) - } - if got, want := string(readTo.ToView()), tc.wantBytes; got != want { - t.Errorf("unexpected content in readTo got: %s, want: %s", got, want) - } - if got, want := tc.vv.Size(), inSize-copied; got != want { - t.Errorf("test VV has incorrect size after reading got: %d, want: %d, tc.vv: %+v", got, want, tc.vv) - } - if got, want := string(tc.vv.ToView()), string(tc.leftVV.ToView()); got != want { - t.Errorf("unexpected data left in vv after read got: %+v, want: %+v", got, want) - } - }) - } -} - -func TestVVReadTo(t *testing.T) { - for _, tc := range createReadToTestCases() { - t.Run(tc.comment, func(t *testing.T) { - b := make([]byte, tc.bytesToRead) - dst := tcpip.SliceWriter(b) - origSize := tc.vv.Size() - copied, err := tc.vv.ReadTo(&dst, false /* peek */) - if err != nil && err != io.ErrShortWrite { - t.Errorf("got ReadTo(&dst, false) = (_, %s); want nil or io.ErrShortWrite", err) - } - if got, want := copied, len(tc.wantBytes); got != want { - t.Errorf("got ReadTo(&dst, false) = (%d, _); want %d", got, want) - } - if got, want := string(b[:copied]), tc.wantBytes; got != want { - t.Errorf("got dst = %q, want %q", got, want) - } - if got, want := tc.vv.Size(), origSize-copied; got != want { - t.Errorf("got after-read tc.vv.Size() = %d, want %d", got, want) - } - if got, want := string(tc.vv.ToView()), string(tc.leftVV.ToView()); got != want { - t.Errorf("got after-read data in tc.vv = %q, want %q", got, want) - } - }) - } -} - -func TestVVReadToPeek(t *testing.T) { - for _, tc := range createReadToTestCases() { - t.Run(tc.comment, func(t *testing.T) { - b := make([]byte, tc.bytesToRead) - dst := tcpip.SliceWriter(b) - origSize := tc.vv.Size() - origData := string(tc.vv.ToView()) - copied, err := tc.vv.ReadTo(&dst, true /* peek */) - if err != nil && err != io.ErrShortWrite { - t.Errorf("got ReadTo(&dst, true) = (_, %s); want nil or io.ErrShortWrite", err) - } - if got, want := copied, len(tc.wantBytes); got != want { - t.Errorf("got ReadTo(&dst, true) = (%d, _); want %d", got, want) - } - if got, want := string(b[:copied]), tc.wantBytes; got != want { - t.Errorf("got dst = %q, want %q", got, want) - } - // Expect tc.vv is unchanged. - if got, want := tc.vv.Size(), origSize; got != want { - t.Errorf("got after-read tc.vv.Size() = %d, want %d", got, want) - } - if got, want := string(tc.vv.ToView()), origData; got != want { - t.Errorf("got after-read data in tc.vv = %q, want %q", got, want) - } - }) - } -} - -func TestVVRead(t *testing.T) { - testCases := []struct { - comment string - vv buffer.VectorisedView - bytesToRead int - readBytes string - leftBytes string - wantError bool - }{ - { - comment: "large VV, short read", - vv: vv(30, "012345678901234567890123456789"), - bytesToRead: 10, - readBytes: "0123456789", - leftBytes: "01234567890123456789", - }, - { - comment: "largeVV, multiple buffers, short read", - vv: vv(13, "123", "345", "567", "8910"), - bytesToRead: 6, - readBytes: "123345", - leftBytes: "5678910", - }, - { - comment: "smallVV, large read", - vv: vv(3, "1", "2", "3"), - bytesToRead: 10, - readBytes: "123", - leftBytes: "", - }, - { - comment: "smallVV, large read", - vv: vv(1, "1"), - bytesToRead: 10, - readBytes: "1", - leftBytes: "", - }, - { - comment: "emptyVV, large read", - vv: vv(0, ""), - bytesToRead: 10, - readBytes: "", - wantError: true, - }, - } - - for _, tc := range testCases { - t.Run(tc.comment, func(t *testing.T) { - readTo := buffer.NewView(tc.bytesToRead) - inSize := tc.vv.Size() - copied, err := tc.vv.Read(readTo) - if !tc.wantError && err != nil { - t.Fatalf("unexpected error in tc.vv.Read(..) = %s", err) - } - readTo = readTo[:copied] - if got, want := copied, len(tc.readBytes); got != want { - t.Errorf("incorrect number of bytes copied returned in ReadToVV got: %d, want: %d, tc.vv: %+v", got, want, tc.vv) - } - if got, want := string(readTo), tc.readBytes; got != want { - t.Errorf("unexpected data in readTo got: %s, want: %s", got, want) - } - if got, want := tc.vv.Size(), inSize-copied; got != want { - t.Errorf("test VV has incorrect size after reading got: %d, want: %d, tc.vv: %+v", got, want, tc.vv) - } - if got, want := string(tc.vv.ToView()), tc.leftBytes; got != want { - t.Errorf("vv has incorrect data after Read got: %s, want: %s", got, want) - } - }) - } -} - -var pullUpTestCases = []struct { - comment string - in buffer.VectorisedView - count int - want []byte - result buffer.VectorisedView - ok bool -}{ - { - comment: "simple case", - in: vv(2, "12"), - count: 1, - want: []byte("1"), - result: vv(2, "12"), - ok: true, - }, - { - comment: "entire View", - in: vv(2, "1", "2"), - count: 1, - want: []byte("1"), - result: vv(2, "1", "2"), - ok: true, - }, - { - comment: "spanning across two Views", - in: vv(3, "1", "23"), - count: 2, - want: []byte("12"), - result: vv(3, "12", "3"), - ok: true, - }, - { - comment: "spanning across all Views", - in: vv(5, "1", "23", "45"), - count: 5, - want: []byte("12345"), - result: vv(5, "12345"), - ok: true, - }, - { - comment: "count = 0", - in: vv(1, "1"), - count: 0, - want: []byte{}, - result: vv(1, "1"), - ok: true, - }, - { - comment: "count = size", - in: vv(1, "1"), - count: 1, - want: []byte("1"), - result: vv(1, "1"), - ok: true, - }, - { - comment: "count too large", - in: vv(3, "1", "23"), - count: 4, - want: nil, - result: vv(3, "1", "23"), - ok: false, - }, - { - comment: "empty vv", - in: vv(0, ""), - count: 1, - want: nil, - result: vv(0, ""), - ok: false, - }, - { - comment: "empty vv, count = 0", - in: vv(0, ""), - count: 0, - want: nil, - result: vv(0, ""), - ok: true, - }, - { - comment: "empty views", - in: vv(3, "", "1", "", "23"), - count: 2, - want: []byte("12"), - result: vv(3, "12", "3"), - ok: true, - }, -} - -func TestPullUp(t *testing.T) { - for _, c := range pullUpTestCases { - got, ok := c.in.PullUp(c.count) - - // Is the return value right? - if ok != c.ok { - t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got an ok of %t. Want %t", - c.comment, c.count, c.in, ok, c.ok) - } - if bytes.Compare(got, buffer.View(c.want)) != 0 { - t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got %v. Want %v", - c.comment, c.count, c.in, got, c.want) - } - - // Is the underlying structure right? - if !reflect.DeepEqual(c.in, c.result) { - t.Errorf("Test %q failed when calling PullUp(%d). Got vv with structure %v. Wanted %v", - c.comment, c.count, c.in, c.result) - } - } -} - -func TestToVectorisedView(t *testing.T) { - testCases := []struct { - in buffer.View - want buffer.VectorisedView - }{ - {nil, buffer.VectorisedView{}}, - {buffer.View{}, buffer.VectorisedView{}}, - {buffer.View{'a'}, buffer.NewVectorisedView(1, []buffer.View{{'a'}})}, - } - for _, tc := range testCases { - if got, want := tc.in.ToVectorisedView(), tc.want; !reflect.DeepEqual(got, want) { - t.Errorf("(%v).ToVectorisedView failed got: %+v, want: %+v", tc.in, got, want) - } - } -} - -func TestAppendView(t *testing.T) { - testCases := []struct { - vv buffer.VectorisedView - in buffer.View - want buffer.VectorisedView - }{ - {buffer.VectorisedView{}, nil, buffer.VectorisedView{}}, - {buffer.VectorisedView{}, buffer.View{}, buffer.VectorisedView{}}, - {buffer.NewVectorisedView(4, []buffer.View{{'a', 'b', 'c', 'd'}}), nil, buffer.NewVectorisedView(4, []buffer.View{{'a', 'b', 'c', 'd'}})}, - {buffer.NewVectorisedView(4, []buffer.View{{'a', 'b', 'c', 'd'}}), buffer.View{}, buffer.NewVectorisedView(4, []buffer.View{{'a', 'b', 'c', 'd'}})}, - {buffer.NewVectorisedView(4, []buffer.View{{'a', 'b', 'c', 'd'}}), buffer.View{'e'}, buffer.NewVectorisedView(5, []buffer.View{{'a', 'b', 'c', 'd'}, {'e'}})}, - } - for _, tc := range testCases { - tc.vv.AppendView(tc.in) - if got, want := tc.vv, tc.want; !reflect.DeepEqual(got, want) { - t.Errorf("(%v).ToVectorisedView failed got: %+v, want: %+v", tc.in, got, want) - } - } -} - -func TestMemSize(t *testing.T) { - const perViewCap = 128 - views := make([]buffer.View, 2, 32) - views[0] = make(buffer.View, 10, perViewCap) - views[1] = make(buffer.View, 20, perViewCap) - vv := buffer.NewVectorisedView(30, views) - want := int(unsafe.Sizeof(vv)) + cap(views)*int(unsafe.Sizeof(views)) + 2*perViewCap - if got := vv.MemSize(); got != want { - t.Errorf("vv.MemSize() = %d, want %d", got, want) - } -} diff --git a/pkg/tcpip/checker/BUILD b/pkg/tcpip/checker/BUILD deleted file mode 100644 index c984470e6..000000000 --- a/pkg/tcpip/checker/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "checker", - testonly = 1, - srcs = ["checker.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/seqnum", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go deleted file mode 100644 index 07b4393a4..000000000 --- a/pkg/tcpip/checker/checker.go +++ /dev/null @@ -1,1617 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package checker provides helper functions to check networking packets for -// validity. -package checker - -import ( - "encoding/binary" - "reflect" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" -) - -// NetworkChecker is a function to check a property of a network packet. -type NetworkChecker func(*testing.T, []header.Network) - -// TransportChecker is a function to check a property of a transport packet. -type TransportChecker func(*testing.T, header.Transport) - -// ControlMessagesChecker is a function to check a property of ancillary data. -type ControlMessagesChecker func(*testing.T, tcpip.ControlMessages) - -// IPv4 checks the validity and properties of the given IPv4 packet. It is -// expected to be used in conjunction with other network checkers for specific -// properties. For example, to check the source and destination address, one -// would call: -// -// checker.IPv4(t, b, checker.SrcAddr(x), checker.DstAddr(y)) -func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) { - t.Helper() - - ipv4 := header.IPv4(b) - - if !ipv4.IsValid(len(b)) { - t.Error("Not a valid IPv4 packet") - } - - xsum := ipv4.CalculateChecksum() - if xsum != 0 && xsum != 0xffff { - t.Errorf("Bad checksum: 0x%x, checksum in packet: 0x%x", xsum, ipv4.Checksum()) - } - - for _, f := range checkers { - f(t, []header.Network{ipv4}) - } - if t.Failed() { - t.FailNow() - } -} - -// IPv6 checks the validity and properties of the given IPv6 packet. The usage -// is similar to IPv4. -func IPv6(t *testing.T, b []byte, checkers ...NetworkChecker) { - t.Helper() - - ipv6 := header.IPv6(b) - if !ipv6.IsValid(len(b)) { - t.Error("Not a valid IPv6 packet") - } - - for _, f := range checkers { - f(t, []header.Network{ipv6}) - } - if t.Failed() { - t.FailNow() - } -} - -// SrcAddr creates a checker that checks the source address. -func SrcAddr(addr tcpip.Address) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if a := h[0].SourceAddress(); a != addr { - t.Errorf("Bad source address, got %v, want %v", a, addr) - } - } -} - -// DstAddr creates a checker that checks the destination address. -func DstAddr(addr tcpip.Address) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if a := h[0].DestinationAddress(); a != addr { - t.Errorf("Bad destination address, got %v, want %v", a, addr) - } - } -} - -// TTL creates a checker that checks the TTL (ipv4) or HopLimit (ipv6). -func TTL(ttl uint8) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - var v uint8 - switch ip := h[0].(type) { - case header.IPv4: - v = ip.TTL() - case header.IPv6: - v = ip.HopLimit() - case *ipv6HeaderWithExtHdr: - v = ip.HopLimit() - default: - t.Fatalf("unrecognized header type %T for TTL evaluation", ip) - } - if v != ttl { - t.Fatalf("Bad TTL, got = %d, want = %d", v, ttl) - } - } -} - -// IPFullLength creates a checker for the full IP packet length. The -// expected size is checked against both the Total Length in the -// header and the number of bytes received. -func IPFullLength(packetLength uint16) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - var v uint16 - var l uint16 - switch ip := h[0].(type) { - case header.IPv4: - v = ip.TotalLength() - l = uint16(len(ip)) - case header.IPv6: - v = ip.PayloadLength() + header.IPv6FixedHeaderSize - l = uint16(len(ip)) - default: - t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4 or header.IPv6", ip) - } - if l != packetLength { - t.Errorf("bad packet length, got = %d, want = %d", l, packetLength) - } - if v != packetLength { - t.Errorf("unexpected packet length in header, got = %d, want = %d", v, packetLength) - } - } -} - -// IPv4HeaderLength creates a checker that checks the IPv4 Header length. -func IPv4HeaderLength(headerLength int) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - switch ip := h[0].(type) { - case header.IPv4: - if hl := ip.HeaderLength(); hl != uint8(headerLength) { - t.Errorf("Bad header length, got = %d, want = %d", hl, headerLength) - } - default: - t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", ip) - } - } -} - -// PayloadLen creates a checker that checks the payload length. -func PayloadLen(payloadLength int) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if l := len(h[0].Payload()); l != payloadLength { - t.Errorf("Bad payload length, got = %d, want = %d", l, payloadLength) - } - } -} - -// IPPayload creates a checker that checks the payload. -func IPPayload(payload []byte) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - got := h[0].Payload() - - // cmp.Diff does not consider nil slices equal to empty slices, but we do. - if len(got) == 0 && len(payload) == 0 { - return - } - - if diff := cmp.Diff(payload, got); diff != "" { - t.Errorf("payload mismatch (-want +got):\n%s", diff) - } - } -} - -// IPv4Options returns a checker that checks the options in an IPv4 packet. -func IPv4Options(want header.IPv4Options) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - ip, ok := h[0].(header.IPv4) - if !ok { - t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0]) - } - options := ip.Options() - // cmp.Diff does not consider nil slices equal to empty slices, but we do. - if len(want) == 0 && len(options) == 0 { - return - } - if diff := cmp.Diff(want, options); diff != "" { - t.Errorf("options mismatch (-want +got):\n%s", diff) - } - } -} - -// IPv4RouterAlert returns a checker that checks that the RouterAlert option is -// set in an IPv4 packet. -func IPv4RouterAlert() NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - ip, ok := h[0].(header.IPv4) - if !ok { - t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0]) - } - iterator := ip.Options().MakeIterator() - for { - opt, done, err := iterator.Next() - if err != nil { - t.Fatalf("error acquiring next IPv4 option at offset %d", err.Pointer) - } - if done { - break - } - if opt.Type() != header.IPv4OptionRouterAlertType { - continue - } - want := [header.IPv4OptionRouterAlertLength]byte{ - byte(header.IPv4OptionRouterAlertType), - header.IPv4OptionRouterAlertLength, - header.IPv4OptionRouterAlertValue, - header.IPv4OptionRouterAlertValue, - } - if diff := cmp.Diff(want[:], opt.Contents()); diff != "" { - t.Errorf("router alert option mismatch (-want +got):\n%s", diff) - } - return - } - t.Errorf("failed to find router alert option in %v", ip.Options()) - } -} - -// FragmentOffset creates a checker that checks the FragmentOffset field. -func FragmentOffset(offset uint16) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - // We only do this for IPv4 for now. - switch ip := h[0].(type) { - case header.IPv4: - if v := ip.FragmentOffset(); v != offset { - t.Errorf("Bad fragment offset, got = %d, want = %d", v, offset) - } - } - } -} - -// FragmentFlags creates a checker that checks the fragment flags field. -func FragmentFlags(flags uint8) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - // We only do this for IPv4 for now. - switch ip := h[0].(type) { - case header.IPv4: - if v := ip.Flags(); v != flags { - t.Errorf("Bad fragment offset, got = %d, want = %d", v, flags) - } - } - } -} - -// ReceiveTClass creates a checker that checks the TCLASS field in -// ControlMessages. -func ReceiveTClass(want uint32) ControlMessagesChecker { - return func(t *testing.T, cm tcpip.ControlMessages) { - t.Helper() - if !cm.HasTClass { - t.Errorf("got cm.HasTClass = %t, want = true", cm.HasTClass) - } else if got := cm.TClass; got != want { - t.Errorf("got cm.TClass = %d, want %d", got, want) - } - } -} - -// ReceiveTOS creates a checker that checks the TOS field in ControlMessages. -func ReceiveTOS(want uint8) ControlMessagesChecker { - return func(t *testing.T, cm tcpip.ControlMessages) { - t.Helper() - if !cm.HasTOS { - t.Errorf("got cm.HasTOS = %t, want = true", cm.HasTOS) - } else if got := cm.TOS; got != want { - t.Errorf("got cm.TOS = %d, want %d", got, want) - } - } -} - -// ReceiveIPPacketInfo creates a checker that checks the PacketInfo field in -// ControlMessages. -func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker { - return func(t *testing.T, cm tcpip.ControlMessages) { - t.Helper() - if !cm.HasIPPacketInfo { - t.Errorf("got cm.HasIPPacketInfo = %t, want = true", cm.HasIPPacketInfo) - } else if diff := cmp.Diff(want, cm.PacketInfo); diff != "" { - t.Errorf("IPPacketInfo mismatch (-want +got):\n%s", diff) - } - } -} - -// ReceiveOriginalDstAddr creates a checker that checks the OriginalDstAddress -// field in ControlMessages. -func ReceiveOriginalDstAddr(want tcpip.FullAddress) ControlMessagesChecker { - return func(t *testing.T, cm tcpip.ControlMessages) { - t.Helper() - if !cm.HasOriginalDstAddress { - t.Errorf("got cm.HasOriginalDstAddress = %t, want = true", cm.HasOriginalDstAddress) - } else if diff := cmp.Diff(want, cm.OriginalDstAddress); diff != "" { - t.Errorf("OriginalDstAddress mismatch (-want +got):\n%s", diff) - } - } -} - -// TOS creates a checker that checks the TOS field. -func TOS(tos uint8, label uint32) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if v, l := h[0].TOS(); v != tos || l != label { - t.Errorf("Bad TOS, got = (%d, %d), want = (%d,%d)", v, l, tos, label) - } - } -} - -// Raw creates a checker that checks the bytes of payload. -// The checker always checks the payload of the last network header. -// For instance, in case of IPv6 fragments, the payload that will be checked -// is the one containing the actual data that the packet is carrying, without -// the bytes added by the IPv6 fragmentation. -func Raw(want []byte) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if got := h[len(h)-1].Payload(); !reflect.DeepEqual(got, want) { - t.Errorf("Wrong payload, got %v, want %v", got, want) - } - } -} - -// IPv6Fragment creates a checker that validates an IPv6 fragment. -func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader { - t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber) - } - - ipv6Frag := header.IPv6Fragment(h[0].Payload()) - if !ipv6Frag.IsValid() { - t.Error("Not a valid IPv6 fragment") - } - - for _, f := range checkers { - f(t, []header.Network{h[0], ipv6Frag}) - } - if t.Failed() { - t.FailNow() - } - } -} - -// TCP creates a checker that checks that the transport protocol is TCP and -// potentially additional transport header fields. -func TCP(checkers ...TransportChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - first := h[0] - last := h[len(h)-1] - - if p := last.TransportProtocol(); p != header.TCPProtocolNumber { - t.Errorf("Bad protocol, got = %d, want = %d", p, header.TCPProtocolNumber) - } - - // Verify the checksum. - tcp := header.TCP(last.Payload()) - l := uint16(len(tcp)) - - xsum := header.Checksum([]byte(first.SourceAddress()), 0) - xsum = header.Checksum([]byte(first.DestinationAddress()), xsum) - xsum = header.Checksum([]byte{0, byte(last.TransportProtocol())}, xsum) - xsum = header.Checksum([]byte{byte(l >> 8), byte(l)}, xsum) - xsum = header.Checksum(tcp, xsum) - - if xsum != 0 && xsum != 0xffff { - t.Errorf("Bad checksum: 0x%x, checksum in segment: 0x%x", xsum, tcp.Checksum()) - } - - // Run the transport checkers. - for _, f := range checkers { - f(t, tcp) - } - if t.Failed() { - t.FailNow() - } - } -} - -// UDP creates a checker that checks that the transport protocol is UDP and -// potentially additional transport header fields. -func UDP(checkers ...TransportChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - last := h[len(h)-1] - - if p := last.TransportProtocol(); p != header.UDPProtocolNumber { - t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber) - } - - udp := header.UDP(last.Payload()) - for _, f := range checkers { - f(t, udp) - } - if t.Failed() { - t.FailNow() - } - } -} - -// SrcPort creates a checker that checks the source port. -func SrcPort(port uint16) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - if p := h.SourcePort(); p != port { - t.Errorf("Bad source port, got = %d, want = %d", p, port) - } - } -} - -// DstPort creates a checker that checks the destination port. -func DstPort(port uint16) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - if p := h.DestinationPort(); p != port { - t.Errorf("Bad destination port, got = %d, want = %d", p, port) - } - } -} - -// NoChecksum creates a checker that checks if the checksum is zero. -func NoChecksum(noChecksum bool) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - udp, ok := h.(header.UDP) - if !ok { - t.Fatalf("UDP header not found in h: %T", h) - } - - if b := udp.Checksum() == 0; b != noChecksum { - t.Errorf("bad checksum state, got %t, want %t", b, noChecksum) - } - } -} - -// TCPSeqNum creates a checker that checks the sequence number. -func TCPSeqNum(seq uint32) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - tcp, ok := h.(header.TCP) - if !ok { - t.Fatalf("TCP header not found in h: %T", h) - } - - if s := tcp.SequenceNumber(); s != seq { - t.Errorf("Bad sequence number, got = %d, want = %d", s, seq) - } - } -} - -// TCPAckNum creates a checker that checks the ack number. -func TCPAckNum(seq uint32) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - tcp, ok := h.(header.TCP) - if !ok { - t.Fatalf("TCP header not found in h: %T", h) - } - - if s := tcp.AckNumber(); s != seq { - t.Errorf("Bad ack number, got = %d, want = %d", s, seq) - } - } -} - -// TCPWindow creates a checker that checks the tcp window. -func TCPWindow(window uint16) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - tcp, ok := h.(header.TCP) - if !ok { - t.Fatalf("TCP header not found in hdr : %T", h) - } - - if w := tcp.WindowSize(); w != window { - t.Errorf("Bad window, got %d, want %d", w, window) - } - } -} - -// TCPWindowGreaterThanEq creates a checker that checks that the TCP window -// is greater than or equal to the provided value. -func TCPWindowGreaterThanEq(window uint16) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - tcp, ok := h.(header.TCP) - if !ok { - t.Fatalf("TCP header not found in h: %T", h) - } - - if w := tcp.WindowSize(); w < window { - t.Errorf("Bad window, got %d, want > %d", w, window) - } - } -} - -// TCPWindowLessThanEq creates a checker that checks that the tcp window -// is less than or equal to the provided value. -func TCPWindowLessThanEq(window uint16) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - tcp, ok := h.(header.TCP) - if !ok { - t.Fatalf("TCP header not found in h: %T", h) - } - - if w := tcp.WindowSize(); w > window { - t.Errorf("Bad window, got %d, want < %d", w, window) - } - } -} - -// TCPFlags creates a checker that checks the tcp flags. -func TCPFlags(flags uint8) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - tcp, ok := h.(header.TCP) - if !ok { - t.Fatalf("TCP header not found in h: %T", h) - } - - if f := tcp.Flags(); f != flags { - t.Errorf("Bad flags, got 0x%x, want 0x%x", f, flags) - } - } -} - -// TCPFlagsMatch creates a checker that checks that the tcp flags, masked by the -// given mask, match the supplied flags. -func TCPFlagsMatch(flags, mask uint8) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - tcp, ok := h.(header.TCP) - if !ok { - t.Fatalf("TCP header not found in h: %T", h) - } - - if f := tcp.Flags(); (f & mask) != (flags & mask) { - t.Errorf("Bad masked flags, got 0x%x, want 0x%x, mask 0x%x", f, flags, mask) - } - } -} - -// TCPSynOptions creates a checker that checks the presence of TCP options in -// SYN segments. -// -// If wndscale is negative, the window scale option must not be present. -func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - tcp, ok := h.(header.TCP) - if !ok { - return - } - opts := tcp.Options() - limit := len(opts) - foundMSS := false - foundWS := false - foundTS := false - foundSACKPermitted := false - tsVal := uint32(0) - tsEcr := uint32(0) - for i := 0; i < limit; { - switch opts[i] { - case header.TCPOptionEOL: - i = limit - case header.TCPOptionNOP: - i++ - case header.TCPOptionMSS: - v := uint16(opts[i+2])<<8 | uint16(opts[i+3]) - if wantOpts.MSS != v { - t.Errorf("Bad MSS, got = %d, want = %d", v, wantOpts.MSS) - } - foundMSS = true - i += 4 - case header.TCPOptionWS: - if wantOpts.WS < 0 { - t.Error("WS present when it shouldn't be") - } - v := int(opts[i+2]) - if v != wantOpts.WS { - t.Errorf("Bad WS, got = %d, want = %d", v, wantOpts.WS) - } - foundWS = true - i += 3 - case header.TCPOptionTS: - if i+9 >= limit { - t.Errorf("TS Option truncated , option is only: %d bytes, want 10", limit-i) - } - if opts[i+1] != 10 { - t.Errorf("Bad length %d for TS option, limit: %d", opts[i+1], limit) - } - tsVal = binary.BigEndian.Uint32(opts[i+2:]) - tsEcr = uint32(0) - if tcp.Flags()&header.TCPFlagAck != 0 { - // If the syn is an SYN-ACK then read - // the tsEcr value as well. - tsEcr = binary.BigEndian.Uint32(opts[i+6:]) - } - foundTS = true - i += 10 - case header.TCPOptionSACKPermitted: - if i+1 >= limit { - t.Errorf("SACKPermitted option truncated, option is only : %d bytes, want 2", limit-i) - } - if opts[i+1] != 2 { - t.Errorf("Bad length %d for SACKPermitted option, limit: %d", opts[i+1], limit) - } - foundSACKPermitted = true - i += 2 - - default: - i += int(opts[i+1]) - } - } - - if !foundMSS { - t.Errorf("MSS option not found. Options: %x", opts) - } - - if !foundWS && wantOpts.WS >= 0 { - t.Errorf("WS option not found. Options: %x", opts) - } - if wantOpts.TS && !foundTS { - t.Errorf("TS option not found. Options: %x", opts) - } - if foundTS && tsVal == 0 { - t.Error("TS option specified but the timestamp value is zero") - } - if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 { - t.Errorf("TS option specified but TSEcr is incorrect, got = %d, want = %d", tsEcr, wantOpts.TSEcr) - } - if wantOpts.SACKPermitted && !foundSACKPermitted { - t.Errorf("SACKPermitted option not found. Options: %x", opts) - } - } -} - -// TCPTimestampChecker creates a checker that validates that a TCP segment has a -// TCP Timestamp option if wantTS is true, it also compares the wantTSVal and -// wantTSEcr values with those in the TCP segment (if present). -// -// If wantTSVal or wantTSEcr is zero then the corresponding comparison is -// skipped. -func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - tcp, ok := h.(header.TCP) - if !ok { - return - } - opts := []byte(tcp.Options()) - limit := len(opts) - foundTS := false - tsVal := uint32(0) - tsEcr := uint32(0) - for i := 0; i < limit; { - switch opts[i] { - case header.TCPOptionEOL: - i = limit - case header.TCPOptionNOP: - i++ - case header.TCPOptionTS: - if i+9 >= limit { - t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i) - } - if opts[i+1] != 10 { - t.Errorf("TS option found, but bad length specified: got = %d, want = 10", opts[i+1]) - } - tsVal = binary.BigEndian.Uint32(opts[i+2:]) - tsEcr = binary.BigEndian.Uint32(opts[i+6:]) - foundTS = true - i += 10 - default: - // We don't recognize this option, just skip over it. - if i+2 > limit { - return - } - l := int(opts[i+1]) - if i < 2 || i+l > limit { - return - } - i += l - } - } - - if wantTS != foundTS { - t.Errorf("TS Option mismatch, got TS= %t, want TS= %t", foundTS, wantTS) - } - if wantTS && wantTSVal != 0 && wantTSVal != tsVal { - t.Errorf("Timestamp value is incorrect, got = %d, want = %d", tsVal, wantTSVal) - } - if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr { - t.Errorf("Timestamp Echo Reply is incorrect, got = %d, want = %d", tsEcr, wantTSEcr) - } - } -} - -// TCPNoSACKBlockChecker creates a checker that verifies that the segment does -// not contain any SACK blocks in the TCP options. -func TCPNoSACKBlockChecker() TransportChecker { - return TCPSACKBlockChecker(nil) -} - -// TCPSACKBlockChecker creates a checker that verifies that the segment does -// contain the specified SACK blocks in the TCP options. -func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - tcp, ok := h.(header.TCP) - if !ok { - return - } - var gotSACKBlocks []header.SACKBlock - - opts := []byte(tcp.Options()) - limit := len(opts) - for i := 0; i < limit; { - switch opts[i] { - case header.TCPOptionEOL: - i = limit - case header.TCPOptionNOP: - i++ - case header.TCPOptionSACK: - if i+2 > limit { - // Malformed SACK block. - t.Errorf("malformed SACK option in options: %v", opts) - } - sackOptionLen := int(opts[i+1]) - if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 { - // Malformed SACK block. - t.Errorf("malformed SACK option length in options: %v", opts) - } - numBlocks := sackOptionLen / 8 - for j := 0; j < numBlocks; j++ { - start := binary.BigEndian.Uint32(opts[i+2+j*8:]) - end := binary.BigEndian.Uint32(opts[i+2+j*8+4:]) - gotSACKBlocks = append(gotSACKBlocks, header.SACKBlock{ - Start: seqnum.Value(start), - End: seqnum.Value(end), - }) - } - i += sackOptionLen - default: - // We don't recognize this option, just skip over it. - if i+2 > limit { - break - } - l := int(opts[i+1]) - if l < 2 || i+l > limit { - break - } - i += l - } - } - - if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) { - t.Errorf("SACKBlocks are not equal, got = %v, want = %v", gotSACKBlocks, sackBlocks) - } - } -} - -// Payload creates a checker that checks the payload. -func Payload(want []byte) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - if got := h.Payload(); !reflect.DeepEqual(got, want) { - t.Errorf("Wrong payload, got %v, want %v", got, want) - } - } -} - -// ICMPv4 creates a checker that checks that the transport protocol is ICMPv4 -// and potentially additional ICMPv4 header fields. -func ICMPv4(checkers ...TransportChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - last := h[len(h)-1] - - if p := last.TransportProtocol(); p != header.ICMPv4ProtocolNumber { - t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv4ProtocolNumber) - } - - icmp := header.ICMPv4(last.Payload()) - for _, f := range checkers { - f(t, icmp) - } - if t.Failed() { - t.FailNow() - } - } -} - -// ICMPv4Type creates a checker that checks the ICMPv4 Type field. -func ICMPv4Type(want header.ICMPv4Type) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmpv4, ok := h.(header.ICMPv4) - if !ok { - t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) - } - if got := icmpv4.Type(); got != want { - t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want) - } - } -} - -// ICMPv4Code creates a checker that checks the ICMPv4 Code field. -func ICMPv4Code(want header.ICMPv4Code) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmpv4, ok := h.(header.ICMPv4) - if !ok { - t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) - } - if got := icmpv4.Code(); got != want { - t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want) - } - } -} - -// ICMPv4Ident creates a checker that checks the ICMPv4 echo Ident. -func ICMPv4Ident(want uint16) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmpv4, ok := h.(header.ICMPv4) - if !ok { - t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) - } - if got := icmpv4.Ident(); got != want { - t.Fatalf("unexpected ICMP ident, got = %d, want = %d", got, want) - } - } -} - -// ICMPv4Seq creates a checker that checks the ICMPv4 echo Sequence. -func ICMPv4Seq(want uint16) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmpv4, ok := h.(header.ICMPv4) - if !ok { - t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) - } - if got := icmpv4.Sequence(); got != want { - t.Fatalf("unexpected ICMP sequence, got = %d, want = %d", got, want) - } - } -} - -// ICMPv4Pointer creates a checker that checks the ICMPv4 Param Problem pointer. -func ICMPv4Pointer(want uint8) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmpv4, ok := h.(header.ICMPv4) - if !ok { - t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) - } - if got := icmpv4.Pointer(); got != want { - t.Fatalf("unexpected ICMP Param Problem pointer, got = %d, want = %d", got, want) - } - } -} - -// ICMPv4Checksum creates a checker that checks the ICMPv4 Checksum. -// This assumes that the payload exactly makes up the rest of the slice. -func ICMPv4Checksum() TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmpv4, ok := h.(header.ICMPv4) - if !ok { - t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) - } - heldChecksum := icmpv4.Checksum() - icmpv4.SetChecksum(0) - newChecksum := ^header.Checksum(icmpv4, 0) - icmpv4.SetChecksum(heldChecksum) - if heldChecksum != newChecksum { - t.Errorf("unexpected ICMP checksum, got = %d, want = %d", heldChecksum, newChecksum) - } - } -} - -// ICMPv4Payload creates a checker that checks the payload in an ICMPv4 packet. -func ICMPv4Payload(want []byte) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmpv4, ok := h.(header.ICMPv4) - if !ok { - t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) - } - payload := icmpv4.Payload() - - // cmp.Diff does not consider nil slices equal to empty slices, but we do. - if len(want) == 0 && len(payload) == 0 { - return - } - - if diff := cmp.Diff(want, payload); diff != "" { - t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff) - } - } -} - -// ICMPv6 creates a checker that checks that the transport protocol is ICMPv6 and -// potentially additional ICMPv6 header fields. -// -// ICMPv6 will validate the checksum field before calling checkers. -func ICMPv6(checkers ...TransportChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - last := h[len(h)-1] - - if p := last.TransportProtocol(); p != header.ICMPv6ProtocolNumber { - t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv6ProtocolNumber) - } - - icmp := header.ICMPv6(last.Payload()) - if got, want := icmp.Checksum(), header.ICMPv6Checksum(icmp, last.SourceAddress(), last.DestinationAddress(), buffer.VectorisedView{}); got != want { - t.Fatalf("Bad ICMPv6 checksum; got %d, want %d", got, want) - } - - for _, f := range checkers { - f(t, icmp) - } - if t.Failed() { - t.FailNow() - } - } -} - -// ICMPv6Type creates a checker that checks the ICMPv6 Type field. -func ICMPv6Type(want header.ICMPv6Type) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmpv6, ok := h.(header.ICMPv6) - if !ok { - t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) - } - if got := icmpv6.Type(); got != want { - t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want) - } - } -} - -// ICMPv6Code creates a checker that checks the ICMPv6 Code field. -func ICMPv6Code(want header.ICMPv6Code) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmpv6, ok := h.(header.ICMPv6) - if !ok { - t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) - } - if got := icmpv6.Code(); got != want { - t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want) - } - } -} - -// ICMPv6TypeSpecific creates a checker that checks the ICMPv6 TypeSpecific -// field. -func ICMPv6TypeSpecific(want uint32) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmpv6, ok := h.(header.ICMPv6) - if !ok { - t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) - } - if got := icmpv6.TypeSpecific(); got != want { - t.Fatalf("unexpected ICMP TypeSpecific, got = %d, want = %d", got, want) - } - } -} - -// ICMPv6Payload creates a checker that checks the payload in an ICMPv6 packet. -func ICMPv6Payload(want []byte) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmpv6, ok := h.(header.ICMPv6) - if !ok { - t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) - } - payload := icmpv6.Payload() - - // cmp.Diff does not consider nil slices equal to empty slices, but we do. - if len(want) == 0 && len(payload) == 0 { - return - } - - if diff := cmp.Diff(want, payload); diff != "" { - t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff) - } - } -} - -// MLD creates a checker that checks that the packet contains a valid MLD -// message for type of mldType, with potentially additional checks specified by -// checkers. -// -// Checkers may assume that a valid ICMPv6 is passed to it containing a valid -// MLD message as far as the size of the message (minSize) is concerned. The -// values within the message are up to checkers to validate. -func MLD(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - // Check normal ICMPv6 first. - ICMPv6( - ICMPv6Type(msgType), - ICMPv6Code(0))(t, h) - - last := h[len(h)-1] - - icmp := header.ICMPv6(last.Payload()) - if got := len(icmp.MessageBody()); got < minSize { - t.Fatalf("ICMPv6 MLD (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize) - } - - for _, f := range checkers { - f(t, icmp) - } - if t.Failed() { - t.FailNow() - } - } -} - -// MLDMaxRespDelay creates a checker that checks the Maximum Response Delay -// field of a MLD message. -// -// The returned TransportChecker assumes that a valid ICMPv6 is passed to it -// containing a valid MLD message as far as the size is concerned. -func MLDMaxRespDelay(want time.Duration) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmp := h.(header.ICMPv6) - ns := header.MLD(icmp.MessageBody()) - - if got := ns.MaximumResponseDelay(); got != want { - t.Errorf("got %T.MaximumResponseDelay() = %s, want = %s", ns, got, want) - } - } -} - -// MLDMulticastAddress creates a checker that checks the Multicast Address -// field of a MLD message. -// -// The returned TransportChecker assumes that a valid ICMPv6 is passed to it -// containing a valid MLD message as far as the size is concerned. -func MLDMulticastAddress(want tcpip.Address) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmp := h.(header.ICMPv6) - ns := header.MLD(icmp.MessageBody()) - - if got := ns.MulticastAddress(); got != want { - t.Errorf("got %T.MulticastAddress() = %s, want = %s", ns, got, want) - } - } -} - -// NDP creates a checker that checks that the packet contains a valid NDP -// message for type of ty, with potentially additional checks specified by -// checkers. -// -// Checkers may assume that a valid ICMPv6 is passed to it containing a valid -// NDP message as far as the size of the message (minSize) is concerned. The -// values within the message are up to checkers to validate. -func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - // Check normal ICMPv6 first. - ICMPv6( - ICMPv6Type(msgType), - ICMPv6Code(0))(t, h) - - last := h[len(h)-1] - - icmp := header.ICMPv6(last.Payload()) - if got := len(icmp.MessageBody()); got < minSize { - t.Fatalf("ICMPv6 NDP (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize) - } - - for _, f := range checkers { - f(t, icmp) - } - if t.Failed() { - t.FailNow() - } - } -} - -// NDPNS creates a checker that checks that the packet contains a valid NDP -// Neighbor Solicitation message (as per the raw wire format), with potentially -// additional checks specified by checkers. -// -// Checkers may assume that a valid ICMPv6 is passed to it containing a valid -// NDPNS message as far as the size of the message is concerned. The values -// within the message are up to checkers to validate. -func NDPNS(checkers ...TransportChecker) NetworkChecker { - return NDP(header.ICMPv6NeighborSolicit, header.NDPNSMinimumSize, checkers...) -} - -// NDPNSTargetAddress creates a checker that checks the Target Address field of -// a header.NDPNeighborSolicit. -// -// The returned TransportChecker assumes that a valid ICMPv6 is passed to it -// containing a valid NDPNS message as far as the size is concerned. -func NDPNSTargetAddress(want tcpip.Address) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmp := h.(header.ICMPv6) - ns := header.NDPNeighborSolicit(icmp.MessageBody()) - - if got := ns.TargetAddress(); got != want { - t.Errorf("got %T.TargetAddress() = %s, want = %s", ns, got, want) - } - } -} - -// NDPNA creates a checker that checks that the packet contains a valid NDP -// Neighbor Advertisement message (as per the raw wire format), with potentially -// additional checks specified by checkers. -// -// Checkers may assume that a valid ICMPv6 is passed to it containing a valid -// NDPNA message as far as the size of the message is concerned. The values -// within the message are up to checkers to validate. -func NDPNA(checkers ...TransportChecker) NetworkChecker { - return NDP(header.ICMPv6NeighborAdvert, header.NDPNAMinimumSize, checkers...) -} - -// NDPNATargetAddress creates a checker that checks the Target Address field of -// a header.NDPNeighborAdvert. -// -// The returned TransportChecker assumes that a valid ICMPv6 is passed to it -// containing a valid NDPNA message as far as the size is concerned. -func NDPNATargetAddress(want tcpip.Address) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmp := h.(header.ICMPv6) - na := header.NDPNeighborAdvert(icmp.MessageBody()) - - if got := na.TargetAddress(); got != want { - t.Errorf("got %T.TargetAddress() = %s, want = %s", na, got, want) - } - } -} - -// NDPNASolicitedFlag creates a checker that checks the Solicited field of -// a header.NDPNeighborAdvert. -// -// The returned TransportChecker assumes that a valid ICMPv6 is passed to it -// containing a valid NDPNA message as far as the size is concerned. -func NDPNASolicitedFlag(want bool) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmp := h.(header.ICMPv6) - na := header.NDPNeighborAdvert(icmp.MessageBody()) - - if got := na.SolicitedFlag(); got != want { - t.Errorf("got %T.SolicitedFlag = %t, want = %t", na, got, want) - } - } -} - -// ndpOptions checks that optsBuf only contains opts. -func ndpOptions(t *testing.T, optsBuf header.NDPOptions, opts []header.NDPOption) { - t.Helper() - - it, err := optsBuf.Iter(true) - if err != nil { - t.Errorf("optsBuf.Iter(true): %s", err) - return - } - - i := 0 - for { - opt, done, err := it.Next() - if err != nil { - // This should never happen as Iter(true) above did not return an error. - t.Fatalf("unexpected error when iterating over NDP options: %s", err) - } - if done { - break - } - - if i >= len(opts) { - t.Errorf("got unexpected option: %s", opt) - continue - } - - switch wantOpt := opts[i].(type) { - case header.NDPSourceLinkLayerAddressOption: - gotOpt, ok := opt.(header.NDPSourceLinkLayerAddressOption) - if !ok { - t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt) - } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want { - t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want) - } - case header.NDPTargetLinkLayerAddressOption: - gotOpt, ok := opt.(header.NDPTargetLinkLayerAddressOption) - if !ok { - t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt) - } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want { - t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want) - } - default: - t.Fatalf("checker not implemented for expected NDP option: %T", wantOpt) - } - - i++ - } - - if missing := opts[i:]; len(missing) > 0 { - t.Errorf("missing options: %s", missing) - } -} - -// NDPNAOptions creates a checker that checks that the packet contains the -// provided NDP options within an NDP Neighbor Solicitation message. -// -// The returned TransportChecker assumes that a valid ICMPv6 is passed to it -// containing a valid NDPNA message as far as the size is concerned. -func NDPNAOptions(opts []header.NDPOption) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmp := h.(header.ICMPv6) - na := header.NDPNeighborAdvert(icmp.MessageBody()) - ndpOptions(t, na.Options(), opts) - } -} - -// NDPNSOptions creates a checker that checks that the packet contains the -// provided NDP options within an NDP Neighbor Solicitation message. -// -// The returned TransportChecker assumes that a valid ICMPv6 is passed to it -// containing a valid NDPNS message as far as the size is concerned. -func NDPNSOptions(opts []header.NDPOption) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmp := h.(header.ICMPv6) - ns := header.NDPNeighborSolicit(icmp.MessageBody()) - ndpOptions(t, ns.Options(), opts) - } -} - -// NDPRS creates a checker that checks that the packet contains a valid NDP -// Router Solicitation message (as per the raw wire format). -// -// Checkers may assume that a valid ICMPv6 is passed to it containing a valid -// NDPRS as far as the size of the message is concerned. The values within the -// message are up to checkers to validate. -func NDPRS(checkers ...TransportChecker) NetworkChecker { - return NDP(header.ICMPv6RouterSolicit, header.NDPRSMinimumSize, checkers...) -} - -// NDPRSOptions creates a checker that checks that the packet contains the -// provided NDP options within an NDP Router Solicitation message. -// -// The returned TransportChecker assumes that a valid ICMPv6 is passed to it -// containing a valid NDPRS message as far as the size is concerned. -func NDPRSOptions(opts []header.NDPOption) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmp := h.(header.ICMPv6) - rs := header.NDPRouterSolicit(icmp.MessageBody()) - ndpOptions(t, rs.Options(), opts) - } -} - -// IGMP checks the validity and properties of the given IGMP packet. It is -// expected to be used in conjunction with other IGMP transport checkers for -// specific properties. -func IGMP(checkers ...TransportChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - last := h[len(h)-1] - - if p := last.TransportProtocol(); p != header.IGMPProtocolNumber { - t.Fatalf("Bad protocol, got %d, want %d", p, header.IGMPProtocolNumber) - } - - igmp := header.IGMP(last.Payload()) - for _, f := range checkers { - f(t, igmp) - } - if t.Failed() { - t.FailNow() - } - } -} - -// IGMPType creates a checker that checks the IGMP Type field. -func IGMPType(want header.IGMPType) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - igmp, ok := h.(header.IGMP) - if !ok { - t.Fatalf("got transport header = %T, want = header.IGMP", h) - } - if got := igmp.Type(); got != want { - t.Errorf("got igmp.Type() = %d, want = %d", got, want) - } - } -} - -// IGMPMaxRespTime creates a checker that checks the IGMP Max Resp Time field. -func IGMPMaxRespTime(want time.Duration) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - igmp, ok := h.(header.IGMP) - if !ok { - t.Fatalf("got transport header = %T, want = header.IGMP", h) - } - if got := igmp.MaxRespTime(); got != want { - t.Errorf("got igmp.MaxRespTime() = %s, want = %s", got, want) - } - } -} - -// IGMPGroupAddress creates a checker that checks the IGMP Group Address field. -func IGMPGroupAddress(want tcpip.Address) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - igmp, ok := h.(header.IGMP) - if !ok { - t.Fatalf("got transport header = %T, want = header.IGMP", h) - } - if got := igmp.GroupAddress(); got != want { - t.Errorf("got igmp.GroupAddress() = %s, want = %s", got, want) - } - } -} - -// IPv6ExtHdrChecker is a function to check an extension header. -type IPv6ExtHdrChecker func(*testing.T, header.IPv6PayloadHeader) - -// IPv6WithExtHdr is like IPv6 but allows IPv6 packets with extension headers. -func IPv6WithExtHdr(t *testing.T, b []byte, checkers ...NetworkChecker) { - t.Helper() - - ipv6 := header.IPv6(b) - if !ipv6.IsValid(len(b)) { - t.Error("not a valid IPv6 packet") - return - } - - payloadIterator := header.MakeIPv6PayloadIterator( - header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()), - buffer.View(ipv6.Payload()).ToVectorisedView(), - ) - - var rawPayloadHeader header.IPv6RawPayloadHeader - for { - h, done, err := payloadIterator.Next() - if err != nil { - t.Errorf("payloadIterator.Next(): %s", err) - return - } - if done { - t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, true, _)", h, done) - return - } - r, ok := h.(header.IPv6RawPayloadHeader) - if ok { - rawPayloadHeader = r - break - } - } - - networkHeader := ipv6HeaderWithExtHdr{ - IPv6: ipv6, - transport: tcpip.TransportProtocolNumber(rawPayloadHeader.Identifier), - payload: rawPayloadHeader.Buf.ToView(), - } - - for _, checker := range checkers { - checker(t, []header.Network{&networkHeader}) - } -} - -// IPv6ExtHdr checks for the presence of extension headers. -// -// All the extension headers in headers will be checked exhaustively in the -// order provided. -func IPv6ExtHdr(headers ...IPv6ExtHdrChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - extHdrs, ok := h[0].(*ipv6HeaderWithExtHdr) - if !ok { - t.Errorf("got network header = %T, want = *ipv6HeaderWithExtHdr", h[0]) - return - } - - payloadIterator := header.MakeIPv6PayloadIterator( - header.IPv6ExtensionHeaderIdentifier(extHdrs.IPv6.NextHeader()), - buffer.View(extHdrs.IPv6.Payload()).ToVectorisedView(), - ) - - for _, check := range headers { - h, done, err := payloadIterator.Next() - if err != nil { - t.Errorf("payloadIterator.Next(): %s", err) - return - } - if done { - t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, false, _)", h, done) - return - } - check(t, h) - } - // Validate we consumed all headers. - // - // The next one over should be a raw payload and then iterator should - // terminate. - wantDone := false - for { - h, done, err := payloadIterator.Next() - if err != nil { - t.Errorf("payloadIterator.Next(): %s", err) - return - } - if done != wantDone { - t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, %t, _)", h, done, wantDone) - return - } - if done { - break - } - if _, ok := h.(header.IPv6RawPayloadHeader); !ok { - t.Errorf("got payloadIterator.Next() = (%T, _, _), want = (header.IPv6RawPayloadHeader, _, _)", h) - continue - } - wantDone = true - } - } -} - -var _ header.Network = (*ipv6HeaderWithExtHdr)(nil) - -// ipv6HeaderWithExtHdr provides a header.Network implementation that takes -// extension headers into consideration, which is not the case with vanilla -// header.IPv6. -type ipv6HeaderWithExtHdr struct { - header.IPv6 - transport tcpip.TransportProtocolNumber - payload []byte -} - -// TransportProtocol implements header.Network. -func (h *ipv6HeaderWithExtHdr) TransportProtocol() tcpip.TransportProtocolNumber { - return h.transport -} - -// Payload implements header.Network. -func (h *ipv6HeaderWithExtHdr) Payload() []byte { - return h.payload -} - -// IPv6ExtHdrOptionChecker is a function to check an extension header option. -type IPv6ExtHdrOptionChecker func(*testing.T, header.IPv6ExtHdrOption) - -// IPv6HopByHopExtensionHeader checks the extension header is a Hop by Hop -// extension header and validates the containing options with checkers. -// -// checkers must exhaustively contain all the expected options. -func IPv6HopByHopExtensionHeader(checkers ...IPv6ExtHdrOptionChecker) IPv6ExtHdrChecker { - return func(t *testing.T, payloadHeader header.IPv6PayloadHeader) { - t.Helper() - - hbh, ok := payloadHeader.(header.IPv6HopByHopOptionsExtHdr) - if !ok { - t.Errorf("unexpected IPv6 payload header, got = %T, want = header.IPv6HopByHopOptionsExtHdr", payloadHeader) - return - } - optionsIterator := hbh.Iter() - for _, f := range checkers { - opt, done, err := optionsIterator.Next() - if err != nil { - t.Errorf("optionsIterator.Next(): %s", err) - return - } - if done { - t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, false, _)", opt, done) - } - f(t, opt) - } - // Validate all options were consumed. - for { - opt, done, err := optionsIterator.Next() - if err != nil { - t.Errorf("optionsIterator.Next(): %s", err) - return - } - if !done { - t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, true, _)", opt, done) - } - if done { - break - } - } - } -} - -// IPv6RouterAlert validates that an extension header option is the RouterAlert -// option and matches on its value. -func IPv6RouterAlert(want header.IPv6RouterAlertValue) IPv6ExtHdrOptionChecker { - return func(t *testing.T, opt header.IPv6ExtHdrOption) { - routerAlert, ok := opt.(*header.IPv6RouterAlertOption) - if !ok { - t.Errorf("unexpected extension header option, got = %T, want = header.IPv6RouterAlertOption", opt) - return - } - if routerAlert.Value != want { - t.Errorf("got routerAlert.Value = %d, want = %d", routerAlert.Value, want) - } - } -} - -// IgnoreCmpPath returns a cmp.Option that ignores listed field paths. -func IgnoreCmpPath(paths ...string) cmp.Option { - ignores := map[string]struct{}{} - for _, path := range paths { - ignores[path] = struct{}{} - } - return cmp.FilterPath(func(path cmp.Path) bool { - _, ok := ignores[path.String()] - return ok - }, cmp.Ignore()) -} diff --git a/pkg/tcpip/faketime/BUILD b/pkg/tcpip/faketime/BUILD deleted file mode 100644 index bb9d44aff..000000000 --- a/pkg/tcpip/faketime/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "faketime", - srcs = ["faketime.go"], - visibility = ["//visibility:public"], - deps = ["//pkg/tcpip"], -) - -go_test( - name = "faketime_test", - size = "small", - srcs = [ - "faketime_test.go", - ], - deps = [ - "//pkg/tcpip/faketime", - ], -) diff --git a/pkg/tcpip/faketime/faketime.go b/pkg/tcpip/faketime/faketime.go deleted file mode 100644 index fb819d7a8..000000000 --- a/pkg/tcpip/faketime/faketime.go +++ /dev/null @@ -1,358 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package faketime provides a fake clock that implements tcpip.Clock interface. -package faketime - -import ( - "container/heap" - "fmt" - "sync" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" -) - -// NullClock implements a clock that never advances. -type NullClock struct{} - -var _ tcpip.Clock = (*NullClock)(nil) - -// NowNanoseconds implements tcpip.Clock.NowNanoseconds. -func (*NullClock) NowNanoseconds() int64 { - return 0 -} - -// NowMonotonic implements tcpip.Clock.NowMonotonic. -func (*NullClock) NowMonotonic() int64 { - return 0 -} - -// AfterFunc implements tcpip.Clock.AfterFunc. -func (*NullClock) AfterFunc(time.Duration, func()) tcpip.Timer { - return nil -} - -type notificationChannels struct { - mu struct { - sync.Mutex - - ch []<-chan struct{} - } -} - -func (n *notificationChannels) add(ch <-chan struct{}) { - n.mu.Lock() - defer n.mu.Unlock() - n.mu.ch = append(n.mu.ch, ch) -} - -// wait returns once all the notification channels are readable. -// -// Channels that are added while waiting on existing channels will be waited on -// as well. -func (n *notificationChannels) wait() { - for { - n.mu.Lock() - ch := n.mu.ch - n.mu.ch = nil - n.mu.Unlock() - - if len(ch) == 0 { - break - } - - for _, c := range ch { - <-c - } - } -} - -// ManualClock implements tcpip.Clock and only advances manually with Advance -// method. -type ManualClock struct { - // runningTimers tracks the completion of timer callbacks that began running - // immediately upon their scheduling. It is used to ensure the proper ordering - // of timer callback dispatch. - runningTimers notificationChannels - - mu struct { - sync.RWMutex - - // now is the current (fake) time of the clock. - now time.Time - - // times is min-heap of times. - times timeHeap - - // timers holds the timers scheduled for each time. - timers map[time.Time]map[*manualTimer]struct{} - } -} - -// NewManualClock creates a new ManualClock instance. -func NewManualClock() *ManualClock { - c := &ManualClock{} - - c.mu.Lock() - defer c.mu.Unlock() - - // Set the initial time to a non-zero value since the zero value is used to - // detect inactive timers. - c.mu.now = time.Unix(0, 0) - c.mu.timers = make(map[time.Time]map[*manualTimer]struct{}) - - return c -} - -var _ tcpip.Clock = (*ManualClock)(nil) - -// NowNanoseconds implements tcpip.Clock.NowNanoseconds. -func (mc *ManualClock) NowNanoseconds() int64 { - mc.mu.RLock() - defer mc.mu.RUnlock() - return mc.mu.now.UnixNano() -} - -// NowMonotonic implements tcpip.Clock.NowMonotonic. -func (mc *ManualClock) NowMonotonic() int64 { - return mc.NowNanoseconds() -} - -// AfterFunc implements tcpip.Clock.AfterFunc. -func (mc *ManualClock) AfterFunc(d time.Duration, f func()) tcpip.Timer { - mt := &manualTimer{ - clock: mc, - f: f, - } - - mc.mu.Lock() - defer mc.mu.Unlock() - - mt.mu.Lock() - defer mt.mu.Unlock() - - mc.resetTimerLocked(mt, d) - return mt -} - -// resetTimerLocked schedules a timer to be fired after the given duration. -// -// Precondition: mc.mu and mt.mu must be locked. -func (mc *ManualClock) resetTimerLocked(mt *manualTimer, d time.Duration) { - if !mt.mu.firesAt.IsZero() { - panic("tried to reset an active timer") - } - - t := mc.mu.now.Add(d) - - if !mc.mu.now.Before(t) { - // If the timer is scheduled to fire immediately, call its callback - // in a new goroutine immediately. - // - // It needs to be called in its own goroutine to escape its current - // execution context - like an actual timer. - ch := make(chan struct{}) - mc.runningTimers.add(ch) - - go func() { - defer close(ch) - - mt.f() - }() - - return - } - - mt.mu.firesAt = t - - timers, ok := mc.mu.timers[t] - if !ok { - timers = make(map[*manualTimer]struct{}) - mc.mu.timers[t] = timers - heap.Push(&mc.mu.times, t) - } - - timers[mt] = struct{}{} -} - -// stopTimerLocked stops a timer from firing. -// -// Precondition: mc.mu and mt.mu must be locked. -func (mc *ManualClock) stopTimerLocked(mt *manualTimer) { - t := mt.mu.firesAt - mt.mu.firesAt = time.Time{} - - if t.IsZero() { - panic("tried to stop an inactive timer") - } - - timers, ok := mc.mu.timers[t] - if !ok { - err := fmt.Sprintf("tried to stop an active timer but the clock does not have anything scheduled for the timer @ t = %s %p\nScheduled timers @:", t.UTC(), mt) - for t := range mc.mu.timers { - err += fmt.Sprintf("%s\n", t.UTC()) - } - panic(err) - } - - if _, ok := timers[mt]; !ok { - panic(fmt.Sprintf("did not have an entry in timers for an active timer @ t = %s", t.UTC())) - } - - delete(timers, mt) - - if len(timers) == 0 { - delete(mc.mu.timers, t) - } -} - -// Advance executes all work that have been scheduled to execute within d from -// the current time. Blocks until all work has completed execution. -func (mc *ManualClock) Advance(d time.Duration) { - // We spawn goroutines for timers that were scheduled to fire at the time of - // being reset. Wait for those goroutines to complete before proceeding so - // that timer callbacks are called in the right order. - mc.runningTimers.wait() - - mc.mu.Lock() - defer mc.mu.Unlock() - - until := mc.mu.now.Add(d) - for mc.mu.times.Len() > 0 { - t := heap.Pop(&mc.mu.times).(time.Time) - if t.After(until) { - // No work to do - heap.Push(&mc.mu.times, t) - break - } - - timers := mc.mu.timers[t] - delete(mc.mu.timers, t) - - mc.mu.now = t - - // Mark the timers as inactive since they will be fired. - // - // This needs to be done while holding mc's lock because we remove the entry - // in the map of timers for the current time. If an attempt to stop a - // timer is made after mc's lock was dropped but before the timer is - // marked inactive, we would panic since no entry exists for the time when - // the timer was expected to fire. - for mt := range timers { - mt.mu.Lock() - mt.mu.firesAt = time.Time{} - mt.mu.Unlock() - } - - // Release the lock before calling the timer's callback fn since the - // callback fn might try to schedule a timer which requires obtaining - // mc's lock. - mc.mu.Unlock() - - for mt := range timers { - mt.f() - } - - // The timer callbacks may have scheduled a timer to fire immediately. - // We spawn goroutines for these timers and need to wait for them to - // finish before proceeding so that timer callbacks are called in the - // right order. - mc.runningTimers.wait() - mc.mu.Lock() - } - - mc.mu.now = until -} - -func (mc *ManualClock) resetTimer(mt *manualTimer, d time.Duration) { - mc.mu.Lock() - defer mc.mu.Unlock() - - mt.mu.Lock() - defer mt.mu.Unlock() - - if !mt.mu.firesAt.IsZero() { - mc.stopTimerLocked(mt) - } - - mc.resetTimerLocked(mt, d) -} - -func (mc *ManualClock) stopTimer(mt *manualTimer) bool { - mc.mu.Lock() - defer mc.mu.Unlock() - - mt.mu.Lock() - defer mt.mu.Unlock() - - if mt.mu.firesAt.IsZero() { - return false - } - - mc.stopTimerLocked(mt) - return true -} - -type manualTimer struct { - clock *ManualClock - f func() - - mu struct { - sync.Mutex - - // firesAt is the time when the timer will fire. - // - // Zero only when the timer is not active. - firesAt time.Time - } -} - -var _ tcpip.Timer = (*manualTimer)(nil) - -// Reset implements tcpip.Timer.Reset. -func (mt *manualTimer) Reset(d time.Duration) { - mt.clock.resetTimer(mt, d) -} - -// Stop implements tcpip.Timer.Stop. -func (mt *manualTimer) Stop() bool { - return mt.clock.stopTimer(mt) -} - -type timeHeap []time.Time - -var _ heap.Interface = (*timeHeap)(nil) - -func (h timeHeap) Len() int { - return len(h) -} - -func (h timeHeap) Less(i, j int) bool { - return h[i].Before(h[j]) -} - -func (h timeHeap) Swap(i, j int) { - h[i], h[j] = h[j], h[i] -} - -func (h *timeHeap) Push(x interface{}) { - *h = append(*h, x.(time.Time)) -} - -func (h *timeHeap) Pop() interface{} { - last := (*h)[len(*h)-1] - *h = (*h)[:len(*h)-1] - return last -} diff --git a/pkg/tcpip/faketime/faketime_test.go b/pkg/tcpip/faketime/faketime_test.go deleted file mode 100644 index c2704df2c..000000000 --- a/pkg/tcpip/faketime/faketime_test.go +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package faketime_test - -import ( - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip/faketime" -) - -func TestManualClockAdvance(t *testing.T) { - const timeout = time.Millisecond - clock := faketime.NewManualClock() - start := clock.NowMonotonic() - clock.Advance(timeout) - if got, want := time.Duration(clock.NowMonotonic()-start)*time.Nanosecond, timeout; got != want { - t.Errorf("got = %d, want = %d", got, want) - } -} - -func TestManualClockAfterFunc(t *testing.T) { - const ( - timeout1 = time.Millisecond // timeout for counter1 - timeout2 = 2 * time.Millisecond // timeout for counter2 - ) - tests := []struct { - name string - advance time.Duration - wantCounter1 int - wantCounter2 int - }{ - { - name: "before timeout1", - advance: timeout1 - 1, - wantCounter1: 0, - wantCounter2: 0, - }, - { - name: "timeout1", - advance: timeout1, - wantCounter1: 1, - wantCounter2: 0, - }, - { - name: "timeout2", - advance: timeout2, - wantCounter1: 1, - wantCounter2: 1, - }, - { - name: "after timeout2", - advance: timeout2 + 1, - wantCounter1: 1, - wantCounter2: 1, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - counter1 := 0 - counter2 := 0 - clock.AfterFunc(timeout1, func() { - counter1++ - }) - clock.AfterFunc(timeout2, func() { - counter2++ - }) - start := clock.NowMonotonic() - clock.Advance(test.advance) - if got, want := counter1, test.wantCounter1; got != want { - t.Errorf("got counter1 = %d, want = %d", got, want) - } - if got, want := counter2, test.wantCounter2; got != want { - t.Errorf("got counter2 = %d, want = %d", got, want) - } - if got, want := time.Duration(clock.NowMonotonic()-start)*time.Nanosecond, test.advance; got != want { - t.Errorf("got elapsed = %d, want = %d", got, want) - } - }) - } -} diff --git a/pkg/tcpip/hash/jenkins/BUILD b/pkg/tcpip/hash/jenkins/BUILD deleted file mode 100644 index ff2719291..000000000 --- a/pkg/tcpip/hash/jenkins/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "jenkins", - srcs = ["jenkins.go"], - visibility = ["//visibility:public"], -) - -go_test( - name = "jenkins_test", - size = "small", - srcs = [ - "jenkins_test.go", - ], - library = ":jenkins", -) diff --git a/pkg/tcpip/hash/jenkins/jenkins_state_autogen.go b/pkg/tcpip/hash/jenkins/jenkins_state_autogen.go new file mode 100644 index 000000000..216cc5a2e --- /dev/null +++ b/pkg/tcpip/hash/jenkins/jenkins_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package jenkins diff --git a/pkg/tcpip/hash/jenkins/jenkins_test.go b/pkg/tcpip/hash/jenkins/jenkins_test.go deleted file mode 100644 index 4c78b5808..000000000 --- a/pkg/tcpip/hash/jenkins/jenkins_test.go +++ /dev/null @@ -1,176 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -package jenkins - -import ( - "bytes" - "encoding/binary" - "hash" - "hash/fnv" - "math" - "testing" -) - -func TestGolden32(t *testing.T) { - var golden32 = []struct { - out []byte - in string - }{ - {[]byte{0x00, 0x00, 0x00, 0x00}, ""}, - {[]byte{0xca, 0x2e, 0x94, 0x42}, "a"}, - {[]byte{0x45, 0xe6, 0x1e, 0x58}, "ab"}, - {[]byte{0xed, 0x13, 0x1f, 0x5b}, "abc"}, - } - - hash := New32() - - for _, g := range golden32 { - hash.Reset() - done, error := hash.Write([]byte(g.in)) - if error != nil { - t.Fatalf("write error: %s", error) - } - if done != len(g.in) { - t.Fatalf("wrote only %d out of %d bytes", done, len(g.in)) - } - if actual := hash.Sum(nil); !bytes.Equal(g.out, actual) { - t.Errorf("hash(%q) = 0x%x want 0x%x", g.in, actual, g.out) - } - } -} - -func TestIntegrity32(t *testing.T) { - data := []byte{'1', '2', 3, 4, 5} - - h := New32() - h.Write(data) - sum := h.Sum(nil) - - if size := h.Size(); size != len(sum) { - t.Fatalf("Size()=%d but len(Sum())=%d", size, len(sum)) - } - - if a := h.Sum(nil); !bytes.Equal(sum, a) { - t.Fatalf("first Sum()=0x%x, second Sum()=0x%x", sum, a) - } - - h.Reset() - h.Write(data) - if a := h.Sum(nil); !bytes.Equal(sum, a) { - t.Fatalf("Sum()=0x%x, but after Reset() Sum()=0x%x", sum, a) - } - - h.Reset() - h.Write(data[:2]) - h.Write(data[2:]) - if a := h.Sum(nil); !bytes.Equal(sum, a) { - t.Fatalf("Sum()=0x%x, but with partial writes, Sum()=0x%x", sum, a) - } - - sum32 := h.(hash.Hash32).Sum32() - if sum32 != binary.BigEndian.Uint32(sum) { - t.Fatalf("Sum()=0x%x, but Sum32()=0x%x", sum, sum32) - } -} - -func BenchmarkJenkins32KB(b *testing.B) { - h := New32() - - b.SetBytes(1024) - data := make([]byte, 1024) - for i := range data { - data[i] = byte(i) - } - in := make([]byte, 0, h.Size()) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - h.Reset() - h.Write(data) - h.Sum(in) - } -} - -func BenchmarkFnv32(b *testing.B) { - arr := make([]int64, 1000) - for i := 0; i < b.N; i++ { - var payload [8]byte - binary.BigEndian.PutUint32(payload[:4], uint32(i)) - binary.BigEndian.PutUint32(payload[4:], uint32(i)) - - h := fnv.New32() - h.Write(payload[:]) - idx := int(h.Sum32()) % len(arr) - arr[idx]++ - } - b.StopTimer() - c := 0 - if b.N > 1000000 { - for i := 0; i < len(arr)-1; i++ { - if math.Abs(float64(arr[i]-arr[i+1]))/float64(arr[i]) > float64(0.1) { - if c == 0 { - b.Logf("i %d val[i] %d val[i+1] %d b.N %b\n", i, arr[i], arr[i+1], b.N) - } - c++ - } - } - if c > 0 { - b.Logf("Unbalanced buckets: %d", c) - } - } -} - -func BenchmarkSum32(b *testing.B) { - arr := make([]int64, 1000) - for i := 0; i < b.N; i++ { - var payload [8]byte - binary.BigEndian.PutUint32(payload[:4], uint32(i)) - binary.BigEndian.PutUint32(payload[4:], uint32(i)) - h := Sum32(0) - h.Write(payload[:]) - idx := int(h.Sum32()) % len(arr) - arr[idx]++ - } - b.StopTimer() - if b.N > 1000000 { - for i := 0; i < len(arr)-1; i++ { - if math.Abs(float64(arr[i]-arr[i+1]))/float64(arr[i]) > float64(0.1) { - b.Logf("val[%3d]=%8d\tval[%3d]=%8d\tb.N=%b\n", i, arr[i], i+1, arr[i+1], b.N) - break - } - } - } -} - -func BenchmarkNew32(b *testing.B) { - arr := make([]int64, 1000) - for i := 0; i < b.N; i++ { - var payload [8]byte - binary.BigEndian.PutUint32(payload[:4], uint32(i)) - binary.BigEndian.PutUint32(payload[4:], uint32(i)) - h := New32() - h.Write(payload[:]) - idx := int(h.Sum32()) % len(arr) - arr[idx]++ - } - b.StopTimer() - if b.N > 1000000 { - for i := 0; i < len(arr)-1; i++ { - if math.Abs(float64(arr[i]-arr[i+1]))/float64(arr[i]) > float64(0.1) { - b.Logf("val[%3d]=%8d\tval[%3d]=%8d\tb.N=%b\n", i, arr[i], i+1, arr[i+1], b.N) - break - } - } - } -} diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD deleted file mode 100644 index 0bdc12d53..000000000 --- a/pkg/tcpip/header/BUILD +++ /dev/null @@ -1,74 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "header", - srcs = [ - "arp.go", - "checksum.go", - "eth.go", - "gue.go", - "icmpv4.go", - "icmpv6.go", - "igmp.go", - "interfaces.go", - "ipv4.go", - "ipv6.go", - "ipv6_extension_headers.go", - "ipv6_fragment.go", - "mld.go", - "ndp_neighbor_advert.go", - "ndp_neighbor_solicit.go", - "ndp_options.go", - "ndp_router_advert.go", - "ndp_router_solicit.go", - "ndpoptionidentifier_string.go", - "tcp.go", - "udp.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/seqnum", - "@com_github_google_btree//:go_default_library", - ], -) - -go_test( - name = "header_x_test", - size = "small", - srcs = [ - "checksum_test.go", - "igmp_test.go", - "ipv4_test.go", - "ipv6_test.go", - "ipversion_test.go", - "tcp_test.go", - ], - deps = [ - ":header", - "//pkg/rand", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) - -go_test( - name = "header_test", - size = "small", - srcs = [ - "eth_test.go", - "ipv6_extension_headers_test.go", - "mld_test.go", - "ndp_test.go", - ], - library = ":header", - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go deleted file mode 100644 index 5ab20ee86..000000000 --- a/pkg/tcpip/header/checksum_test.go +++ /dev/null @@ -1,265 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package header provides the implementation of the encoding and decoding of -// network protocol headers. -package header_test - -import ( - "fmt" - "math/rand" - "sync" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -func TestChecksumVVWithOffset(t *testing.T) { - testCases := []struct { - name string - vv buffer.VectorisedView - off, size int - initial uint16 - want uint16 - }{ - { - name: "empty", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}), - }), - off: 0, - size: 0, - want: 0, - }, - { - name: "OneView", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}), - }), - off: 0, - size: 5, - want: 1294, - }, - { - name: "TwoViews", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}), - buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}), - }), - off: 0, - size: 11, - want: 33819, - }, - { - name: "TwoViewsWithOffset", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), - buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}), - }), - off: 1, - size: 11, - want: 33819, - }, - { - name: "ThreeViewsWithOffset", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), - buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), - buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}), - }), - off: 7, - size: 11, - want: 33819, - }, - { - name: "ThreeViewsWithInitial", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{77, 11, 33, 0, 55, 44}), - buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), - buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123, 99}), - }), - initial: 77, - off: 7, - size: 11, - want: 33896, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - if got, want := header.ChecksumVVWithOffset(tc.vv, tc.initial, tc.off, tc.size), tc.want; got != want { - t.Errorf("header.ChecksumVVWithOffset(%v) = %v, want: %v", tc, got, tc.want) - } - v := tc.vv.ToView() - v.TrimFront(tc.off) - v.CapLength(tc.size) - if got, want := header.Checksum(v, tc.initial), tc.want; got != want { - t.Errorf("header.Checksum(%v) = %v, want: %v", tc, got, tc.want) - } - }) - } -} - -func TestChecksum(t *testing.T) { - var bufSizes = []int{0, 1, 2, 3, 4, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 255, 256, 257, 1023, 1024} - type testCase struct { - buf []byte - initial uint16 - csumOrig uint16 - csumNew uint16 - } - testCases := make([]testCase, 100000) - // Ensure same buffer generation for test consistency. - rnd := rand.New(rand.NewSource(42)) - for i := range testCases { - testCases[i].buf = make([]byte, bufSizes[i%len(bufSizes)]) - testCases[i].initial = uint16(rnd.Intn(65536)) - rnd.Read(testCases[i].buf) - } - - for i := range testCases { - testCases[i].csumOrig = header.ChecksumOld(testCases[i].buf, testCases[i].initial) - testCases[i].csumNew = header.Checksum(testCases[i].buf, testCases[i].initial) - if got, want := testCases[i].csumNew, testCases[i].csumOrig; got != want { - t.Fatalf("new checksum for (buf = %x, initial = %d) does not match old got: %d, want: %d", testCases[i].buf, testCases[i].initial, got, want) - } - } -} - -func BenchmarkChecksum(b *testing.B) { - var bufSizes = []int{64, 128, 256, 512, 1024, 1500, 2048, 4096, 8192, 16384, 32767, 32768, 65535, 65536} - - checkSumImpls := []struct { - fn func([]byte, uint16) uint16 - name string - }{ - {header.ChecksumOld, fmt.Sprintf("checksum_old")}, - {header.Checksum, fmt.Sprintf("checksum")}, - } - - for _, csumImpl := range checkSumImpls { - // Ensure same buffer generation for test consistency. - rnd := rand.New(rand.NewSource(42)) - for _, bufSz := range bufSizes { - b.Run(fmt.Sprintf("%s_%d", csumImpl.name, bufSz), func(b *testing.B) { - tc := struct { - buf []byte - initial uint16 - csum uint16 - }{ - buf: make([]byte, bufSz), - initial: uint16(rnd.Intn(65536)), - } - rnd.Read(tc.buf) - b.ResetTimer() - for i := 0; i < b.N; i++ { - tc.csum = csumImpl.fn(tc.buf, tc.initial) - } - }) - } - } -} - -func testICMPChecksum(t *testing.T, headerChecksum func() uint16, icmpChecksum func() uint16, want uint16, pktStr string) { - // icmpChecksum should not do any modifications of the header to - // calculate its checksum. Let's call it from a few go-routines and the - // race detector will trigger a warning if there are any concurrent - // read/write accesses. - - const concurrency = 5 - start := make(chan int) - ready := make(chan bool, concurrency) - var wg sync.WaitGroup - wg.Add(concurrency) - defer wg.Wait() - - for i := 0; i < concurrency; i++ { - go func() { - defer wg.Done() - - ready <- true - <-start - - if got := headerChecksum(); want != got { - t.Errorf("new checksum for %s does not match old got: %x, want: %x", pktStr, got, want) - } - if got := icmpChecksum(); want != got { - t.Errorf("new checksum for %s does not match old got: %x, want: %x", pktStr, got, want) - } - }() - } - for i := 0; i < concurrency; i++ { - <-ready - } - close(start) -} - -func TestICMPv4Checksum(t *testing.T) { - rnd := rand.New(rand.NewSource(42)) - - h := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize)) - if _, err := rnd.Read(h); err != nil { - t.Fatalf("rnd.Read failed: %v", err) - } - h.SetChecksum(0) - - buf := make([]byte, 13) - if _, err := rnd.Read(buf); err != nil { - t.Fatalf("rnd.Read failed: %v", err) - } - vv := buffer.NewVectorisedView(len(buf), []buffer.View{ - buffer.NewViewFromBytes(buf[:5]), - buffer.NewViewFromBytes(buf[5:]), - }) - - want := header.Checksum(vv.ToView(), 0) - want = ^header.Checksum(h, want) - h.SetChecksum(want) - - testICMPChecksum(t, h.Checksum, func() uint16 { - return header.ICMPv4Checksum(h, vv) - }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView())) -} - -func TestICMPv6Checksum(t *testing.T) { - rnd := rand.New(rand.NewSource(42)) - - h := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize)) - if _, err := rnd.Read(h); err != nil { - t.Fatalf("rnd.Read failed: %v", err) - } - h.SetChecksum(0) - - buf := make([]byte, 13) - if _, err := rnd.Read(buf); err != nil { - t.Fatalf("rnd.Read failed: %v", err) - } - vv := buffer.NewVectorisedView(len(buf), []buffer.View{ - buffer.NewViewFromBytes(buf[:7]), - buffer.NewViewFromBytes(buf[7:10]), - buffer.NewViewFromBytes(buf[10:]), - }) - - dst := header.IPv6Loopback - src := header.IPv6Loopback - - want := header.PseudoHeaderChecksum(header.ICMPv6ProtocolNumber, src, dst, uint16(len(h)+vv.Size())) - want = header.Checksum(vv.ToView(), want) - want = ^header.Checksum(h, want) - h.SetChecksum(want) - - testICMPChecksum(t, h.Checksum, func() uint16 { - return header.ICMPv6Checksum(h, src, dst, vv) - }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView())) -} diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go deleted file mode 100644 index 3bc8b2b21..000000000 --- a/pkg/tcpip/header/eth_test.go +++ /dev/null @@ -1,149 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package header - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" -) - -func TestIsValidUnicastEthernetAddress(t *testing.T) { - tests := []struct { - name string - addr tcpip.LinkAddress - expected bool - }{ - { - "Nil", - tcpip.LinkAddress([]byte(nil)), - false, - }, - { - "Empty", - tcpip.LinkAddress(""), - false, - }, - { - "InvalidLength", - tcpip.LinkAddress("\x01\x02\x03"), - false, - }, - { - "Unspecified", - unspecifiedEthernetAddress, - false, - }, - { - "Multicast", - tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), - false, - }, - { - "Valid", - tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06"), - true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if got := IsValidUnicastEthernetAddress(test.addr); got != test.expected { - t.Fatalf("got IsValidUnicastEthernetAddress = %t, want = %t", got, test.expected) - } - }) - } -} - -func TestIsMulticastEthernetAddress(t *testing.T) { - tests := []struct { - name string - addr tcpip.LinkAddress - expected bool - }{ - { - "Nil", - tcpip.LinkAddress([]byte(nil)), - false, - }, - { - "Empty", - tcpip.LinkAddress(""), - false, - }, - { - "InvalidLength", - tcpip.LinkAddress("\x01\x02\x03"), - false, - }, - { - "Unspecified", - unspecifiedEthernetAddress, - false, - }, - { - "Multicast", - tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), - true, - }, - { - "Unicast", - tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06"), - false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if got := IsMulticastEthernetAddress(test.addr); got != test.expected { - t.Fatalf("got IsMulticastEthernetAddress = %t, want = %t", got, test.expected) - } - }) - } -} - -func TestEthernetAddressFromMulticastIPv4Address(t *testing.T) { - tests := []struct { - name string - addr tcpip.Address - expectedLinkAddr tcpip.LinkAddress - }{ - { - name: "IPv4 Multicast without 24th bit set", - addr: "\xe0\x7e\xdc\xba", - expectedLinkAddr: "\x01\x00\x5e\x7e\xdc\xba", - }, - { - name: "IPv4 Multicast with 24th bit set", - addr: "\xe0\xfe\xdc\xba", - expectedLinkAddr: "\x01\x00\x5e\x7e\xdc\xba", - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if got := EthernetAddressFromMulticastIPv4Address(test.addr); got != test.expectedLinkAddr { - t.Fatalf("got EthernetAddressFromMulticastIPv4Address(%s) = %s, want = %s", test.addr, got, test.expectedLinkAddr) - } - }) - } -} - -func TestEthernetAddressFromMulticastIPv6Address(t *testing.T) { - addr := tcpip.Address("\xff\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x1a") - if got, want := EthernetAddressFromMulticastIPv6Address(addr), tcpip.LinkAddress("\x33\x33\x0d\x0e\x0f\x1a"); got != want { - t.Fatalf("got EthernetAddressFromMulticastIPv6Address(%s) = %s, want = %s", addr, got, want) - } -} diff --git a/pkg/tcpip/header/header_state_autogen.go b/pkg/tcpip/header/header_state_autogen.go new file mode 100644 index 000000000..ddcc980e8 --- /dev/null +++ b/pkg/tcpip/header/header_state_autogen.go @@ -0,0 +1,70 @@ +// automatically generated by stateify. + +package header + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (r *SACKBlock) StateTypeName() string { + return "pkg/tcpip/header.SACKBlock" +} + +func (r *SACKBlock) StateFields() []string { + return []string{ + "Start", + "End", + } +} + +func (r *SACKBlock) beforeSave() {} + +func (r *SACKBlock) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.Start) + stateSinkObject.Save(1, &r.End) +} + +func (r *SACKBlock) afterLoad() {} + +func (r *SACKBlock) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.Start) + stateSourceObject.Load(1, &r.End) +} + +func (t *TCPOptions) StateTypeName() string { + return "pkg/tcpip/header.TCPOptions" +} + +func (t *TCPOptions) StateFields() []string { + return []string{ + "TS", + "TSVal", + "TSEcr", + "SACKBlocks", + } +} + +func (t *TCPOptions) beforeSave() {} + +func (t *TCPOptions) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.TS) + stateSinkObject.Save(1, &t.TSVal) + stateSinkObject.Save(2, &t.TSEcr) + stateSinkObject.Save(3, &t.SACKBlocks) +} + +func (t *TCPOptions) afterLoad() {} + +func (t *TCPOptions) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.TS) + stateSourceObject.Load(1, &t.TSVal) + stateSourceObject.Load(2, &t.TSEcr) + stateSourceObject.Load(3, &t.SACKBlocks) +} + +func init() { + state.Register((*SACKBlock)(nil)) + state.Register((*TCPOptions)(nil)) +} diff --git a/pkg/tcpip/header/igmp_test.go b/pkg/tcpip/header/igmp_test.go deleted file mode 100644 index b6126d29a..000000000 --- a/pkg/tcpip/header/igmp_test.go +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package header_test - -import ( - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -// TestIGMPHeader tests the functions within header.igmp -func TestIGMPHeader(t *testing.T) { - const maxRespTimeTenthSec = 0xF0 - b := []byte{ - 0x11, // IGMP Type, Membership Query - maxRespTimeTenthSec, // Maximum Response Time - 0xC0, 0xC0, // Checksum - 0x01, 0x02, 0x03, 0x04, // Group Address - } - - igmpHeader := header.IGMP(b) - - if got, want := igmpHeader.Type(), header.IGMPMembershipQuery; got != want { - t.Errorf("got igmpHeader.Type() = %x, want = %x", got, want) - } - - if got, want := igmpHeader.MaxRespTime(), header.DecisecondToDuration(maxRespTimeTenthSec); got != want { - t.Errorf("got igmpHeader.MaxRespTime() = %s, want = %s", got, want) - } - - if got, want := igmpHeader.Checksum(), uint16(0xC0C0); got != want { - t.Errorf("got igmpHeader.Checksum() = %x, want = %x", got, want) - } - - if got, want := igmpHeader.GroupAddress(), tcpip.Address("\x01\x02\x03\x04"); got != want { - t.Errorf("got igmpHeader.GroupAddress() = %s, want = %s", got, want) - } - - igmpType := header.IGMPv2MembershipReport - igmpHeader.SetType(igmpType) - if got := igmpHeader.Type(); got != igmpType { - t.Errorf("got igmpHeader.Type() = %x, want = %x", got, igmpType) - } - if got := header.IGMPType(b[0]); got != igmpType { - t.Errorf("got IGMPtype in backing buffer = %x, want %x", got, igmpType) - } - - respTime := byte(0x02) - igmpHeader.SetMaxRespTime(respTime) - if got, want := igmpHeader.MaxRespTime(), header.DecisecondToDuration(respTime); got != want { - t.Errorf("got igmpHeader.MaxRespTime() = %s, want = %s", got, want) - } - - checksum := uint16(0x0102) - igmpHeader.SetChecksum(checksum) - if got := igmpHeader.Checksum(); got != checksum { - t.Errorf("got igmpHeader.Checksum() = %x, want = %x", got, checksum) - } - - groupAddress := tcpip.Address("\x04\x03\x02\x01") - igmpHeader.SetGroupAddress(groupAddress) - if got := igmpHeader.GroupAddress(); got != groupAddress { - t.Errorf("got igmpHeader.GroupAddress() = %s, want = %s", got, groupAddress) - } -} - -// TestIGMPChecksum ensures that the checksum calculator produces the expected -// checksum. -func TestIGMPChecksum(t *testing.T) { - b := []byte{ - 0x11, // IGMP Type, Membership Query - 0xF0, // Maximum Response Time - 0xC0, 0xC0, // Checksum - 0x01, 0x02, 0x03, 0x04, // Group Address - } - - igmpHeader := header.IGMP(b) - - // Calculate the initial checksum after setting the checksum temporarily to 0 - // to avoid checksumming the checksum. - initialChecksum := igmpHeader.Checksum() - igmpHeader.SetChecksum(0) - checksum := ^header.Checksum(b, 0) - igmpHeader.SetChecksum(initialChecksum) - - if got := header.IGMPCalculateChecksum(igmpHeader); got != checksum { - t.Errorf("got IGMPCalculateChecksum = %x, want %x", got, checksum) - } -} - -func TestDecisecondToDuration(t *testing.T) { - const valueInDeciseconds = 5 - if got, want := header.DecisecondToDuration(valueInDeciseconds), valueInDeciseconds*time.Second/10; got != want { - t.Fatalf("got header.DecisecondToDuration(%d) = %s, want = %s", valueInDeciseconds, got, want) - } -} diff --git a/pkg/tcpip/header/ipv4_test.go b/pkg/tcpip/header/ipv4_test.go deleted file mode 100644 index 6475cd694..000000000 --- a/pkg/tcpip/header/ipv4_test.go +++ /dev/null @@ -1,179 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package header_test - -import ( - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -func TestIPv4OptionsSerializer(t *testing.T) { - optCases := []struct { - name string - option []header.IPv4SerializableOption - expect []byte - }{ - { - name: "NOP", - option: []header.IPv4SerializableOption{ - &header.IPv4SerializableNOPOption{}, - }, - expect: []byte{1, 0, 0, 0}, - }, - { - name: "ListEnd", - option: []header.IPv4SerializableOption{ - &header.IPv4SerializableListEndOption{}, - }, - expect: []byte{0, 0, 0, 0}, - }, - { - name: "RouterAlert", - option: []header.IPv4SerializableOption{ - &header.IPv4SerializableRouterAlertOption{}, - }, - expect: []byte{148, 4, 0, 0}, - }, { - name: "NOP and RouterAlert", - option: []header.IPv4SerializableOption{ - &header.IPv4SerializableNOPOption{}, - &header.IPv4SerializableRouterAlertOption{}, - }, - expect: []byte{1, 148, 4, 0, 0, 0, 0, 0}, - }, - } - - for _, opt := range optCases { - t.Run(opt.name, func(t *testing.T) { - s := header.IPv4OptionsSerializer(opt.option) - l := s.Length() - if got := len(opt.expect); got != int(l) { - t.Fatalf("s.Length() = %d, want = %d", got, l) - } - b := make([]byte, l) - for i := range b { - // Fill the buffer with full bytes to ensure padding is being set - // correctly. - b[i] = 0xFF - } - if serializedLength := s.Serialize(b); serializedLength != l { - t.Fatalf("s.Serialize(_) = %d, want %d", serializedLength, l) - } - if diff := cmp.Diff(opt.expect, b); diff != "" { - t.Errorf("mismatched serialized option (-want +got):\n%s", diff) - } - }) - } -} - -// TestIPv4Encode checks that ipv4.Encode correctly fills out the requested -// fields when options are supplied. -func TestIPv4EncodeOptions(t *testing.T) { - tests := []struct { - name string - numberOfNops int - encodedOptions header.IPv4Options // reply should look like this - wantIHL int - }{ - { - name: "valid no options", - wantIHL: header.IPv4MinimumSize, - }, - { - name: "one byte options", - numberOfNops: 1, - encodedOptions: header.IPv4Options{1, 0, 0, 0}, - wantIHL: header.IPv4MinimumSize + 4, - }, - { - name: "two byte options", - numberOfNops: 2, - encodedOptions: header.IPv4Options{1, 1, 0, 0}, - wantIHL: header.IPv4MinimumSize + 4, - }, - { - name: "three byte options", - numberOfNops: 3, - encodedOptions: header.IPv4Options{1, 1, 1, 0}, - wantIHL: header.IPv4MinimumSize + 4, - }, - { - name: "four byte options", - numberOfNops: 4, - encodedOptions: header.IPv4Options{1, 1, 1, 1}, - wantIHL: header.IPv4MinimumSize + 4, - }, - { - name: "five byte options", - numberOfNops: 5, - encodedOptions: header.IPv4Options{1, 1, 1, 1, 1, 0, 0, 0}, - wantIHL: header.IPv4MinimumSize + 8, - }, - { - name: "thirty nine byte options", - numberOfNops: 39, - encodedOptions: header.IPv4Options{ - 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 0, - }, - wantIHL: header.IPv4MinimumSize + 40, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - serializeOpts := header.IPv4OptionsSerializer(make([]header.IPv4SerializableOption, test.numberOfNops)) - for i := range serializeOpts { - serializeOpts[i] = &header.IPv4SerializableNOPOption{} - } - paddedOptionLength := serializeOpts.Length() - ipHeaderLength := int(header.IPv4MinimumSize + paddedOptionLength) - if ipHeaderLength > header.IPv4MaximumHeaderSize { - t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) - } - totalLen := uint16(ipHeaderLength) - hdr := buffer.NewPrependable(int(totalLen)) - ip := header.IPv4(hdr.Prepend(ipHeaderLength)) - // To check the padding works, poison the last byte of the options space. - if paddedOptionLength != serializeOpts.Length() { - ip.SetHeaderLength(uint8(ipHeaderLength)) - ip.Options()[paddedOptionLength-1] = 0xff - ip.SetHeaderLength(0) - } - ip.Encode(&header.IPv4Fields{ - Options: serializeOpts, - }) - options := ip.Options() - wantOptions := test.encodedOptions - if got, want := int(ip.HeaderLength()), test.wantIHL; got != want { - t.Errorf("got IHL of %d, want %d", got, want) - } - - // cmp.Diff does not consider nil slices equal to empty slices, but we do. - if len(wantOptions) == 0 && len(options) == 0 { - return - } - - if diff := cmp.Diff(wantOptions, options); diff != "" { - t.Errorf("options mismatch (-want +got):\n%s", diff) - } - }) - } -} diff --git a/pkg/tcpip/header/ipv6_extension_headers_test.go b/pkg/tcpip/header/ipv6_extension_headers_test.go deleted file mode 100644 index 65adc6250..000000000 --- a/pkg/tcpip/header/ipv6_extension_headers_test.go +++ /dev/null @@ -1,1346 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package header - -import ( - "bytes" - "errors" - "io" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" -) - -// Equal returns true of a and b are equivalent. -// -// Note, Equal will return true if a and b hold the same Identifier value and -// contain the same bytes in Buf, even if the bytes are split across views -// differently. -// -// Needed to use cmp.Equal on IPv6RawPayloadHeader as it contains unexported -// fields. -func (a IPv6RawPayloadHeader) Equal(b IPv6RawPayloadHeader) bool { - return a.Identifier == b.Identifier && bytes.Equal(a.Buf.ToView(), b.Buf.ToView()) -} - -// Equal returns true of a and b are equivalent. -// -// Note, Equal will return true if a and b hold equivalent ipv6OptionsExtHdrs. -// -// Needed to use cmp.Equal on IPv6RawPayloadHeader as it contains unexported -// fields. -func (a IPv6HopByHopOptionsExtHdr) Equal(b IPv6HopByHopOptionsExtHdr) bool { - return bytes.Equal(a.ipv6OptionsExtHdr, b.ipv6OptionsExtHdr) -} - -// Equal returns true of a and b are equivalent. -// -// Note, Equal will return true if a and b hold equivalent ipv6OptionsExtHdrs. -// -// Needed to use cmp.Equal on IPv6RawPayloadHeader as it contains unexported -// fields. -func (a IPv6DestinationOptionsExtHdr) Equal(b IPv6DestinationOptionsExtHdr) bool { - return bytes.Equal(a.ipv6OptionsExtHdr, b.ipv6OptionsExtHdr) -} - -func TestIPv6UnknownExtHdrOption(t *testing.T) { - tests := []struct { - name string - identifier IPv6ExtHdrOptionIdentifier - expectedUnknownAction IPv6OptionUnknownAction - }{ - { - name: "Skip with zero LSBs", - identifier: 0, - expectedUnknownAction: IPv6OptionUnknownActionSkip, - }, - { - name: "Discard with zero LSBs", - identifier: 64, - expectedUnknownAction: IPv6OptionUnknownActionDiscard, - }, - { - name: "Discard and ICMP with zero LSBs", - identifier: 128, - expectedUnknownAction: IPv6OptionUnknownActionDiscardSendICMP, - }, - { - name: "Discard and ICMP for non multicast destination with zero LSBs", - identifier: 192, - expectedUnknownAction: IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest, - }, - { - name: "Skip with non-zero LSBs", - identifier: 63, - expectedUnknownAction: IPv6OptionUnknownActionSkip, - }, - { - name: "Discard with non-zero LSBs", - identifier: 127, - expectedUnknownAction: IPv6OptionUnknownActionDiscard, - }, - { - name: "Discard and ICMP with non-zero LSBs", - identifier: 191, - expectedUnknownAction: IPv6OptionUnknownActionDiscardSendICMP, - }, - { - name: "Discard and ICMP for non multicast destination with non-zero LSBs", - identifier: 255, - expectedUnknownAction: IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - opt := &IPv6UnknownExtHdrOption{Identifier: test.identifier, Data: []byte{1, 2, 3, 4}} - if a := opt.UnknownAction(); a != test.expectedUnknownAction { - t.Fatalf("got UnknownAction() = %d, want = %d", a, test.expectedUnknownAction) - } - }) - } - -} - -func TestIPv6OptionsExtHdrIterErr(t *testing.T) { - tests := []struct { - name string - bytes []byte - err error - }{ - { - name: "Single unknown with zero length", - bytes: []byte{255, 0}, - }, - { - name: "Single unknown with non-zero length", - bytes: []byte{255, 3, 1, 2, 3}, - }, - { - name: "Two options", - bytes: []byte{ - 255, 0, - 254, 1, 1, - }, - }, - { - name: "Three options", - bytes: []byte{ - 255, 0, - 254, 1, 1, - 253, 4, 2, 3, 4, 5, - }, - }, - { - name: "Single unknown only identifier", - bytes: []byte{255}, - err: io.ErrUnexpectedEOF, - }, - { - name: "Single unknown too small with length = 1", - bytes: []byte{255, 1}, - err: io.ErrUnexpectedEOF, - }, - { - name: "Single unknown too small with length = 2", - bytes: []byte{255, 2, 1}, - err: io.ErrUnexpectedEOF, - }, - { - name: "Valid first with second unknown only identifier", - bytes: []byte{ - 255, 0, - 254, - }, - err: io.ErrUnexpectedEOF, - }, - { - name: "Valid first with second unknown missing data", - bytes: []byte{ - 255, 0, - 254, 1, - }, - err: io.ErrUnexpectedEOF, - }, - { - name: "Valid first with second unknown too small", - bytes: []byte{ - 255, 0, - 254, 2, 1, - }, - err: io.ErrUnexpectedEOF, - }, - { - name: "One Pad1", - bytes: []byte{0}, - }, - { - name: "Multiple Pad1", - bytes: []byte{0, 0, 0}, - }, - { - name: "Multiple PadN", - bytes: []byte{ - // Pad3 - 1, 1, 1, - - // Pad5 - 1, 3, 1, 2, 3, - }, - }, - { - name: "Pad5 too small middle of data buffer", - bytes: []byte{1, 3, 1, 2}, - err: io.ErrUnexpectedEOF, - }, - { - name: "Pad5 no data", - bytes: []byte{1, 3}, - err: io.ErrUnexpectedEOF, - }, - { - name: "Router alert without data", - bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 0}, - err: ErrMalformedIPv6ExtHdrOption, - }, - { - name: "Router alert with partial data", - bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1, 1}, - err: ErrMalformedIPv6ExtHdrOption, - }, - { - name: "Router alert with partial data and Pad1", - bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1, 1, 0}, - err: ErrMalformedIPv6ExtHdrOption, - }, - { - name: "Router alert with extra data", - bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 3, 1, 2, 3}, - err: ErrMalformedIPv6ExtHdrOption, - }, - { - name: "Router alert with missing data", - bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1}, - err: io.ErrUnexpectedEOF, - }, - } - - check := func(t *testing.T, it IPv6OptionsExtHdrOptionsIterator, expectedErr error) { - for i := 0; ; i++ { - _, done, err := it.Next() - if err != nil { - // If we encountered a non-nil error while iterating, make sure it is - // is the same error as expectedErr. - if !errors.Is(err, expectedErr) { - t.Fatalf("got %d-th Next() = %v, want = %v", i, err, expectedErr) - } - - return - } - if done { - // If we are done (without an error), make sure that we did not expect - // an error. - if expectedErr != nil { - t.Fatalf("expected error when iterating; want = %s", expectedErr) - } - - return - } - } - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - t.Run("Hop By Hop", func(t *testing.T) { - extHdr := IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: test.bytes} - check(t, extHdr.Iter(), test.err) - }) - - t.Run("Destination", func(t *testing.T) { - extHdr := IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: test.bytes} - check(t, extHdr.Iter(), test.err) - }) - }) - } -} - -func TestIPv6OptionsExtHdrIter(t *testing.T) { - tests := []struct { - name string - bytes []byte - expected []IPv6ExtHdrOption - }{ - { - name: "Single unknown with zero length", - bytes: []byte{255, 0}, - expected: []IPv6ExtHdrOption{ - &IPv6UnknownExtHdrOption{Identifier: 255, Data: []byte{}}, - }, - }, - { - name: "Single unknown with non-zero length", - bytes: []byte{255, 3, 1, 2, 3}, - expected: []IPv6ExtHdrOption{ - &IPv6UnknownExtHdrOption{Identifier: 255, Data: []byte{1, 2, 3}}, - }, - }, - { - name: "Single Pad1", - bytes: []byte{0}, - }, - { - name: "Two Pad1", - bytes: []byte{0, 0}, - }, - { - name: "Single Pad3", - bytes: []byte{1, 1, 1}, - }, - { - name: "Single Pad5", - bytes: []byte{1, 3, 1, 2, 3}, - }, - { - name: "Multiple Pad", - bytes: []byte{ - // Pad1 - 0, - - // Pad2 - 1, 0, - - // Pad3 - 1, 1, 1, - - // Pad4 - 1, 2, 1, 2, - - // Pad5 - 1, 3, 1, 2, 3, - }, - }, - { - name: "Multiple options", - bytes: []byte{ - // Pad1 - 0, - - // Unknown - 255, 0, - - // Pad2 - 1, 0, - - // Unknown - 254, 1, 1, - - // Pad3 - 1, 1, 1, - - // Unknown - 253, 4, 2, 3, 4, 5, - - // Pad4 - 1, 2, 1, 2, - }, - expected: []IPv6ExtHdrOption{ - &IPv6UnknownExtHdrOption{Identifier: 255, Data: []byte{}}, - &IPv6UnknownExtHdrOption{Identifier: 254, Data: []byte{1}}, - &IPv6UnknownExtHdrOption{Identifier: 253, Data: []byte{2, 3, 4, 5}}, - }, - }, - } - - checkIter := func(t *testing.T, it IPv6OptionsExtHdrOptionsIterator, expected []IPv6ExtHdrOption) { - for i, e := range expected { - opt, done, err := it.Next() - if err != nil { - t.Errorf("(i=%d) Next(): %s", i, err) - } - if done { - t.Errorf("(i=%d) unexpectedly done iterating", i) - } - if diff := cmp.Diff(e, opt); diff != "" { - t.Errorf("(i=%d) got option mismatch (-want +got):\n%s", i, diff) - } - - if t.Failed() { - t.FailNow() - } - } - - opt, done, err := it.Next() - if err != nil { - t.Errorf("(last) Next(): %s", err) - } - if !done { - t.Errorf("(last) iterator unexpectedly not done") - } - if opt != nil { - t.Errorf("(last) got Next() = %T, want = nil", opt) - } - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - t.Run("Hop By Hop", func(t *testing.T) { - extHdr := IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: test.bytes} - checkIter(t, extHdr.Iter(), test.expected) - }) - - t.Run("Destination", func(t *testing.T) { - extHdr := IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: test.bytes} - checkIter(t, extHdr.Iter(), test.expected) - }) - }) - } -} - -func TestIPv6RoutingExtHdr(t *testing.T) { - tests := []struct { - name string - bytes []byte - segmentsLeft uint8 - }{ - { - name: "Zeroes", - bytes: []byte{0, 0, 0, 0, 0, 0}, - segmentsLeft: 0, - }, - { - name: "Ones", - bytes: []byte{1, 1, 1, 1, 1, 1}, - segmentsLeft: 1, - }, - { - name: "Mixed", - bytes: []byte{1, 2, 3, 4, 5, 6}, - segmentsLeft: 2, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - extHdr := IPv6RoutingExtHdr(test.bytes) - if got := extHdr.SegmentsLeft(); got != test.segmentsLeft { - t.Errorf("got SegmentsLeft() = %d, want = %d", got, test.segmentsLeft) - } - }) - } -} - -func TestIPv6FragmentExtHdr(t *testing.T) { - tests := []struct { - name string - bytes [6]byte - fragmentOffset uint16 - more bool - id uint32 - }{ - { - name: "Zeroes", - bytes: [6]byte{0, 0, 0, 0, 0, 0}, - fragmentOffset: 0, - more: false, - id: 0, - }, - { - name: "Ones", - bytes: [6]byte{0, 9, 0, 0, 0, 1}, - fragmentOffset: 1, - more: true, - id: 1, - }, - { - name: "Mixed", - bytes: [6]byte{68, 9, 128, 4, 2, 1}, - fragmentOffset: 2177, - more: true, - id: 2147746305, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - extHdr := IPv6FragmentExtHdr(test.bytes) - if got := extHdr.FragmentOffset(); got != test.fragmentOffset { - t.Errorf("got FragmentOffset() = %d, want = %d", got, test.fragmentOffset) - } - if got := extHdr.More(); got != test.more { - t.Errorf("got More() = %t, want = %t", got, test.more) - } - if got := extHdr.ID(); got != test.id { - t.Errorf("got ID() = %d, want = %d", got, test.id) - } - }) - } -} - -func makeVectorisedViewFromByteBuffers(bs ...[]byte) buffer.VectorisedView { - size := 0 - var vs []buffer.View - - for _, b := range bs { - vs = append(vs, buffer.View(b)) - size += len(b) - } - - return buffer.NewVectorisedView(size, vs) -} - -func TestIPv6ExtHdrIterErr(t *testing.T) { - tests := []struct { - name string - firstNextHdr IPv6ExtensionHeaderIdentifier - payload buffer.VectorisedView - err error - }{ - { - name: "Upper layer only without data", - firstNextHdr: 255, - }, - { - name: "Upper layer only with data", - firstNextHdr: 255, - payload: makeVectorisedViewFromByteBuffers([]byte{1, 2, 3, 4}), - }, - { - name: "No next header", - firstNextHdr: IPv6NoNextHeaderIdentifier, - }, - { - name: "No next header with data", - firstNextHdr: IPv6NoNextHeaderIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{1, 2, 3, 4}), - }, - { - name: "Valid single hop by hop", - firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 4, 1, 2, 3, 4}), - }, - { - name: "Hop by hop too small", - firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 4, 1, 2, 3}), - err: io.ErrUnexpectedEOF, - }, - { - name: "Valid single fragment", - firstNextHdr: IPv6FragmentExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 68, 9, 128, 4, 2, 1}), - }, - { - name: "Fragment too small", - firstNextHdr: IPv6FragmentExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 68, 9, 128, 4, 2}), - err: io.ErrUnexpectedEOF, - }, - { - name: "Valid single destination", - firstNextHdr: IPv6DestinationOptionsExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 4, 1, 2, 3, 4}), - }, - { - name: "Destination too small", - firstNextHdr: IPv6DestinationOptionsExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 4, 1, 2, 3}), - err: io.ErrUnexpectedEOF, - }, - { - name: "Valid single routing", - firstNextHdr: IPv6RoutingExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 2, 3, 4, 5, 6}), - }, - { - name: "Valid single routing across views", - firstNextHdr: IPv6RoutingExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 2}, []byte{3, 4, 5, 6}), - }, - { - name: "Routing too small with zero length field", - firstNextHdr: IPv6RoutingExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 2, 3, 4, 5}), - err: io.ErrUnexpectedEOF, - }, - { - name: "Valid routing with non-zero length field", - firstNextHdr: IPv6RoutingExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8}), - }, - { - name: "Valid routing with non-zero length field across views", - firstNextHdr: IPv6RoutingExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6}, []byte{1, 2, 3, 4, 5, 6, 7, 8}), - }, - { - name: "Routing too small with non-zero length field", - firstNextHdr: IPv6RoutingExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7}), - err: io.ErrUnexpectedEOF, - }, - { - name: "Routing too small with non-zero length field across views", - firstNextHdr: IPv6RoutingExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6}, []byte{1, 2, 3, 4, 5, 6, 7}), - err: io.ErrUnexpectedEOF, - }, - { - name: "Mixed", - firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{ - // Hop By Hop Options extension header. - uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4, - - // (Atomic) Fragment extension header. - // - // Reserved bits are 1 which should not affect anything. - uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6, 128, 4, 2, 1, - - // Routing extension header. - uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6, - - // Destination Options extension header. - 255, 0, 255, 4, 1, 2, 3, 4, - - // Upper layer data. - 1, 2, 3, 4, - }), - }, - { - name: "Mixed without upper layer data", - firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{ - // Hop By Hop Options extension header. - uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4, - - // (Atomic) Fragment extension header. - // - // Reserved bits are 1 which should not affect anything. - uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6, 128, 4, 2, 1, - - // Routing extension header. - uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6, - - // Destination Options extension header. - 255, 0, 255, 4, 1, 2, 3, 4, - }), - }, - { - name: "Mixed without upper layer data but last ext hdr too small", - firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{ - // Hop By Hop Options extension header. - uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4, - - // (Atomic) Fragment extension header. - // - // Reserved bits are 1 which should not affect anything. - uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6, 128, 4, 2, 1, - - // Routing extension header. - uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6, - - // Destination Options extension header. - 255, 0, 255, 4, 1, 2, 3, - }), - err: io.ErrUnexpectedEOF, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - it := MakeIPv6PayloadIterator(test.firstNextHdr, test.payload) - - for i := 0; ; i++ { - _, done, err := it.Next() - if err != nil { - // If we encountered a non-nil error while iterating, make sure it is - // is the same error as test.err. - if !errors.Is(err, test.err) { - t.Fatalf("got %d-th Next() = %v, want = %v", i, err, test.err) - } - - return - } - if done { - // If we are done (without an error), make sure that we did not expect - // an error. - if test.err != nil { - t.Fatalf("expected error when iterating; want = %s", test.err) - } - - return - } - } - }) - } -} - -func TestIPv6ExtHdrIter(t *testing.T) { - routingExtHdrWithUpperLayerData := buffer.View([]byte{255, 0, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4}) - upperLayerData := buffer.View([]byte{1, 2, 3, 4}) - tests := []struct { - name string - firstNextHdr IPv6ExtensionHeaderIdentifier - payload buffer.VectorisedView - expected []IPv6PayloadHeader - }{ - // With a non-atomic fragment that is not the first fragment, the payload - // after the fragment will not be parsed because the payload is expected to - // only hold upper layer data. - { - name: "hopbyhop - fragment (not first) - routing - upper", - firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{ - // Hop By Hop extension header. - uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4, - - // Fragment extension header. - // - // More = 1, Fragment Offset = 2117, ID = 2147746305 - uint8(IPv6RoutingExtHdrIdentifier), 0, 68, 9, 128, 4, 2, 1, - - // Routing extension header. - // - // Even though we have a routing ext header here, it should be - // be interpretted as raw bytes as only the first fragment is expected - // to hold headers. - 255, 0, 1, 2, 3, 4, 5, 6, - - // Upper layer data. - 1, 2, 3, 4, - }), - expected: []IPv6PayloadHeader{ - IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}}, - IPv6FragmentExtHdr([6]byte{68, 9, 128, 4, 2, 1}), - IPv6RawPayloadHeader{ - Identifier: IPv6RoutingExtHdrIdentifier, - Buf: routingExtHdrWithUpperLayerData.ToVectorisedView(), - }, - }, - }, - { - name: "hopbyhop - fragment (first) - routing - upper", - firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{ - // Hop By Hop extension header. - uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4, - - // Fragment extension header. - // - // More = 1, Fragment Offset = 0, ID = 2147746305 - uint8(IPv6RoutingExtHdrIdentifier), 0, 0, 1, 128, 4, 2, 1, - - // Routing extension header. - 255, 0, 1, 2, 3, 4, 5, 6, - - // Upper layer data. - 1, 2, 3, 4, - }), - expected: []IPv6PayloadHeader{ - IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}}, - IPv6FragmentExtHdr([6]byte{0, 1, 128, 4, 2, 1}), - IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}), - IPv6RawPayloadHeader{ - Identifier: 255, - Buf: upperLayerData.ToVectorisedView(), - }, - }, - }, - { - name: "fragment - routing - upper (across views)", - firstNextHdr: IPv6FragmentExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{ - // Fragment extension header. - uint8(IPv6RoutingExtHdrIdentifier), 0, 68, 9, 128, 4, 2, 1, - - // Routing extension header. - 255, 0, 1, 2}, []byte{3, 4, 5, 6, - - // Upper layer data. - 1, 2, 3, 4, - }), - expected: []IPv6PayloadHeader{ - IPv6FragmentExtHdr([6]byte{68, 9, 128, 4, 2, 1}), - IPv6RawPayloadHeader{ - Identifier: IPv6RoutingExtHdrIdentifier, - Buf: routingExtHdrWithUpperLayerData.ToVectorisedView(), - }, - }, - }, - - // If we have an atomic fragment, the payload following the fragment - // extension header should be parsed normally. - { - name: "atomic fragment - routing - destination - upper", - firstNextHdr: IPv6FragmentExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{ - // Fragment extension header. - // - // Reserved bits are 1 which should not affect anything. - uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6, 128, 4, 2, 1, - - // Routing extension header. - uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6, - - // Destination Options extension header. - 255, 0, 1, 4, 1, 2, 3, 4, - - // Upper layer data. - 1, 2, 3, 4, - }), - expected: []IPv6PayloadHeader{ - IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}), - IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}), - IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}}, - IPv6RawPayloadHeader{ - Identifier: 255, - Buf: upperLayerData.ToVectorisedView(), - }, - }, - }, - { - name: "atomic fragment - routing - upper (across views)", - firstNextHdr: IPv6FragmentExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{ - // Fragment extension header. - // - // Reserved bits are 1 which should not affect anything. - uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6}, []byte{128, 4, 2, 1, - - // Routing extension header. - 255, 0, 1, 2}, []byte{3, 4, 5, 6, - - // Upper layer data. - 1, 2}, []byte{3, 4}), - expected: []IPv6PayloadHeader{ - IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}), - IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}), - IPv6RawPayloadHeader{ - Identifier: 255, - Buf: makeVectorisedViewFromByteBuffers(upperLayerData[:2], upperLayerData[2:]), - }, - }, - }, - { - name: "atomic fragment - destination - no next header", - firstNextHdr: IPv6FragmentExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{ - // Fragment extension header. - // - // Res (Reserved) bits are 1 which should not affect anything. - uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 0, 6, 128, 4, 2, 1, - - // Destination Options extension header. - uint8(IPv6NoNextHeaderIdentifier), 0, 1, 4, 1, 2, 3, 4, - - // Random data. - 1, 2, 3, 4, - }), - expected: []IPv6PayloadHeader{ - IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}), - IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}}, - }, - }, - { - name: "routing - atomic fragment - no next header", - firstNextHdr: IPv6RoutingExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{ - // Routing extension header. - uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6, - - // Fragment extension header. - // - // Reserved bits are 1 which should not affect anything. - uint8(IPv6NoNextHeaderIdentifier), 0, 0, 6, 128, 4, 2, 1, - - // Random data. - 1, 2, 3, 4, - }), - expected: []IPv6PayloadHeader{ - IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}), - IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}), - }, - }, - { - name: "routing - atomic fragment - no next header (across views)", - firstNextHdr: IPv6RoutingExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{ - // Routing extension header. - uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6, - - // Fragment extension header. - // - // Reserved bits are 1 which should not affect anything. - uint8(IPv6NoNextHeaderIdentifier), 255, 0, 6}, []byte{128, 4, 2, 1, - - // Random data. - 1, 2, 3, 4, - }), - expected: []IPv6PayloadHeader{ - IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}), - IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}), - }, - }, - { - name: "hopbyhop - routing - fragment - no next header", - firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{ - // Hop By Hop Options extension header. - uint8(IPv6RoutingExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4, - - // Routing extension header. - uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6, - - // Fragment extension header. - // - // Fragment Offset = 32; Res = 6. - uint8(IPv6NoNextHeaderIdentifier), 0, 1, 6, 128, 4, 2, 1, - - // Random data. - 1, 2, 3, 4, - }), - expected: []IPv6PayloadHeader{ - IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}}, - IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}), - IPv6FragmentExtHdr([6]byte{1, 6, 128, 4, 2, 1}), - IPv6RawPayloadHeader{ - Identifier: IPv6NoNextHeaderIdentifier, - Buf: upperLayerData.ToVectorisedView(), - }, - }, - }, - - // Test the raw payload for common transport layer protocol numbers. - { - name: "TCP raw payload", - firstNextHdr: IPv6ExtensionHeaderIdentifier(TCPProtocolNumber), - payload: makeVectorisedViewFromByteBuffers(upperLayerData), - expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{ - Identifier: IPv6ExtensionHeaderIdentifier(TCPProtocolNumber), - Buf: upperLayerData.ToVectorisedView(), - }}, - }, - { - name: "UDP raw payload", - firstNextHdr: IPv6ExtensionHeaderIdentifier(UDPProtocolNumber), - payload: makeVectorisedViewFromByteBuffers(upperLayerData), - expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{ - Identifier: IPv6ExtensionHeaderIdentifier(UDPProtocolNumber), - Buf: upperLayerData.ToVectorisedView(), - }}, - }, - { - name: "ICMPv4 raw payload", - firstNextHdr: IPv6ExtensionHeaderIdentifier(ICMPv4ProtocolNumber), - payload: makeVectorisedViewFromByteBuffers(upperLayerData), - expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{ - Identifier: IPv6ExtensionHeaderIdentifier(ICMPv4ProtocolNumber), - Buf: upperLayerData.ToVectorisedView(), - }}, - }, - { - name: "ICMPv6 raw payload", - firstNextHdr: IPv6ExtensionHeaderIdentifier(ICMPv6ProtocolNumber), - payload: makeVectorisedViewFromByteBuffers(upperLayerData), - expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{ - Identifier: IPv6ExtensionHeaderIdentifier(ICMPv6ProtocolNumber), - Buf: upperLayerData.ToVectorisedView(), - }}, - }, - { - name: "Unknwon next header raw payload", - firstNextHdr: 255, - payload: makeVectorisedViewFromByteBuffers(upperLayerData), - expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{ - Identifier: 255, - Buf: upperLayerData.ToVectorisedView(), - }}, - }, - { - name: "Unknwon next header raw payload (across views)", - firstNextHdr: 255, - payload: makeVectorisedViewFromByteBuffers(upperLayerData[:2], upperLayerData[2:]), - expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{ - Identifier: 255, - Buf: makeVectorisedViewFromByteBuffers(upperLayerData[:2], upperLayerData[2:]), - }}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - it := MakeIPv6PayloadIterator(test.firstNextHdr, test.payload) - - for i, e := range test.expected { - extHdr, done, err := it.Next() - if err != nil { - t.Errorf("(i=%d) Next(): %s", i, err) - } - if done { - t.Errorf("(i=%d) unexpectedly done iterating", i) - } - if diff := cmp.Diff(e, extHdr); diff != "" { - t.Errorf("(i=%d) got ext hdr mismatch (-want +got):\n%s", i, diff) - } - - if t.Failed() { - t.FailNow() - } - } - - extHdr, done, err := it.Next() - if err != nil { - t.Errorf("(last) Next(): %s", err) - } - if !done { - t.Errorf("(last) iterator unexpectedly not done") - } - if extHdr != nil { - t.Errorf("(last) got Next() = %T, want = nil", extHdr) - } - }) - } -} - -var _ IPv6SerializableHopByHopOption = (*dummyHbHOptionSerializer)(nil) - -// dummyHbHOptionSerializer provides a generic implementation of -// IPv6SerializableHopByHopOption for use in tests. -type dummyHbHOptionSerializer struct { - id IPv6ExtHdrOptionIdentifier - payload []byte - align int - alignOffset int -} - -// identifier implements IPv6SerializableHopByHopOption. -func (s *dummyHbHOptionSerializer) identifier() IPv6ExtHdrOptionIdentifier { - return s.id -} - -// length implements IPv6SerializableHopByHopOption. -func (s *dummyHbHOptionSerializer) length() uint8 { - return uint8(len(s.payload)) -} - -// alignment implements IPv6SerializableHopByHopOption. -func (s *dummyHbHOptionSerializer) alignment() (int, int) { - align := 1 - if s.align != 0 { - align = s.align - } - return align, s.alignOffset -} - -// serializeInto implements IPv6SerializableHopByHopOption. -func (s *dummyHbHOptionSerializer) serializeInto(b []byte) uint8 { - return uint8(copy(b, s.payload)) -} - -func TestIPv6HopByHopSerializer(t *testing.T) { - validateDummies := func(t *testing.T, serializable IPv6SerializableHopByHopOption, deserialized IPv6ExtHdrOption) { - t.Helper() - dummy, ok := serializable.(*dummyHbHOptionSerializer) - if !ok { - t.Fatalf("got serializable = %T, want = *dummyHbHOptionSerializer", serializable) - } - unknown, ok := deserialized.(*IPv6UnknownExtHdrOption) - if !ok { - t.Fatalf("got deserialized = %T, want = %T", deserialized, &IPv6UnknownExtHdrOption{}) - } - if dummy.id != unknown.Identifier { - t.Errorf("got deserialized identifier = %d, want = %d", unknown.Identifier, dummy.id) - } - if diff := cmp.Diff(dummy.payload, unknown.Data); diff != "" { - t.Errorf("option payload deserialization mismatch (-want +got):\n%s", diff) - } - } - tests := []struct { - name string - nextHeader uint8 - options []IPv6SerializableHopByHopOption - expect []byte - validate func(*testing.T, IPv6SerializableHopByHopOption, IPv6ExtHdrOption) - }{ - { - name: "single option", - nextHeader: 13, - options: []IPv6SerializableHopByHopOption{ - &dummyHbHOptionSerializer{ - id: 15, - payload: []byte{9, 8, 7, 6}, - }, - }, - expect: []byte{13, 0, 15, 4, 9, 8, 7, 6}, - validate: validateDummies, - }, - { - name: "short option padN zero", - nextHeader: 88, - options: []IPv6SerializableHopByHopOption{ - &dummyHbHOptionSerializer{ - id: 22, - payload: []byte{4, 5}, - }, - }, - expect: []byte{88, 0, 22, 2, 4, 5, 1, 0}, - validate: validateDummies, - }, - { - name: "short option pad1", - nextHeader: 11, - options: []IPv6SerializableHopByHopOption{ - &dummyHbHOptionSerializer{ - id: 33, - payload: []byte{1, 2, 3}, - }, - }, - expect: []byte{11, 0, 33, 3, 1, 2, 3, 0}, - validate: validateDummies, - }, - { - name: "long option padN", - nextHeader: 55, - options: []IPv6SerializableHopByHopOption{ - &dummyHbHOptionSerializer{ - id: 77, - payload: []byte{1, 2, 3, 4, 5, 6, 7, 8}, - }, - }, - expect: []byte{55, 1, 77, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 0, 0}, - validate: validateDummies, - }, - { - name: "two options", - nextHeader: 33, - options: []IPv6SerializableHopByHopOption{ - &dummyHbHOptionSerializer{ - id: 11, - payload: []byte{1, 2, 3}, - }, - &dummyHbHOptionSerializer{ - id: 22, - payload: []byte{4, 5, 6}, - }, - }, - expect: []byte{33, 1, 11, 3, 1, 2, 3, 22, 3, 4, 5, 6, 1, 2, 0, 0}, - validate: validateDummies, - }, - { - name: "two options align 2n", - nextHeader: 33, - options: []IPv6SerializableHopByHopOption{ - &dummyHbHOptionSerializer{ - id: 11, - payload: []byte{1, 2, 3}, - }, - &dummyHbHOptionSerializer{ - id: 22, - payload: []byte{4, 5, 6}, - align: 2, - }, - }, - expect: []byte{33, 1, 11, 3, 1, 2, 3, 0, 22, 3, 4, 5, 6, 1, 1, 0}, - validate: validateDummies, - }, - { - name: "two options align 8n+1", - nextHeader: 33, - options: []IPv6SerializableHopByHopOption{ - &dummyHbHOptionSerializer{ - id: 11, - payload: []byte{1, 2}, - }, - &dummyHbHOptionSerializer{ - id: 22, - payload: []byte{4, 5, 6}, - align: 8, - alignOffset: 1, - }, - }, - expect: []byte{33, 1, 11, 2, 1, 2, 1, 1, 0, 22, 3, 4, 5, 6, 1, 0}, - validate: validateDummies, - }, - { - name: "no options", - nextHeader: 33, - options: []IPv6SerializableHopByHopOption{}, - expect: []byte{33, 0, 1, 4, 0, 0, 0, 0}, - }, - { - name: "Router Alert", - nextHeader: 33, - options: []IPv6SerializableHopByHopOption{&IPv6RouterAlertOption{Value: IPv6RouterAlertMLD}}, - expect: []byte{33, 0, 5, 2, 0, 0, 1, 0}, - validate: func(t *testing.T, _ IPv6SerializableHopByHopOption, deserialized IPv6ExtHdrOption) { - t.Helper() - routerAlert, ok := deserialized.(*IPv6RouterAlertOption) - if !ok { - t.Fatalf("got deserialized = %T, want = *IPv6RouterAlertOption", deserialized) - } - if routerAlert.Value != IPv6RouterAlertMLD { - t.Errorf("got routerAlert.Value = %d, want = %d", routerAlert.Value, IPv6RouterAlertMLD) - } - }, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := IPv6SerializableHopByHopExtHdr(test.options) - length := s.length() - if length != len(test.expect) { - t.Fatalf("got s.length() = %d, want = %d", length, len(test.expect)) - } - b := make([]byte, length) - for i := range b { - // Fill the buffer with ones to ensure all padding is correctly set. - b[i] = 0xFF - } - if got := s.serializeInto(test.nextHeader, b); got != length { - t.Fatalf("got s.serializeInto(..) = %d, want = %d", got, length) - } - if diff := cmp.Diff(test.expect, b); diff != "" { - t.Fatalf("serialization mismatch (-want +got):\n%s", diff) - } - - // Deserialize the options and verify them. - optLen := (b[ipv6HopByHopExtHdrLengthOffset] + ipv6HopByHopExtHdrUnaccountedLenWords) * ipv6ExtHdrLenBytesPerUnit - iter := ipv6OptionsExtHdr(b[ipv6HopByHopExtHdrOptionsOffset:optLen]).Iter() - for _, testOpt := range test.options { - opt, done, err := iter.Next() - if err != nil { - t.Fatalf("iter.Next(): %s", err) - } - if done { - t.Fatalf("got iter.Next() = (%T, %t, _), want = (_, false, _)", opt, done) - } - test.validate(t, testOpt, opt) - } - opt, done, err := iter.Next() - if err != nil { - t.Fatalf("iter.Next(): %s", err) - } - if !done { - t.Fatalf("got iter.Next() = (%T, %t, _), want = (_, true, _)", opt, done) - } - }) - } -} - -var _ IPv6SerializableExtHdr = (*dummyIPv6ExtHdrSerializer)(nil) - -// dummyIPv6ExtHdrSerializer provides a generic implementation of -// IPv6SerializableExtHdr for use in tests. -// -// The dummy header always carries the nextHeader value in the first byte. -type dummyIPv6ExtHdrSerializer struct { - id IPv6ExtensionHeaderIdentifier - headerContents []byte -} - -// identifier implements IPv6SerializableExtHdr. -func (s *dummyIPv6ExtHdrSerializer) identifier() IPv6ExtensionHeaderIdentifier { - return s.id -} - -// length implements IPv6SerializableExtHdr. -func (s *dummyIPv6ExtHdrSerializer) length() int { - return len(s.headerContents) + 1 -} - -// serializeInto implements IPv6SerializableExtHdr. -func (s *dummyIPv6ExtHdrSerializer) serializeInto(nextHeader uint8, b []byte) int { - b[0] = nextHeader - return copy(b[1:], s.headerContents) + 1 -} - -func TestIPv6ExtHdrSerializer(t *testing.T) { - tests := []struct { - name string - headers []IPv6SerializableExtHdr - nextHeader tcpip.TransportProtocolNumber - expectSerialized []byte - expectNextHeader uint8 - }{ - { - name: "one header", - headers: []IPv6SerializableExtHdr{ - &dummyIPv6ExtHdrSerializer{ - id: 15, - headerContents: []byte{1, 2, 3, 4}, - }, - }, - nextHeader: TCPProtocolNumber, - expectSerialized: []byte{byte(TCPProtocolNumber), 1, 2, 3, 4}, - expectNextHeader: 15, - }, - { - name: "two headers", - headers: []IPv6SerializableExtHdr{ - &dummyIPv6ExtHdrSerializer{ - id: 22, - headerContents: []byte{1, 2, 3}, - }, - &dummyIPv6ExtHdrSerializer{ - id: 23, - headerContents: []byte{4, 5, 6}, - }, - }, - nextHeader: ICMPv6ProtocolNumber, - expectSerialized: []byte{ - 23, 1, 2, 3, - byte(ICMPv6ProtocolNumber), 4, 5, 6, - }, - expectNextHeader: 22, - }, - { - name: "no headers", - headers: []IPv6SerializableExtHdr{}, - nextHeader: UDPProtocolNumber, - expectSerialized: []byte{}, - expectNextHeader: byte(UDPProtocolNumber), - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := IPv6ExtHdrSerializer(test.headers) - l := s.Length() - if got, want := l, len(test.expectSerialized); got != want { - t.Fatalf("got serialized length = %d, want = %d", got, want) - } - b := make([]byte, l) - for i := range b { - // Fill the buffer with garbage to make sure we're writing to all bytes. - b[i] = 0xFF - } - nextHeader, serializedLen := s.Serialize(test.nextHeader, b) - if serializedLen != len(test.expectSerialized) || nextHeader != test.expectNextHeader { - t.Errorf( - "got s.Serialize(..) = (%d, %d), want = (%d, %d)", - nextHeader, - serializedLen, - test.expectNextHeader, - len(test.expectSerialized), - ) - } - if diff := cmp.Diff(test.expectSerialized, b); diff != "" { - t.Errorf("serialization mismatch (-want +got):\n%s", diff) - } - }) - } -} diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go deleted file mode 100644 index f10f446a6..000000000 --- a/pkg/tcpip/header/ipv6_test.go +++ /dev/null @@ -1,375 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package header_test - -import ( - "bytes" - "crypto/sha256" - "fmt" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/rand" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -const ( - linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - linkLocalAddr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - linkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - globalAddr = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") -) - -func TestEthernetAdddressToModifiedEUI64(t *testing.T) { - expectedIID := [header.IIDSize]byte{0, 2, 3, 255, 254, 4, 5, 6} - - if diff := cmp.Diff(expectedIID, header.EthernetAddressToModifiedEUI64(linkAddr)); diff != "" { - t.Errorf("EthernetAddressToModifiedEUI64(%s) mismatch (-want +got):\n%s", linkAddr, diff) - } - - var buf [header.IIDSize]byte - header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, buf[:]) - if diff := cmp.Diff(expectedIID, buf); diff != "" { - t.Errorf("EthernetAddressToModifiedEUI64IntoBuf(%s, _) mismatch (-want +got):\n%s", linkAddr, diff) - } -} - -func TestLinkLocalAddr(t *testing.T) { - if got, want := header.LinkLocalAddr(linkAddr), tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x02\x03\xff\xfe\x04\x05\x06"); got != want { - t.Errorf("got LinkLocalAddr(%s) = %s, want = %s", linkAddr, got, want) - } -} - -func TestAppendOpaqueInterfaceIdentifier(t *testing.T) { - var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes * 2]byte - if n, err := rand.Read(secretKeyBuf[:]); err != nil { - t.Fatalf("rand.Read(_): %s", err) - } else if want := header.OpaqueIIDSecretKeyMinBytes * 2; n != want { - t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", want, n) - } - - tests := []struct { - name string - prefix tcpip.Subnet - nicName string - dadCounter uint8 - secretKey []byte - }{ - { - name: "SecretKey of minimum size", - prefix: header.IPv6LinkLocalPrefix.Subnet(), - nicName: "eth0", - dadCounter: 0, - secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes], - }, - { - name: "SecretKey of less than minimum size", - prefix: func() tcpip.Subnet { - addrWithPrefix := tcpip.AddressWithPrefix{ - Address: "\x01\x02\x03\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - PrefixLen: header.IIDOffsetInIPv6Address * 8, - } - return addrWithPrefix.Subnet() - }(), - nicName: "eth10", - dadCounter: 1, - secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes/2], - }, - { - name: "SecretKey of more than minimum size", - prefix: func() tcpip.Subnet { - addrWithPrefix := tcpip.AddressWithPrefix{ - Address: "\x01\x02\x03\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - PrefixLen: header.IIDOffsetInIPv6Address * 8, - } - return addrWithPrefix.Subnet() - }(), - nicName: "eth11", - dadCounter: 2, - secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes*2], - }, - { - name: "Nil SecretKey and empty nicName", - prefix: func() tcpip.Subnet { - addrWithPrefix := tcpip.AddressWithPrefix{ - Address: "\x01\x02\x03\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - PrefixLen: header.IIDOffsetInIPv6Address * 8, - } - return addrWithPrefix.Subnet() - }(), - nicName: "", - dadCounter: 3, - secretKey: nil, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - h := sha256.New() - h.Write([]byte(test.prefix.ID()[:header.IIDOffsetInIPv6Address])) - h.Write([]byte(test.nicName)) - h.Write([]byte{test.dadCounter}) - if k := test.secretKey; k != nil { - h.Write(k) - } - var hashSum [sha256.Size]byte - h.Sum(hashSum[:0]) - want := hashSum[:header.IIDSize] - - // Passing a nil buffer should result in a new buffer returned with the - // IID. - if got := header.AppendOpaqueInterfaceIdentifier(nil, test.prefix, test.nicName, test.dadCounter, test.secretKey); !bytes.Equal(got, want) { - t.Errorf("got AppendOpaqueInterfaceIdentifier(nil, %s, %s, %d, %x) = %x, want = %x", test.prefix, test.nicName, test.dadCounter, test.secretKey, got, want) - } - - // Passing a buffer with sufficient capacity for the IID should populate - // the buffer provided. - var iidBuf [header.IIDSize]byte - if got := header.AppendOpaqueInterfaceIdentifier(iidBuf[:0], test.prefix, test.nicName, test.dadCounter, test.secretKey); !bytes.Equal(got, want) { - t.Errorf("got AppendOpaqueInterfaceIdentifier(iidBuf[:0], %s, %s, %d, %x) = %x, want = %x", test.prefix, test.nicName, test.dadCounter, test.secretKey, got, want) - } - if got := iidBuf[:]; !bytes.Equal(got, want) { - t.Errorf("got iidBuf = %x, want = %x", got, want) - } - }) - } -} - -func TestLinkLocalAddrWithOpaqueIID(t *testing.T) { - var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes * 2]byte - if n, err := rand.Read(secretKeyBuf[:]); err != nil { - t.Fatalf("rand.Read(_): %s", err) - } else if want := header.OpaqueIIDSecretKeyMinBytes * 2; n != want { - t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", want, n) - } - - prefix := header.IPv6LinkLocalPrefix.Subnet() - - tests := []struct { - name string - prefix tcpip.Subnet - nicName string - dadCounter uint8 - secretKey []byte - }{ - { - name: "SecretKey of minimum size", - nicName: "eth0", - dadCounter: 0, - secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes], - }, - { - name: "SecretKey of less than minimum size", - nicName: "eth10", - dadCounter: 1, - secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes/2], - }, - { - name: "SecretKey of more than minimum size", - nicName: "eth11", - dadCounter: 2, - secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes*2], - }, - { - name: "Nil SecretKey and empty nicName", - nicName: "", - dadCounter: 3, - secretKey: nil, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - addrBytes := [header.IPv6AddressSize]byte{ - 0: 0xFE, - 1: 0x80, - } - - want := tcpip.Address(header.AppendOpaqueInterfaceIdentifier( - addrBytes[:header.IIDOffsetInIPv6Address], - prefix, - test.nicName, - test.dadCounter, - test.secretKey, - )) - - if got := header.LinkLocalAddrWithOpaqueIID(test.nicName, test.dadCounter, test.secretKey); got != want { - t.Errorf("got LinkLocalAddrWithOpaqueIID(%s, %d, %x) = %s, want = %s", test.nicName, test.dadCounter, test.secretKey, got, want) - } - }) - } -} - -func TestIsV6LinkLocalMulticastAddress(t *testing.T) { - tests := []struct { - name string - addr tcpip.Address - expected bool - }{ - { - name: "Valid Link Local Multicast", - addr: linkLocalMulticastAddr, - expected: true, - }, - { - name: "Valid Link Local Multicast with flags", - addr: "\xff\xf2\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - expected: true, - }, - { - name: "Link Local Unicast", - addr: linkLocalAddr, - expected: false, - }, - { - name: "IPv4 Multicast", - addr: "\xe0\x00\x00\x01", - expected: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if got := header.IsV6LinkLocalMulticastAddress(test.addr); got != test.expected { - t.Errorf("got header.IsV6LinkLocalMulticastAddress(%s) = %t, want = %t", test.addr, got, test.expected) - } - }) - } -} - -func TestIsV6LinkLocalAddress(t *testing.T) { - tests := []struct { - name string - addr tcpip.Address - expected bool - }{ - { - name: "Valid Link Local Unicast", - addr: linkLocalAddr, - expected: true, - }, - { - name: "Link Local Multicast", - addr: linkLocalMulticastAddr, - expected: false, - }, - { - name: "Unique Local", - addr: uniqueLocalAddr1, - expected: false, - }, - { - name: "Global", - addr: globalAddr, - expected: false, - }, - { - name: "IPv4 Link Local", - addr: "\xa9\xfe\x00\x01", - expected: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if got := header.IsV6LinkLocalAddress(test.addr); got != test.expected { - t.Errorf("got header.IsV6LinkLocalAddress(%s) = %t, want = %t", test.addr, got, test.expected) - } - }) - } -} - -func TestScopeForIPv6Address(t *testing.T) { - tests := []struct { - name string - addr tcpip.Address - scope header.IPv6AddressScope - err tcpip.Error - }{ - { - name: "Unique Local", - addr: uniqueLocalAddr1, - scope: header.GlobalScope, - err: nil, - }, - { - name: "Link Local Unicast", - addr: linkLocalAddr, - scope: header.LinkLocalScope, - err: nil, - }, - { - name: "Link Local Multicast", - addr: linkLocalMulticastAddr, - scope: header.LinkLocalScope, - err: nil, - }, - { - name: "Global", - addr: globalAddr, - scope: header.GlobalScope, - err: nil, - }, - { - name: "IPv4", - addr: "\x01\x02\x03\x04", - scope: header.GlobalScope, - err: &tcpip.ErrBadAddress{}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - got, err := header.ScopeForIPv6Address(test.addr) - if diff := cmp.Diff(test.err, err); diff != "" { - t.Errorf("unexpected error from header.IsV6UniqueLocalAddress(%s), (-want, +got):\n%s", test.addr, diff) - } - if got != test.scope { - t.Errorf("got header.IsV6UniqueLocalAddress(%s) = (%d, _), want = (%d, _)", test.addr, got, test.scope) - } - }) - } -} - -func TestSolicitedNodeAddr(t *testing.T) { - tests := []struct { - addr tcpip.Address - want tcpip.Address - }{ - { - addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\xa0", - want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x0e\x0f\xa0", - }, - { - addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\xdd\x0e\x0f\xa0", - want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x0e\x0f\xa0", - }, - { - addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\xdd\x01\x02\x03", - want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x01\x02\x03", - }, - } - - for _, test := range tests { - t.Run(fmt.Sprintf("%s", test.addr), func(t *testing.T) { - if got := header.SolicitedNodeAddr(test.addr); got != test.want { - t.Fatalf("got header.SolicitedNodeAddr(%s) = %s, want = %s", test.addr, got, test.want) - } - }) - } -} diff --git a/pkg/tcpip/header/ipversion_test.go b/pkg/tcpip/header/ipversion_test.go deleted file mode 100644 index b5540bf66..000000000 --- a/pkg/tcpip/header/ipversion_test.go +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package header_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -func TestIPv4(t *testing.T) { - b := header.IPv4(make([]byte, header.IPv4MinimumSize)) - b.Encode(&header.IPv4Fields{}) - - const want = header.IPv4Version - if v := header.IPVersion(b); v != want { - t.Fatalf("Bad version, want %v, got %v", want, v) - } -} - -func TestIPv6(t *testing.T) { - b := header.IPv6(make([]byte, header.IPv6MinimumSize)) - b.Encode(&header.IPv6Fields{}) - - const want = header.IPv6Version - if v := header.IPVersion(b); v != want { - t.Fatalf("Bad version, want %v, got %v", want, v) - } -} - -func TestOtherVersion(t *testing.T) { - const want = header.IPv4Version + header.IPv6Version - b := make([]byte, 1) - b[0] = want << 4 - - if v := header.IPVersion(b); v != want { - t.Fatalf("Bad version, want %v, got %v", want, v) - } -} - -func TestTooShort(t *testing.T) { - b := make([]byte, 1) - b[0] = (header.IPv4Version + header.IPv6Version) << 4 - - // Get the version of a zero-length slice. - const want = -1 - if v := header.IPVersion(b[:0]); v != want { - t.Fatalf("Bad version, want %v, got %v", want, v) - } - - // Get the version of a nil slice. - if v := header.IPVersion(nil); v != want { - t.Fatalf("Bad version, want %v, got %v", want, v) - } -} diff --git a/pkg/tcpip/header/mld_test.go b/pkg/tcpip/header/mld_test.go deleted file mode 100644 index 0cecf10d4..000000000 --- a/pkg/tcpip/header/mld_test.go +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package header - -import ( - "encoding/binary" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" -) - -func TestMLD(t *testing.T) { - b := []byte{ - // Maximum Response Delay - 0, 0, - - // Reserved - 0, 0, - - // MulticastAddress - 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, - } - - const maxRespDelay = 513 - binary.BigEndian.PutUint16(b, maxRespDelay) - - mld := MLD(b) - - if got, want := mld.MaximumResponseDelay(), maxRespDelay*time.Millisecond; got != want { - t.Errorf("got mld.MaximumResponseDelay() = %s, want = %s", got, want) - } - - const newMaxRespDelay = 1234 - mld.SetMaximumResponseDelay(newMaxRespDelay) - if got, want := mld.MaximumResponseDelay(), newMaxRespDelay*time.Millisecond; got != want { - t.Errorf("got mld.MaximumResponseDelay() = %s, want = %s", got, want) - } - - if got, want := mld.MulticastAddress(), tcpip.Address([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6}); got != want { - t.Errorf("got mld.MulticastAddress() = %s, want = %s", got, want) - } - - multicastAddress := tcpip.Address([]byte{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}) - mld.SetMulticastAddress(multicastAddress) - if got := mld.MulticastAddress(); got != multicastAddress { - t.Errorf("got mld.MulticastAddress() = %s, want = %s", got, multicastAddress) - } -} diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go deleted file mode 100644 index dc4591253..000000000 --- a/pkg/tcpip/header/ndp_test.go +++ /dev/null @@ -1,1521 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package header - -import ( - "bytes" - "errors" - "fmt" - "io" - "regexp" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" -) - -// TestNDPNeighborSolicit tests the functions of NDPNeighborSolicit. -func TestNDPNeighborSolicit(t *testing.T) { - b := []byte{ - 0, 0, 0, 0, - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - } - - // Test getting the Target Address. - ns := NDPNeighborSolicit(b) - addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10") - if got := ns.TargetAddress(); got != addr { - t.Errorf("got ns.TargetAddress = %s, want %s", got, addr) - } - - // Test updating the Target Address. - addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11") - ns.SetTargetAddress(addr2) - if got := ns.TargetAddress(); got != addr2 { - t.Errorf("got ns.TargetAddress = %s, want %s", got, addr2) - } - // Make sure the address got updated in the backing buffer. - if got := tcpip.Address(b[ndpNSTargetAddessOffset:][:IPv6AddressSize]); got != addr2 { - t.Errorf("got targetaddress buffer = %s, want %s", got, addr2) - } -} - -// TestNDPNeighborAdvert tests the functions of NDPNeighborAdvert. -func TestNDPNeighborAdvert(t *testing.T) { - b := []byte{ - 160, 0, 0, 0, - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - } - - // Test getting the Target Address. - na := NDPNeighborAdvert(b) - addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10") - if got := na.TargetAddress(); got != addr { - t.Errorf("got TargetAddress = %s, want %s", got, addr) - } - - // Test getting the Router Flag. - if got := na.RouterFlag(); !got { - t.Errorf("got RouterFlag = false, want = true") - } - - // Test getting the Solicited Flag. - if got := na.SolicitedFlag(); got { - t.Errorf("got SolicitedFlag = true, want = false") - } - - // Test getting the Override Flag. - if got := na.OverrideFlag(); !got { - t.Errorf("got OverrideFlag = false, want = true") - } - - // Test updating the Target Address. - addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11") - na.SetTargetAddress(addr2) - if got := na.TargetAddress(); got != addr2 { - t.Errorf("got TargetAddress = %s, want %s", got, addr2) - } - // Make sure the address got updated in the backing buffer. - if got := tcpip.Address(b[ndpNATargetAddressOffset:][:IPv6AddressSize]); got != addr2 { - t.Errorf("got targetaddress buffer = %s, want %s", got, addr2) - } - - // Test updating the Router Flag. - na.SetRouterFlag(false) - if got := na.RouterFlag(); got { - t.Errorf("got RouterFlag = true, want = false") - } - - // Test updating the Solicited Flag. - na.SetSolicitedFlag(true) - if got := na.SolicitedFlag(); !got { - t.Errorf("got SolicitedFlag = false, want = true") - } - - // Test updating the Override Flag. - na.SetOverrideFlag(false) - if got := na.OverrideFlag(); got { - t.Errorf("got OverrideFlag = true, want = false") - } - - // Make sure flags got updated in the backing buffer. - if got := b[ndpNAFlagsOffset]; got != 64 { - t.Errorf("got flags byte = %d, want = 64", got) - } -} - -func TestNDPRouterAdvert(t *testing.T) { - b := []byte{ - 64, 128, 1, 2, - 3, 4, 5, 6, - 7, 8, 9, 10, - } - - ra := NDPRouterAdvert(b) - - if got := ra.CurrHopLimit(); got != 64 { - t.Errorf("got ra.CurrHopLimit = %d, want = 64", got) - } - - if got := ra.ManagedAddrConfFlag(); !got { - t.Errorf("got ManagedAddrConfFlag = false, want = true") - } - - if got := ra.OtherConfFlag(); got { - t.Errorf("got OtherConfFlag = true, want = false") - } - - if got, want := ra.RouterLifetime(), time.Second*258; got != want { - t.Errorf("got ra.RouterLifetime = %d, want = %d", got, want) - } - - if got, want := ra.ReachableTime(), time.Millisecond*50595078; got != want { - t.Errorf("got ra.ReachableTime = %d, want = %d", got, want) - } - - if got, want := ra.RetransTimer(), time.Millisecond*117967114; got != want { - t.Errorf("got ra.RetransTimer = %d, want = %d", got, want) - } -} - -// TestNDPSourceLinkLayerAddressOptionEthernetAddress tests getting the -// Ethernet address from an NDPSourceLinkLayerAddressOption. -func TestNDPSourceLinkLayerAddressOptionEthernetAddress(t *testing.T) { - tests := []struct { - name string - buf []byte - expected tcpip.LinkAddress - }{ - { - "ValidMAC", - []byte{1, 2, 3, 4, 5, 6}, - tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), - }, - { - "SLLBodyTooShort", - []byte{1, 2, 3, 4, 5}, - tcpip.LinkAddress([]byte(nil)), - }, - { - "SLLBodyLargerThanNeeded", - []byte{1, 2, 3, 4, 5, 6, 7, 8}, - tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - sll := NDPSourceLinkLayerAddressOption(test.buf) - if got := sll.EthernetAddress(); got != test.expected { - t.Errorf("got sll.EthernetAddress = %s, want = %s", got, test.expected) - } - }) - } -} - -// TestNDPSourceLinkLayerAddressOptionSerialize tests serializing a -// NDPSourceLinkLayerAddressOption. -func TestNDPSourceLinkLayerAddressOptionSerialize(t *testing.T) { - tests := []struct { - name string - buf []byte - expectedBuf []byte - addr tcpip.LinkAddress - }{ - { - "Ethernet", - make([]byte, 8), - []byte{1, 1, 1, 2, 3, 4, 5, 6}, - "\x01\x02\x03\x04\x05\x06", - }, - { - "Padding", - []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, - []byte{1, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0}, - "\x01\x02\x03\x04\x05\x06\x07\x08", - }, - { - "Empty", - nil, - nil, - "", - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - opts := NDPOptions(test.buf) - serializer := NDPOptionsSerializer{ - NDPSourceLinkLayerAddressOption(test.addr), - } - if got, want := int(serializer.Length()), len(test.expectedBuf); got != want { - t.Fatalf("got Length = %d, want = %d", got, want) - } - opts.Serialize(serializer) - if !bytes.Equal(test.buf, test.expectedBuf) { - t.Fatalf("got b = %d, want = %d", test.buf, test.expectedBuf) - } - - it, err := opts.Iter(true) - if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) - } - - if len(test.expectedBuf) > 0 { - next, done, err := it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got := next.Type(); got != NDPSourceLinkLayerAddressOptionType { - t.Fatalf("got Type = %d, want = %d", got, NDPSourceLinkLayerAddressOptionType) - } - sll := next.(NDPSourceLinkLayerAddressOption) - if got, want := []byte(sll), test.expectedBuf[2:]; !bytes.Equal(got, want) { - t.Fatalf("got Next = (%x, _, _), want = (%x, _, _)", got, want) - } - - if got, want := sll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want { - t.Errorf("got sll.EthernetAddress = %s, want = %s", got, want) - } - } - - // Iterator should not return anything else. - next, done, err := it.Next() - if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") - } - if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) - } - }) - } -} - -// TestNDPTargetLinkLayerAddressOptionEthernetAddress tests getting the -// Ethernet address from an NDPTargetLinkLayerAddressOption. -func TestNDPTargetLinkLayerAddressOptionEthernetAddress(t *testing.T) { - tests := []struct { - name string - buf []byte - expected tcpip.LinkAddress - }{ - { - "ValidMAC", - []byte{1, 2, 3, 4, 5, 6}, - tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), - }, - { - "TLLBodyTooShort", - []byte{1, 2, 3, 4, 5}, - tcpip.LinkAddress([]byte(nil)), - }, - { - "TLLBodyLargerThanNeeded", - []byte{1, 2, 3, 4, 5, 6, 7, 8}, - tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - tll := NDPTargetLinkLayerAddressOption(test.buf) - if got := tll.EthernetAddress(); got != test.expected { - t.Errorf("got tll.EthernetAddress = %s, want = %s", got, test.expected) - } - }) - } -} - -// TestNDPTargetLinkLayerAddressOptionSerialize tests serializing a -// NDPTargetLinkLayerAddressOption. -func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) { - tests := []struct { - name string - buf []byte - expectedBuf []byte - addr tcpip.LinkAddress - }{ - { - "Ethernet", - make([]byte, 8), - []byte{2, 1, 1, 2, 3, 4, 5, 6}, - "\x01\x02\x03\x04\x05\x06", - }, - { - "Padding", - []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, - []byte{2, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0}, - "\x01\x02\x03\x04\x05\x06\x07\x08", - }, - { - "Empty", - nil, - nil, - "", - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - opts := NDPOptions(test.buf) - serializer := NDPOptionsSerializer{ - NDPTargetLinkLayerAddressOption(test.addr), - } - if got, want := int(serializer.Length()), len(test.expectedBuf); got != want { - t.Fatalf("got Length = %d, want = %d", got, want) - } - opts.Serialize(serializer) - if !bytes.Equal(test.buf, test.expectedBuf) { - t.Fatalf("got b = %d, want = %d", test.buf, test.expectedBuf) - } - - it, err := opts.Iter(true) - if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) - } - - if len(test.expectedBuf) > 0 { - next, done, err := it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType { - t.Fatalf("got Type = %d, want = %d", got, NDPTargetLinkLayerAddressOptionType) - } - tll := next.(NDPTargetLinkLayerAddressOption) - if got, want := []byte(tll), test.expectedBuf[2:]; !bytes.Equal(got, want) { - t.Fatalf("got Next = (%x, _, _), want = (%x, _, _)", got, want) - } - - if got, want := tll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want { - t.Errorf("got tll.EthernetAddress = %s, want = %s", got, want) - } - } - - // Iterator should not return anything else. - next, done, err := it.Next() - if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") - } - if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) - } - }) - } -} - -// TestNDPPrefixInformationOption tests the field getters and serialization of a -// NDPPrefixInformation. -func TestNDPPrefixInformationOption(t *testing.T) { - b := []byte{ - 43, 127, - 1, 2, 3, 4, - 5, 6, 7, 8, - 5, 5, 5, 5, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 21, 22, 23, 24, - } - - targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} - opts := NDPOptions(targetBuf) - serializer := NDPOptionsSerializer{ - NDPPrefixInformation(b), - } - opts.Serialize(serializer) - expectedBuf := []byte{ - 3, 4, 43, 64, - 1, 2, 3, 4, - 5, 6, 7, 8, - 0, 0, 0, 0, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 21, 22, 23, 24, - } - if !bytes.Equal(targetBuf, expectedBuf) { - t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expectedBuf) - } - - it, err := opts.Iter(true) - if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) - } - - next, done, err := it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got := next.Type(); got != NDPPrefixInformationType { - t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType) - } - - pi := next.(NDPPrefixInformation) - - if got := pi.Type(); got != 3 { - t.Errorf("got Type = %d, want = 3", got) - } - - if got := pi.Length(); got != 30 { - t.Errorf("got Length = %d, want = 30", got) - } - - if got := pi.PrefixLength(); got != 43 { - t.Errorf("got PrefixLength = %d, want = 43", got) - } - - if pi.OnLinkFlag() { - t.Error("got OnLinkFlag = true, want = false") - } - - if !pi.AutonomousAddressConfigurationFlag() { - t.Error("got AutonomousAddressConfigurationFlag = false, want = true") - } - - if got, want := pi.ValidLifetime(), 16909060*time.Second; got != want { - t.Errorf("got ValidLifetime = %d, want = %d", got, want) - } - - if got, want := pi.PreferredLifetime(), 84281096*time.Second; got != want { - t.Errorf("got PreferredLifetime = %d, want = %d", got, want) - } - - if got, want := pi.Prefix(), tcpip.Address("\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18"); got != want { - t.Errorf("got Prefix = %s, want = %s", got, want) - } - - // Iterator should not return anything else. - next, done, err = it.Next() - if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") - } - if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) - } -} - -func TestNDPRecursiveDNSServerOptionSerialize(t *testing.T) { - b := []byte{ - 9, 8, - 1, 2, 4, 8, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - } - targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} - expected := []byte{ - 25, 3, 0, 0, - 1, 2, 4, 8, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - } - opts := NDPOptions(targetBuf) - serializer := NDPOptionsSerializer{ - NDPRecursiveDNSServer(b), - } - if got, want := opts.Serialize(serializer), len(expected); got != want { - t.Errorf("got Serialize = %d, want = %d", got, want) - } - if !bytes.Equal(targetBuf, expected) { - t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expected) - } - - it, err := opts.Iter(true) - if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) - } - - next, done, err := it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got := next.Type(); got != NDPRecursiveDNSServerOptionType { - t.Errorf("got Type = %d, want = %d", got, NDPRecursiveDNSServerOptionType) - } - - opt, ok := next.(NDPRecursiveDNSServer) - if !ok { - t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next) - } - if got := opt.Type(); got != 25 { - t.Errorf("got Type = %d, want = 31", got) - } - if got := opt.Length(); got != 22 { - t.Errorf("got Length = %d, want = 22", got) - } - if got, want := opt.Lifetime(), 16909320*time.Second; got != want { - t.Errorf("got Lifetime = %s, want = %s", got, want) - } - want := []tcpip.Address{ - "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", - } - addrs, err := opt.Addresses() - if err != nil { - t.Errorf("opt.Addresses() = %s", err) - } - if diff := cmp.Diff(addrs, want); diff != "" { - t.Errorf("mismatched addresses (-want +got):\n%s", diff) - } - - // Iterator should not return anything else. - next, done, err = it.Next() - if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") - } - if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) - } -} - -func TestNDPRecursiveDNSServerOption(t *testing.T) { - tests := []struct { - name string - buf []byte - lifetime time.Duration - addrs []tcpip.Address - }{ - { - "Valid1Addr", - []byte{ - 25, 3, 0, 0, - 0, 0, 0, 0, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - }, - 0, - []tcpip.Address{ - "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", - }, - }, - { - "Valid2Addr", - []byte{ - 25, 5, 0, 0, - 0, 0, 0, 0, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, - }, - 0, - []tcpip.Address{ - "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", - "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x10", - }, - }, - { - "Valid3Addr", - []byte{ - 25, 7, 0, 0, - 0, 0, 0, 0, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, - 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 17, - }, - 0, - []tcpip.Address{ - "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", - "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x10", - "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x11", - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - opts := NDPOptions(test.buf) - it, err := opts.Iter(true) - if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) - } - - // Iterator should get our option. - next, done, err := it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got := next.Type(); got != NDPRecursiveDNSServerOptionType { - t.Fatalf("got Type = %d, want = %d", got, NDPRecursiveDNSServerOptionType) - } - - opt, ok := next.(NDPRecursiveDNSServer) - if !ok { - t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next) - } - if got := opt.Lifetime(); got != test.lifetime { - t.Errorf("got Lifetime = %d, want = %d", got, test.lifetime) - } - addrs, err := opt.Addresses() - if err != nil { - t.Errorf("opt.Addresses() = %s", err) - } - if diff := cmp.Diff(addrs, test.addrs); diff != "" { - t.Errorf("mismatched addresses (-want +got):\n%s", diff) - } - - // Iterator should not return anything else. - next, done, err = it.Next() - if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") - } - if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) - } - }) - } -} - -// TestNDPDNSSearchListOption tests the getters of NDPDNSSearchList. -func TestNDPDNSSearchListOption(t *testing.T) { - tests := []struct { - name string - buf []byte - lifetime time.Duration - domainNames []string - err error - }{ - { - name: "Valid1Label", - buf: []byte{ - 0, 0, - 0, 0, 0, 1, - 3, 'a', 'b', 'c', - 0, - 0, 0, 0, - }, - lifetime: time.Second, - domainNames: []string{ - "abc", - }, - err: nil, - }, - { - name: "Valid2Label", - buf: []byte{ - 0, 0, - 0, 0, 0, 5, - 3, 'a', 'b', 'c', - 4, 'a', 'b', 'c', 'd', - 0, - 0, 0, 0, 0, 0, 0, - }, - lifetime: 5 * time.Second, - domainNames: []string{ - "abc.abcd", - }, - err: nil, - }, - { - name: "Valid3Label", - buf: []byte{ - 0, 0, - 1, 0, 0, 0, - 3, 'a', 'b', 'c', - 4, 'a', 'b', 'c', 'd', - 1, 'e', - 0, - 0, 0, 0, 0, - }, - lifetime: 16777216 * time.Second, - domainNames: []string{ - "abc.abcd.e", - }, - err: nil, - }, - { - name: "Valid2Domains", - buf: []byte{ - 0, 0, - 1, 2, 3, 4, - 3, 'a', 'b', 'c', - 0, - 2, 'd', 'e', - 3, 'x', 'y', 'z', - 0, - 0, 0, 0, - }, - lifetime: 16909060 * time.Second, - domainNames: []string{ - "abc", - "de.xyz", - }, - err: nil, - }, - { - name: "Valid3DomainsMixedCase", - buf: []byte{ - 0, 0, - 0, 0, 0, 0, - 3, 'a', 'B', 'c', - 0, - 2, 'd', 'E', - 3, 'X', 'y', 'z', - 0, - 1, 'J', - 0, - }, - lifetime: 0, - domainNames: []string{ - "abc", - "de.xyz", - "j", - }, - err: nil, - }, - { - name: "ValidDomainAfterNULL", - buf: []byte{ - 0, 0, - 0, 0, 0, 0, - 3, 'a', 'B', 'c', - 0, 0, 0, 0, - 2, 'd', 'E', - 3, 'X', 'y', 'z', - 0, - }, - lifetime: 0, - domainNames: []string{ - "abc", - "de.xyz", - }, - err: nil, - }, - { - name: "Valid0Domains", - buf: []byte{ - 0, 0, - 0, 0, 0, 0, - 0, - 0, 0, 0, 0, 0, 0, 0, - }, - lifetime: 0, - domainNames: nil, - err: nil, - }, - { - name: "NoTrailingNull", - buf: []byte{ - 0, 0, - 0, 0, 0, 0, - 7, 'a', 'b', 'c', 'd', 'e', 'f', 'g', - }, - lifetime: 0, - domainNames: nil, - err: io.ErrUnexpectedEOF, - }, - { - name: "IncorrectLength", - buf: []byte{ - 0, 0, - 0, 0, 0, 0, - 8, 'a', 'b', 'c', 'd', 'e', 'f', 'g', - }, - lifetime: 0, - domainNames: nil, - err: io.ErrUnexpectedEOF, - }, - { - name: "IncorrectLengthWithNULL", - buf: []byte{ - 0, 0, - 0, 0, 0, 0, - 7, 'a', 'b', 'c', 'd', 'e', 'f', - 0, - }, - lifetime: 0, - domainNames: nil, - err: ErrNDPOptMalformedBody, - }, - { - name: "LabelOfLength63", - buf: []byte{ - 0, 0, - 0, 0, 0, 0, - 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', - 0, - }, - lifetime: 0, - domainNames: []string{ - "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk", - }, - err: nil, - }, - { - name: "LabelOfLength64", - buf: []byte{ - 0, 0, - 0, 0, 0, 0, - 64, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', - 0, - }, - lifetime: 0, - domainNames: nil, - err: ErrNDPOptMalformedBody, - }, - { - name: "DomainNameOfLength255", - buf: []byte{ - 0, 0, - 0, 0, 0, 0, - 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', - 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', - 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', - 62, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', - 0, - }, - lifetime: 0, - domainNames: []string{ - "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk.abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk.abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk.abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghij", - }, - err: nil, - }, - { - name: "DomainNameOfLength256", - buf: []byte{ - 0, 0, - 0, 0, 0, 0, - 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', - 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', - 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', - 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', - 0, - }, - lifetime: 0, - domainNames: nil, - err: ErrNDPOptMalformedBody, - }, - { - name: "StartingDigitForLabel", - buf: []byte{ - 0, 0, - 0, 0, 0, 1, - 3, '9', 'b', 'c', - 0, - 0, 0, 0, - }, - lifetime: time.Second, - domainNames: nil, - err: ErrNDPOptMalformedBody, - }, - { - name: "StartingHyphenForLabel", - buf: []byte{ - 0, 0, - 0, 0, 0, 1, - 3, '-', 'b', 'c', - 0, - 0, 0, 0, - }, - lifetime: time.Second, - domainNames: nil, - err: ErrNDPOptMalformedBody, - }, - { - name: "EndingHyphenForLabel", - buf: []byte{ - 0, 0, - 0, 0, 0, 1, - 3, 'a', 'b', '-', - 0, - 0, 0, 0, - }, - lifetime: time.Second, - domainNames: nil, - err: ErrNDPOptMalformedBody, - }, - { - name: "EndingDigitForLabel", - buf: []byte{ - 0, 0, - 0, 0, 0, 1, - 3, 'a', 'b', '9', - 0, - 0, 0, 0, - }, - lifetime: time.Second, - domainNames: []string{ - "ab9", - }, - err: nil, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - opt := NDPDNSSearchList(test.buf) - - if got := opt.Lifetime(); got != test.lifetime { - t.Errorf("got Lifetime = %d, want = %d", got, test.lifetime) - } - domainNames, err := opt.DomainNames() - if !errors.Is(err, test.err) { - t.Errorf("opt.DomainNames() = %s", err) - } - if diff := cmp.Diff(domainNames, test.domainNames); diff != "" { - t.Errorf("mismatched domain names (-want +got):\n%s", diff) - } - }) - } -} - -func TestNDPSearchListOptionDomainNameLabelInvalidSymbols(t *testing.T) { - for r := rune(0); r <= 255; r++ { - t.Run(fmt.Sprintf("RuneVal=%d", r), func(t *testing.T) { - buf := []byte{ - 0, 0, - 0, 0, 0, 0, - 3, 'a', 0 /* will be replaced */, 'c', - 0, - 0, 0, 0, - } - buf[8] = uint8(r) - opt := NDPDNSSearchList(buf) - - // As per RFC 1035 section 2.3.1, the label must only include ASCII - // letters, digits and hyphens (a-z, A-Z, 0-9, -). - var expectedErr error - re := regexp.MustCompile(`[a-zA-Z0-9-]`) - if !re.Match([]byte{byte(r)}) { - expectedErr = ErrNDPOptMalformedBody - } - - if domainNames, err := opt.DomainNames(); !errors.Is(err, expectedErr) { - t.Errorf("got opt.DomainNames() = (%s, %v), want = (_, %v)", domainNames, err, ErrNDPOptMalformedBody) - } - }) - } -} - -func TestNDPDNSSearchListOptionSerialize(t *testing.T) { - b := []byte{ - 9, 8, - 1, 0, 0, 0, - 3, 'a', 'b', 'c', - 4, 'a', 'b', 'c', 'd', - 1, 'e', - 0, - } - targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} - expected := []byte{ - 31, 3, 0, 0, - 1, 0, 0, 0, - 3, 'a', 'b', 'c', - 4, 'a', 'b', 'c', 'd', - 1, 'e', - 0, - 0, 0, 0, 0, - } - opts := NDPOptions(targetBuf) - serializer := NDPOptionsSerializer{ - NDPDNSSearchList(b), - } - if got, want := opts.Serialize(serializer), len(expected); got != want { - t.Errorf("got Serialize = %d, want = %d", got, want) - } - if !bytes.Equal(targetBuf, expected) { - t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expected) - } - - it, err := opts.Iter(true) - if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) - } - - next, done, err := it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got := next.Type(); got != NDPDNSSearchListOptionType { - t.Errorf("got Type = %d, want = %d", got, NDPDNSSearchListOptionType) - } - - opt, ok := next.(NDPDNSSearchList) - if !ok { - t.Fatalf("next (type = %T) cannot be casted to an NDPDNSSearchList", next) - } - if got := opt.Type(); got != 31 { - t.Errorf("got Type = %d, want = 31", got) - } - if got := opt.Length(); got != 22 { - t.Errorf("got Length = %d, want = 22", got) - } - if got, want := opt.Lifetime(), 16777216*time.Second; got != want { - t.Errorf("got Lifetime = %s, want = %s", got, want) - } - domainNames, err := opt.DomainNames() - if err != nil { - t.Errorf("opt.DomainNames() = %s", err) - } - if diff := cmp.Diff(domainNames, []string{"abc.abcd.e"}); diff != "" { - t.Errorf("domain names mismatch (-want +got):\n%s", diff) - } - - // Iterator should not return anything else. - next, done, err = it.Next() - if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") - } - if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) - } -} - -// TestNDPOptionsIterCheck tests that Iter will return false if the NDPOptions -// the iterator was returned for is malformed. -func TestNDPOptionsIterCheck(t *testing.T) { - tests := []struct { - name string - buf []byte - expectedErr error - }{ - { - name: "ZeroLengthField", - buf: []byte{0, 0, 0, 0, 0, 0, 0, 0}, - expectedErr: ErrNDPOptMalformedHeader, - }, - { - name: "ValidSourceLinkLayerAddressOption", - buf: []byte{1, 1, 1, 2, 3, 4, 5, 6}, - expectedErr: nil, - }, - { - name: "TooSmallSourceLinkLayerAddressOption", - buf: []byte{1, 1, 1, 2, 3, 4, 5}, - expectedErr: io.ErrUnexpectedEOF, - }, - { - name: "ValidTargetLinkLayerAddressOption", - buf: []byte{2, 1, 1, 2, 3, 4, 5, 6}, - expectedErr: nil, - }, - { - name: "TooSmallTargetLinkLayerAddressOption", - buf: []byte{2, 1, 1, 2, 3, 4, 5}, - expectedErr: io.ErrUnexpectedEOF, - }, - { - name: "ValidPrefixInformation", - buf: []byte{ - 3, 4, 43, 64, - 1, 2, 3, 4, - 5, 6, 7, 8, - 0, 0, 0, 0, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 21, 22, 23, 24, - }, - expectedErr: nil, - }, - { - name: "TooSmallPrefixInformation", - buf: []byte{ - 3, 4, 43, 64, - 1, 2, 3, 4, - 5, 6, 7, 8, - 0, 0, 0, 0, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 21, 22, 23, - }, - expectedErr: io.ErrUnexpectedEOF, - }, - { - name: "InvalidPrefixInformationLength", - buf: []byte{ - 3, 3, 43, 64, - 1, 2, 3, 4, - 5, 6, 7, 8, - 0, 0, 0, 0, - 9, 10, 11, 12, - 13, 14, 15, 16, - }, - expectedErr: ErrNDPOptMalformedBody, - }, - { - name: "ValidSourceAndTargetLinkLayerAddressWithPrefixInformation", - buf: []byte{ - // Source Link-Layer Address. - 1, 1, 1, 2, 3, 4, 5, 6, - - // Target Link-Layer Address. - 2, 1, 7, 8, 9, 10, 11, 12, - - // Prefix information. - 3, 4, 43, 64, - 1, 2, 3, 4, - 5, 6, 7, 8, - 0, 0, 0, 0, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 21, 22, 23, 24, - }, - expectedErr: nil, - }, - { - name: "ValidSourceAndTargetLinkLayerAddressWithPrefixInformationWithUnrecognized", - buf: []byte{ - // Source Link-Layer Address. - 1, 1, 1, 2, 3, 4, 5, 6, - - // Target Link-Layer Address. - 2, 1, 7, 8, 9, 10, 11, 12, - - // 255 is an unrecognized type. If 255 ends up - // being the type for some recognized type, - // update 255 to some other unrecognized value. - 255, 2, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8, - - // Prefix information. - 3, 4, 43, 64, - 1, 2, 3, 4, - 5, 6, 7, 8, - 0, 0, 0, 0, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 21, 22, 23, 24, - }, - expectedErr: nil, - }, - { - name: "InvalidRecursiveDNSServerCutsOffAddress", - buf: []byte{ - 25, 4, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, - 0, 1, 2, 3, 4, 5, 6, 7, - }, - expectedErr: ErrNDPOptMalformedBody, - }, - { - name: "InvalidRecursiveDNSServerInvalidLengthField", - buf: []byte{ - 25, 2, 0, 0, - 0, 0, 0, 0, - 0, 1, 2, 3, 4, 5, 6, 7, 8, - }, - expectedErr: io.ErrUnexpectedEOF, - }, - { - name: "RecursiveDNSServerTooSmall", - buf: []byte{ - 25, 1, 0, 0, - 0, 0, 0, - }, - expectedErr: io.ErrUnexpectedEOF, - }, - { - name: "RecursiveDNSServerMulticast", - buf: []byte{ - 25, 3, 0, 0, - 0, 0, 0, 0, - 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, - }, - expectedErr: ErrNDPOptMalformedBody, - }, - { - name: "RecursiveDNSServerUnspecified", - buf: []byte{ - 25, 3, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - }, - expectedErr: ErrNDPOptMalformedBody, - }, - { - name: "DNSSearchListLargeCompliantRFC1035", - buf: []byte{ - 31, 33, 0, 0, - 0, 0, 0, 0, - 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', - 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', - 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', - 62, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', - 0, - }, - expectedErr: nil, - }, - { - name: "DNSSearchListNonCompliantRFC1035", - buf: []byte{ - 31, 33, 0, 0, - 0, 0, 0, 0, - 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', - 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', - 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', - 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', - 0, - 0, 0, 0, 0, 0, 0, 0, 0, - }, - expectedErr: ErrNDPOptMalformedBody, - }, - { - name: "DNSSearchListValidSmall", - buf: []byte{ - 31, 2, 0, 0, - 0, 0, 0, 0, - 6, 'a', 'b', 'c', 'd', 'e', 'f', - 0, - }, - expectedErr: nil, - }, - { - name: "DNSSearchListTooSmall", - buf: []byte{ - 31, 1, 0, 0, - 0, 0, 0, - }, - expectedErr: io.ErrUnexpectedEOF, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - opts := NDPOptions(test.buf) - - if _, err := opts.Iter(true); !errors.Is(err, test.expectedErr) { - t.Fatalf("got Iter(true) = (_, %v), want = (_, %v)", err, test.expectedErr) - } - - // test.buf may be malformed but we chose not to check - // the iterator so it must return true. - if _, err := opts.Iter(false); err != nil { - t.Fatalf("got Iter(false) = (_, %s), want = (_, nil)", err) - } - }) - } -} - -// TestNDPOptionsIter tests that we can iterator over a valid NDPOptions. Note, -// this test does not actually check any of the option's getters, it simply -// checks the option Type and Body. We have other tests that tests the option -// field gettings given an option body and don't need to duplicate those tests -// here. -func TestNDPOptionsIter(t *testing.T) { - buf := []byte{ - // Source Link-Layer Address. - 1, 1, 1, 2, 3, 4, 5, 6, - - // Target Link-Layer Address. - 2, 1, 7, 8, 9, 10, 11, 12, - - // 255 is an unrecognized type. If 255 ends up being the type - // for some recognized type, update 255 to some other - // unrecognized value. Note, this option should be skipped when - // iterating. - 255, 2, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8, - - // Prefix information. - 3, 4, 43, 64, - 1, 2, 3, 4, - 5, 6, 7, 8, - 0, 0, 0, 0, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 21, 22, 23, 24, - } - - opts := NDPOptions(buf) - it, err := opts.Iter(true) - if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) - } - - // Test the first (Source Link-Layer) option. - next, done, err := it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got, want := []byte(next.(NDPSourceLinkLayerAddressOption)), buf[2:][:6]; !bytes.Equal(got, want) { - t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want) - } - if got := next.Type(); got != NDPSourceLinkLayerAddressOptionType { - t.Errorf("got Type = %d, want = %d", got, NDPSourceLinkLayerAddressOptionType) - } - - // Test the next (Target Link-Layer) option. - next, done, err = it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got, want := []byte(next.(NDPTargetLinkLayerAddressOption)), buf[10:][:6]; !bytes.Equal(got, want) { - t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want) - } - if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType { - t.Errorf("got Type = %d, want = %d", got, NDPTargetLinkLayerAddressOptionType) - } - - // Test the next (Prefix Information) option. - // Note, the unrecognized option should be skipped. - next, done, err = it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got, want := next.(NDPPrefixInformation), buf[34:][:30]; !bytes.Equal(got, want) { - t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want) - } - if got := next.Type(); got != NDPPrefixInformationType { - t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType) - } - - // Iterator should not return anything else. - next, done, err = it.Next() - if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") - } - if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) - } -} diff --git a/pkg/tcpip/header/parse/BUILD b/pkg/tcpip/header/parse/BUILD deleted file mode 100644 index 2adee9288..000000000 --- a/pkg/tcpip/header/parse/BUILD +++ /dev/null @@ -1,15 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "parse", - srcs = ["parse.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/header/parse/parse_state_autogen.go b/pkg/tcpip/header/parse/parse_state_autogen.go new file mode 100644 index 000000000..ad047be32 --- /dev/null +++ b/pkg/tcpip/header/parse/parse_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package parse diff --git a/pkg/tcpip/header/tcp_test.go b/pkg/tcpip/header/tcp_test.go deleted file mode 100644 index 72563837b..000000000 --- a/pkg/tcpip/header/tcp_test.go +++ /dev/null @@ -1,148 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package header_test - -import ( - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -func TestEncodeSACKBlocks(t *testing.T) { - testCases := []struct { - sackBlocks []header.SACKBlock - want []header.SACKBlock - bufSize int - }{ - { - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}}, - 40, - }, - { - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, - 30, - }, - { - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, - []header.SACKBlock{{10, 20}, {22, 30}}, - 20, - }, - { - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, - []header.SACKBlock{{10, 20}}, - 10, - }, - { - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, - nil, - 8, - }, - { - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}}, - 60, - }, - } - for _, tc := range testCases { - b := make([]byte, tc.bufSize) - t.Logf("testing: %v", tc) - header.EncodeSACKBlocks(tc.sackBlocks, b) - opts := header.ParseTCPOptions(b) - if got, want := opts.SACKBlocks, tc.want; !reflect.DeepEqual(got, want) { - t.Errorf("header.EncodeSACKBlocks(%v, %v), encoded blocks got: %v, want: %v", tc.sackBlocks, b, got, want) - } - } -} - -func TestTCPParseOptions(t *testing.T) { - type tsOption struct { - tsVal uint32 - tsEcr uint32 - } - - generateOptions := func(tsOpt *tsOption, sackBlocks []header.SACKBlock) []byte { - l := 0 - if tsOpt != nil { - l += 10 - } - if len(sackBlocks) != 0 { - l += len(sackBlocks)*8 + 2 - } - b := make([]byte, l) - offset := 0 - if tsOpt != nil { - offset = header.EncodeTSOption(tsOpt.tsVal, tsOpt.tsEcr, b) - } - header.EncodeSACKBlocks(sackBlocks, b[offset:]) - return b - } - - testCases := []struct { - b []byte - want header.TCPOptions - }{ - // Trivial cases. - {nil, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionNOP}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionNOP, header.TCPOptionNOP}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionEOL}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionNOP, header.TCPOptionEOL, header.TCPOptionTS, 10, 1, 1}, header.TCPOptions{false, 0, 0, nil}}, - - // Test timestamp parsing. - {[]byte{header.TCPOptionNOP, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}}, - {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}}, - - // Test malformed timestamp option. - {[]byte{header.TCPOptionTS, 8, 1, 1}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionNOP, header.TCPOptionTS, 8, 1, 1}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionNOP, header.TCPOptionTS, 8, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}}, - - // Test SACKBlock parsing. - {[]byte{header.TCPOptionSACK, 10, 0, 0, 0, 1, 0, 0, 0, 10}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{1, 10}}}}, - {[]byte{header.TCPOptionSACK, 18, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{1, 10}, {11, 12}}}}, - - // Test malformed SACK option. - {[]byte{header.TCPOptionSACK, 0}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK, 8, 0, 0, 0, 1, 0, 0, 0, 10}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK, 17, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK, 10}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK, 10, 0, 0, 0, 1, 0, 0, 0}, header.TCPOptions{false, 0, 0, nil}}, - - // Test Timestamp + SACK block parsing. - {generateOptions(&tsOption{1, 1}, []header.SACKBlock{{1, 10}, {11, 12}}), header.TCPOptions{true, 1, 1, []header.SACKBlock{{1, 10}, {11, 12}}}}, - {generateOptions(&tsOption{1, 2}, []header.SACKBlock{{1, 10}, {11, 12}}), header.TCPOptions{true, 1, 2, []header.SACKBlock{{1, 10}, {11, 12}}}}, - {generateOptions(&tsOption{1, 3}, []header.SACKBlock{{1, 10}, {11, 12}, {13, 14}, {14, 15}, {15, 16}}), header.TCPOptions{true, 1, 3, []header.SACKBlock{{1, 10}, {11, 12}, {13, 14}, {14, 15}}}}, - - // Test valid timestamp + malformed SACK block parsing. - {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK}, header.TCPOptions{true, 1, 1, nil}}, - {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 10}, header.TCPOptions{true, 1, 1, nil}}, - {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 10, 0, 0, 0}, header.TCPOptions{true, 1, 1, nil}}, - {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}}, - {[]byte{header.TCPOptionSACK, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK, 10, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{134873088, 65536}}}}, - {[]byte{header.TCPOptionSACK, 10, 0, 0, 0, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{8, 167772160}}}}, - {[]byte{header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}}, - } - for _, tc := range testCases { - if got, want := header.ParseTCPOptions(tc.b), tc.want; !reflect.DeepEqual(got, want) { - t.Errorf("ParseTCPOptions(%v) = %v, want: %v", tc.b, got, tc.want) - } - } -} diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD deleted file mode 100644 index 973f06cbc..000000000 --- a/pkg/tcpip/link/channel/BUILD +++ /dev/null @@ -1,15 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "channel", - srcs = ["channel.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/channel/channel_state_autogen.go b/pkg/tcpip/link/channel/channel_state_autogen.go new file mode 100644 index 000000000..38c12a3bf --- /dev/null +++ b/pkg/tcpip/link/channel/channel_state_autogen.go @@ -0,0 +1,34 @@ +// automatically generated by stateify. + +package channel + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (n *NotificationHandle) StateTypeName() string { + return "pkg/tcpip/link/channel.NotificationHandle" +} + +func (n *NotificationHandle) StateFields() []string { + return []string{ + "n", + } +} + +func (n *NotificationHandle) beforeSave() {} + +func (n *NotificationHandle) StateSave(stateSinkObject state.Sink) { + n.beforeSave() + stateSinkObject.Save(0, &n.n) +} + +func (n *NotificationHandle) afterLoad() {} + +func (n *NotificationHandle) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &n.n) +} + +func init() { + state.Register((*NotificationHandle)(nil)) +} diff --git a/pkg/tcpip/link/ethernet/BUILD b/pkg/tcpip/link/ethernet/BUILD deleted file mode 100644 index 0ae0d201a..000000000 --- a/pkg/tcpip/link/ethernet/BUILD +++ /dev/null @@ -1,29 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "ethernet", - srcs = ["ethernet.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/header", - "//pkg/tcpip/link/nested", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "ethernet_test", - size = "small", - srcs = ["ethernet_test.go"], - deps = [ - ":ethernet", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/ethernet/ethernet_state_autogen.go b/pkg/tcpip/link/ethernet/ethernet_state_autogen.go new file mode 100644 index 000000000..71d255c20 --- /dev/null +++ b/pkg/tcpip/link/ethernet/ethernet_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package ethernet diff --git a/pkg/tcpip/link/ethernet/ethernet_test.go b/pkg/tcpip/link/ethernet/ethernet_test.go deleted file mode 100644 index 08a7f1ce1..000000000 --- a/pkg/tcpip/link/ethernet/ethernet_test.go +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ethernet_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/ethernet" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -var _ stack.NetworkDispatcher = (*testNetworkDispatcher)(nil) - -type testNetworkDispatcher struct { - networkPackets int -} - -func (t *testNetworkDispatcher) DeliverNetworkPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { - t.networkPackets++ -} - -func (*testNetworkDispatcher) DeliverOutboundPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { -} - -func TestDeliverNetworkPacket(t *testing.T) { - const ( - linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - otherLinkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07") - otherLinkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08") - ) - - e := ethernet.New(channel.New(0, 0, linkAddr)) - var networkDispatcher testNetworkDispatcher - e.Attach(&networkDispatcher) - - if networkDispatcher.networkPackets != 0 { - t.Fatalf("got networkDispatcher.networkPackets = %d, want = 0", networkDispatcher.networkPackets) - } - - // An ethernet frame with a destination link address that is not assigned to - // our ethernet link endpoint should still be delivered to the network - // dispatcher since the ethernet endpoint is not expected to filter frames. - eth := buffer.NewView(header.EthernetMinimumSize) - header.Ethernet(eth).Encode(&header.EthernetFields{ - SrcAddr: otherLinkAddr1, - DstAddr: otherLinkAddr2, - Type: header.IPv4ProtocolNumber, - }) - e.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: eth.ToVectorisedView(), - })) - if networkDispatcher.networkPackets != 1 { - t.Fatalf("got networkDispatcher.networkPackets = %d, want = 1", networkDispatcher.networkPackets) - } -} diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD deleted file mode 100644 index ae1394ebf..000000000 --- a/pkg/tcpip/link/fdbased/BUILD +++ /dev/null @@ -1,41 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "fdbased", - srcs = [ - "endpoint.go", - "endpoint_unsafe.go", - "mmap.go", - "mmap_stub.go", - "mmap_unsafe.go", - "packet_dispatchers.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/binary", - "//pkg/iovec", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/rawfile", - "//pkg/tcpip/stack", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "fdbased_test", - size = "small", - srcs = ["endpoint_test.go"], - library = ":fdbased", - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go deleted file mode 100644 index e82371798..000000000 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ /dev/null @@ -1,624 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build linux - -package fdbased - -import ( - "bytes" - "fmt" - "math/rand" - "reflect" - "syscall" - "testing" - "time" - "unsafe" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -const ( - mtu = 1500 - laddr = tcpip.LinkAddress("\x11\x22\x33\x44\x55\x66") - raddr = tcpip.LinkAddress("\x77\x88\x99\xaa\xbb\xcc") - proto = 10 - csumOffset = 48 - gsoMSS = 500 -) - -type packetInfo struct { - Raddr tcpip.LinkAddress - Proto tcpip.NetworkProtocolNumber - Contents *stack.PacketBuffer -} - -type packetContents struct { - LinkHeader buffer.View - NetworkHeader buffer.View - TransportHeader buffer.View - Data buffer.View -} - -func checkPacketInfoEqual(t *testing.T, got, want packetInfo) { - t.Helper() - if diff := cmp.Diff( - want, got, - cmp.Transformer("ExtractPacketBuffer", func(pk *stack.PacketBuffer) *packetContents { - if pk == nil { - return nil - } - return &packetContents{ - LinkHeader: pk.LinkHeader().View(), - NetworkHeader: pk.NetworkHeader().View(), - TransportHeader: pk.TransportHeader().View(), - Data: pk.Data.ToView(), - } - }), - ); diff != "" { - t.Errorf("unexpected packetInfo (-want +got):\n%s", diff) - } -} - -type context struct { - t *testing.T - readFDs []int - writeFDs []int - ep stack.LinkEndpoint - ch chan packetInfo - done chan struct{} -} - -func newContext(t *testing.T, opt *Options) *context { - firstFDPair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0) - if err != nil { - t.Fatalf("Socketpair failed: %v", err) - } - secondFDPair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0) - if err != nil { - t.Fatalf("Socketpair failed: %v", err) - } - - done := make(chan struct{}, 2) - opt.ClosedFunc = func(tcpip.Error) { - done <- struct{}{} - } - - opt.FDs = []int{firstFDPair[1], secondFDPair[1]} - ep, err := New(opt) - if err != nil { - t.Fatalf("Failed to create FD endpoint: %v", err) - } - - c := &context{ - t: t, - readFDs: []int{firstFDPair[0], secondFDPair[0]}, - writeFDs: opt.FDs, - ep: ep, - ch: make(chan packetInfo, 100), - done: done, - } - - ep.Attach(c) - - return c -} - -func (c *context) cleanup() { - for _, fd := range c.readFDs { - syscall.Close(fd) - } - <-c.done - <-c.done - for _, fd := range c.writeFDs { - syscall.Close(fd) - } -} - -func (c *context) DeliverNetworkPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - c.ch <- packetInfo{remote, protocol, pkt} -} - -func (c *context) DeliverOutboundPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - panic("unimplemented") -} - -func TestNoEthernetProperties(t *testing.T) { - c := newContext(t, &Options{MTU: mtu}) - defer c.cleanup() - - if want, v := uint16(0), c.ep.MaxHeaderLength(); want != v { - t.Fatalf("MaxHeaderLength() = %v, want %v", v, want) - } - - if want, v := uint32(mtu), c.ep.MTU(); want != v { - t.Fatalf("MTU() = %v, want %v", v, want) - } -} - -func TestEthernetProperties(t *testing.T) { - c := newContext(t, &Options{EthernetHeader: true, MTU: mtu}) - defer c.cleanup() - - if want, v := uint16(header.EthernetMinimumSize), c.ep.MaxHeaderLength(); want != v { - t.Fatalf("MaxHeaderLength() = %v, want %v", v, want) - } - - if want, v := uint32(mtu), c.ep.MTU(); want != v { - t.Fatalf("MTU() = %v, want %v", v, want) - } -} - -func TestAddress(t *testing.T) { - addrs := []tcpip.LinkAddress{"", "abc", "def"} - for _, a := range addrs { - t.Run(fmt.Sprintf("Address: %q", a), func(t *testing.T) { - c := newContext(t, &Options{Address: a, MTU: mtu}) - defer c.cleanup() - - if want, v := a, c.ep.LinkAddress(); want != v { - t.Fatalf("LinkAddress() = %v, want %v", v, want) - } - }) - } -} - -func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash uint32) { - c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth, GSOMaxSize: gsoMaxSize}) - defer c.cleanup() - - var r stack.RouteInfo - r.RemoteLinkAddress = raddr - - // Build payload. - payload := buffer.NewView(plen) - if _, err := rand.Read(payload); err != nil { - t.Fatalf("rand.Read(payload): %s", err) - } - - // Build packet buffer. - const netHdrLen = 100 - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(c.ep.MaxHeaderLength()) + netHdrLen, - Data: payload.ToVectorisedView(), - }) - pkt.Hash = hash - - // Build header. - b := pkt.NetworkHeader().Push(netHdrLen) - if _, err := rand.Read(b); err != nil { - t.Fatalf("rand.Read(b): %s", err) - } - - // Write. - want := append(append(buffer.View(nil), b...), payload...) - var gso *stack.GSO - if gsoMaxSize != 0 { - gso = &stack.GSO{ - Type: stack.GSOTCPv6, - NeedsCsum: true, - CsumOffset: csumOffset, - MSS: gsoMSS, - MaxSize: gsoMaxSize, - L3HdrLen: header.IPv4MaximumHeaderSize, - } - } - if err := c.ep.WritePacket(r, gso, proto, pkt); err != nil { - t.Fatalf("WritePacket failed: %v", err) - } - - // Read from the corresponding FD, then compare with what we wrote. - b = make([]byte, mtu) - fd := c.readFDs[hash%uint32(len(c.readFDs))] - n, err := syscall.Read(fd, b) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - b = b[:n] - if gsoMaxSize != 0 { - vnetHdr := *(*virtioNetHdr)(unsafe.Pointer(&b[0])) - if vnetHdr.flags&_VIRTIO_NET_HDR_F_NEEDS_CSUM == 0 { - t.Fatalf("virtioNetHdr.flags %v doesn't contain %v", vnetHdr.flags, _VIRTIO_NET_HDR_F_NEEDS_CSUM) - } - csumStart := header.EthernetMinimumSize + gso.L3HdrLen - if vnetHdr.csumStart != csumStart { - t.Fatalf("vnetHdr.csumStart = %v, want %v", vnetHdr.csumStart, csumStart) - } - if vnetHdr.csumOffset != csumOffset { - t.Fatalf("vnetHdr.csumOffset = %v, want %v", vnetHdr.csumOffset, csumOffset) - } - gsoType := uint8(0) - if int(gso.MSS) < plen { - gsoType = _VIRTIO_NET_HDR_GSO_TCPV6 - } - if vnetHdr.gsoType != gsoType { - t.Fatalf("vnetHdr.gsoType = %v, want %v", vnetHdr.gsoType, gsoType) - } - b = b[virtioNetHdrSize:] - } - if eth { - h := header.Ethernet(b) - b = b[header.EthernetMinimumSize:] - - if a := h.SourceAddress(); a != laddr { - t.Fatalf("SourceAddress() = %v, want %v", a, laddr) - } - - if a := h.DestinationAddress(); a != raddr { - t.Fatalf("DestinationAddress() = %v, want %v", a, raddr) - } - - if et := h.Type(); et != proto { - t.Fatalf("Type() = %v, want %v", et, proto) - } - } - if len(b) != len(want) { - t.Fatalf("Read returned %v bytes, want %v", len(b), len(want)) - } - if !bytes.Equal(b, want) { - t.Fatalf("Read returned %x, want %x", b, want) - } -} - -func TestWritePacket(t *testing.T) { - lengths := []int{0, 100, 1000} - eths := []bool{true, false} - gsos := []uint32{0, 32768} - - for _, eth := range eths { - for _, plen := range lengths { - for _, gso := range gsos { - t.Run( - fmt.Sprintf("Eth=%v,PayloadLen=%v,GSOMaxSize=%v", eth, plen, gso), - func(t *testing.T) { - testWritePacket(t, plen, eth, gso, 0) - }, - ) - } - } - } -} - -func TestHashedWritePacket(t *testing.T) { - lengths := []int{0, 100, 1000} - eths := []bool{true, false} - gsos := []uint32{0, 32768} - hashes := []uint32{0, 1} - for _, eth := range eths { - for _, plen := range lengths { - for _, gso := range gsos { - for _, hash := range hashes { - t.Run( - fmt.Sprintf("Eth=%v,PayloadLen=%v,GSOMaxSize=%v,Hash=%d", eth, plen, gso, hash), - func(t *testing.T) { - testWritePacket(t, plen, eth, gso, hash) - }, - ) - } - } - } - } -} - -func TestPreserveSrcAddress(t *testing.T) { - baddr := tcpip.LinkAddress("\xcc\xbb\xaa\x77\x88\x99") - - c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: true}) - defer c.cleanup() - - // Set LocalLinkAddress in route to the value of the bridged address. - var r stack.RouteInfo - r.LocalLinkAddress = baddr - r.RemoteLinkAddress = raddr - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - // WritePacket panics given a prependable with anything less than - // the minimum size of the ethernet header. - // TODO(b/153685824): Figure out if this should use c.ep.MaxHeaderLength(). - ReserveHeaderBytes: header.EthernetMinimumSize, - Data: buffer.VectorisedView{}, - }) - if err := c.ep.WritePacket(r, nil /* gso */, proto, pkt); err != nil { - t.Fatalf("WritePacket failed: %v", err) - } - - // Read from the FD, then compare with what we wrote. - b := make([]byte, mtu) - n, err := syscall.Read(c.readFDs[0], b) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - b = b[:n] - h := header.Ethernet(b) - - if a := h.SourceAddress(); a != baddr { - t.Fatalf("SourceAddress() = %v, want %v", a, baddr) - } -} - -func TestDeliverPacket(t *testing.T) { - lengths := []int{100, 1000} - eths := []bool{true, false} - - for _, eth := range eths { - for _, plen := range lengths { - t.Run(fmt.Sprintf("Eth=%v,PayloadLen=%v", eth, plen), func(t *testing.T) { - c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth}) - defer c.cleanup() - - // Build packet. - all := make([]byte, plen) - if _, err := rand.Read(all); err != nil { - t.Fatalf("rand.Read(all): %s", err) - } - // Make it look like an IPv4 packet. - all[0] = 0x40 - - wantPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.EthernetMinimumSize, - Data: buffer.NewViewFromBytes(all).ToVectorisedView(), - }) - if eth { - hdr := header.Ethernet(wantPkt.LinkHeader().Push(header.EthernetMinimumSize)) - hdr.Encode(&header.EthernetFields{ - SrcAddr: raddr, - DstAddr: laddr, - Type: proto, - }) - all = append(hdr, all...) - } - - // Write packet via the file descriptor. - if _, err := syscall.Write(c.readFDs[0], all); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Receive packet through the endpoint. - select { - case pi := <-c.ch: - want := packetInfo{ - Raddr: raddr, - Proto: proto, - Contents: wantPkt, - } - if !eth { - want.Proto = header.IPv4ProtocolNumber - want.Raddr = "" - } - checkPacketInfoEqual(t, pi, want) - case <-time.After(10 * time.Second): - t.Fatalf("Timed out waiting for packet") - } - }) - } - } -} - -func TestBufConfigMaxLength(t *testing.T) { - got := 0 - for _, i := range BufConfig { - got += i - } - want := header.MaxIPPacketSize // maximum TCP packet size - if got < want { - t.Errorf("total buffer size is invalid: got %d, want >= %d", got, want) - } -} - -func TestBufConfigFirst(t *testing.T) { - // The stack assumes that the TCP/IP header is enterily contained in the first view. - // Therefore, the first view needs to be large enough to contain the maximum TCP/IP - // header, which is 120 bytes (60 bytes for IP + 60 bytes for TCP). - want := 120 - got := BufConfig[0] - if got < want { - t.Errorf("first view has an invalid size: got %d, want >= %d", got, want) - } -} - -var capLengthTestCases = []struct { - comment string - config []int - n int - wantUsed int - wantLengths []int -}{ - { - comment: "Single slice", - config: []int{2}, - n: 1, - wantUsed: 1, - wantLengths: []int{1}, - }, - { - comment: "Multiple slices", - config: []int{1, 2}, - n: 2, - wantUsed: 2, - wantLengths: []int{1, 1}, - }, - { - comment: "Entire buffer", - config: []int{1, 2}, - n: 3, - wantUsed: 2, - wantLengths: []int{1, 2}, - }, - { - comment: "Entire buffer but not on the last slice", - config: []int{1, 2, 3}, - n: 3, - wantUsed: 2, - wantLengths: []int{1, 2}, - }, -} - -func TestIovecBuffer(t *testing.T) { - for _, c := range capLengthTestCases { - t.Run(c.comment, func(t *testing.T) { - b := newIovecBuffer(c.config, false /* skipsVnetHdr */) - - // Test initial allocation. - iovecs := b.nextIovecs() - if got, want := len(iovecs), len(c.config); got != want { - t.Fatalf("len(iovecs) = %d, want %d", got, want) - } - - // Make a copy as iovecs points to internal slice. We will need this state - // later. - oldIovecs := append([]syscall.Iovec(nil), iovecs...) - - // Test the views that get pulled. - vv := b.pullViews(c.n) - var lengths []int - for _, v := range vv.Views() { - lengths = append(lengths, len(v)) - } - if !reflect.DeepEqual(lengths, c.wantLengths) { - t.Errorf("Pulled view lengths = %v, want %v", lengths, c.wantLengths) - } - - // Test that new views get reallocated. - for i, newIov := range b.nextIovecs() { - if i < c.wantUsed { - if newIov.Base == oldIovecs[i].Base { - t.Errorf("b.views[%d] should have been reallocated", i) - } - } else { - if newIov.Base != oldIovecs[i].Base { - t.Errorf("b.views[%d] should not have been reallocated", i) - } - } - } - }) - } -} - -func TestIovecBufferSkipVnetHdr(t *testing.T) { - for _, test := range []struct { - desc string - readN int - wantLen int - }{ - { - desc: "nothing read", - readN: 0, - wantLen: 0, - }, - { - desc: "smaller than vnet header", - readN: virtioNetHdrSize - 1, - wantLen: 0, - }, - { - desc: "header skipped", - readN: virtioNetHdrSize + 100, - wantLen: 100, - }, - } { - t.Run(test.desc, func(t *testing.T) { - b := newIovecBuffer([]int{10, 20, 50, 50}, true) - // Pretend a read happend. - b.nextIovecs() - vv := b.pullViews(test.readN) - if got, want := vv.Size(), test.wantLen; got != want { - t.Errorf("b.pullView(%d).Size() = %d; want %d", test.readN, got, want) - } - if got, want := len(vv.ToOwnedView()), test.wantLen; got != want { - t.Errorf("b.pullView(%d).ToOwnedView() has length %d; want %d", test.readN, got, want) - } - }) - } -} - -// fakeNetworkDispatcher delivers packets to pkts. -type fakeNetworkDispatcher struct { - pkts []*stack.PacketBuffer -} - -func (d *fakeNetworkDispatcher) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - d.pkts = append(d.pkts, pkt) -} - -func (d *fakeNetworkDispatcher) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - panic("unimplemented") -} - -func TestDispatchPacketFormat(t *testing.T) { - for _, test := range []struct { - name string - newDispatcher func(fd int, e *endpoint) (linkDispatcher, error) - }{ - { - name: "readVDispatcher", - newDispatcher: newReadVDispatcher, - }, - { - name: "recvMMsgDispatcher", - newDispatcher: newRecvMMsgDispatcher, - }, - } { - t.Run(test.name, func(t *testing.T) { - // Create a socket pair to send/recv. - fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_DGRAM, 0) - if err != nil { - t.Fatal(err) - } - defer syscall.Close(fds[0]) - defer syscall.Close(fds[1]) - - data := []byte{ - // Ethernet header. - 1, 2, 3, 4, 5, 60, - 1, 2, 3, 4, 5, 61, - 8, 0, - // Mock network header. - 40, 41, 42, 43, - } - err = syscall.Sendmsg(fds[1], data, nil, nil, 0) - if err != nil { - t.Fatal(err) - } - - // Create and run dispatcher once. - sink := &fakeNetworkDispatcher{} - d, err := test.newDispatcher(fds[0], &endpoint{ - hdrSize: header.EthernetMinimumSize, - dispatcher: sink, - }) - if err != nil { - t.Fatal(err) - } - if ok, err := d.dispatch(); !ok || err != nil { - t.Fatalf("d.dispatch() = %v, %v", ok, err) - } - - // Verify packet. - if got, want := len(sink.pkts), 1; got != want { - t.Fatalf("len(sink.pkts) = %d, want %d", got, want) - } - pkt := sink.pkts[0] - if got, want := pkt.LinkHeader().View().Size(), header.EthernetMinimumSize; got != want { - t.Errorf("pkt.LinkHeader().View().Size() = %d, want %d", got, want) - } - if got, want := pkt.Data.Size(), 4; got != want { - t.Errorf("pkt.Data.Size() = %d, want %d", got, want) - } - }) - } -} diff --git a/pkg/tcpip/link/fdbased/fdbased_state_autogen.go b/pkg/tcpip/link/fdbased/fdbased_state_autogen.go new file mode 100644 index 000000000..b84e8f21c --- /dev/null +++ b/pkg/tcpip/link/fdbased/fdbased_state_autogen.go @@ -0,0 +1,8 @@ +// automatically generated by stateify. + +// +build linux +// +build linux,amd64 linux,arm64 +// +build !linux !amd64,!arm64 +// +build linux + +package fdbased diff --git a/pkg/tcpip/link/fdbased/fdbased_unsafe_state_autogen.go b/pkg/tcpip/link/fdbased/fdbased_unsafe_state_autogen.go new file mode 100644 index 000000000..e2ed505b2 --- /dev/null +++ b/pkg/tcpip/link/fdbased/fdbased_unsafe_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +// +build linux +// +build linux,amd64 linux,arm64 + +package fdbased diff --git a/pkg/tcpip/link/loopback/BUILD b/pkg/tcpip/link/loopback/BUILD deleted file mode 100644 index 6bf3805b7..000000000 --- a/pkg/tcpip/link/loopback/BUILD +++ /dev/null @@ -1,15 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "loopback", - srcs = ["loopback.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/loopback/loopback_state_autogen.go b/pkg/tcpip/link/loopback/loopback_state_autogen.go new file mode 100644 index 000000000..c00fd9f19 --- /dev/null +++ b/pkg/tcpip/link/loopback/loopback_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package loopback diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD deleted file mode 100644 index cbda59775..000000000 --- a/pkg/tcpip/link/muxed/BUILD +++ /dev/null @@ -1,28 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "muxed", - srcs = ["injectable.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "muxed_test", - size = "small", - srcs = ["injectable_test.go"], - library = ":muxed", - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/link/fdbased", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go deleted file mode 100644 index ba30287bc..000000000 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package muxed - -import ( - "bytes" - "net" - "os" - "syscall" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -func TestInjectableEndpointRawDispatch(t *testing.T) { - endpoint, sock, dstIP := makeTestInjectableEndpoint(t) - - endpoint.InjectOutbound(dstIP, []byte{0xFA}) - - buf := make([]byte, ipv4.MaxTotalSize) - bytesRead, err := sock.Read(buf) - if err != nil { - t.Fatalf("Unable to read from socketpair: %v", err) - } - if got, want := buf[:bytesRead], []byte{0xFA}; !bytes.Equal(got, want) { - t.Fatalf("Read %v from the socketpair, wanted %v", got, want) - } -} - -func TestInjectableEndpointDispatch(t *testing.T) { - endpoint, sock, dstIP := makeTestInjectableEndpoint(t) - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: 1, - Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), - }) - pkt.TransportHeader().Push(1)[0] = 0xFA - var packetRoute stack.RouteInfo - packetRoute.RemoteAddress = dstIP - - endpoint.WritePacket(packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt) - - buf := make([]byte, 6500) - bytesRead, err := sock.Read(buf) - if err != nil { - t.Fatalf("Unable to read from socketpair: %v", err) - } - if got, want := buf[:bytesRead], []byte{0xFA, 0xFB}; !bytes.Equal(got, want) { - t.Fatalf("Read %v from the socketpair, wanted %v", got, want) - } -} - -func TestInjectableEndpointDispatchHdrOnly(t *testing.T) { - endpoint, sock, dstIP := makeTestInjectableEndpoint(t) - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: 1, - Data: buffer.NewView(0).ToVectorisedView(), - }) - pkt.TransportHeader().Push(1)[0] = 0xFA - var packetRoute stack.RouteInfo - packetRoute.RemoteAddress = dstIP - endpoint.WritePacket(packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt) - buf := make([]byte, 6500) - bytesRead, err := sock.Read(buf) - if err != nil { - t.Fatalf("Unable to read from socketpair: %v", err) - } - if got, want := buf[:bytesRead], []byte{0xFA}; !bytes.Equal(got, want) { - t.Fatalf("Read %v from the socketpair, wanted %v", got, want) - } -} - -func makeTestInjectableEndpoint(t *testing.T) (*InjectableEndpoint, *os.File, tcpip.Address) { - dstIP := tcpip.Address(net.ParseIP("1.2.3.4").To4()) - pair, err := syscall.Socketpair(syscall.AF_UNIX, - syscall.SOCK_SEQPACKET|syscall.SOCK_CLOEXEC|syscall.SOCK_NONBLOCK, 0) - if err != nil { - t.Fatal("Failed to create socket pair:", err) - } - underlyingEndpoint := fdbased.NewInjectable(pair[1], 6500, stack.CapabilityNone) - routes := map[tcpip.Address]stack.InjectableLinkEndpoint{dstIP: underlyingEndpoint} - endpoint := NewInjectableEndpoint(routes) - return endpoint, os.NewFile(uintptr(pair[0]), "test route end"), dstIP -} diff --git a/pkg/tcpip/link/muxed/muxed_state_autogen.go b/pkg/tcpip/link/muxed/muxed_state_autogen.go new file mode 100644 index 000000000..56330e2a5 --- /dev/null +++ b/pkg/tcpip/link/muxed/muxed_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package muxed diff --git a/pkg/tcpip/link/nested/BUILD b/pkg/tcpip/link/nested/BUILD deleted file mode 100644 index 00b42b924..000000000 --- a/pkg/tcpip/link/nested/BUILD +++ /dev/null @@ -1,31 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "nested", - srcs = [ - "nested.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "nested_test", - size = "small", - srcs = [ - "nested_test.go", - ], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/header", - "//pkg/tcpip/link/nested", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/nested/nested_state_autogen.go b/pkg/tcpip/link/nested/nested_state_autogen.go new file mode 100644 index 000000000..9e1b5ca4e --- /dev/null +++ b/pkg/tcpip/link/nested/nested_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package nested diff --git a/pkg/tcpip/link/nested/nested_test.go b/pkg/tcpip/link/nested/nested_test.go deleted file mode 100644 index c1f9d308c..000000000 --- a/pkg/tcpip/link/nested/nested_test.go +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package nested_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/nested" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -type parentEndpoint struct { - nested.Endpoint -} - -var _ stack.LinkEndpoint = (*parentEndpoint)(nil) -var _ stack.NetworkDispatcher = (*parentEndpoint)(nil) - -type childEndpoint struct { - stack.LinkEndpoint - dispatcher stack.NetworkDispatcher -} - -var _ stack.LinkEndpoint = (*childEndpoint)(nil) - -func (c *childEndpoint) Attach(dispatcher stack.NetworkDispatcher) { - c.dispatcher = dispatcher -} - -func (c *childEndpoint) IsAttached() bool { - return c.dispatcher != nil -} - -type counterDispatcher struct { - count int -} - -var _ stack.NetworkDispatcher = (*counterDispatcher)(nil) - -func (d *counterDispatcher) DeliverNetworkPacket(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) { - d.count++ -} - -func (d *counterDispatcher) DeliverOutboundPacket(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) { - panic("unimplemented") -} - -func TestNestedLinkEndpoint(t *testing.T) { - const emptyAddress = tcpip.LinkAddress("") - - var ( - childEP childEndpoint - nestedEP parentEndpoint - disp counterDispatcher - ) - nestedEP.Endpoint.Init(&childEP, &nestedEP) - - if childEP.IsAttached() { - t.Error("On init, childEP.IsAttached() = true, want = false") - } - if nestedEP.IsAttached() { - t.Error("On init, nestedEP.IsAttached() = true, want = false") - } - - nestedEP.Attach(&disp) - if disp.count != 0 { - t.Fatalf("After attach, got disp.count = %d, want = 0", disp.count) - } - if !childEP.IsAttached() { - t.Error("After attach, childEP.IsAttached() = false, want = true") - } - if !nestedEP.IsAttached() { - t.Error("After attach, nestedEP.IsAttached() = false, want = true") - } - - nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{})) - if disp.count != 1 { - t.Errorf("After first packet with dispatcher attached, got disp.count = %d, want = 1", disp.count) - } - - nestedEP.Attach(nil) - if childEP.IsAttached() { - t.Error("After detach, childEP.IsAttached() = true, want = false") - } - if nestedEP.IsAttached() { - t.Error("After detach, nestedEP.IsAttached() = true, want = false") - } - - disp.count = 0 - nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{})) - if disp.count != 0 { - t.Errorf("After second packet with dispatcher detached, got disp.count = %d, want = 0", disp.count) - } - -} diff --git a/pkg/tcpip/link/packetsocket/BUILD b/pkg/tcpip/link/packetsocket/BUILD deleted file mode 100644 index 6fff160ce..000000000 --- a/pkg/tcpip/link/packetsocket/BUILD +++ /dev/null @@ -1,14 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "packetsocket", - srcs = ["endpoint.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/link/nested", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/packetsocket/packetsocket_state_autogen.go b/pkg/tcpip/link/packetsocket/packetsocket_state_autogen.go new file mode 100644 index 000000000..6b3221fd8 --- /dev/null +++ b/pkg/tcpip/link/packetsocket/packetsocket_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package packetsocket diff --git a/pkg/tcpip/link/pipe/BUILD b/pkg/tcpip/link/pipe/BUILD deleted file mode 100644 index 9f31c1ffc..000000000 --- a/pkg/tcpip/link/pipe/BUILD +++ /dev/null @@ -1,15 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "pipe", - srcs = ["pipe.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/pipe/pipe_state_autogen.go b/pkg/tcpip/link/pipe/pipe_state_autogen.go new file mode 100644 index 000000000..d3b40feb4 --- /dev/null +++ b/pkg/tcpip/link/pipe/pipe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package pipe diff --git a/pkg/tcpip/link/qdisc/fifo/BUILD b/pkg/tcpip/link/qdisc/fifo/BUILD deleted file mode 100644 index 5bea598eb..000000000 --- a/pkg/tcpip/link/qdisc/fifo/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "fifo", - srcs = [ - "endpoint.go", - "packet_buffer_queue.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sleep", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/qdisc/fifo/fifo_state_autogen.go b/pkg/tcpip/link/qdisc/fifo/fifo_state_autogen.go new file mode 100644 index 000000000..9eb52b1cb --- /dev/null +++ b/pkg/tcpip/link/qdisc/fifo/fifo_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package fifo diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD deleted file mode 100644 index e1047da50..000000000 --- a/pkg/tcpip/link/rawfile/BUILD +++ /dev/null @@ -1,32 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "rawfile", - srcs = [ - "blockingpoll_amd64.s", - "blockingpoll_arm64.s", - "blockingpoll_noyield_unsafe.go", - "blockingpoll_yield_unsafe.go", - "errors.go", - "rawfile_unsafe.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "rawfile_test", - srcs = [ - "errors_test.go", - ], - library = "rawfile", - deps = [ - "//pkg/tcpip", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) diff --git a/pkg/tcpip/link/rawfile/errors_test.go b/pkg/tcpip/link/rawfile/errors_test.go deleted file mode 100644 index 61aea1744..000000000 --- a/pkg/tcpip/link/rawfile/errors_test.go +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build linux - -package rawfile - -import ( - "syscall" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" -) - -func TestTranslateErrno(t *testing.T) { - for _, test := range []struct { - errno syscall.Errno - translated tcpip.Error - }{ - { - errno: syscall.Errno(0), - translated: &tcpip.ErrInvalidEndpointState{}, - }, - { - errno: syscall.Errno(maxErrno), - translated: &tcpip.ErrInvalidEndpointState{}, - }, - { - errno: syscall.Errno(514), - translated: &tcpip.ErrInvalidEndpointState{}, - }, - { - errno: syscall.EEXIST, - translated: &tcpip.ErrDuplicateAddress{}, - }, - } { - got := TranslateErrno(test.errno) - if diff := cmp.Diff(test.translated, got); diff != "" { - t.Errorf("unexpected result from TranslateErrno(%q), (-want, +got):\n%s", test.errno, diff) - } - } -} diff --git a/pkg/tcpip/link/rawfile/rawfile_state_autogen.go b/pkg/tcpip/link/rawfile/rawfile_state_autogen.go new file mode 100644 index 000000000..338e9679b --- /dev/null +++ b/pkg/tcpip/link/rawfile/rawfile_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build linux + +package rawfile diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe_state_autogen.go b/pkg/tcpip/link/rawfile/rawfile_unsafe_state_autogen.go new file mode 100644 index 000000000..239d165f0 --- /dev/null +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe_state_autogen.go @@ -0,0 +1,9 @@ +// automatically generated by stateify. + +// +build linux,!amd64,!arm64 +// +build linux,amd64 linux,arm64 +// +build go1.12 +// +build !go1.17 +// +build linux + +package rawfile diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD deleted file mode 100644 index 13243ebbb..000000000 --- a/pkg/tcpip/link/sharedmem/BUILD +++ /dev/null @@ -1,41 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "sharedmem", - srcs = [ - "rx.go", - "sharedmem.go", - "sharedmem_unsafe.go", - "tx.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/rawfile", - "//pkg/tcpip/link/sharedmem/queue", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "sharedmem_test", - srcs = [ - "sharedmem_test.go", - ], - library = ":sharedmem", - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/sharedmem/pipe", - "//pkg/tcpip/link/sharedmem/queue", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD deleted file mode 100644 index 87020ec08..000000000 --- a/pkg/tcpip/link/sharedmem/pipe/BUILD +++ /dev/null @@ -1,23 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "pipe", - srcs = [ - "pipe.go", - "pipe_unsafe.go", - "rx.go", - "tx.go", - ], - visibility = ["//visibility:public"], -) - -go_test( - name = "pipe_test", - srcs = [ - "pipe_test.go", - ], - library = ":pipe", - deps = ["//pkg/sync"], -) diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_state_autogen.go b/pkg/tcpip/link/sharedmem/pipe/pipe_state_autogen.go new file mode 100644 index 000000000..d3b40feb4 --- /dev/null +++ b/pkg/tcpip/link/sharedmem/pipe/pipe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package pipe diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go deleted file mode 100644 index 2777f1411..000000000 --- a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go +++ /dev/null @@ -1,512 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pipe - -import ( - "math/rand" - "reflect" - "runtime" - "testing" - - "gvisor.dev/gvisor/pkg/sync" -) - -func TestSimpleReadWrite(t *testing.T) { - // Check that a simple write can be properly read from the rx side. - tr := rand.New(rand.NewSource(99)) - rr := rand.New(rand.NewSource(99)) - - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - wb := tx.Push(10) - if wb == nil { - t.Fatalf("Push failed on empty pipe") - } - for i := range wb { - wb[i] = byte(tr.Intn(256)) - } - tx.Flush() - - var rx Rx - rx.Init(b) - rb := rx.Pull() - if len(rb) != 10 { - t.Fatalf("Bad buffer size returned: got %v, want %v", len(rb), 10) - } - - for i := range rb { - if v := byte(rr.Intn(256)); v != rb[i] { - t.Fatalf("Bad read buffer at index %v: got %v, want %v", i, rb[i], v) - } - } - rx.Flush() -} - -func TestEmptyRead(t *testing.T) { - // Check that pulling from an empty pipe fails. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on empty pipe") - } -} - -func TestTooLargeWrite(t *testing.T) { - // Check that writes that are too large are properly rejected. - b := make([]byte, 96) - var tx Tx - tx.Init(b) - - if wb := tx.Push(96); wb != nil { - t.Fatalf("Write of 96 bytes succeeded on 96-byte pipe") - } - - if wb := tx.Push(88); wb != nil { - t.Fatalf("Write of 88 bytes succeeded on 96-byte pipe") - } - - if wb := tx.Push(80); wb == nil { - t.Fatalf("Write of 80 bytes failed on 96-byte pipe") - } -} - -func TestFullWrite(t *testing.T) { - // Check that writes fail when the pipe is full. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(80); wb == nil { - t.Fatalf("Write of 80 bytes failed on 96-byte pipe") - } - - if wb := tx.Push(1); wb != nil { - t.Fatalf("Write succeeded on full pipe") - } -} - -func TestFullAndFlushedWrite(t *testing.T) { - // Check that writes fail when the pipe is full and has already been - // flushed. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(80); wb == nil { - t.Fatalf("Write of 80 bytes failed on 96-byte pipe") - } - - tx.Flush() - - if wb := tx.Push(1); wb != nil { - t.Fatalf("Write succeeded on full pipe") - } -} - -func TestTxFlushTwice(t *testing.T) { - // Checks that a second consecutive tx flush is a no-op. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - // Make copy of original tx queue, flush it, then check that it didn't - // change. - orig := tx - tx.Flush() - - if !reflect.DeepEqual(orig, tx) { - t.Fatalf("Flush mutated tx pipe: got %v, want %v", tx, orig) - } -} - -func TestRxFlushTwice(t *testing.T) { - // Checks that a second consecutive rx flush is a no-op. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - rx.Flush() - - // Make copy of original rx queue, flush it, then check that it didn't - // change. - orig := rx - rx.Flush() - - if !reflect.DeepEqual(orig, rx) { - t.Fatalf("Flush mutated rx pipe: got %v, want %v", rx, orig) - } -} - -func TestWrapInMiddleOfTransaction(t *testing.T) { - // Check that writes are not flushed when we need to wrap the buffer - // around. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - rx.Flush() - - // At this point the ring buffer is empty, but the write is at offset - // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment). - if wb := tx.Push(10); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on non-full pipe") - } - - // We haven't flushed yet, so pull must return nil. - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on non-flushed pipe") - } - - tx.Flush() - - // The two buffers must be available now. - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } -} - -func TestWriteAbort(t *testing.T) { - // Check that a read fails on a pipe that has had data pushed to it but - // has aborted the push. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(10); wb == nil { - t.Fatalf("Write failed on empty pipe") - } - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on empty pipe") - } - - tx.Abort() - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on empty pipe") - } -} - -func TestWrappedWriteAbort(t *testing.T) { - // Check that writes are properly aborted even if the writes wrap - // around. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - rx.Flush() - - // At this point the ring buffer is empty, but the write is at offset - // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment). - if wb := tx.Push(10); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on non-full pipe") - } - - // We haven't flushed yet, so pull must return nil. - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on non-flushed pipe") - } - - tx.Abort() - - // The pushes were aborted, so no data should be readable. - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on non-flushed pipe") - } - - // Try the same transactions again, but flush this time. - if wb := tx.Push(10); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on non-full pipe") - } - - tx.Flush() - - // The two buffers must be available now. - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } -} - -func TestEmptyReadOnNonFlushedWrite(t *testing.T) { - // Check that a read fails on a pipe that has had data pushed to it - // but not yet flushed. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(10); wb == nil { - t.Fatalf("Write failed on empty pipe") - } - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on empty pipe") - } - - tx.Flush() - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull on failed on non-empty pipe") - } -} - -func TestPullAfterPullingEntirePipe(t *testing.T) { - // Check that Pull fails when the pipe is full, but all of it has - // already been pulled but not yet flushed. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - rx.Flush() - - // At this point the ring buffer is empty, but the write is at offset - // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment). Write 3 - // buffers that will fill the pipe. - if wb := tx.Push(10); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - - if wb := tx.Push(20); wb == nil { - t.Fatalf("Push failed on non-full pipe") - } - - if wb := tx.Push(24); wb == nil { - t.Fatalf("Push failed on non-full pipe") - } - - tx.Flush() - - // The three buffers must be available now. - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - - // Fourth pull must fail. - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on empty pipe") - } -} - -func TestNoRoomToWrapOnPush(t *testing.T) { - // Check that Push fails when it tries to allocate room to add a wrap - // message. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - rx.Flush() - - // At this point the ring buffer is empty, but the write is at offset - // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment). Write 20, - // which won't fit (64+20+8+padding = 96, which wouldn't leave room for - // the padding), so it wraps around. - if wb := tx.Push(20); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - - tx.Flush() - - // Buffer offset is at 28. Try to write 70, which would require a wrap - // slot which cannot be created now. - if wb := tx.Push(70); wb != nil { - t.Fatalf("Push succeeded on pipe with no room for wrap message") - } -} - -func TestRxImplicitFlushOfWrapMessage(t *testing.T) { - // Check if the first read is that of a wrapping message, that it gets - // immediately flushed. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - // This will cause a wrapping message to written. - if wb := tx.Push(60); wb != nil { - t.Fatalf("Push succeeded when there is no room in pipe") - } - - var rx Rx - rx.Init(b) - - // Read the first message. - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - rx.Flush() - - // This should fail because of the wrapping message is taking up space. - if wb := tx.Push(60); wb != nil { - t.Fatalf("Push succeeded when there is no room in pipe") - } - - // Try to read the next one. This should consume the wrapping message. - rx.Pull() - - // This must now succeed. - if wb := tx.Push(60); wb == nil { - t.Fatalf("Push failed on empty pipe") - } -} - -func TestConcurrentReaderWriter(t *testing.T) { - // Push a million buffers of random sizes and random contents. Check - // that buffers read match what was written. - tr := rand.New(rand.NewSource(99)) - rr := rand.New(rand.NewSource(99)) - - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - var rx Rx - rx.Init(b) - - const count = 1000000 - var wg sync.WaitGroup - defer wg.Wait() - wg.Add(1) - go func() { - defer wg.Done() - runtime.Gosched() - for i := 0; i < count; i++ { - n := 1 + tr.Intn(80) - wb := tx.Push(uint64(n)) - for wb == nil { - wb = tx.Push(uint64(n)) - } - - for j := range wb { - wb[j] = byte(tr.Intn(256)) - } - - tx.Flush() - } - }() - - for i := 0; i < count; i++ { - n := 1 + rr.Intn(80) - rb := rx.Pull() - for rb == nil { - rb = rx.Pull() - } - - if n != len(rb) { - t.Fatalf("Bad %v-th buffer length: got %v, want %v", i, len(rb), n) - } - - for j := range rb { - if v := byte(rr.Intn(256)); v != rb[j] { - t.Fatalf("Bad %v-th read buffer at index %v: got %v, want %v", i, j, rb[j], v) - } - } - - rx.Flush() - } -} diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe_state_autogen.go b/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe_state_autogen.go new file mode 100644 index 000000000..d3b40feb4 --- /dev/null +++ b/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package pipe diff --git a/pkg/tcpip/link/sharedmem/queue/BUILD b/pkg/tcpip/link/sharedmem/queue/BUILD deleted file mode 100644 index 3ba06af73..000000000 --- a/pkg/tcpip/link/sharedmem/queue/BUILD +++ /dev/null @@ -1,27 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "queue", - srcs = [ - "rx.go", - "tx.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/tcpip/link/sharedmem/pipe", - ], -) - -go_test( - name = "queue_test", - srcs = [ - "queue_test.go", - ], - library = ":queue", - deps = [ - "//pkg/tcpip/link/sharedmem/pipe", - ], -) diff --git a/pkg/tcpip/link/sharedmem/queue/queue_state_autogen.go b/pkg/tcpip/link/sharedmem/queue/queue_state_autogen.go new file mode 100644 index 000000000..563d4fbb4 --- /dev/null +++ b/pkg/tcpip/link/sharedmem/queue/queue_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package queue diff --git a/pkg/tcpip/link/sharedmem/queue/queue_test.go b/pkg/tcpip/link/sharedmem/queue/queue_test.go deleted file mode 100644 index 9a0aad5d7..000000000 --- a/pkg/tcpip/link/sharedmem/queue/queue_test.go +++ /dev/null @@ -1,517 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package queue - -import ( - "encoding/binary" - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe" -) - -func TestBasicTxQueue(t *testing.T) { - // Tests that a basic transmit on a queue works, and that completion - // gets properly reported as well. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Tx - q.Init(pb1, pb2) - - // Enqueue two buffers. - b := []TxBuffer{ - {nil, 100, 60}, - {nil, 200, 40}, - } - - b[0].Next = &b[1] - - const usedID = 1002 - const usedTotalSize = 100 - if !q.Enqueue(usedID, usedTotalSize, 2, &b[0]) { - t.Fatalf("Enqueue failed on empty queue") - } - - // Check the contents of the pipe. - d := rxp.Pull() - if d == nil { - t.Fatalf("Tx pipe is empty after Enqueue") - } - - want := []byte{ - 234, 3, 0, 0, 0, 0, 0, 0, // id - 100, 0, 0, 0, // total size - 0, 0, 0, 0, // reserved - 100, 0, 0, 0, 0, 0, 0, 0, // offset 1 - 60, 0, 0, 0, // size 1 - 200, 0, 0, 0, 0, 0, 0, 0, // offset 2 - 40, 0, 0, 0, // size 2 - } - - if !reflect.DeepEqual(want, d) { - t.Fatalf("Bad posted packet: got %v, want %v", d, want) - } - - rxp.Flush() - - // Check that there are no completions yet. - if _, ok := q.CompletedPacket(); ok { - t.Fatalf("Packet reported as completed too soon") - } - - // Post a completion. - d = txp.Push(8) - if d == nil { - t.Fatalf("Unable to push to rx pipe") - } - binary.LittleEndian.PutUint64(d, usedID) - txp.Flush() - - // Check that completion is properly reported. - id, ok := q.CompletedPacket() - if !ok { - t.Fatalf("Completion not reported") - } - - if id != usedID { - t.Fatalf("Bad completion id: got %v, want %v", id, usedID) - } -} - -func TestBasicRxQueue(t *testing.T) { - // Tests that a basic receive on a queue works. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Rx - q.Init(pb1, pb2, nil) - - // Post two buffers. - b := []RxBuffer{ - {100, 60, 1077, 0}, - {200, 40, 2123, 0}, - } - - if !q.PostBuffers(b) { - t.Fatalf("PostBuffers failed on empty queue") - } - - // Check the contents of the pipe. - want := [][]byte{ - { - 100, 0, 0, 0, 0, 0, 0, 0, // Offset1 - 60, 0, 0, 0, // Size1 - 0, 0, 0, 0, // Remaining in group 1 - 0, 0, 0, 0, 0, 0, 0, 0, // User data 1 - 53, 4, 0, 0, 0, 0, 0, 0, // ID 1 - }, - { - 200, 0, 0, 0, 0, 0, 0, 0, // Offset2 - 40, 0, 0, 0, // Size2 - 0, 0, 0, 0, // Remaining in group 2 - 0, 0, 0, 0, 0, 0, 0, 0, // User data 2 - 75, 8, 0, 0, 0, 0, 0, 0, // ID 2 - }, - } - - for i := range b { - d := rxp.Pull() - if d == nil { - t.Fatalf("Tx pipe is empty after PostBuffers") - } - - if !reflect.DeepEqual(want[i], d) { - t.Fatalf("Bad posted packet: got %v, want %v", d, want[i]) - } - - rxp.Flush() - } - - // Check that there are no completions. - if _, n := q.Dequeue(nil); n != 0 { - t.Fatalf("Packet reported as received too soon") - } - - // Post a completion. - d := txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer) - if d == nil { - t.Fatalf("Unable to push to rx pipe") - } - - copy(d, []byte{ - 100, 0, 0, 0, // packet size - 0, 0, 0, 0, // reserved - - 100, 0, 0, 0, 0, 0, 0, 0, // offset 1 - 60, 0, 0, 0, // size 1 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 1 - 53, 4, 0, 0, 0, 0, 0, 0, // ID 1 - - 200, 0, 0, 0, 0, 0, 0, 0, // offset 2 - 40, 0, 0, 0, // size 2 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 2 - 75, 8, 0, 0, 0, 0, 0, 0, // ID 2 - }) - - txp.Flush() - - // Check that completion is properly reported. - bufs, n := q.Dequeue(nil) - if n != 100 { - t.Fatalf("Bad packet size: got %v, want %v", n, 100) - } - - if !reflect.DeepEqual(bufs, b) { - t.Fatalf("Bad returned buffers: got %v, want %v", bufs, b) - } -} - -func TestBadTxCompletion(t *testing.T) { - // Check that tx completions with bad sizes are properly ignored. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Tx - q.Init(pb1, pb2) - - // Post a completion that is too short, and check that it is ignored. - if d := txp.Push(7); d == nil { - t.Fatalf("Unable to push to rx pipe") - } - txp.Flush() - - if _, ok := q.CompletedPacket(); ok { - t.Fatalf("Bad completion not ignored") - } - - // Post a completion that is too long, and check that it is ignored. - if d := txp.Push(10); d == nil { - t.Fatalf("Unable to push to rx pipe") - } - txp.Flush() - - if _, ok := q.CompletedPacket(); ok { - t.Fatalf("Bad completion not ignored") - } -} - -func TestBadRxCompletion(t *testing.T) { - // Check that bad rx completions are properly ignored. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Rx - q.Init(pb1, pb2, nil) - - // Post a completion that is too short, and check that it is ignored. - if d := txp.Push(7); d == nil { - t.Fatalf("Unable to push to rx pipe") - } - txp.Flush() - - if b, _ := q.Dequeue(nil); b != nil { - t.Fatalf("Bad completion not ignored") - } - - // Post a completion whose buffer sizes add up to less than the total - // size. - d := txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer) - if d == nil { - t.Fatalf("Unable to push to rx pipe") - } - - copy(d, []byte{ - 100, 0, 0, 0, // packet size - 0, 0, 0, 0, // reserved - - 100, 0, 0, 0, 0, 0, 0, 0, // offset 1 - 10, 0, 0, 0, // size 1 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 1 - 53, 4, 0, 0, 0, 0, 0, 0, // ID 1 - - 200, 0, 0, 0, 0, 0, 0, 0, // offset 2 - 10, 0, 0, 0, // size 2 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 2 - 75, 8, 0, 0, 0, 0, 0, 0, // ID 2 - }) - - txp.Flush() - if b, _ := q.Dequeue(nil); b != nil { - t.Fatalf("Bad completion not ignored") - } - - // Post a completion whose buffer sizes will cause a 32-bit overflow, - // but adds up to the right number. - d = txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer) - if d == nil { - t.Fatalf("Unable to push to rx pipe") - } - - copy(d, []byte{ - 100, 0, 0, 0, // packet size - 0, 0, 0, 0, // reserved - - 100, 0, 0, 0, 0, 0, 0, 0, // offset 1 - 255, 255, 255, 255, // size 1 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 1 - 53, 4, 0, 0, 0, 0, 0, 0, // ID 1 - - 200, 0, 0, 0, 0, 0, 0, 0, // offset 2 - 101, 0, 0, 0, // size 2 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 2 - 75, 8, 0, 0, 0, 0, 0, 0, // ID 2 - }) - - txp.Flush() - if b, _ := q.Dequeue(nil); b != nil { - t.Fatalf("Bad completion not ignored") - } -} - -func TestFillTxPipe(t *testing.T) { - // Check that transmitting a new buffer when the buffer pipe is full - // fails gracefully. - pb1 := make([]byte, 104) - pb2 := make([]byte, 104) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Tx - q.Init(pb1, pb2) - - // Transmit twice, which should fill the tx pipe. - b := []TxBuffer{ - {nil, 100, 60}, - {nil, 200, 40}, - } - - b[0].Next = &b[1] - - const usedID = 1002 - const usedTotalSize = 100 - for i := uint64(0); i < 2; i++ { - if !q.Enqueue(usedID+i, usedTotalSize, 2, &b[0]) { - t.Fatalf("Failed to transmit buffer") - } - } - - // Transmit another packet now that the tx pipe is full. - if q.Enqueue(usedID+2, usedTotalSize, 2, &b[0]) { - t.Fatalf("Enqueue succeeded when tx pipe is full") - } -} - -func TestFillRxPipe(t *testing.T) { - // Check that posting a new buffer when the buffer pipe is full fails - // gracefully. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Rx - q.Init(pb1, pb2, nil) - - // Post a buffer twice, it should fill the tx pipe. - b := []RxBuffer{ - {100, 60, 1077, 0}, - } - - for i := 0; i < 2; i++ { - if !q.PostBuffers(b) { - t.Fatalf("PostBuffers failed on non-full queue") - } - } - - // Post another buffer now that the tx pipe is full. - if q.PostBuffers(b) { - t.Fatalf("PostBuffers succeeded on full queue") - } -} - -func TestLotsOfTransmissions(t *testing.T) { - // Make sure pipes are being properly flushed when transmitting packets. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Tx - q.Init(pb1, pb2) - - // Prepare packet with two buffers. - b := []TxBuffer{ - {nil, 100, 60}, - {nil, 200, 40}, - } - - b[0].Next = &b[1] - - const usedID = 1002 - const usedTotalSize = 100 - - // Post 100000 packets and completions. - for i := 100000; i > 0; i-- { - if !q.Enqueue(usedID, usedTotalSize, 2, &b[0]) { - t.Fatalf("Enqueue failed on non-full queue") - } - - if d := rxp.Pull(); d == nil { - t.Fatalf("Tx pipe is empty after Enqueue") - } - rxp.Flush() - - d := txp.Push(8) - if d == nil { - t.Fatalf("Unable to write to rx pipe") - } - binary.LittleEndian.PutUint64(d, usedID) - txp.Flush() - if _, ok := q.CompletedPacket(); !ok { - t.Fatalf("Completion not returned") - } - } -} - -func TestLotsOfReceptions(t *testing.T) { - // Make sure pipes are being properly flushed when receiving packets. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Rx - q.Init(pb1, pb2, nil) - - // Prepare for posting two buffers. - b := []RxBuffer{ - {100, 60, 1077, 0}, - {200, 40, 2123, 0}, - } - - // Post 100000 buffers and completions. - for i := 100000; i > 0; i-- { - if !q.PostBuffers(b) { - t.Fatalf("PostBuffers failed on non-full queue") - } - - if d := rxp.Pull(); d == nil { - t.Fatalf("Tx pipe is empty after PostBuffers") - } - rxp.Flush() - - if d := rxp.Pull(); d == nil { - t.Fatalf("Tx pipe is empty after PostBuffers") - } - rxp.Flush() - - d := txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer) - if d == nil { - t.Fatalf("Unable to push to rx pipe") - } - - copy(d, []byte{ - 100, 0, 0, 0, // packet size - 0, 0, 0, 0, // reserved - - 100, 0, 0, 0, 0, 0, 0, 0, // offset 1 - 60, 0, 0, 0, // size 1 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 1 - 53, 4, 0, 0, 0, 0, 0, 0, // ID 1 - - 200, 0, 0, 0, 0, 0, 0, 0, // offset 2 - 40, 0, 0, 0, // size 2 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 2 - 75, 8, 0, 0, 0, 0, 0, 0, // ID 2 - }) - - txp.Flush() - - if _, n := q.Dequeue(nil); n == 0 { - t.Fatalf("Dequeue failed when there is a completion") - } - } -} - -func TestRxEnableNotification(t *testing.T) { - // Check that enabling nofifications results in properly updated state. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var state uint32 - var q Rx - q.Init(pb1, pb2, &state) - - q.EnableNotification() - if state != eventFDEnabled { - t.Fatalf("Bad value in shared state: got %v, want %v", state, eventFDEnabled) - } -} - -func TestRxDisableNotification(t *testing.T) { - // Check that disabling nofifications results in properly updated state. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var state uint32 - var q Rx - q.Init(pb1, pb2, &state) - - q.DisableNotification() - if state != eventFDDisabled { - t.Fatalf("Bad value in shared state: got %v, want %v", state, eventFDDisabled) - } -} diff --git a/pkg/tcpip/link/sharedmem/sharedmem_state_autogen.go b/pkg/tcpip/link/sharedmem/sharedmem_state_autogen.go new file mode 100644 index 000000000..bc12017b2 --- /dev/null +++ b/pkg/tcpip/link/sharedmem/sharedmem_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +// +build linux +// +build linux + +package sharedmem diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go deleted file mode 100644 index d480ad656..000000000 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ /dev/null @@ -1,814 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build linux - -package sharedmem - -import ( - "bytes" - "io/ioutil" - "math/rand" - "os" - "strings" - "syscall" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe" - "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -const ( - localLinkAddr = "\xde\xad\xbe\xef\x56\x78" - remoteLinkAddr = "\xde\xad\xbe\xef\x12\x34" - - queueDataSize = 1024 * 1024 - queuePipeSize = 4096 -) - -type queueBuffers struct { - data []byte - rx pipe.Tx - tx pipe.Rx -} - -func initQueue(t *testing.T, q *queueBuffers, c *QueueConfig) { - // Prepare tx pipe. - b, err := getBuffer(c.TxPipeFD) - if err != nil { - t.Fatalf("getBuffer failed: %v", err) - } - q.tx.Init(b) - - // Prepare rx pipe. - b, err = getBuffer(c.RxPipeFD) - if err != nil { - t.Fatalf("getBuffer failed: %v", err) - } - q.rx.Init(b) - - // Get data slice. - q.data, err = getBuffer(c.DataFD) - if err != nil { - t.Fatalf("getBuffer failed: %v", err) - } -} - -func (q *queueBuffers) cleanup() { - syscall.Munmap(q.tx.Bytes()) - syscall.Munmap(q.rx.Bytes()) - syscall.Munmap(q.data) -} - -type packetInfo struct { - addr tcpip.LinkAddress - proto tcpip.NetworkProtocolNumber - vv buffer.VectorisedView - linkHeader buffer.View -} - -type testContext struct { - t *testing.T - ep *endpoint - txCfg QueueConfig - rxCfg QueueConfig - txq queueBuffers - rxq queueBuffers - - packetCh chan struct{} - mu sync.Mutex - packets []packetInfo -} - -func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress) *testContext { - var err error - c := &testContext{ - t: t, - packetCh: make(chan struct{}, 1000000), - } - c.txCfg = createQueueFDs(t, queueSizes{ - dataSize: queueDataSize, - txPipeSize: queuePipeSize, - rxPipeSize: queuePipeSize, - sharedDataSize: 4096, - }) - - c.rxCfg = createQueueFDs(t, queueSizes{ - dataSize: queueDataSize, - txPipeSize: queuePipeSize, - rxPipeSize: queuePipeSize, - sharedDataSize: 4096, - }) - - initQueue(t, &c.txq, &c.txCfg) - initQueue(t, &c.rxq, &c.rxCfg) - - ep, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg) - if err != nil { - t.Fatalf("New failed: %v", err) - } - - c.ep = ep.(*endpoint) - c.ep.Attach(c) - - return c -} - -func (c *testContext) DeliverNetworkPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - c.mu.Lock() - c.packets = append(c.packets, packetInfo{ - addr: remoteLinkAddr, - proto: proto, - vv: pkt.Data.Clone(nil), - }) - c.mu.Unlock() - - c.packetCh <- struct{}{} -} - -func (c *testContext) DeliverOutboundPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - panic("unimplemented") -} - -func (c *testContext) cleanup() { - c.ep.Close() - closeFDs(&c.txCfg) - closeFDs(&c.rxCfg) - c.txq.cleanup() - c.rxq.cleanup() -} - -func (c *testContext) waitForPackets(n int, to <-chan time.Time, errorStr string) { - for i := 0; i < n; i++ { - select { - case <-c.packetCh: - case <-to: - c.t.Fatalf(errorStr) - } - } -} - -func (c *testContext) pushRxCompletion(size uint32, bs []queue.RxBuffer) { - b := c.rxq.rx.Push(queue.RxCompletionSize(len(bs))) - queue.EncodeRxCompletion(b, size, 0) - for i := range bs { - queue.EncodeRxCompletionBuffer(b, i, queue.RxBuffer{ - Offset: bs[i].Offset, - Size: bs[i].Size, - ID: bs[i].ID, - }) - } -} - -func randomFill(b []byte) { - for i := range b { - b[i] = byte(rand.Intn(256)) - } -} - -func shuffle(b []int) { - for i := len(b) - 1; i >= 0; i-- { - j := rand.Intn(i + 1) - b[i], b[j] = b[j], b[i] - } -} - -func createFile(t *testing.T, size int64, initQueue bool) int { - tmpDir, ok := os.LookupEnv("TEST_TMPDIR") - if !ok { - tmpDir = os.Getenv("TMPDIR") - } - f, err := ioutil.TempFile(tmpDir, "sharedmem_test") - if err != nil { - t.Fatalf("TempFile failed: %v", err) - } - defer f.Close() - syscall.Unlink(f.Name()) - - if initQueue { - // Write the "slot-free" flag in the initial queue. - _, err := f.WriteAt([]byte{0, 0, 0, 0, 0, 0, 0, 0x80}, 0) - if err != nil { - t.Fatalf("WriteAt failed: %v", err) - } - } - - fd, err := syscall.Dup(int(f.Fd())) - if err != nil { - t.Fatalf("Dup failed: %v", err) - } - - if err := syscall.Ftruncate(fd, size); err != nil { - syscall.Close(fd) - t.Fatalf("Ftruncate failed: %v", err) - } - - return fd -} - -func closeFDs(c *QueueConfig) { - syscall.Close(c.DataFD) - syscall.Close(c.EventFD) - syscall.Close(c.TxPipeFD) - syscall.Close(c.RxPipeFD) - syscall.Close(c.SharedDataFD) -} - -type queueSizes struct { - dataSize int64 - txPipeSize int64 - rxPipeSize int64 - sharedDataSize int64 -} - -func createQueueFDs(t *testing.T, s queueSizes) QueueConfig { - fd, _, err := syscall.RawSyscall(syscall.SYS_EVENTFD2, 0, 0, 0) - if err != 0 { - t.Fatalf("eventfd failed: %v", error(err)) - } - - return QueueConfig{ - EventFD: int(fd), - DataFD: createFile(t, s.dataSize, false), - TxPipeFD: createFile(t, s.txPipeSize, true), - RxPipeFD: createFile(t, s.rxPipeSize, true), - SharedDataFD: createFile(t, s.sharedDataSize, false), - } -} - -// TestSimpleSend sends 1000 packets with random header and payload sizes, -// then checks that the right payload is received on the shared memory queues. -func TestSimpleSend(t *testing.T) { - c := newTestContext(t, 20000, 1500, localLinkAddr) - defer c.cleanup() - - // Prepare route. - var r stack.RouteInfo - r.RemoteLinkAddress = remoteLinkAddr - - for iters := 1000; iters > 0; iters-- { - func() { - hdrLen, dataLen := rand.Intn(10000), rand.Intn(10000) - - // Prepare and send packet. - hdrBuf := buffer.NewView(hdrLen) - randomFill(hdrBuf) - - data := buffer.NewView(dataLen) - randomFill(data) - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: hdrLen + int(c.ep.MaxHeaderLength()), - Data: data.ToVectorisedView(), - }) - copy(pkt.NetworkHeader().Push(hdrLen), hdrBuf) - - proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(r, nil /* gso */, proto, pkt); err != nil { - t.Fatalf("WritePacket failed: %v", err) - } - - // Receive packet. - desc := c.txq.tx.Pull() - pi := queue.DecodeTxPacketHeader(desc) - if pi.Reserved != 0 { - t.Fatalf("Reserved value is non-zero: 0x%x", pi.Reserved) - } - contents := make([]byte, 0, pi.Size) - for i := 0; i < pi.BufferCount; i++ { - bi := queue.DecodeTxBufferHeader(desc, i) - contents = append(contents, c.txq.data[bi.Offset:][:bi.Size]...) - } - c.txq.tx.Flush() - - defer func() { - // Tell the endpoint about the completion of the write. - b := c.txq.rx.Push(8) - queue.EncodeTxCompletion(b, pi.ID) - c.txq.rx.Flush() - }() - - // Check the ethernet header. - ethTemplate := make(header.Ethernet, header.EthernetMinimumSize) - ethTemplate.Encode(&header.EthernetFields{ - SrcAddr: localLinkAddr, - DstAddr: remoteLinkAddr, - Type: proto, - }) - if got := contents[:header.EthernetMinimumSize]; !bytes.Equal(got, []byte(ethTemplate)) { - t.Fatalf("Bad ethernet header in packet: got %x, want %x", got, ethTemplate) - } - - // Compare contents skipping the ethernet header added by the - // endpoint. - merged := append(hdrBuf, data...) - if uint32(len(contents)) < pi.Size { - t.Fatalf("Sum of buffers is less than packet size: %v < %v", len(contents), pi.Size) - } - contents = contents[:pi.Size][header.EthernetMinimumSize:] - - if !bytes.Equal(contents, merged) { - t.Fatalf("Buffers are different: got %x (%v bytes), want %x (%v bytes)", contents, len(contents), merged, len(merged)) - } - }() - } -} - -// TestPreserveSrcAddressInSend calls WritePacket once with LocalLinkAddress -// set in Route (using much of the same code as TestSimpleSend), then checks -// that the encoded ethernet header received includes the correct SrcAddr. -func TestPreserveSrcAddressInSend(t *testing.T) { - c := newTestContext(t, 20000, 1500, localLinkAddr) - defer c.cleanup() - - newLocalLinkAddress := tcpip.LinkAddress(strings.Repeat("0xFE", 6)) - // Set both remote and local link address in route. - var r stack.RouteInfo - r.LocalLinkAddress = newLocalLinkAddress - r.RemoteLinkAddress = remoteLinkAddr - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - // WritePacket panics given a prependable with anything less than - // the minimum size of the ethernet header. - ReserveHeaderBytes: header.EthernetMinimumSize, - }) - - proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(r, nil /* gso */, proto, pkt); err != nil { - t.Fatalf("WritePacket failed: %v", err) - } - - // Receive packet. - desc := c.txq.tx.Pull() - pi := queue.DecodeTxPacketHeader(desc) - if pi.Reserved != 0 { - t.Fatalf("Reserved value is non-zero: 0x%x", pi.Reserved) - } - contents := make([]byte, 0, pi.Size) - for i := 0; i < pi.BufferCount; i++ { - bi := queue.DecodeTxBufferHeader(desc, i) - contents = append(contents, c.txq.data[bi.Offset:][:bi.Size]...) - } - c.txq.tx.Flush() - - defer func() { - // Tell the endpoint about the completion of the write. - b := c.txq.rx.Push(8) - queue.EncodeTxCompletion(b, pi.ID) - c.txq.rx.Flush() - }() - - // Check that the ethernet header contains the expected SrcAddr. - ethTemplate := make(header.Ethernet, header.EthernetMinimumSize) - ethTemplate.Encode(&header.EthernetFields{ - SrcAddr: newLocalLinkAddress, - DstAddr: remoteLinkAddr, - Type: proto, - }) - if got := contents[:header.EthernetMinimumSize]; !bytes.Equal(got, []byte(ethTemplate)) { - t.Fatalf("Bad ethernet header in packet: got %x, want %x", got, ethTemplate) - } -} - -// TestFillTxQueue sends packets until the queue is full. -func TestFillTxQueue(t *testing.T) { - c := newTestContext(t, 20000, 1500, localLinkAddr) - defer c.cleanup() - - // Prepare to send a packet. - var r stack.RouteInfo - r.RemoteLinkAddress = remoteLinkAddr - - buf := buffer.NewView(100) - - // Each packet is uses no more than 40 bytes, so write that many packets - // until the tx queue if full. - ids := make(map[uint64]struct{}) - for i := queuePipeSize / 40; i > 0; i-- { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), - Data: buf.ToVectorisedView(), - }) - - if err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { - t.Fatalf("WritePacket failed unexpectedly: %v", err) - } - - // Check that they have different IDs. - desc := c.txq.tx.Pull() - pi := queue.DecodeTxPacketHeader(desc) - if _, ok := ids[pi.ID]; ok { - t.Fatalf("ID (%v) reused", pi.ID) - } - ids[pi.ID] = struct{}{} - } - - // Next attempt to write must fail. - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), - Data: buf.ToVectorisedView(), - }) - err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{}) - } -} - -// TestFillTxQueueAfterBadCompletion sends a bad completion, then sends packets -// until the queue is full. -func TestFillTxQueueAfterBadCompletion(t *testing.T) { - c := newTestContext(t, 20000, 1500, localLinkAddr) - defer c.cleanup() - - // Send a bad completion. - queue.EncodeTxCompletion(c.txq.rx.Push(8), 1) - c.txq.rx.Flush() - - // Prepare to send a packet. - var r stack.RouteInfo - r.RemoteLinkAddress = remoteLinkAddr - - buf := buffer.NewView(100) - - // Send two packets so that the id slice has at least two slots. - for i := 2; i > 0; i-- { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), - Data: buf.ToVectorisedView(), - }) - if err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { - t.Fatalf("WritePacket failed unexpectedly: %v", err) - } - } - - // Complete the two writes twice. - for i := 2; i > 0; i-- { - pi := queue.DecodeTxPacketHeader(c.txq.tx.Pull()) - - queue.EncodeTxCompletion(c.txq.rx.Push(8), pi.ID) - queue.EncodeTxCompletion(c.txq.rx.Push(8), pi.ID) - c.txq.rx.Flush() - } - c.txq.tx.Flush() - - // Each packet is uses no more than 40 bytes, so write that many packets - // until the tx queue if full. - ids := make(map[uint64]struct{}) - for i := queuePipeSize / 40; i > 0; i-- { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), - Data: buf.ToVectorisedView(), - }) - if err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { - t.Fatalf("WritePacket failed unexpectedly: %v", err) - } - - // Check that they have different IDs. - desc := c.txq.tx.Pull() - pi := queue.DecodeTxPacketHeader(desc) - if _, ok := ids[pi.ID]; ok { - t.Fatalf("ID (%v) reused", pi.ID) - } - ids[pi.ID] = struct{}{} - } - - // Next attempt to write must fail. - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), - Data: buf.ToVectorisedView(), - }) - err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{}) - } -} - -// TestFillTxMemory sends packets until the we run out of shared memory. -func TestFillTxMemory(t *testing.T) { - const bufferSize = 1500 - c := newTestContext(t, 20000, bufferSize, localLinkAddr) - defer c.cleanup() - - // Prepare to send a packet. - var r stack.RouteInfo - r.RemoteLinkAddress = remoteLinkAddr - - buf := buffer.NewView(100) - - // Each packet is uses up one buffer, so write as many as possible until - // we fill the memory. - ids := make(map[uint64]struct{}) - for i := queueDataSize / bufferSize; i > 0; i-- { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), - Data: buf.ToVectorisedView(), - }) - if err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { - t.Fatalf("WritePacket failed unexpectedly: %v", err) - } - - // Check that they have different IDs. - desc := c.txq.tx.Pull() - pi := queue.DecodeTxPacketHeader(desc) - if _, ok := ids[pi.ID]; ok { - t.Fatalf("ID (%v) reused", pi.ID) - } - ids[pi.ID] = struct{}{} - c.txq.tx.Flush() - } - - // Next attempt to write must fail. - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), - Data: buf.ToVectorisedView(), - }) - err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{}) - } -} - -// TestFillTxMemoryWithMultiBuffer sends packets until the we run out of -// shared memory for a 2-buffer packet, but still with room for a 1-buffer -// packet. -func TestFillTxMemoryWithMultiBuffer(t *testing.T) { - const bufferSize = 1500 - c := newTestContext(t, 20000, bufferSize, localLinkAddr) - defer c.cleanup() - - // Prepare to send a packet. - var r stack.RouteInfo - r.RemoteLinkAddress = remoteLinkAddr - - buf := buffer.NewView(100) - - // Each packet is uses up one buffer, so write as many as possible - // until there is only one buffer left. - for i := queueDataSize/bufferSize - 1; i > 0; i-- { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), - Data: buf.ToVectorisedView(), - }) - if err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { - t.Fatalf("WritePacket failed unexpectedly: %v", err) - } - - // Pull the posted buffer. - c.txq.tx.Pull() - c.txq.tx.Flush() - } - - // Attempt to write a two-buffer packet. It must fail. - { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), - Data: buffer.NewView(bufferSize).ToVectorisedView(), - }) - err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{}) - } - } - - // Attempt to write the one-buffer packet again. It must succeed. - { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), - Data: buf.ToVectorisedView(), - }) - if err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { - t.Fatalf("WritePacket failed unexpectedly: %v", err) - } - } -} - -func pollPull(t *testing.T, p *pipe.Rx, to <-chan time.Time, errStr string) []byte { - t.Helper() - - for { - b := p.Pull() - if b != nil { - return b - } - - select { - case <-time.After(10 * time.Millisecond): - case <-to: - t.Fatal(errStr) - } - } -} - -// TestSimpleReceive completes 1000 different receives with random payload and -// random number of buffers. It checks that the contents match the expected -// values. -func TestSimpleReceive(t *testing.T) { - const bufferSize = 1500 - c := newTestContext(t, 20000, bufferSize, localLinkAddr) - defer c.cleanup() - - // Check that buffers have been posted. - limit := c.ep.rx.q.PostedBuffersLimit() - for i := uint64(0); i < limit; i++ { - timeout := time.After(2 * time.Second) - bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for all buffers to be posted")) - - if want := i * bufferSize; want != bi.Offset { - t.Fatalf("Bad posted offset: got %v, want %v", bi.Offset, want) - } - - if want := i; want != bi.ID { - t.Fatalf("Bad posted ID: got %v, want %v", bi.ID, want) - } - - if bufferSize != bi.Size { - t.Fatalf("Bad posted bufferSize: got %v, want %v", bi.Size, bufferSize) - } - } - c.rxq.tx.Flush() - - // Create a slice with the indices 0..limit-1. - idx := make([]int, limit) - for i := range idx { - idx[i] = i - } - - // Complete random packets 1000 times. - for iters := 1000; iters > 0; iters-- { - timeout := time.After(2 * time.Second) - // Prepare a random packet. - shuffle(idx) - n := 1 + rand.Intn(10) - bufs := make([]queue.RxBuffer, n) - contents := make([]byte, bufferSize*n-rand.Intn(500)) - randomFill(contents) - for i := range bufs { - j := idx[i] - bufs[i].Size = bufferSize - bufs[i].Offset = uint64(bufferSize * j) - bufs[i].ID = uint64(j) - - copy(c.rxq.data[bufs[i].Offset:][:bufferSize], contents[i*bufferSize:]) - } - - // Push completion. - c.pushRxCompletion(uint32(len(contents)), bufs) - c.rxq.rx.Flush() - syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Wait for packet to be received, then check it. - c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet") - c.mu.Lock() - rcvd := []byte(c.packets[0].vv.ToView()) - c.packets = c.packets[:0] - c.mu.Unlock() - - if contents := contents[header.EthernetMinimumSize:]; !bytes.Equal(contents, rcvd) { - t.Fatalf("Unexpected buffer contents: got %x, want %x", rcvd, contents) - } - - // Check that buffers have been reposted. - for i := range bufs { - bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffers to be reposted")) - if bi != bufs[i] { - t.Fatalf("Unexpected buffer reposted: got %x, want %x", bi, bufs[i]) - } - } - c.rxq.tx.Flush() - } -} - -// TestRxBuffersReposted tests that rx buffers get reposted after they have been -// completed. -func TestRxBuffersReposted(t *testing.T) { - const bufferSize = 1500 - c := newTestContext(t, 20000, bufferSize, localLinkAddr) - defer c.cleanup() - - // Receive all posted buffers. - limit := c.ep.rx.q.PostedBuffersLimit() - buffers := make([]queue.RxBuffer, 0, limit) - for i := limit; i > 0; i-- { - timeout := time.After(2 * time.Second) - buffers = append(buffers, queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for all buffers"))) - } - c.rxq.tx.Flush() - - // Check that all buffers are reposted when individually completed. - for i := range buffers { - timeout := time.After(2 * time.Second) - // Complete the buffer. - c.pushRxCompletion(buffers[i].Size, buffers[i:][:1]) - c.rxq.rx.Flush() - syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Wait for it to be reposted. - bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffer to be reposted")) - if bi != buffers[i] { - t.Fatalf("Different buffer posted: got %v, want %v", bi, buffers[i]) - } - } - c.rxq.tx.Flush() - - // Check that all buffers are reposted when completed in pairs. - for i := 0; i < len(buffers)/2; i++ { - timeout := time.After(2 * time.Second) - // Complete with two buffers. - c.pushRxCompletion(2*bufferSize, buffers[2*i:][:2]) - c.rxq.rx.Flush() - syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Wait for them to be reposted. - for j := 0; j < 2; j++ { - bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffer to be reposted")) - if bi != buffers[2*i+j] { - t.Fatalf("Different buffer posted: got %v, want %v", bi, buffers[2*i+j]) - } - } - } - c.rxq.tx.Flush() -} - -// TestReceivePostingIsFull checks that the endpoint will properly handle the -// case when a received buffer cannot be immediately reposted because it hasn't -// been pulled from the tx pipe yet. -func TestReceivePostingIsFull(t *testing.T) { - const bufferSize = 1500 - c := newTestContext(t, 20000, bufferSize, localLinkAddr) - defer c.cleanup() - - // Complete first posted buffer before flushing it from the tx pipe. - first := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for first buffer to be posted")) - c.pushRxCompletion(first.Size, []queue.RxBuffer{first}) - c.rxq.rx.Flush() - syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Check that packet is received. - c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet") - - // Complete another buffer. - second := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for second buffer to be posted")) - c.pushRxCompletion(second.Size, []queue.RxBuffer{second}) - c.rxq.rx.Flush() - syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Check that no packet is received yet, as the worker is blocked trying - // to repost. - select { - case <-time.After(500 * time.Millisecond): - case <-c.packetCh: - t.Fatalf("Unexpected packet received") - } - - // Flush tx queue, which will allow the first buffer to be reposted, - // and the second completion to be pulled. - c.rxq.tx.Flush() - syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Check that second packet completes. - c.waitForPackets(1, time.After(time.Second), "Timeout waiting for second completed packet") -} - -// TestCloseWhileWaitingToPost closes the endpoint while it is waiting to -// repost a buffer. Make sure it backs out. -func TestCloseWhileWaitingToPost(t *testing.T) { - const bufferSize = 1500 - c := newTestContext(t, 20000, bufferSize, localLinkAddr) - cleaned := false - defer func() { - if !cleaned { - c.cleanup() - } - }() - - // Complete first posted buffer before flushing it from the tx pipe. - bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for initial buffer to be posted")) - c.pushRxCompletion(bi.Size, []queue.RxBuffer{bi}) - c.rxq.rx.Flush() - syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Wait for packet to be indicated. - c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet") - - // Cleanup and wait for worker to complete. - c.cleanup() - cleaned = true - c.ep.Wait() -} diff --git a/pkg/tcpip/link/sharedmem/sharedmem_unsafe_state_autogen.go b/pkg/tcpip/link/sharedmem/sharedmem_unsafe_state_autogen.go new file mode 100644 index 000000000..ac3a66520 --- /dev/null +++ b/pkg/tcpip/link/sharedmem/sharedmem_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package sharedmem diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD deleted file mode 100644 index 4aac12a8c..000000000 --- a/pkg/tcpip/link/sniffer/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "sniffer", - srcs = [ - "pcap.go", - "sniffer.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/header/parse", - "//pkg/tcpip/link/nested", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/sniffer/sniffer_state_autogen.go b/pkg/tcpip/link/sniffer/sniffer_state_autogen.go new file mode 100644 index 000000000..8d79defea --- /dev/null +++ b/pkg/tcpip/link/sniffer/sniffer_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package sniffer diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD deleted file mode 100644 index 86f14db76..000000000 --- a/pkg/tcpip/link/tun/BUILD +++ /dev/null @@ -1,41 +0,0 @@ -load("//tools:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "tun_endpoint_refs", - out = "tun_endpoint_refs.go", - package = "tun", - prefix = "tunEndpoint", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "tunEndpoint", - }, -) - -go_library( - name = "tun", - srcs = [ - "device.go", - "protocol.go", - "tun_endpoint_refs.go", - "tun_unsafe.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/log", - "//pkg/refs", - "//pkg/refsvfs2", - "//pkg/sync", - "//pkg/syserror", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/stack", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/link/tun/tun_endpoint_refs.go b/pkg/tcpip/link/tun/tun_endpoint_refs.go new file mode 100644 index 000000000..276cbdb20 --- /dev/null +++ b/pkg/tcpip/link/tun/tun_endpoint_refs.go @@ -0,0 +1,132 @@ +package tun + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const tunEndpointenableLogging = false + +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var tunEndpointobj *tunEndpoint + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// +stateify savable +type tunEndpointRefs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *tunEndpointRefs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) +} + +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *tunEndpointRefs) RefType() string { + return fmt.Sprintf("%T", tunEndpointobj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *tunEndpointRefs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *tunEndpointRefs) LogRefs() bool { + return tunEndpointenableLogging +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *tunEndpointRefs) ReadRefs() int64 { + return atomic.LoadInt64(&r.refCount) +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *tunEndpointRefs) IncRef() { + v := atomic.AddInt64(&r.refCount, 1) + if tunEndpointenableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) + } +} + +// TryIncRef implements refs.RefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *tunEndpointRefs) TryIncRef() bool { + const speculativeRef = 1 << 32 + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { + + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + if tunEndpointenableLogging { + refsvfs2.LogTryIncRef(r, v) + } + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *tunEndpointRefs) DecRef(destroy func()) { + v := atomic.AddInt64(&r.refCount, -1) + if tunEndpointenableLogging { + refsvfs2.LogDecRef(r, v) + } + switch { + case v < 0: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) + + case v == 0: + refsvfs2.Unregister(r) + + if destroy != nil { + destroy() + } + } +} + +func (r *tunEndpointRefs) afterLoad() { + if r.ReadRefs() > 0 { + refsvfs2.Register(r) + } +} diff --git a/pkg/tcpip/link/tun/tun_state_autogen.go b/pkg/tcpip/link/tun/tun_state_autogen.go new file mode 100644 index 000000000..3515d86fd --- /dev/null +++ b/pkg/tcpip/link/tun/tun_state_autogen.go @@ -0,0 +1,64 @@ +// automatically generated by stateify. + +package tun + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (d *Device) StateTypeName() string { + return "pkg/tcpip/link/tun.Device" +} + +func (d *Device) StateFields() []string { + return []string{ + "Queue", + "endpoint", + "notifyHandle", + "flags", + } +} + +func (d *Device) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.Queue) + stateSinkObject.Save(1, &d.endpoint) + stateSinkObject.Save(2, &d.notifyHandle) + stateSinkObject.Save(3, &d.flags) +} + +func (d *Device) afterLoad() {} + +func (d *Device) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.Queue) + stateSourceObject.Load(1, &d.endpoint) + stateSourceObject.Load(2, &d.notifyHandle) + stateSourceObject.Load(3, &d.flags) +} + +func (r *tunEndpointRefs) StateTypeName() string { + return "pkg/tcpip/link/tun.tunEndpointRefs" +} + +func (r *tunEndpointRefs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (r *tunEndpointRefs) beforeSave() {} + +func (r *tunEndpointRefs) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) +} + +func (r *tunEndpointRefs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.AfterLoad(r.afterLoad) +} + +func init() { + state.Register((*Device)(nil)) + state.Register((*tunEndpointRefs)(nil)) +} diff --git a/pkg/tcpip/link/tun/tun_unsafe_state_autogen.go b/pkg/tcpip/link/tun/tun_unsafe_state_autogen.go new file mode 100644 index 000000000..149299ea3 --- /dev/null +++ b/pkg/tcpip/link/tun/tun_unsafe_state_autogen.go @@ -0,0 +1,5 @@ +// automatically generated by stateify. + +// +build linux + +package tun diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD deleted file mode 100644 index 9b4602c1b..000000000 --- a/pkg/tcpip/link/waitable/BUILD +++ /dev/null @@ -1,30 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "waitable", - srcs = [ - "waitable.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/gate", - "//pkg/tcpip", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "waitable_test", - srcs = [ - "waitable_test.go", - ], - library = ":waitable", - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/waitable/waitable_state_autogen.go b/pkg/tcpip/link/waitable/waitable_state_autogen.go new file mode 100644 index 000000000..059424fa0 --- /dev/null +++ b/pkg/tcpip/link/waitable/waitable_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package waitable diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go deleted file mode 100644 index e368a9eaa..000000000 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ /dev/null @@ -1,182 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package waitable - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -type countedEndpoint struct { - dispatchCount int - writeCount int - attachCount int - - mtu uint32 - capabilities stack.LinkEndpointCapabilities - hdrLen uint16 - linkAddr tcpip.LinkAddress - - dispatcher stack.NetworkDispatcher -} - -func (e *countedEndpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - e.dispatchCount++ -} - -func (e *countedEndpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - panic("unimplemented") -} - -func (e *countedEndpoint) Attach(dispatcher stack.NetworkDispatcher) { - e.attachCount++ - e.dispatcher = dispatcher -} - -// IsAttached implements stack.LinkEndpoint.IsAttached. -func (e *countedEndpoint) IsAttached() bool { - return e.dispatcher != nil -} - -func (e *countedEndpoint) MTU() uint32 { - return e.mtu -} - -func (e *countedEndpoint) Capabilities() stack.LinkEndpointCapabilities { - return e.capabilities -} - -func (e *countedEndpoint) MaxHeaderLength() uint16 { - return e.hdrLen -} - -func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress { - return e.linkAddr -} - -func (e *countedEndpoint) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { - e.writeCount++ - return nil -} - -// WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *countedEndpoint) WritePackets(_ stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - e.writeCount += pkts.Len() - return pkts.Len(), nil -} - -// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. -func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType { - panic("unimplemented") -} - -// Wait implements stack.LinkEndpoint.Wait. -func (*countedEndpoint) Wait() {} - -// AddHeader implements stack.LinkEndpoint.AddHeader. -func (e *countedEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - panic("unimplemented") -} - -func TestWaitWrite(t *testing.T) { - ep := &countedEndpoint{} - wep := New(ep) - - // Write and check that it goes through. - wep.WritePacket(stack.RouteInfo{}, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) - if want := 1; ep.writeCount != want { - t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) - } - - // Wait on dispatches, then try to write. It must go through. - wep.WaitDispatch() - wep.WritePacket(stack.RouteInfo{}, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) - if want := 2; ep.writeCount != want { - t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) - } - - // Wait on writes, then try to write. It must not go through. - wep.WaitWrite() - wep.WritePacket(stack.RouteInfo{}, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) - if want := 2; ep.writeCount != want { - t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) - } -} - -func TestWaitDispatch(t *testing.T) { - ep := &countedEndpoint{} - wep := New(ep) - - // Check that attach happens. - wep.Attach(ep) - if want := 1; ep.attachCount != want { - t.Fatalf("Unexpected attachCount: got=%v, want=%v", ep.attachCount, want) - } - - // Dispatch and check that it goes through. - ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) - if want := 1; ep.dispatchCount != want { - t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) - } - - // Wait on writes, then try to dispatch. It must go through. - wep.WaitWrite() - ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) - if want := 2; ep.dispatchCount != want { - t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) - } - - // Wait on dispatches, then try to dispatch. It must not go through. - wep.WaitDispatch() - ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) - if want := 2; ep.dispatchCount != want { - t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) - } -} - -func TestOtherMethods(t *testing.T) { - const ( - mtu = 0xdead - capabilities = 0xbeef - hdrLen = 0x1234 - linkAddr = "test address" - ) - ep := &countedEndpoint{ - mtu: mtu, - capabilities: capabilities, - hdrLen: hdrLen, - linkAddr: linkAddr, - } - wep := New(ep) - - if v := wep.MTU(); v != mtu { - t.Fatalf("Unexpected mtu: got=%v, want=%v", v, mtu) - } - - if v := wep.Capabilities(); v != capabilities { - t.Fatalf("Unexpected capabilities: got=%v, want=%v", v, capabilities) - } - - if v := wep.MaxHeaderLength(); v != hdrLen { - t.Fatalf("Unexpected MaxHeaderLength: got=%v, want=%v", v, hdrLen) - } - - if v := wep.LinkAddress(); v != linkAddr { - t.Fatalf("Unexpected LinkAddress: got=%q, want=%q", v, linkAddr) - } -} diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD deleted file mode 100644 index fa8814bac..000000000 --- a/pkg/tcpip/network/BUILD +++ /dev/null @@ -1,29 +0,0 @@ -load("//tools:defs.bzl", "go_test") - -package(licenses = ["notice"]) - -go_test( - name = "ip_test", - size = "small", - srcs = [ - "ip_test.go", - "multicast_group_test.go", - ], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/faketime", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/icmp", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD deleted file mode 100644 index d59d678b2..000000000 --- a/pkg/tcpip/network/arp/BUILD +++ /dev/null @@ -1,52 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "arp", - srcs = [ - "arp.go", - "stats.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/header/parse", - "//pkg/tcpip/network/internal/ip", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "arp_test", - size = "small", - srcs = ["arp_test.go"], - deps = [ - ":arp", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/icmp", - "@com_github_google_go_cmp//cmp:go_default_library", - "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", - ], -) - -go_test( - name = "stats_test", - size = "small", - srcs = ["stats_test.go"], - library = ":arp", - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/network/internal/testutil", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/network/arp/arp_state_autogen.go b/pkg/tcpip/network/arp/arp_state_autogen.go new file mode 100644 index 000000000..5cd8535e3 --- /dev/null +++ b/pkg/tcpip/network/arp/arp_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package arp diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go deleted file mode 100644 index 018d6a578..000000000 --- a/pkg/tcpip/network/arp/arp_test.go +++ /dev/null @@ -1,711 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package arp_test - -import ( - "context" - "fmt" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" -) - -const ( - nicID = 1 - - stackAddr = tcpip.Address("\x0a\x00\x00\x01") - stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") - - remoteAddr = tcpip.Address("\x0a\x00\x00\x02") - remoteLinkAddr = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06") - - unknownAddr = tcpip.Address("\x0a\x00\x00\x03") - - defaultChannelSize = 1 - defaultMTU = 65536 - - // eventChanSize defines the size of event channels used by the neighbor - // cache's event dispatcher. The size chosen here needs to be sufficient to - // queue all the events received during tests before consumption. - // If eventChanSize is too small, the tests may deadlock. - eventChanSize = 32 -) - -type eventType uint8 - -const ( - entryAdded eventType = iota - entryChanged - entryRemoved -) - -func (t eventType) String() string { - switch t { - case entryAdded: - return "add" - case entryChanged: - return "change" - case entryRemoved: - return "remove" - default: - return fmt.Sprintf("unknown (%d)", t) - } -} - -type eventInfo struct { - eventType eventType - nicID tcpip.NICID - entry stack.NeighborEntry -} - -func (e eventInfo) String() string { - return fmt.Sprintf("%s event for NIC #%d, %#v", e.eventType, e.nicID, e.entry) -} - -// arpDispatcher implements NUDDispatcher to validate the dispatching of -// events upon certain NUD state machine events. -type arpDispatcher struct { - // C is where events are queued - C chan eventInfo -} - -var _ stack.NUDDispatcher = (*arpDispatcher)(nil) - -func (d *arpDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry stack.NeighborEntry) { - e := eventInfo{ - eventType: entryAdded, - nicID: nicID, - entry: entry, - } - d.C <- e -} - -func (d *arpDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry stack.NeighborEntry) { - e := eventInfo{ - eventType: entryChanged, - nicID: nicID, - entry: entry, - } - d.C <- e -} - -func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.NeighborEntry) { - e := eventInfo{ - eventType: entryRemoved, - nicID: nicID, - entry: entry, - } - d.C <- e -} - -func (d *arpDispatcher) waitForEvent(ctx context.Context, want eventInfo) error { - select { - case got := <-d.C: - if diff := cmp.Diff(want, got, cmp.AllowUnexported(got), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" { - return fmt.Errorf("got invalid event (-want +got):\n%s", diff) - } - case <-ctx.Done(): - return fmt.Errorf("%s for %s", ctx.Err(), want) - } - return nil -} - -func (d *arpDispatcher) waitForEventWithTimeout(want eventInfo, timeout time.Duration) error { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - return d.waitForEvent(ctx, want) -} - -func (d *arpDispatcher) nextEvent() (eventInfo, bool) { - select { - case event := <-d.C: - return event, true - default: - return eventInfo{}, false - } -} - -type testContext struct { - s *stack.Stack - linkEP *channel.Endpoint - nudDisp *arpDispatcher -} - -func newTestContext(t *testing.T) *testContext { - c := stack.DefaultNUDConfigurations() - // Transition from Reachable to Stale almost immediately to test if receiving - // probes refreshes positive reachability. - c.BaseReachableTime = time.Microsecond - - d := arpDispatcher{ - // Create an event channel large enough so the neighbor cache doesn't block - // while dispatching events. Blocking could interfere with the timing of - // NUD transitions. - C: make(chan eventInfo, eventChanSize), - } - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, - NUDConfigs: c, - NUDDisp: &d, - }) - - ep := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) - ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - wep := stack.LinkEndpoint(ep) - - if testing.Verbose() { - wep = sniffer.New(ep) - } - if err := s.CreateNIC(nicID, wep); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress for ipv4 failed: %v", err) - } - - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }}) - - return &testContext{ - s: s, - linkEP: ep, - nudDisp: &d, - } -} - -func (c *testContext) cleanup() { - c.linkEP.Close() -} - -func TestMalformedPacket(t *testing.T) { - c := newTestContext(t) - defer c.cleanup() - - v := make(buffer.View, header.ARPSize) - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: v.ToVectorisedView(), - }) - - c.linkEP.InjectInbound(arp.ProtocolNumber, pkt) - - if got := c.s.Stats().ARP.PacketsReceived.Value(); got != 1 { - t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = 1", got) - } - if got := c.s.Stats().ARP.MalformedPacketsReceived.Value(); got != 1 { - t.Errorf("got c.s.Stats().ARP.MalformedPacketsReceived.Value() = %d, want = 1", got) - } -} - -func TestDisabledEndpoint(t *testing.T) { - c := newTestContext(t) - defer c.cleanup() - - ep, err := c.s.GetNetworkEndpoint(nicID, header.ARPProtocolNumber) - if err != nil { - t.Fatalf("GetNetworkEndpoint(%d, header.ARPProtocolNumber) failed: %s", nicID, err) - } - ep.Disable() - - v := make(buffer.View, header.ARPSize) - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: v.ToVectorisedView(), - }) - - c.linkEP.InjectInbound(arp.ProtocolNumber, pkt) - - if got := c.s.Stats().ARP.PacketsReceived.Value(); got != 1 { - t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = 1", got) - } - if got := c.s.Stats().ARP.DisabledPacketsReceived.Value(); got != 1 { - t.Errorf("got c.s.Stats().ARP.DisabledPacketsReceived.Value() = %d, want = 1", got) - } -} - -func TestDirectReply(t *testing.T) { - c := newTestContext(t) - defer c.cleanup() - - const senderMAC = "\x01\x02\x03\x04\x05\x06" - const senderIPv4 = "\x0a\x00\x00\x02" - - v := make(buffer.View, header.ARPSize) - h := header.ARP(v) - h.SetIPv4OverEthernet() - h.SetOp(header.ARPReply) - - copy(h.HardwareAddressSender(), senderMAC) - copy(h.ProtocolAddressSender(), senderIPv4) - copy(h.HardwareAddressTarget(), stackLinkAddr) - copy(h.ProtocolAddressTarget(), stackAddr) - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: v.ToVectorisedView(), - }) - - c.linkEP.InjectInbound(arp.ProtocolNumber, pkt) - - if got := c.s.Stats().ARP.PacketsReceived.Value(); got != 1 { - t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = 1", got) - } - if got := c.s.Stats().ARP.RepliesReceived.Value(); got != 1 { - t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = 1", got) - } -} - -func TestDirectRequest(t *testing.T) { - c := newTestContext(t) - defer c.cleanup() - - tests := []struct { - name string - senderAddr tcpip.Address - senderLinkAddr tcpip.LinkAddress - targetAddr tcpip.Address - isValid bool - }{ - { - name: "Loopback", - senderAddr: stackAddr, - senderLinkAddr: stackLinkAddr, - targetAddr: stackAddr, - isValid: true, - }, - { - name: "Remote", - senderAddr: remoteAddr, - senderLinkAddr: remoteLinkAddr, - targetAddr: stackAddr, - isValid: true, - }, - { - name: "RemoteInvalidTarget", - senderAddr: remoteAddr, - senderLinkAddr: remoteLinkAddr, - targetAddr: unknownAddr, - isValid: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - packetsRecv := c.s.Stats().ARP.PacketsReceived.Value() - requestsRecv := c.s.Stats().ARP.RequestsReceived.Value() - requestsRecvUnknownAddr := c.s.Stats().ARP.RequestsReceivedUnknownTargetAddress.Value() - outgoingReplies := c.s.Stats().ARP.OutgoingRepliesSent.Value() - - // Inject an incoming ARP request. - v := make(buffer.View, header.ARPSize) - h := header.ARP(v) - h.SetIPv4OverEthernet() - h.SetOp(header.ARPRequest) - copy(h.HardwareAddressSender(), test.senderLinkAddr) - copy(h.ProtocolAddressSender(), test.senderAddr) - copy(h.ProtocolAddressTarget(), test.targetAddr) - c.linkEP.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: v.ToVectorisedView(), - })) - - if got, want := c.s.Stats().ARP.PacketsReceived.Value(), packetsRecv+1; got != want { - t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = %d", got, want) - } - if got, want := c.s.Stats().ARP.RequestsReceived.Value(), requestsRecv+1; got != want { - t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = %d", got, want) - } - - if !test.isValid { - // No packets should be sent after receiving an invalid ARP request. - // There is no need to perform a blocking read here, since packets are - // sent in the same function that handles ARP requests. - if pkt, ok := c.linkEP.Read(); ok { - t.Errorf("unexpected packet sent with network protocol number %d", pkt.Proto) - } - if got, want := c.s.Stats().ARP.RequestsReceivedUnknownTargetAddress.Value(), requestsRecvUnknownAddr+1; got != want { - t.Errorf("got c.s.Stats().ARP.RequestsReceivedUnknownTargetAddress.Value() = %d, want = %d", got, want) - } - if got, want := c.s.Stats().ARP.OutgoingRepliesSent.Value(), outgoingReplies; got != want { - t.Errorf("got c.s.Stats().ARP.OutgoingRepliesSent.Value() = %d, want = %d", got, want) - } - - return - } - - if got, want := c.s.Stats().ARP.OutgoingRepliesSent.Value(), outgoingReplies+1; got != want { - t.Errorf("got c.s.Stats().ARP.OutgoingRepliesSent.Value() = %d, want = %d", got, want) - } - - // Verify an ARP response was sent. - pi, ok := c.linkEP.Read() - if !ok { - t.Fatal("expected ARP response to be sent, got none") - } - - if pi.Proto != arp.ProtocolNumber { - t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto) - } - rep := header.ARP(pi.Pkt.NetworkHeader().View()) - if !rep.IsValid() { - t.Fatalf("invalid ARP response: len = %d; response = %x", len(rep), rep) - } - if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want { - t.Errorf("got HardwareAddressSender() = %s, want = %s", got, want) - } - if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want { - t.Errorf("got ProtocolAddressSender() = %s, want = %s", got, want) - } - if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress(h.HardwareAddressSender()); got != want { - t.Errorf("got HardwareAddressTarget() = %s, want = %s", got, want) - } - if got, want := tcpip.Address(rep.ProtocolAddressTarget()), tcpip.Address(h.ProtocolAddressSender()); got != want { - t.Errorf("got ProtocolAddressTarget() = %s, want = %s", got, want) - } - - // Verify the sender was saved in the neighbor cache. - wantEvent := eventInfo{ - eventType: entryAdded, - nicID: nicID, - entry: stack.NeighborEntry{ - Addr: test.senderAddr, - LinkAddr: tcpip.LinkAddress(test.senderLinkAddr), - State: stack.Stale, - }, - } - if err := c.nudDisp.waitForEventWithTimeout(wantEvent, time.Second); err != nil { - t.Fatal(err) - } - - neighbors, err := c.s.Neighbors(nicID, ipv4.ProtocolNumber) - if err != nil { - t.Fatalf("c.s.Neighbors(%d, %d): %s", nicID, ipv4.ProtocolNumber, err) - } - - neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry) - for _, n := range neighbors { - if existing, ok := neighborByAddr[n.Addr]; ok { - if diff := cmp.Diff(existing, n); diff != "" { - t.Fatalf("duplicate neighbor entry found (-existing +got):\n%s", diff) - } - t.Fatalf("exact neighbor entry duplicate found for addr=%s", n.Addr) - } - neighborByAddr[n.Addr] = n - } - - neigh, ok := neighborByAddr[test.senderAddr] - if !ok { - t.Fatalf("expected neighbor entry with Addr = %s", test.senderAddr) - } - if got, want := neigh.LinkAddr, test.senderLinkAddr; got != want { - t.Errorf("got neighbor LinkAddr = %s, want = %s", got, want) - } - if got, want := neigh.State, stack.Stale; got != want { - t.Errorf("got neighbor State = %s, want = %s", got, want) - } - - // No more events should be dispatched - for { - event, ok := c.nudDisp.nextEvent() - if !ok { - break - } - t.Errorf("unexpected %s", event) - } - }) - } -} - -var _ stack.LinkEndpoint = (*testLinkEndpoint)(nil) - -type testLinkEndpoint struct { - stack.LinkEndpoint - - writeErr tcpip.Error -} - -func (t *testLinkEndpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - if t.writeErr != nil { - return t.writeErr - } - - return t.LinkEndpoint.WritePacket(r, gso, protocol, pkt) -} - -func TestLinkAddressRequest(t *testing.T) { - const nicID = 1 - - testAddr := tcpip.Address([]byte{1, 2, 3, 4}) - - tests := []struct { - name string - nicAddr tcpip.Address - localAddr tcpip.Address - remoteLinkAddr tcpip.LinkAddress - linkErr tcpip.Error - expectedErr tcpip.Error - expectedLocalAddr tcpip.Address - expectedRemoteLinkAddr tcpip.LinkAddress - expectedRequestsSent uint64 - expectedRequestBadLocalAddressErrors uint64 - expectedRequestInterfaceHasNoLocalAddressErrors uint64 - expectedRequestDroppedErrors uint64 - }{ - { - name: "Unicast", - nicAddr: stackAddr, - localAddr: stackAddr, - remoteLinkAddr: remoteLinkAddr, - expectedLocalAddr: stackAddr, - expectedRemoteLinkAddr: remoteLinkAddr, - expectedRequestsSent: 1, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestInterfaceHasNoLocalAddressErrors: 0, - expectedRequestDroppedErrors: 0, - }, - { - name: "Multicast", - nicAddr: stackAddr, - localAddr: stackAddr, - remoteLinkAddr: "", - expectedLocalAddr: stackAddr, - expectedRemoteLinkAddr: header.EthernetBroadcastAddress, - expectedRequestsSent: 1, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestInterfaceHasNoLocalAddressErrors: 0, - expectedRequestDroppedErrors: 0, - }, - { - name: "Unicast with unspecified source", - nicAddr: stackAddr, - localAddr: "", - remoteLinkAddr: remoteLinkAddr, - expectedLocalAddr: stackAddr, - expectedRemoteLinkAddr: remoteLinkAddr, - expectedRequestsSent: 1, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestInterfaceHasNoLocalAddressErrors: 0, - expectedRequestDroppedErrors: 0, - }, - { - name: "Multicast with unspecified source", - nicAddr: stackAddr, - localAddr: "", - remoteLinkAddr: "", - expectedLocalAddr: stackAddr, - expectedRemoteLinkAddr: header.EthernetBroadcastAddress, - expectedRequestsSent: 1, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestInterfaceHasNoLocalAddressErrors: 0, - expectedRequestDroppedErrors: 0, - }, - { - name: "Unicast with unassigned address", - nicAddr: stackAddr, - localAddr: testAddr, - remoteLinkAddr: remoteLinkAddr, - expectedErr: &tcpip.ErrBadLocalAddress{}, - expectedRequestsSent: 0, - expectedRequestBadLocalAddressErrors: 1, - expectedRequestInterfaceHasNoLocalAddressErrors: 0, - expectedRequestDroppedErrors: 0, - }, - { - name: "Multicast with unassigned address", - nicAddr: stackAddr, - localAddr: testAddr, - remoteLinkAddr: "", - expectedErr: &tcpip.ErrBadLocalAddress{}, - expectedRequestsSent: 0, - expectedRequestBadLocalAddressErrors: 1, - expectedRequestInterfaceHasNoLocalAddressErrors: 0, - expectedRequestDroppedErrors: 0, - }, - { - name: "Unicast with no local address available", - nicAddr: "", - localAddr: "", - remoteLinkAddr: remoteLinkAddr, - expectedErr: &tcpip.ErrNetworkUnreachable{}, - expectedRequestsSent: 0, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestInterfaceHasNoLocalAddressErrors: 1, - expectedRequestDroppedErrors: 0, - }, - { - name: "Multicast with no local address available", - nicAddr: "", - localAddr: "", - remoteLinkAddr: "", - expectedErr: &tcpip.ErrNetworkUnreachable{}, - expectedRequestsSent: 0, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestInterfaceHasNoLocalAddressErrors: 1, - expectedRequestDroppedErrors: 0, - }, - { - name: "Link error", - nicAddr: stackAddr, - localAddr: stackAddr, - remoteLinkAddr: remoteLinkAddr, - linkErr: &tcpip.ErrInvalidEndpointState{}, - expectedErr: &tcpip.ErrInvalidEndpointState{}, - expectedRequestsSent: 0, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestInterfaceHasNoLocalAddressErrors: 0, - expectedRequestDroppedErrors: 1, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, - }) - linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) - if err := s.CreateNIC(nicID, &testLinkEndpoint{LinkEndpoint: linkEP, writeErr: test.linkErr}); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - - ep, err := s.GetNetworkEndpoint(nicID, arp.ProtocolNumber) - if err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, arp.ProtocolNumber, err) - } - linkRes, ok := ep.(stack.LinkAddressResolver) - if !ok { - t.Fatalf("expected %T to implement stack.LinkAddressResolver", ep) - } - - if len(test.nicAddr) != 0 { - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, test.nicAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, test.nicAddr, err) - } - } - - { - err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr) - if diff := cmp.Diff(test.expectedErr, err); diff != "" { - t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", remoteAddr, test.localAddr, test.remoteLinkAddr, diff) - } - } - - if got := s.Stats().ARP.OutgoingRequestsSent.Value(); got != test.expectedRequestsSent { - t.Errorf("got s.Stats().ARP.OutgoingRequestsSent.Value() = %d, want = %d", got, test.expectedRequestsSent) - } - if got := s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value(); got != test.expectedRequestInterfaceHasNoLocalAddressErrors { - t.Errorf("got s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value() = %d, want = %d", got, test.expectedRequestInterfaceHasNoLocalAddressErrors) - } - if got := s.Stats().ARP.OutgoingRequestBadLocalAddressErrors.Value(); got != test.expectedRequestBadLocalAddressErrors { - t.Errorf("got s.Stats().ARP.OutgoingRequestBadLocalAddressErrors.Value() = %d, want = %d", got, test.expectedRequestBadLocalAddressErrors) - } - if got := s.Stats().ARP.OutgoingRequestsDropped.Value(); got != test.expectedRequestDroppedErrors { - t.Errorf("got s.Stats().ARP.OutgoingRequestsDropped.Value() = %d, want = %d", got, test.expectedRequestDroppedErrors) - } - - if test.expectedErr != nil { - return - } - - pkt, ok := linkEP.Read() - if !ok { - t.Fatal("expected to send a link address request") - } - - if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr { - t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr) - } - - rep := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader())) - if got := rep.Op(); got != header.ARPRequest { - t.Errorf("got Op = %d, want = %d", got, header.ARPRequest) - } - if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr { - t.Errorf("got HardwareAddressSender = %s, want = %s", got, stackLinkAddr) - } - if got := tcpip.Address(rep.ProtocolAddressSender()); got != test.expectedLocalAddr { - t.Errorf("got ProtocolAddressSender = %s, want = %s", got, test.expectedLocalAddr) - } - if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"); got != want { - t.Errorf("got HardwareAddressTarget = %s, want = %s", got, want) - } - if got := tcpip.Address(rep.ProtocolAddressTarget()); got != remoteAddr { - t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, remoteAddr) - } - }) - } -} - -func TestDADARPRequestPacket(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocolWithOptions(arp.Options{ - DADConfigs: stack.DADConfigurations{ - DupAddrDetectTransmits: 1, - RetransmitTimer: time.Second, - }, - }), ipv4.NewProtocol}, - }) - e := channel.New(1, defaultMTU, stackLinkAddr) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - - if res, err := s.CheckDuplicateAddress(nicID, header.IPv4ProtocolNumber, remoteAddr, func(stack.DADResult) {}); err != nil { - t.Fatalf("s.CheckDuplicateAddress(%d, %d, %s, _): %s", nicID, header.IPv4ProtocolNumber, remoteAddr, err) - } else if res != stack.DADStarting { - t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, header.IPv4ProtocolNumber, remoteAddr, res, stack.DADStarting) - } - - pkt, ok := e.ReadContext(context.Background()) - if !ok { - t.Fatal("expected to send an ARP request") - } - - if pkt.Route.RemoteLinkAddress != header.EthernetBroadcastAddress { - t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, header.EthernetBroadcastAddress) - } - - req := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader())) - if !req.IsValid() { - t.Errorf("got req.IsValid() = false, want = true") - } - if got := req.Op(); got != header.ARPRequest { - t.Errorf("got req.Op() = %d, want = %d", got, header.ARPRequest) - } - if got := tcpip.LinkAddress(req.HardwareAddressSender()); got != stackLinkAddr { - t.Errorf("got req.HardwareAddressSender() = %s, want = %s", got, stackLinkAddr) - } - if got := tcpip.Address(req.ProtocolAddressSender()); got != header.IPv4Any { - t.Errorf("got req.ProtocolAddressSender() = %s, want = %s", got, header.IPv4Any) - } - if got, want := tcpip.LinkAddress(req.HardwareAddressTarget()), tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"); got != want { - t.Errorf("got req.HardwareAddressTarget() = %s, want = %s", got, want) - } - if got := tcpip.Address(req.ProtocolAddressTarget()); got != remoteAddr { - t.Errorf("got req.ProtocolAddressTarget() = %s, want = %s", got, remoteAddr) - } -} diff --git a/pkg/tcpip/network/arp/stats_test.go b/pkg/tcpip/network/arp/stats_test.go deleted file mode 100644 index e867b3c3f..000000000 --- a/pkg/tcpip/network/arp/stats_test.go +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package arp - -import ( - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -var _ stack.NetworkInterface = (*testInterface)(nil) - -type testInterface struct { - stack.NetworkInterface - nicID tcpip.NICID -} - -func (t *testInterface) ID() tcpip.NICID { - return t.nicID -} - -func TestMultiCounterStatsInitialization(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) - var nic testInterface - ep := proto.NewEndpoint(&nic, nil).(*endpoint) - // At this point, the Stack's stats and the NetworkEndpoint's stats are - // expected to be bound by a MultiCounterStat. - refStack := s.Stats() - refEP := ep.stats.localStats - if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.arp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.ARP).Elem(), reflect.ValueOf(&refStack.ARP).Elem()}); err != nil { - t.Error(err) - } -} diff --git a/pkg/tcpip/network/hash/BUILD b/pkg/tcpip/network/hash/BUILD deleted file mode 100644 index 872165866..000000000 --- a/pkg/tcpip/network/hash/BUILD +++ /dev/null @@ -1,13 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "hash", - srcs = ["hash.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/rand", - "//pkg/tcpip/header", - ], -) diff --git a/pkg/tcpip/network/hash/hash_state_autogen.go b/pkg/tcpip/network/hash/hash_state_autogen.go new file mode 100644 index 000000000..9467fe298 --- /dev/null +++ b/pkg/tcpip/network/hash/hash_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package hash diff --git a/pkg/tcpip/network/internal/fragmentation/BUILD b/pkg/tcpip/network/internal/fragmentation/BUILD deleted file mode 100644 index 274f09092..000000000 --- a/pkg/tcpip/network/internal/fragmentation/BUILD +++ /dev/null @@ -1,54 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "reassembler_list", - out = "reassembler_list.go", - package = "fragmentation", - prefix = "reassembler", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*reassembler", - "Linker": "*reassembler", - }, -) - -go_library( - name = "fragmentation", - srcs = [ - "fragmentation.go", - "reassembler.go", - "reassembler_list.go", - ], - visibility = [ - "//pkg/tcpip/network/ipv4:__pkg__", - "//pkg/tcpip/network/ipv6:__pkg__", - ], - deps = [ - "//pkg/log", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "fragmentation_test", - size = "small", - srcs = [ - "fragmentation_test.go", - "reassembler_test.go", - ], - library = ":fragmentation", - deps = [ - "//pkg/tcpip/buffer", - "//pkg/tcpip/faketime", - "//pkg/tcpip/network/internal/testutil", - "//pkg/tcpip/stack", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation_state_autogen.go b/pkg/tcpip/network/internal/fragmentation/fragmentation_state_autogen.go new file mode 100644 index 000000000..3f82c184a --- /dev/null +++ b/pkg/tcpip/network/internal/fragmentation/fragmentation_state_autogen.go @@ -0,0 +1,64 @@ +// automatically generated by stateify. + +package fragmentation + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (l *reassemblerList) StateTypeName() string { + return "pkg/tcpip/network/internal/fragmentation.reassemblerList" +} + +func (l *reassemblerList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *reassemblerList) beforeSave() {} + +func (l *reassemblerList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *reassemblerList) afterLoad() {} + +func (l *reassemblerList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *reassemblerEntry) StateTypeName() string { + return "pkg/tcpip/network/internal/fragmentation.reassemblerEntry" +} + +func (e *reassemblerEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *reassemblerEntry) beforeSave() {} + +func (e *reassemblerEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *reassemblerEntry) afterLoad() {} + +func (e *reassemblerEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func init() { + state.Register((*reassemblerList)(nil)) + state.Register((*reassemblerEntry)(nil)) +} diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go deleted file mode 100644 index 47ea3173e..000000000 --- a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go +++ /dev/null @@ -1,638 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fragmentation - -import ( - "errors" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -// reassembleTimeout is dummy timeout used for testing, where the clock never -// advances. -const reassembleTimeout = 1 - -// vv is a helper to build VectorisedView from different strings. -func vv(size int, pieces ...string) buffer.VectorisedView { - views := make([]buffer.View, len(pieces)) - for i, p := range pieces { - views[i] = []byte(p) - } - - return buffer.NewVectorisedView(size, views) -} - -func pkt(size int, pieces ...string) *stack.PacketBuffer { - return stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv(size, pieces...), - }) -} - -type processInput struct { - id FragmentID - first uint16 - last uint16 - more bool - proto uint8 - pkt *stack.PacketBuffer -} - -type processOutput struct { - vv buffer.VectorisedView - proto uint8 - done bool -} - -var processTestCases = []struct { - comment string - in []processInput - out []processOutput -}{ - { - comment: "One ID", - in: []processInput{ - {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")}, - {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")}, - }, - out: []processOutput{ - {vv: buffer.VectorisedView{}, done: false}, - {vv: vv(4, "01", "23"), done: true}, - }, - }, - { - comment: "Next Header protocol mismatch", - in: []processInput{ - {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, proto: 6, pkt: pkt(2, "01")}, - {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, proto: 17, pkt: pkt(2, "23")}, - }, - out: []processOutput{ - {vv: buffer.VectorisedView{}, done: false}, - {vv: vv(4, "01", "23"), proto: 6, done: true}, - }, - }, - { - comment: "Two IDs", - in: []processInput{ - {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")}, - {id: FragmentID{ID: 1}, first: 0, last: 1, more: true, pkt: pkt(2, "ab")}, - {id: FragmentID{ID: 1}, first: 2, last: 3, more: false, pkt: pkt(2, "cd")}, - {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")}, - }, - out: []processOutput{ - {vv: buffer.VectorisedView{}, done: false}, - {vv: buffer.VectorisedView{}, done: false}, - {vv: vv(4, "ab", "cd"), done: true}, - {vv: vv(4, "01", "23"), done: true}, - }, - }, -} - -func TestFragmentationProcess(t *testing.T) { - for _, c := range processTestCases { - t.Run(c.comment, func(t *testing.T) { - f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{}, nil) - firstFragmentProto := c.in[0].proto - for i, in := range c.in { - resPkt, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.pkt) - if err != nil { - t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %#v) failed: %s", - in.id, in.first, in.last, in.more, in.proto, in.pkt, err) - } - if done != c.out[i].done { - t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, _, %t, _), want = (_, _, %t, _)", - in.id, in.first, in.last, in.more, in.proto, done, c.out[i].done) - } - if c.out[i].done { - if diff := cmp.Diff(c.out[i].vv.ToOwnedView(), resPkt.Data.ToOwnedView()); diff != "" { - t.Errorf("got Process(%+v, %d, %d, %t, %d, %#v) result mismatch (-want, +got):\n%s", - in.id, in.first, in.last, in.more, in.proto, in.pkt, diff) - } - if firstFragmentProto != proto { - t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, %d, _, _), want = (_, %d, _, _)", - in.id, in.first, in.last, in.more, in.proto, proto, firstFragmentProto) - } - if _, ok := f.reassemblers[in.id]; ok { - t.Errorf("Process(%d) did not remove buffer from reassemblers", i) - } - for n := f.rList.Front(); n != nil; n = n.Next() { - if n.id == in.id { - t.Errorf("Process(%d) did not remove buffer from rList", i) - } - } - } - } - }) - } -} - -func TestReassemblingTimeout(t *testing.T) { - const ( - reassemblyTimeout = time.Millisecond - protocol = 0xff - ) - - type fragment struct { - first uint16 - last uint16 - more bool - data string - } - - type event struct { - // name is a nickname of this event. - name string - - // clockAdvance is a duration to advance the clock. The clock advances - // before a fragment specified in the fragment field is processed. - clockAdvance time.Duration - - // fragment is a fragment to process. This can be nil if there is no - // fragment to process. - fragment *fragment - - // expectDone is true if the fragmentation instance should report the - // reassembly is done after the fragment is processd. - expectDone bool - - // memSizeAfterEvent is the expected memory size of the fragmentation - // instance after the event. - memSizeAfterEvent int - } - - memSizeOfFrags := func(frags ...*fragment) int { - var size int - for _, frag := range frags { - size += pkt(len(frag.data), frag.data).MemSize() - } - return size - } - - half1 := &fragment{first: 0, last: 0, more: true, data: "0"} - half2 := &fragment{first: 1, last: 1, more: false, data: "1"} - - tests := []struct { - name string - events []event - }{ - { - name: "half1 and half2 are reassembled successfully", - events: []event{ - { - name: "half1", - fragment: half1, - expectDone: false, - memSizeAfterEvent: memSizeOfFrags(half1), - }, - { - name: "half2", - fragment: half2, - expectDone: true, - memSizeAfterEvent: 0, - }, - }, - }, - { - name: "half1 timeout, half2 timeout", - events: []event{ - { - name: "half1", - fragment: half1, - expectDone: false, - memSizeAfterEvent: memSizeOfFrags(half1), - }, - { - name: "half1 just before reassembly timeout", - clockAdvance: reassemblyTimeout - 1, - memSizeAfterEvent: memSizeOfFrags(half1), - }, - { - name: "half1 reassembly timeout", - clockAdvance: 1, - memSizeAfterEvent: 0, - }, - { - name: "half2", - fragment: half2, - expectDone: false, - memSizeAfterEvent: memSizeOfFrags(half2), - }, - { - name: "half2 just before reassembly timeout", - clockAdvance: reassemblyTimeout - 1, - memSizeAfterEvent: memSizeOfFrags(half2), - }, - { - name: "half2 reassembly timeout", - clockAdvance: 1, - memSizeAfterEvent: 0, - }, - }, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassemblyTimeout, clock, nil) - for _, event := range test.events { - clock.Advance(event.clockAdvance) - if frag := event.fragment; frag != nil { - _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, pkt(len(frag.data), frag.data)) - if err != nil { - t.Fatalf("%s: f.Process failed: %s", event.name, err) - } - if done != event.expectDone { - t.Fatalf("%s: got done = %t, want = %t", event.name, done, event.expectDone) - } - } - if got, want := f.memSize, event.memSizeAfterEvent; got != want { - t.Errorf("%s: got f.memSize = %d, want = %d", event.name, got, want) - } - } - }) - } -} - -func TestMemoryLimits(t *testing.T) { - lowLimit := pkt(1, "0").MemSize() - highLimit := 3 * lowLimit // Allow at most 3 such packets. - f := NewFragmentation(minBlockSize, highLimit, lowLimit, reassembleTimeout, &faketime.NullClock{}, nil) - // Send first fragment with id = 0. - f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0")) - // Send first fragment with id = 1. - f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, pkt(1, "1")) - // Send first fragment with id = 2. - f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, pkt(1, "2")) - - // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be - // evicted. - f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, pkt(1, "3")) - - if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok { - t.Errorf("Memory limits are not respected: id=0 has not been evicted.") - } - if _, ok := f.reassemblers[FragmentID{ID: 1}]; ok { - t.Errorf("Memory limits are not respected: id=1 has not been evicted.") - } - if _, ok := f.reassemblers[FragmentID{ID: 3}]; !ok { - t.Errorf("Implementation of memory limits is wrong: id=3 is not present.") - } -} - -func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { - memSize := pkt(1, "0").MemSize() - f := NewFragmentation(minBlockSize, memSize, 0, reassembleTimeout, &faketime.NullClock{}, nil) - // Send first fragment with id = 0. - f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")) - // Send the same packet again. - f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")) - - if got, want := f.memSize, memSize; got != want { - t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want) - } -} - -func TestErrors(t *testing.T) { - tests := []struct { - name string - blockSize uint16 - first uint16 - last uint16 - more bool - data string - err error - }{ - { - name: "exact block size without more", - blockSize: 2, - first: 2, - last: 3, - more: false, - data: "01", - }, - { - name: "exact block size with more", - blockSize: 2, - first: 2, - last: 3, - more: true, - data: "01", - }, - { - name: "exact block size with more and extra data", - blockSize: 2, - first: 2, - last: 3, - more: true, - data: "012", - err: ErrInvalidArgs, - }, - { - name: "exact block size with more and too little data", - blockSize: 2, - first: 2, - last: 3, - more: true, - data: "0", - err: ErrInvalidArgs, - }, - { - name: "not exact block size with more", - blockSize: 2, - first: 2, - last: 2, - more: true, - data: "0", - err: ErrInvalidArgs, - }, - { - name: "not exact block size without more", - blockSize: 2, - first: 2, - last: 2, - more: false, - data: "0", - }, - { - name: "first not a multiple of block size", - blockSize: 2, - first: 3, - last: 4, - more: true, - data: "01", - err: ErrInvalidArgs, - }, - { - name: "first more than last", - blockSize: 2, - first: 4, - last: 3, - more: true, - data: "01", - err: ErrInvalidArgs, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, nil) - _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, pkt(len(test.data), test.data)) - if !errors.Is(err, test.err) { - t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err) - } - if done { - t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, true, _), want = (_, _, false, _)", test.first, test.last, test.more, test.data) - } - }) - } -} - -type fragmentInfo struct { - remaining int - copied int - offset int - more bool -} - -func TestPacketFragmenter(t *testing.T) { - const ( - reserve = 60 - proto = 0 - ) - - tests := []struct { - name string - fragmentPayloadLen uint32 - transportHeaderLen int - payloadSize int - wantFragments []fragmentInfo - }{ - { - name: "Packet exactly fits in MTU", - fragmentPayloadLen: 1280, - transportHeaderLen: 0, - payloadSize: 1280, - wantFragments: []fragmentInfo{ - {remaining: 0, copied: 1280, offset: 0, more: false}, - }, - }, - { - name: "Packet exactly does not fit in MTU", - fragmentPayloadLen: 1000, - transportHeaderLen: 0, - payloadSize: 1001, - wantFragments: []fragmentInfo{ - {remaining: 1, copied: 1000, offset: 0, more: true}, - {remaining: 0, copied: 1, offset: 1000, more: false}, - }, - }, - { - name: "Packet has a transport header", - fragmentPayloadLen: 560, - transportHeaderLen: 40, - payloadSize: 560, - wantFragments: []fragmentInfo{ - {remaining: 1, copied: 560, offset: 0, more: true}, - {remaining: 0, copied: 40, offset: 560, more: false}, - }, - }, - { - name: "Packet has a huge transport header", - fragmentPayloadLen: 500, - transportHeaderLen: 1300, - payloadSize: 500, - wantFragments: []fragmentInfo{ - {remaining: 3, copied: 500, offset: 0, more: true}, - {remaining: 2, copied: 500, offset: 500, more: true}, - {remaining: 1, copied: 500, offset: 1000, more: true}, - {remaining: 0, copied: 300, offset: 1500, more: false}, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - pkt := testutil.MakeRandPkt(test.transportHeaderLen, reserve, []int{test.payloadSize}, proto) - var originalPayload buffer.VectorisedView - originalPayload.AppendView(pkt.TransportHeader().View()) - originalPayload.Append(pkt.Data) - var reassembledPayload buffer.VectorisedView - pf := MakePacketFragmenter(pkt, test.fragmentPayloadLen, reserve) - for i := 0; ; i++ { - fragPkt, offset, copied, more := pf.BuildNextFragment() - wantFragment := test.wantFragments[i] - if got := pf.RemainingFragmentCount(); got != wantFragment.remaining { - t.Errorf("(fragment #%d) got pf.RemainingFragmentCount() = %d, want = %d", i, got, wantFragment.remaining) - } - if copied != wantFragment.copied { - t.Errorf("(fragment #%d) got copied = %d, want = %d", i, copied, wantFragment.copied) - } - if offset != wantFragment.offset { - t.Errorf("(fragment #%d) got offset = %d, want = %d", i, offset, wantFragment.offset) - } - if more != wantFragment.more { - t.Errorf("(fragment #%d) got more = %t, want = %t", i, more, wantFragment.more) - } - if got := uint32(fragPkt.Size()); got > test.fragmentPayloadLen { - t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.fragmentPayloadLen) - } - if got := fragPkt.AvailableHeaderBytes(); got != reserve { - t.Errorf("(fragment #%d) got fragPkt.AvailableHeaderBytes() = %d, want = %d", i, got, reserve) - } - if got := fragPkt.TransportHeader().View().Size(); got != 0 { - t.Errorf("(fragment #%d) got fragPkt.TransportHeader().View().Size() = %d, want = 0", i, got) - } - reassembledPayload.Append(fragPkt.Data) - if !more { - if i != len(test.wantFragments)-1 { - t.Errorf("got fragment count = %d, want = %d", i, len(test.wantFragments)-1) - } - break - } - } - if diff := cmp.Diff(reassembledPayload.ToView(), originalPayload.ToView()); diff != "" { - t.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff) - } - }) - } -} - -type testTimeoutHandler struct { - pkt *stack.PacketBuffer -} - -func (h *testTimeoutHandler) OnReassemblyTimeout(pkt *stack.PacketBuffer) { - h.pkt = pkt -} - -func TestTimeoutHandler(t *testing.T) { - const ( - proto = 99 - ) - - pk1 := pkt(1, "1") - pk2 := pkt(1, "2") - - type processParam struct { - first uint16 - last uint16 - more bool - pkt *stack.PacketBuffer - } - - tests := []struct { - name string - params []processParam - wantError bool - wantPkt *stack.PacketBuffer - }{ - { - name: "onTimeout runs", - params: []processParam{ - { - first: 0, - last: 0, - more: true, - pkt: pk1, - }, - }, - wantError: false, - wantPkt: pk1, - }, - { - name: "no first fragment", - params: []processParam{ - { - first: 1, - last: 1, - more: true, - pkt: pk1, - }, - }, - wantError: false, - wantPkt: nil, - }, - { - name: "second pkt is ignored", - params: []processParam{ - { - first: 0, - last: 0, - more: true, - pkt: pk1, - }, - { - first: 0, - last: 0, - more: true, - pkt: pk2, - }, - }, - wantError: false, - wantPkt: pk1, - }, - { - name: "invalid args - first is greater than last", - params: []processParam{ - { - first: 1, - last: 0, - more: true, - pkt: pk1, - }, - }, - wantError: true, - wantPkt: nil, - }, - } - - id := FragmentID{ID: 0} - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - handler := &testTimeoutHandler{pkt: nil} - - f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, handler) - - for _, p := range test.params { - if _, _, _, err := f.Process(id, p.first, p.last, p.more, proto, p.pkt); err != nil && !test.wantError { - t.Errorf("f.Process error = %s", err) - } - } - if !test.wantError { - r, ok := f.reassemblers[id] - if !ok { - t.Fatal("Reassembler not found") - } - f.release(r, true) - } - switch { - case handler.pkt != nil && test.wantPkt == nil: - t.Errorf("got handler.pkt = not nil (pkt.Data = %x), want = nil", handler.pkt.Data.ToView()) - case handler.pkt == nil && test.wantPkt != nil: - t.Errorf("got handler.pkt = nil, want = not nil (pkt.Data = %x)", test.wantPkt.Data.ToView()) - case handler.pkt != nil && test.wantPkt != nil: - if diff := cmp.Diff(test.wantPkt.Data.ToView(), handler.pkt.Data.ToView()); diff != "" { - t.Errorf("pkt.Data mismatch (-want, +got):\n%s", diff) - } - } - }) - } -} diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler_list.go b/pkg/tcpip/network/internal/fragmentation/reassembler_list.go new file mode 100644 index 000000000..673bb11b0 --- /dev/null +++ b/pkg/tcpip/network/internal/fragmentation/reassembler_list.go @@ -0,0 +1,221 @@ +package fragmentation + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type reassemblerElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (reassemblerElementMapper) linkerFor(elem *reassembler) *reassembler { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type reassemblerList struct { + head *reassembler + tail *reassembler +} + +// Reset resets list l to the empty state. +func (l *reassemblerList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *reassemblerList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *reassemblerList) Front() *reassembler { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *reassemblerList) Back() *reassembler { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *reassemblerList) Len() (count int) { + for e := l.Front(); e != nil; e = (reassemblerElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *reassemblerList) PushFront(e *reassembler) { + linker := reassemblerElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + reassemblerElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *reassemblerList) PushBack(e *reassembler) { + linker := reassemblerElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + reassemblerElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *reassemblerList) PushBackList(m *reassemblerList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + reassemblerElementMapper{}.linkerFor(l.tail).SetNext(m.head) + reassemblerElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *reassemblerList) InsertAfter(b, e *reassembler) { + bLinker := reassemblerElementMapper{}.linkerFor(b) + eLinker := reassemblerElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + reassemblerElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *reassemblerList) InsertBefore(a, e *reassembler) { + aLinker := reassemblerElementMapper{}.linkerFor(a) + eLinker := reassemblerElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + reassemblerElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *reassemblerList) Remove(e *reassembler) { + linker := reassemblerElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + reassemblerElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + reassemblerElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type reassemblerEntry struct { + next *reassembler + prev *reassembler +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *reassemblerEntry) Next() *reassembler { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *reassemblerEntry) Prev() *reassembler { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *reassemblerEntry) SetNext(elem *reassembler) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *reassemblerEntry) SetPrev(elem *reassembler) { + e.prev = elem +} diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler_test.go b/pkg/tcpip/network/internal/fragmentation/reassembler_test.go deleted file mode 100644 index 214a93709..000000000 --- a/pkg/tcpip/network/internal/fragmentation/reassembler_test.go +++ /dev/null @@ -1,233 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fragmentation - -import ( - "bytes" - "math" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -type processParams struct { - first uint16 - last uint16 - more bool - pkt *stack.PacketBuffer - wantDone bool - wantError error -} - -func TestReassemblerProcess(t *testing.T) { - const proto = 99 - - v := func(size int) buffer.View { - payload := buffer.NewView(size) - for i := 1; i < size; i++ { - payload[i] = uint8(i) * 3 - } - return payload - } - - pkt := func(sizes ...int) *stack.PacketBuffer { - var vv buffer.VectorisedView - for _, size := range sizes { - vv.AppendView(v(size)) - } - return stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - }) - } - - var tests = []struct { - name string - params []processParams - want []hole - wantPkt *stack.PacketBuffer - }{ - { - name: "No fragments", - params: nil, - want: []hole{{first: 0, last: math.MaxUint16, filled: false, final: true}}, - }, - { - name: "One fragment at beginning", - params: []processParams{{first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}}, - want: []hole{ - {first: 0, last: 1, filled: true, final: false, pkt: pkt(2)}, - {first: 2, last: math.MaxUint16, filled: false, final: true}, - }, - }, - { - name: "One fragment in the middle", - params: []processParams{{first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}}, - want: []hole{ - {first: 1, last: 2, filled: true, final: false, pkt: pkt(2)}, - {first: 0, last: 0, filled: false, final: false}, - {first: 3, last: math.MaxUint16, filled: false, final: true}, - }, - }, - { - name: "One fragment at the end", - params: []processParams{{first: 1, last: 2, more: false, pkt: pkt(2), wantDone: false, wantError: nil}}, - want: []hole{ - {first: 1, last: 2, filled: true, final: true, pkt: pkt(2)}, - {first: 0, last: 0, filled: false}, - }, - }, - { - name: "One fragment completing a packet", - params: []processParams{{first: 0, last: 1, more: false, pkt: pkt(2), wantDone: true, wantError: nil}}, - want: []hole{ - {first: 0, last: 1, filled: true, final: true}, - }, - wantPkt: pkt(2), - }, - { - name: "Two fragments completing a packet", - params: []processParams{ - {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, - {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, - }, - want: []hole{ - {first: 0, last: 1, filled: true, final: false}, - {first: 2, last: 3, filled: true, final: true}, - }, - wantPkt: pkt(2, 2), - }, - { - name: "Two fragments completing a packet with a duplicate", - params: []processParams{ - {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, - {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, - {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, - }, - want: []hole{ - {first: 0, last: 1, filled: true, final: false}, - {first: 2, last: 3, filled: true, final: true}, - }, - wantPkt: pkt(2, 2), - }, - { - name: "Two fragments completing a packet with a partial duplicate", - params: []processParams{ - {first: 0, last: 3, more: true, pkt: pkt(4), wantDone: false, wantError: nil}, - {first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, - {first: 4, last: 5, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, - }, - want: []hole{ - {first: 0, last: 3, filled: true, final: false}, - {first: 4, last: 5, filled: true, final: true}, - }, - wantPkt: pkt(4, 2), - }, - { - name: "Two overlapping fragments", - params: []processParams{ - {first: 0, last: 10, more: true, pkt: pkt(11), wantDone: false, wantError: nil}, - {first: 5, last: 15, more: false, pkt: pkt(11), wantDone: false, wantError: ErrFragmentOverlap}, - }, - want: []hole{ - {first: 0, last: 10, filled: true, final: false, pkt: pkt(11)}, - {first: 11, last: math.MaxUint16, filled: false, final: true}, - }, - }, - { - name: "Two final fragments with different ends", - params: []processParams{ - {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil}, - {first: 0, last: 9, more: false, pkt: pkt(10), wantDone: false, wantError: ErrFragmentConflict}, - }, - want: []hole{ - {first: 10, last: 14, filled: true, final: true, pkt: pkt(5)}, - {first: 0, last: 9, filled: false, final: false}, - }, - }, - { - name: "Two final fragments - duplicate", - params: []processParams{ - {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil}, - {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil}, - }, - want: []hole{ - {first: 5, last: 14, filled: true, final: true, pkt: pkt(10)}, - {first: 0, last: 4, filled: false, final: false}, - }, - }, - { - name: "Two final fragments - duplicate, with different ends", - params: []processParams{ - {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil}, - {first: 10, last: 13, more: false, pkt: pkt(4), wantDone: false, wantError: ErrFragmentConflict}, - }, - want: []hole{ - {first: 5, last: 14, filled: true, final: true, pkt: pkt(10)}, - {first: 0, last: 4, filled: false, final: false}, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - r := newReassembler(FragmentID{}, &faketime.NullClock{}) - var resPkt *stack.PacketBuffer - var isDone bool - for _, param := range test.params { - pkt, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt) - if done != param.wantDone || err != param.wantError { - t.Errorf("got r.process(%d, %d, %t, %d, _) = (_, _, %t, _, %v), want = (%t, %v)", param.first, param.last, param.more, proto, done, err, param.wantDone, param.wantError) - } - if done { - resPkt = pkt - isDone = true - } - } - - ignorePkt := func(a, b *stack.PacketBuffer) bool { return true } - cmpPktData := func(a, b *stack.PacketBuffer) bool { - if a == nil || b == nil { - return a == b - } - return bytes.Equal(a.Data.ToOwnedView(), b.Data.ToOwnedView()) - } - - if isDone { - if diff := cmp.Diff( - test.want, r.holes, - cmp.AllowUnexported(hole{}), - // Do not compare pkt in hole. Data will be altered. - cmp.Comparer(ignorePkt), - ); diff != "" { - t.Errorf("r.holes mismatch (-want +got):\n%s", diff) - } - if diff := cmp.Diff(test.wantPkt, resPkt, cmp.Comparer(cmpPktData)); diff != "" { - t.Errorf("Reassembled pkt mismatch (-want +got):\n%s", diff) - } - } else { - if diff := cmp.Diff( - test.want, r.holes, - cmp.AllowUnexported(hole{}), - cmp.Comparer(cmpPktData), - ); diff != "" { - t.Errorf("r.holes mismatch (-want +got):\n%s", diff) - } - } - }) - } -} diff --git a/pkg/tcpip/network/internal/ip/BUILD b/pkg/tcpip/network/internal/ip/BUILD deleted file mode 100644 index d21b4c7ef..000000000 --- a/pkg/tcpip/network/internal/ip/BUILD +++ /dev/null @@ -1,39 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "ip", - srcs = [ - "duplicate_address_detection.go", - "generic_multicast_protocol.go", - "stats.go", - ], - visibility = [ - "//pkg/tcpip/network/arp:__pkg__", - "//pkg/tcpip/network/ipv4:__pkg__", - "//pkg/tcpip/network/ipv6:__pkg__", - ], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "ip_x_test", - size = "small", - srcs = [ - "duplicate_address_detection_test.go", - "generic_multicast_protocol_test.go", - ], - deps = [ - ":ip", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/faketime", - "//pkg/tcpip/stack", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go deleted file mode 100644 index 18c357b56..000000000 --- a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go +++ /dev/null @@ -1,279 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ip_test - -import ( - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -type mockDADProtocol struct { - t *testing.T - - mu struct { - sync.Mutex - - dad ip.DAD - sendCount map[tcpip.Address]int - } -} - -func (m *mockDADProtocol) init(t *testing.T, c stack.DADConfigurations, opts ip.DADOptions) { - m.mu.Lock() - defer m.mu.Unlock() - - m.t = t - opts.Protocol = m - m.mu.dad.Init(&m.mu, c, opts) - m.initLocked() -} - -func (m *mockDADProtocol) initLocked() { - m.mu.sendCount = make(map[tcpip.Address]int) -} - -func (m *mockDADProtocol) SendDADMessage(addr tcpip.Address) tcpip.Error { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.sendCount[addr]++ - return nil -} - -func (m *mockDADProtocol) check(addrs []tcpip.Address) string { - m.mu.Lock() - defer m.mu.Unlock() - - sendCount := make(map[tcpip.Address]int) - for _, a := range addrs { - sendCount[a]++ - } - - diff := cmp.Diff(sendCount, m.mu.sendCount) - m.initLocked() - return diff -} - -func (m *mockDADProtocol) checkDuplicateAddress(addr tcpip.Address, h stack.DADCompletionHandler) stack.DADCheckAddressDisposition { - m.mu.Lock() - defer m.mu.Unlock() - return m.mu.dad.CheckDuplicateAddressLocked(addr, h) -} - -func (m *mockDADProtocol) stop(addr tcpip.Address, aborted bool) { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.dad.StopLocked(addr, aborted) -} - -func (m *mockDADProtocol) setConfigs(c stack.DADConfigurations) { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.dad.SetConfigsLocked(c) -} - -const ( - addr1 = tcpip.Address("\x01") - addr2 = tcpip.Address("\x02") - addr3 = tcpip.Address("\x03") - addr4 = tcpip.Address("\x04") -) - -type dadResult struct { - Addr tcpip.Address - R stack.DADResult -} - -func handler(ch chan<- dadResult, a tcpip.Address) func(stack.DADResult) { - return func(r stack.DADResult) { - ch <- dadResult{Addr: a, R: r} - } -} - -func TestDADCheckDuplicateAddress(t *testing.T) { - var dad mockDADProtocol - clock := faketime.NewManualClock() - dad.init(t, stack.DADConfigurations{}, ip.DADOptions{ - Clock: clock, - }) - - ch := make(chan dadResult, 2) - - // DAD should initially be disabled. - if res := dad.checkDuplicateAddress(addr1, handler(nil, "")); res != stack.DADDisabled { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADDisabled) - } - // Wait for any initially fired timers to complete. - clock.Advance(0) - if diff := dad.check(nil); diff != "" { - t.Errorf("dad check mismatch (-want +got):\n%s", diff) - } - - // Enable and request DAD. - dadConfigs1 := stack.DADConfigurations{ - DupAddrDetectTransmits: 1, - RetransmitTimer: time.Second, - } - dad.setConfigs(dadConfigs1) - if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) - } - clock.Advance(0) - if diff := dad.check([]tcpip.Address{addr1}); diff != "" { - t.Errorf("dad check mismatch (-want +got):\n%s", diff) - } - // The second request for DAD on the same address should use the original - // request since it has not completed yet. - if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADAlreadyRunning { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADAlreadyRunning) - } - clock.Advance(0) - if diff := dad.check(nil); diff != "" { - t.Errorf("dad check mismatch (-want +got):\n%s", diff) - } - - dadConfigs2 := stack.DADConfigurations{ - DupAddrDetectTransmits: 2, - RetransmitTimer: time.Second, - } - dad.setConfigs(dadConfigs2) - // A new address should start a new DAD process. - if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) - } - clock.Advance(0) - if diff := dad.check([]tcpip.Address{addr2}); diff != "" { - t.Errorf("dad check mismatch (-want +got):\n%s", diff) - } - - // Make sure DAD for addr1 only resolves after the expected timeout. - const delta = time.Nanosecond - dadConfig1Duration := time.Duration(dadConfigs1.DupAddrDetectTransmits) * dadConfigs1.RetransmitTimer - clock.Advance(dadConfig1Duration - delta) - select { - case r := <-ch: - t.Fatalf("unexpectedly got a DAD result before the expected timeout of %s; r = %#v", dadConfig1Duration, r) - default: - } - clock.Advance(delta) - for i := 0; i < 2; i++ { - if diff := cmp.Diff(dadResult{Addr: addr1, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" { - t.Errorf("(i=%d) dad result mismatch (-want +got):\n%s", i, diff) - } - } - - // Make sure DAD for addr2 only resolves after the expected timeout. - dadConfig2Duration := time.Duration(dadConfigs2.DupAddrDetectTransmits) * dadConfigs2.RetransmitTimer - clock.Advance(dadConfig2Duration - dadConfig1Duration - delta) - select { - case r := <-ch: - t.Fatalf("unexpectedly got a DAD result before the expected timeout of %s; r = %#v", dadConfig2Duration, r) - default: - } - clock.Advance(delta) - if diff := cmp.Diff(dadResult{Addr: addr2, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" { - t.Errorf("dad result mismatch (-want +got):\n%s", diff) - } - - // Should be able to restart DAD for addr2 after it resolved. - if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) - } - clock.Advance(0) - if diff := dad.check([]tcpip.Address{addr2, addr2}); diff != "" { - t.Errorf("dad check mismatch (-want +got):\n%s", diff) - } - clock.Advance(dadConfig2Duration) - if diff := cmp.Diff(dadResult{Addr: addr2, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" { - t.Errorf("dad result mismatch (-want +got):\n%s", diff) - } - - // Should not have anymore results. - select { - case r := <-ch: - t.Fatalf("unexpectedly got an extra DAD result; r = %#v", r) - default: - } -} - -func TestDADStop(t *testing.T) { - var dad mockDADProtocol - clock := faketime.NewManualClock() - dadConfigs := stack.DADConfigurations{ - DupAddrDetectTransmits: 1, - RetransmitTimer: time.Second, - } - dad.init(t, dadConfigs, ip.DADOptions{ - Clock: clock, - }) - - ch := make(chan dadResult, 1) - - if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) - } - if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) - } - if res := dad.checkDuplicateAddress(addr3, handler(ch, addr3)); res != stack.DADStarting { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) - } - clock.Advance(0) - if diff := dad.check([]tcpip.Address{addr1, addr2, addr3}); diff != "" { - t.Errorf("dad check mismatch (-want +got):\n%s", diff) - } - - dad.stop(addr1, true /* aborted */) - if diff := cmp.Diff(dadResult{Addr: addr1, R: stack.DADResult{Resolved: false, Err: &tcpip.ErrAborted{}}}, <-ch); diff != "" { - t.Errorf("dad result mismatch (-want +got):\n%s", diff) - } - - dad.stop(addr2, false /* aborted */) - if diff := cmp.Diff(dadResult{Addr: addr2, R: stack.DADResult{Resolved: false, Err: nil}}, <-ch); diff != "" { - t.Errorf("dad result mismatch (-want +got):\n%s", diff) - } - - dadResolutionDuration := time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer - clock.Advance(dadResolutionDuration) - if diff := cmp.Diff(dadResult{Addr: addr3, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" { - t.Errorf("dad result mismatch (-want +got):\n%s", diff) - } - - // Should be able to restart DAD for an address we stopped DAD on. - if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) - } - clock.Advance(0) - if diff := dad.check([]tcpip.Address{addr1}); diff != "" { - t.Errorf("dad check mismatch (-want +got):\n%s", diff) - } - clock.Advance(dadResolutionDuration) - if diff := cmp.Diff(dadResult{Addr: addr1, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" { - t.Errorf("dad result mismatch (-want +got):\n%s", diff) - } - - // Should not have anymore updates. - select { - case r := <-ch: - t.Fatalf("unexpectedly got an extra DAD result; r = %#v", r) - default: - } -} diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go deleted file mode 100644 index 381460c82..000000000 --- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go +++ /dev/null @@ -1,805 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ip_test - -import ( - "math/rand" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip" -) - -const maxUnsolicitedReportDelay = time.Second - -var _ ip.MulticastGroupProtocol = (*mockMulticastGroupProtocol)(nil) - -type mockMulticastGroupProtocolProtectedFields struct { - sync.RWMutex - - genericMulticastGroup ip.GenericMulticastProtocolState - sendReportGroupAddrCount map[tcpip.Address]int - sendLeaveGroupAddrCount map[tcpip.Address]int - makeQueuePackets bool - disabled bool -} - -type mockMulticastGroupProtocol struct { - t *testing.T - - mu mockMulticastGroupProtocolProtectedFields -} - -func (m *mockMulticastGroupProtocol) init(opts ip.GenericMulticastProtocolOptions) { - m.mu.Lock() - defer m.mu.Unlock() - m.initLocked() - opts.Protocol = m - m.mu.genericMulticastGroup.Init(&m.mu.RWMutex, opts) -} - -func (m *mockMulticastGroupProtocol) initLocked() { - m.mu.sendReportGroupAddrCount = make(map[tcpip.Address]int) - m.mu.sendLeaveGroupAddrCount = make(map[tcpip.Address]int) -} - -func (m *mockMulticastGroupProtocol) setEnabled(v bool) { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.disabled = !v -} - -func (m *mockMulticastGroupProtocol) setQueuePackets(v bool) { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.makeQueuePackets = v -} - -func (m *mockMulticastGroupProtocol) joinGroup(addr tcpip.Address) { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.genericMulticastGroup.JoinGroupLocked(addr) -} - -func (m *mockMulticastGroupProtocol) leaveGroup(addr tcpip.Address) bool { - m.mu.Lock() - defer m.mu.Unlock() - return m.mu.genericMulticastGroup.LeaveGroupLocked(addr) -} - -func (m *mockMulticastGroupProtocol) handleReport(addr tcpip.Address) { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.genericMulticastGroup.HandleReportLocked(addr) -} - -func (m *mockMulticastGroupProtocol) handleQuery(addr tcpip.Address, maxRespTime time.Duration) { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.genericMulticastGroup.HandleQueryLocked(addr, maxRespTime) -} - -func (m *mockMulticastGroupProtocol) isLocallyJoined(addr tcpip.Address) bool { - m.mu.RLock() - defer m.mu.RUnlock() - return m.mu.genericMulticastGroup.IsLocallyJoinedRLocked(addr) -} - -func (m *mockMulticastGroupProtocol) makeAllNonMember() { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.genericMulticastGroup.MakeAllNonMemberLocked() -} - -func (m *mockMulticastGroupProtocol) initializeGroups() { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.genericMulticastGroup.InitializeGroupsLocked() -} - -func (m *mockMulticastGroupProtocol) sendQueuedReports() { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.genericMulticastGroup.SendQueuedReportsLocked() -} - -// Enabled implements ip.MulticastGroupProtocol. -// -// Precondition: m.mu must be read locked. -func (m *mockMulticastGroupProtocol) Enabled() bool { - if m.mu.TryLock() { - m.mu.Unlock() - m.t.Fatal("got write lock, expected to not take the lock; generic multicast protocol must take the read or write lock before calling Enabled") - } - - return !m.mu.disabled -} - -// SendReport implements ip.MulticastGroupProtocol. -// -// Precondition: m.mu must be locked. -func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) { - if m.mu.TryLock() { - m.mu.Unlock() - m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) - } - if m.mu.TryRLock() { - m.mu.RUnlock() - m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) - } - - m.mu.sendReportGroupAddrCount[groupAddress]++ - return !m.mu.makeQueuePackets, nil -} - -// SendLeave implements ip.MulticastGroupProtocol. -// -// Precondition: m.mu must be locked. -func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) tcpip.Error { - if m.mu.TryLock() { - m.mu.Unlock() - m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) - } - if m.mu.TryRLock() { - m.mu.RUnlock() - m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) - } - - m.mu.sendLeaveGroupAddrCount[groupAddress]++ - return nil -} - -func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string { - m.mu.Lock() - defer m.mu.Unlock() - - sendReportGroupAddrCount := make(map[tcpip.Address]int) - for _, a := range sendReportGroupAddresses { - sendReportGroupAddrCount[a] = 1 - } - - sendLeaveGroupAddrCount := make(map[tcpip.Address]int) - for _, a := range sendLeaveGroupAddresses { - sendLeaveGroupAddrCount[a] = 1 - } - - diff := cmp.Diff( - &mockMulticastGroupProtocol{ - mu: mockMulticastGroupProtocolProtectedFields{ - sendReportGroupAddrCount: sendReportGroupAddrCount, - sendLeaveGroupAddrCount: sendLeaveGroupAddrCount, - }, - }, - m, - cmp.AllowUnexported(mockMulticastGroupProtocol{}), - cmp.AllowUnexported(mockMulticastGroupProtocolProtectedFields{}), - // ignore mockMulticastGroupProtocol.mu and mockMulticastGroupProtocol.t - cmp.FilterPath( - func(p cmp.Path) bool { - switch p.Last().String() { - case ".RWMutex", ".t", ".makeQueuePackets", ".disabled", ".genericMulticastGroup": - return true - } - return false - }, - cmp.Ignore(), - ), - ) - m.initLocked() - return diff -} - -func TestJoinGroup(t *testing.T) { - tests := []struct { - name string - addr tcpip.Address - shouldSendReports bool - }{ - { - name: "Normal group", - addr: addr1, - shouldSendReports: true, - }, - { - name: "All-nodes group", - addr: addr2, - shouldSendReports: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(0)), - Clock: clock, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr2, - }) - - // Joining a group should send a report immediately and another after - // a random interval between 0 and the maximum unsolicited report delay. - mgp.joinGroup(test.addr) - if test.shouldSendReports { - if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Generic multicast protocol timers are expected to take the job mutex. - clock.Advance(maxUnsolicitedReportDelay) - if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - } - - // Should have no more messages to send. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - }) - } -} - -func TestLeaveGroup(t *testing.T) { - tests := []struct { - name string - addr tcpip.Address - shouldSendMessages bool - }{ - { - name: "Normal group", - addr: addr1, - shouldSendMessages: true, - }, - { - name: "All-nodes group", - addr: addr2, - shouldSendMessages: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(1)), - Clock: clock, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr2, - }) - - mgp.joinGroup(test.addr) - if test.shouldSendMessages { - if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - } - - // Leaving a group should send a leave report immediately and cancel any - // delayed reports. - { - - if !mgp.leaveGroup(test.addr) { - t.Fatalf("got mgp.leaveGroup(%s) = false, want = true", test.addr) - } - } - if test.shouldSendMessages { - if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{test.addr} /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - } - - // Should have no more messages to send. - // - // Generic multicast protocol timers are expected to take the job mutex. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - }) - } -} - -func TestHandleReport(t *testing.T) { - tests := []struct { - name string - reportAddr tcpip.Address - expectReportsFor []tcpip.Address - }{ - { - name: "Unpecified empty", - reportAddr: "", - expectReportsFor: []tcpip.Address{addr1, addr2}, - }, - { - name: "Unpecified any", - reportAddr: "\x00", - expectReportsFor: []tcpip.Address{addr1, addr2}, - }, - { - name: "Specified", - reportAddr: addr1, - expectReportsFor: []tcpip.Address{addr2}, - }, - { - name: "Specified all-nodes", - reportAddr: addr3, - expectReportsFor: []tcpip.Address{addr1, addr2}, - }, - { - name: "Specified other", - reportAddr: addr4, - expectReportsFor: []tcpip.Address{addr1, addr2}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(2)), - Clock: clock, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr3, - }) - - mgp.joinGroup(addr1) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr2) - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr3) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Receiving a report for a group we have a timer scheduled for should - // cancel our delayed report timer for the group. - mgp.handleReport(test.reportAddr) - if len(test.expectReportsFor) != 0 { - // Generic multicast protocol timers are expected to take the job mutex. - clock.Advance(maxUnsolicitedReportDelay) - if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - } - - // Should have no more messages to send. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - }) - } -} - -func TestHandleQuery(t *testing.T) { - tests := []struct { - name string - queryAddr tcpip.Address - maxDelay time.Duration - expectQueriedReportsFor []tcpip.Address - expectDelayedReportsFor []tcpip.Address - }{ - { - name: "Unpecified empty", - queryAddr: "", - maxDelay: 0, - expectQueriedReportsFor: []tcpip.Address{addr1, addr2}, - expectDelayedReportsFor: nil, - }, - { - name: "Unpecified any", - queryAddr: "\x00", - maxDelay: 1, - expectQueriedReportsFor: []tcpip.Address{addr1, addr2}, - expectDelayedReportsFor: nil, - }, - { - name: "Specified", - queryAddr: addr1, - maxDelay: 2, - expectQueriedReportsFor: []tcpip.Address{addr1}, - expectDelayedReportsFor: []tcpip.Address{addr2}, - }, - { - name: "Specified all-nodes", - queryAddr: addr3, - maxDelay: 3, - expectQueriedReportsFor: nil, - expectDelayedReportsFor: []tcpip.Address{addr1, addr2}, - }, - { - name: "Specified other", - queryAddr: addr4, - maxDelay: 4, - expectQueriedReportsFor: nil, - expectDelayedReportsFor: []tcpip.Address{addr1, addr2}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(3)), - Clock: clock, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr3, - }) - - mgp.joinGroup(addr1) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr2) - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr3) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Receiving a query should make us reschedule our delayed report timer - // to some time within the new max response delay. - mgp.handleQuery(test.queryAddr, test.maxDelay) - clock.Advance(test.maxDelay) - if diff := mgp.check(test.expectQueriedReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // The groups that were not affected by the query should still send a - // report after the max unsolicited report delay. - clock.Advance(maxUnsolicitedReportDelay) - if diff := mgp.check(test.expectDelayedReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should have no more messages to send. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - }) - } -} - -func TestJoinCount(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(4)), - Clock: clock, - MaxUnsolicitedReportDelay: time.Second, - }) - - // Set the join count to 2 for a group. - mgp.joinGroup(addr1) - if !mgp.isLocallyJoined(addr1) { - t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) - } - // Only the first join should trigger a report to be sent. - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr1) - if !mgp.isLocallyJoined(addr1) { - t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) - } - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - if t.Failed() { - t.FailNow() - } - - // Group should still be considered joined after leaving once. - if !mgp.leaveGroup(addr1) { - t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1) - } - if !mgp.isLocallyJoined(addr1) { - t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) - } - // A leave report should only be sent once the join count reaches 0. - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - if t.Failed() { - t.FailNow() - } - - // Leaving once more should actually remove us from the group. - if !mgp.leaveGroup(addr1) { - t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1) - } - if mgp.isLocallyJoined(addr1) { - t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1) - } - if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - if t.Failed() { - t.FailNow() - } - - // Group should no longer be joined so we should not have anything to - // leave. - if mgp.leaveGroup(addr1) { - t.Errorf("got mgp.leaveGroup(%s) = true, want = false", addr1) - } - if mgp.isLocallyJoined(addr1) { - t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1) - } - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should have no more messages to send. - // - // Generic multicast protocol timers are expected to take the job mutex. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } -} - -func TestMakeAllNonMemberAndInitialize(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(3)), - Clock: clock, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr3, - }) - - mgp.joinGroup(addr1) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr2) - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr3) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should send the leave reports for each but still consider them locally - // joined. - mgp.makeAllNonMember() - if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - // Generic multicast protocol timers are expected to take the job mutex. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - for _, group := range []tcpip.Address{addr1, addr2, addr3} { - if !mgp.isLocallyJoined(group) { - t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", group) - } - } - - // Should send the initial set of unsolcited reports. - mgp.initializeGroups() - if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - clock.Advance(maxUnsolicitedReportDelay) - if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should have no more messages to send. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } -} - -// TestGroupStateNonMember tests that groups do not send packets when in the -// non-member state, but are still considered locally joined. -func TestGroupStateNonMember(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(3)), - Clock: clock, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - }) - mgp.setEnabled(false) - - // Joining groups should not send any reports. - mgp.joinGroup(addr1) - if !mgp.isLocallyJoined(addr1) { - t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) - } - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr2) - if !mgp.isLocallyJoined(addr1) { - t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr2) - } - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Receiving a query should not send any reports. - mgp.handleQuery(addr1, time.Nanosecond) - // Generic multicast protocol timers are expected to take the job mutex. - clock.Advance(time.Nanosecond) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Leaving groups should not send any leave messages. - if !mgp.leaveGroup(addr1) { - t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr2) - } - if mgp.isLocallyJoined(addr1) { - t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr2) - } - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } -} - -func TestQueuedPackets(t *testing.T) { - clock := faketime.NewManualClock() - mgp := mockMulticastGroupProtocol{t: t} - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(4)), - Clock: clock, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - }) - - // Joining should trigger a SendReport, but mgp should report that we did not - // send the packet. - mgp.setQueuePackets(true) - mgp.joinGroup(addr1) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // The delayed report timer should have been cancelled since we did not send - // the initial report earlier. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Mock being able to successfully send the report. - mgp.setQueuePackets(false) - mgp.sendQueuedReports() - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // The delayed report (sent after the initial report) should now be sent. - clock.Advance(maxUnsolicitedReportDelay) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should not have anything else to send (we should be idle). - mgp.sendQueuedReports() - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Receive a query but mock being unable to send reports again. - mgp.setQueuePackets(true) - mgp.handleQuery(addr1, time.Nanosecond) - clock.Advance(time.Nanosecond) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Mock being able to send reports again - we should have a packet queued to - // send. - mgp.setQueuePackets(false) - mgp.sendQueuedReports() - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should not have anything else to send. - mgp.sendQueuedReports() - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Receive a query again, but mock being unable to send reports. - mgp.setQueuePackets(true) - mgp.handleQuery(addr1, time.Nanosecond) - clock.Advance(time.Nanosecond) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Receiving a report should should transition us into the idle member state, - // even if we had a packet queued. We should no longer have any packets to - // send. - mgp.handleReport(addr1) - mgp.sendQueuedReports() - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // When we fail to send the initial set of reports, incoming reports should - // not affect a newly joined group's reports from being sent. - mgp.setQueuePackets(true) - mgp.joinGroup(addr2) - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.handleReport(addr2) - // Attempting to send queued reports while still unable to send reports should - // not change the host state. - mgp.sendQueuedReports() - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - // Mock being able to successfully send the report. - mgp.setQueuePackets(false) - mgp.sendQueuedReports() - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - // The delayed report (sent after the initial report) should now be sent. - clock.Advance(maxUnsolicitedReportDelay) - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should not have anything else to send. - mgp.sendQueuedReports() - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } -} diff --git a/pkg/tcpip/network/internal/ip/ip_state_autogen.go b/pkg/tcpip/network/internal/ip/ip_state_autogen.go new file mode 100644 index 000000000..aee77044e --- /dev/null +++ b/pkg/tcpip/network/internal/ip/ip_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package ip diff --git a/pkg/tcpip/network/internal/testutil/BUILD b/pkg/tcpip/network/internal/testutil/BUILD deleted file mode 100644 index 1c4f583c7..000000000 --- a/pkg/tcpip/network/internal/testutil/BUILD +++ /dev/null @@ -1,23 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "testutil", - srcs = [ - "testutil.go", - "testutil_unsafe.go", - ], - visibility = [ - "//pkg/tcpip/network/arp:__pkg__", - "//pkg/tcpip/network/internal/fragmentation:__pkg__", - "//pkg/tcpip/network/ipv4:__pkg__", - "//pkg/tcpip/network/ipv6:__pkg__", - ], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/network/internal/testutil/testutil.go b/pkg/tcpip/network/internal/testutil/testutil.go deleted file mode 100644 index f5fa77b65..000000000 --- a/pkg/tcpip/network/internal/testutil/testutil.go +++ /dev/null @@ -1,197 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package testutil defines types and functions used to test Network Layer -// functionality such as IP fragmentation. -package testutil - -import ( - "fmt" - "math/rand" - "reflect" - "strings" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -// MockLinkEndpoint is an endpoint used for testing, it stores packets written -// to it and can mock errors. -type MockLinkEndpoint struct { - // WrittenPackets is where packets written to the endpoint are stored. - WrittenPackets []*stack.PacketBuffer - - mtu uint32 - err tcpip.Error - allowPackets int -} - -// NewMockLinkEndpoint creates a new MockLinkEndpoint. -// -// err is the error that will be returned once allowPackets packets are written -// to the endpoint. -func NewMockLinkEndpoint(mtu uint32, err tcpip.Error, allowPackets int) *MockLinkEndpoint { - return &MockLinkEndpoint{ - mtu: mtu, - err: err, - allowPackets: allowPackets, - } -} - -// MTU implements LinkEndpoint.MTU. -func (ep *MockLinkEndpoint) MTU() uint32 { return ep.mtu } - -// Capabilities implements LinkEndpoint.Capabilities. -func (*MockLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return 0 } - -// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength. -func (*MockLinkEndpoint) MaxHeaderLength() uint16 { return 0 } - -// LinkAddress implements LinkEndpoint.LinkAddress. -func (*MockLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } - -// WritePacket implements LinkEndpoint.WritePacket. -func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - if ep.allowPackets == 0 { - return ep.err - } - ep.allowPackets-- - ep.WrittenPackets = append(ep.WrittenPackets, pkt) - return nil -} - -// WritePackets implements LinkEndpoint.WritePackets. -func (ep *MockLinkEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - var n int - - for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - if err := ep.WritePacket(r, gso, protocol, pkt); err != nil { - return n, err - } - n++ - } - - return n, nil -} - -// Attach implements LinkEndpoint.Attach. -func (*MockLinkEndpoint) Attach(stack.NetworkDispatcher) {} - -// IsAttached implements LinkEndpoint.IsAttached. -func (*MockLinkEndpoint) IsAttached() bool { return false } - -// Wait implements LinkEndpoint.Wait. -func (*MockLinkEndpoint) Wait() {} - -// ARPHardwareType implements LinkEndpoint.ARPHardwareType. -func (*MockLinkEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone } - -// AddHeader implements LinkEndpoint.AddHeader. -func (*MockLinkEndpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { -} - -// MakeRandPkt generates a randomized packet. transportHeaderLength indicates -// how many random bytes will be copied in the Transport Header. -// extraHeaderReserveLength indicates how much extra space will be reserved for -// the other headers. The payload is made from Views of the sizes listed in -// viewSizes. -func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSizes []int, proto tcpip.NetworkProtocolNumber) *stack.PacketBuffer { - var views buffer.VectorisedView - - for _, s := range viewSizes { - newView := buffer.NewView(s) - if _, err := rand.Read(newView); err != nil { - panic(fmt.Sprintf("rand.Read: %s", err)) - } - views.AppendView(newView) - } - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: transportHeaderLength + extraHeaderReserveLength, - Data: views, - }) - pkt.NetworkProtocolNumber = proto - if _, err := rand.Read(pkt.TransportHeader().Push(transportHeaderLength)); err != nil { - panic(fmt.Sprintf("rand.Read: %s", err)) - } - return pkt -} - -func checkFieldCounts(ref, multi reflect.Value) error { - refTypeName := ref.Type().Name() - multiTypeName := multi.Type().Name() - refNumField := ref.NumField() - multiNumField := multi.NumField() - - if refNumField != multiNumField { - return fmt.Errorf("type %s has an incorrect number of fields: got = %d, want = %d (same as type %s)", multiTypeName, multiNumField, refNumField, refTypeName) - } - - return nil -} - -func validateField(ref reflect.Value, refName string, m tcpip.MultiCounterStat, multiName string) error { - s, ok := ref.Addr().Interface().(**tcpip.StatCounter) - if !ok { - return fmt.Errorf("expected ref type's to be *StatCounter, but its type is %s", ref.Type().Elem().Name()) - } - - // The field names are expected to match (case insensitive). - if !strings.EqualFold(refName, multiName) { - return fmt.Errorf("wrong field name: got = %s, want = %s", multiName, refName) - } - - base := (*s).Value() - m.Increment() - if (*s).Value() != base+1 { - return fmt.Errorf("updates to the '%s MultiCounterStat' counters are not reflected in the '%s CounterStat'", multiName, refName) - } - - return nil -} - -// ValidateMultiCounterStats verifies that every counter stored in multi is -// correctly tracking its counterpart in the given counters. -func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value) error { - for _, c := range counters { - if err := checkFieldCounts(c, multi); err != nil { - return err - } - } - - for i := 0; i < multi.NumField(); i++ { - multiName := multi.Type().Field(i).Name - multiUnsafe := unsafeExposeUnexportedFields(multi.Field(i)) - - if m, ok := multiUnsafe.Addr().Interface().(*tcpip.MultiCounterStat); ok { - for _, c := range counters { - if err := validateField(unsafeExposeUnexportedFields(c.Field(i)), c.Type().Field(i).Name, *m, multiName); err != nil { - return err - } - } - } else { - var countersNextField []reflect.Value - for _, c := range counters { - countersNextField = append(countersNextField, c.Field(i)) - } - if err := ValidateMultiCounterStats(multi.Field(i), countersNextField); err != nil { - return err - } - } - } - - return nil -} diff --git a/pkg/tcpip/network/internal/testutil/testutil_unsafe.go b/pkg/tcpip/network/internal/testutil/testutil_unsafe.go deleted file mode 100644 index 5ff764800..000000000 --- a/pkg/tcpip/network/internal/testutil/testutil_unsafe.go +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package testutil - -import ( - "reflect" - "unsafe" -) - -// unsafeExposeUnexportedFields takes a Value and returns a version of it in -// which even unexported fields can be read and written. -func unsafeExposeUnexportedFields(a reflect.Value) reflect.Value { - return reflect.NewAt(a.Type(), unsafe.Pointer(a.UnsafeAddr())).Elem() -} diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go deleted file mode 100644 index 90236ed9e..000000000 --- a/pkg/tcpip/network/ip_test.go +++ /dev/null @@ -1,1929 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ip_test - -import ( - "strings" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" -) - -const ( - localIPv4Addr = tcpip.Address("\x0a\x00\x00\x01") - remoteIPv4Addr = tcpip.Address("\x0a\x00\x00\x02") - ipv4SubnetAddr = tcpip.Address("\x0a\x00\x00\x00") - ipv4SubnetMask = tcpip.Address("\xff\xff\xff\x00") - ipv4Gateway = tcpip.Address("\x0a\x00\x00\x03") - localIPv6Addr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - remoteIPv6Addr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - ipv6SubnetAddr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00") - ipv6SubnetMask = tcpip.Address("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00") - ipv6Gateway = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") - nicID = 1 -) - -var localIPv4AddrWithPrefix = tcpip.AddressWithPrefix{ - Address: localIPv4Addr, - PrefixLen: 24, -} - -var localIPv6AddrWithPrefix = tcpip.AddressWithPrefix{ - Address: localIPv6Addr, - PrefixLen: 120, -} - -type transportError struct { - origin tcpip.SockErrOrigin - typ uint8 - code uint8 - info uint32 - kind stack.TransportErrorKind -} - -// testObject implements two interfaces: LinkEndpoint and TransportDispatcher. -// The former is used to pretend that it's a link endpoint so that we can -// inspect packets written by the network endpoints. The latter is used to -// pretend that it's the network stack so that it can inspect incoming packets -// that have been handled by the network endpoints. -// -// Packets are checked by comparing their fields/values against the expected -// values stored in the test object itself. -type testObject struct { - t *testing.T - protocol tcpip.TransportProtocolNumber - contents []byte - srcAddr tcpip.Address - dstAddr tcpip.Address - v4 bool - transErr transportError - - dataCalls int - controlCalls int -} - -// checkValues verifies that the transport protocol, data contents, src & dst -// addresses of a packet match what's expected. If any field doesn't match, the -// test fails. -func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView, srcAddr, dstAddr tcpip.Address) { - v := vv.ToView() - if protocol != t.protocol { - t.t.Errorf("protocol = %v, want %v", protocol, t.protocol) - } - - if srcAddr != t.srcAddr { - t.t.Errorf("srcAddr = %v, want %v", srcAddr, t.srcAddr) - } - - if dstAddr != t.dstAddr { - t.t.Errorf("dstAddr = %v, want %v", dstAddr, t.dstAddr) - } - - if len(v) != len(t.contents) { - t.t.Fatalf("len(payload) = %v, want %v", len(v), len(t.contents)) - } - - for i := range t.contents { - if t.contents[i] != v[i] { - t.t.Fatalf("payload[%v] = %v, want %v", i, v[i], t.contents[i]) - } - } -} - -// DeliverTransportPacket is called by network endpoints after parsing incoming -// packets. This is used by the test object to verify that the results of the -// parsing are expected. -func (t *testObject) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) stack.TransportPacketDisposition { - netHdr := pkt.Network() - t.checkValues(protocol, pkt.Data, netHdr.SourceAddress(), netHdr.DestinationAddress()) - t.dataCalls++ - return stack.TransportPacketHandled -} - -// DeliverTransportError is called by network endpoints after parsing -// incoming control (ICMP) packets. This is used by the test object to verify -// that the results of the parsing are expected. -func (t *testObject) DeliverTransportError(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr stack.TransportError, pkt *stack.PacketBuffer) { - t.checkValues(trans, pkt.Data, remote, local) - if diff := cmp.Diff( - t.transErr, - transportError{ - origin: transErr.Origin(), - typ: transErr.Type(), - code: transErr.Code(), - info: transErr.Info(), - kind: transErr.Kind(), - }, - cmp.AllowUnexported(transportError{}), - ); diff != "" { - t.t.Errorf("transport error mismatch (-want +got):\n%s", diff) - } - t.controlCalls++ -} - -// Attach is only implemented to satisfy the LinkEndpoint interface. -func (*testObject) Attach(stack.NetworkDispatcher) {} - -// IsAttached implements stack.LinkEndpoint.IsAttached. -func (*testObject) IsAttached() bool { - return true -} - -// MTU implements stack.LinkEndpoint.MTU. It just returns a constant that -// matches the linux loopback MTU. -func (*testObject) MTU() uint32 { - return 65536 -} - -// Capabilities implements stack.LinkEndpoint.Capabilities. -func (*testObject) Capabilities() stack.LinkEndpointCapabilities { - return 0 -} - -// MaxHeaderLength is only implemented to satisfy the LinkEndpoint interface. -func (*testObject) MaxHeaderLength() uint16 { - return 0 -} - -// LinkAddress returns the link address of this endpoint. -func (*testObject) LinkAddress() tcpip.LinkAddress { - return "" -} - -// Wait implements stack.LinkEndpoint.Wait. -func (*testObject) Wait() {} - -// WritePacket is called by network endpoints after producing a packet and -// writing it to the link endpoint. This is used by the test object to verify -// that the produced packet is as expected. -func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - var prot tcpip.TransportProtocolNumber - var srcAddr tcpip.Address - var dstAddr tcpip.Address - - if t.v4 { - h := header.IPv4(pkt.NetworkHeader().View()) - prot = tcpip.TransportProtocolNumber(h.Protocol()) - srcAddr = h.SourceAddress() - dstAddr = h.DestinationAddress() - - } else { - h := header.IPv6(pkt.NetworkHeader().View()) - prot = tcpip.TransportProtocolNumber(h.NextHeader()) - srcAddr = h.SourceAddress() - dstAddr = h.DestinationAddress() - } - t.checkValues(prot, pkt.Data, srcAddr, dstAddr) - return nil -} - -// WritePackets implements stack.LinkEndpoint.WritePackets. -func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - panic("not implemented") -} - -// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. -func (*testObject) ARPHardwareType() header.ARPHardwareType { - panic("not implemented") -} - -// AddHeader implements stack.LinkEndpoint.AddHeader. -func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - panic("not implemented") -} - -func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, - }) - s.CreateNIC(nicID, loopback.New()) - s.AddAddress(nicID, ipv4.ProtocolNumber, local) - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv4EmptySubnet, - Gateway: ipv4Gateway, - NIC: 1, - }}) - - return s.FindRoute(nicID, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */) -} - -func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, - }) - s.CreateNIC(nicID, loopback.New()) - s.AddAddress(nicID, ipv6.ProtocolNumber, local) - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv6EmptySubnet, - Gateway: ipv6Gateway, - NIC: 1, - }}) - - return s.FindRoute(nicID, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */) -} - -func buildDummyStackWithLinkEndpoint(t *testing.T, mtu uint32) (*stack.Stack, *channel.Endpoint) { - t.Helper() - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, - }) - e := channel.New(1, mtu, "") - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - v4Addr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: localIPv4AddrWithPrefix} - if err := s.AddProtocolAddress(nicID, v4Addr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v4Addr, err) - } - - v6Addr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: localIPv6AddrWithPrefix} - if err := s.AddProtocolAddress(nicID, v6Addr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v6Addr, err) - } - - return s, e -} - -func buildDummyStack(t *testing.T) *stack.Stack { - t.Helper() - - s, _ := buildDummyStackWithLinkEndpoint(t, header.IPv6MinimumMTU) - return s -} - -var _ stack.NetworkInterface = (*testInterface)(nil) - -type testInterface struct { - testObject - - mu struct { - sync.RWMutex - disabled bool - } -} - -func (*testInterface) ID() tcpip.NICID { - return nicID -} - -func (*testInterface) IsLoopback() bool { - return false -} - -func (*testInterface) Name() string { - return "" -} - -func (t *testInterface) Enabled() bool { - t.mu.RLock() - defer t.mu.RUnlock() - return !t.mu.disabled -} - -func (*testInterface) Promiscuous() bool { - return false -} - -func (*testInterface) Spoofing() bool { - return false -} - -func (t *testInterface) setEnabled(v bool) { - t.mu.Lock() - defer t.mu.Unlock() - t.mu.disabled = !v -} - -func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { - return &tcpip.ErrNotSupported{} -} - -func (*testInterface) HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress) tcpip.Error { - return nil -} - -func (*testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) tcpip.Error { - return nil -} - -func (*testInterface) CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool { - return false -} - -func TestSourceAddressValidation(t *testing.T) { - rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) { - totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize - hdr := buffer.NewPrependable(totalLen) - pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - pkt.SetType(header.ICMPv4Echo) - pkt.SetCode(0) - pkt.SetChecksum(0) - pkt.SetChecksum(^header.Checksum(pkt, 0)) - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLen), - Protocol: uint8(icmp.ProtocolNumber4), - TTL: ipv4.DefaultTTL, - SrcAddr: src, - DstAddr: localIPv4Addr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - } - - rxIPv6ICMP := func(e *channel.Endpoint, src tcpip.Address) { - totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize - hdr := buffer.NewPrependable(totalLen) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) - pkt.SetType(header.ICMPv6EchoRequest) - pkt.SetCode(0) - pkt.SetChecksum(0) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, localIPv6Addr, buffer.VectorisedView{})) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - TransportProtocol: icmp.ProtocolNumber6, - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: localIPv6Addr, - }) - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - } - - tests := []struct { - name string - srcAddress tcpip.Address - rxICMP func(*channel.Endpoint, tcpip.Address) - valid bool - }{ - { - name: "IPv4 valid", - srcAddress: "\x01\x02\x03\x04", - rxICMP: rxIPv4ICMP, - valid: true, - }, - { - name: "IPv6 valid", - srcAddress: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10", - rxICMP: rxIPv6ICMP, - valid: true, - }, - { - name: "IPv4 unspecified", - srcAddress: header.IPv4Any, - rxICMP: rxIPv4ICMP, - valid: true, - }, - { - name: "IPv6 unspecified", - srcAddress: header.IPv4Any, - rxICMP: rxIPv6ICMP, - valid: true, - }, - { - name: "IPv4 multicast", - srcAddress: "\xe0\x00\x00\x01", - rxICMP: rxIPv4ICMP, - valid: false, - }, - { - name: "IPv6 multicast", - srcAddress: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - rxICMP: rxIPv6ICMP, - valid: false, - }, - { - name: "IPv4 broadcast", - srcAddress: header.IPv4Broadcast, - rxICMP: rxIPv4ICMP, - valid: false, - }, - { - name: "IPv4 subnet broadcast", - srcAddress: func() tcpip.Address { - subnet := localIPv4AddrWithPrefix.Subnet() - return subnet.Broadcast() - }(), - rxICMP: rxIPv4ICMP, - valid: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s, e := buildDummyStackWithLinkEndpoint(t, header.IPv6MinimumMTU) - test.rxICMP(e, test.srcAddress) - - var wantValid uint64 - if test.valid { - wantValid = 1 - } - - if got, want := s.Stats().IP.InvalidSourceAddressesReceived.Value(), 1-wantValid; got != want { - t.Errorf("got s.Stats().IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want) - } - if got := s.Stats().IP.PacketsDelivered.Value(); got != wantValid { - t.Errorf("got s.Stats().IP.PacketsDelivered.Value() = %d, want = %d", got, wantValid) - } - }) - } -} - -func TestEnableWhenNICDisabled(t *testing.T) { - tests := []struct { - name string - protocolFactory stack.NetworkProtocolFactory - protoNum tcpip.NetworkProtocolNumber - }{ - { - name: "IPv4", - protocolFactory: ipv4.NewProtocol, - protoNum: ipv4.ProtocolNumber, - }, - { - name: "IPv6", - protocolFactory: ipv6.NewProtocol, - protoNum: ipv6.ProtocolNumber, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - var nic testInterface - nic.setEnabled(false) - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{test.protocolFactory}, - }) - p := s.NetworkProtocolInstance(test.protoNum) - - // We pass nil for all parameters except the NetworkInterface and Stack - // since Enable only depends on these. - ep := p.NewEndpoint(&nic, nil) - - // The endpoint should initially be disabled, regardless the NIC's enabled - // status. - if ep.Enabled() { - t.Fatal("got ep.Enabled() = true, want = false") - } - nic.setEnabled(true) - if ep.Enabled() { - t.Fatal("got ep.Enabled() = true, want = false") - } - - // Attempting to enable the endpoint while the NIC is disabled should - // fail. - nic.setEnabled(false) - err := ep.Enable() - if _, ok := err.(*tcpip.ErrNotPermitted); !ok { - t.Fatalf("got ep.Enable() = %s, want = %s", err, &tcpip.ErrNotPermitted{}) - } - // ep should consider the NIC's enabled status when determining its own - // enabled status so we "enable" the NIC to read just the endpoint's - // enabled status. - nic.setEnabled(true) - if ep.Enabled() { - t.Fatal("got ep.Enabled() = true, want = false") - } - - // Enabling the interface after the NIC has been enabled should succeed. - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - if !ep.Enabled() { - t.Fatal("got ep.Enabled() = false, want = true") - } - - // ep should consider the NIC's enabled status when determining its own - // enabled status. - nic.setEnabled(false) - if ep.Enabled() { - t.Fatal("got ep.Enabled() = true, want = false") - } - - // Disabling the endpoint when the NIC is enabled should make the endpoint - // disabled. - nic.setEnabled(true) - ep.Disable() - if ep.Enabled() { - t.Fatal("got ep.Enabled() = true, want = false") - } - }) - } -} - -func TestIPv4Send(t *testing.T) { - s := buildDummyStack(t) - proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) - nic := testInterface{ - testObject: testObject{ - t: t, - v4: true, - }, - } - ep := proto.NewEndpoint(&nic, nil) - defer ep.Close() - - // Allocate and initialize the payload view. - payload := buffer.NewView(100) - for i := 0; i < len(payload); i++ { - payload[i] = uint8(i) - } - - // Setup the packet buffer. - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(ep.MaxHeaderLength()), - Data: payload.ToVectorisedView(), - }) - - // Issue the write. - nic.testObject.protocol = 123 - nic.testObject.srcAddr = localIPv4Addr - nic.testObject.dstAddr = remoteIPv4Addr - nic.testObject.contents = payload - - r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - if err := ep.WritePacket(r, nil /* gso */, stack.NetworkHeaderParams{ - Protocol: 123, - TTL: 123, - TOS: stack.DefaultTOS, - }, pkt); err != nil { - t.Fatalf("WritePacket failed: %v", err) - } -} - -func TestReceive(t *testing.T) { - tests := []struct { - name string - protoFactory stack.NetworkProtocolFactory - protoNum tcpip.NetworkProtocolNumber - v4 bool - epAddr tcpip.AddressWithPrefix - handlePacket func(*testing.T, stack.NetworkEndpoint, *testInterface) - }{ - { - name: "IPv4", - protoFactory: ipv4.NewProtocol, - protoNum: ipv4.ProtocolNumber, - v4: true, - epAddr: localIPv4Addr.WithPrefix(), - handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) { - const totalLen = header.IPv4MinimumSize + 30 /* payload length */ - - view := buffer.NewView(totalLen) - ip := header.IPv4(view) - ip.Encode(&header.IPv4Fields{ - TotalLength: totalLen, - TTL: ipv4.DefaultTTL, - Protocol: 10, - SrcAddr: remoteIPv4Addr, - DstAddr: localIPv4Addr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - // Make payload be non-zero. - for i := header.IPv4MinimumSize; i < len(view); i++ { - view[i] = uint8(i) - } - - // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - nic.testObject.protocol = 10 - nic.testObject.srcAddr = remoteIPv4Addr - nic.testObject.dstAddr = localIPv4Addr - nic.testObject.contents = view[header.IPv4MinimumSize:totalLen] - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: view.ToVectorisedView(), - }) - ep.HandlePacket(pkt) - }, - }, - { - name: "IPv6", - protoFactory: ipv6.NewProtocol, - protoNum: ipv6.ProtocolNumber, - v4: false, - epAddr: localIPv6Addr.WithPrefix(), - handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) { - const payloadLen = 30 - view := buffer.NewView(header.IPv6MinimumSize + payloadLen) - ip := header.IPv6(view) - ip.Encode(&header.IPv6Fields{ - PayloadLength: payloadLen, - TransportProtocol: 10, - HopLimit: ipv6.DefaultTTL, - SrcAddr: remoteIPv6Addr, - DstAddr: localIPv6Addr, - }) - - // Make payload be non-zero. - for i := header.IPv6MinimumSize; i < len(view); i++ { - view[i] = uint8(i) - } - - // Give packet to ipv6 endpoint, dispatcher will validate that it's ok. - nic.testObject.protocol = 10 - nic.testObject.srcAddr = remoteIPv6Addr - nic.testObject.dstAddr = localIPv6Addr - nic.testObject.contents = view[header.IPv6MinimumSize:][:payloadLen] - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: view.ToVectorisedView(), - }) - ep.HandlePacket(pkt) - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory}, - }) - nic := testInterface{ - testObject: testObject{ - t: t, - v4: test.v4, - }, - } - ep := s.NetworkProtocolInstance(test.protoNum).NewEndpoint(&nic, &nic.testObject) - defer ep.Close() - - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - - addressableEndpoint, ok := ep.(stack.AddressableEndpoint) - if !ok { - t.Fatalf("expected network endpoint with number = %d to implement stack.AddressableEndpoint", test.protoNum) - } - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", test.epAddr, err) - } else { - ep.DecRef() - } - - stat := s.Stats().IP.PacketsReceived - if got := stat.Value(); got != 0 { - t.Fatalf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 0", got) - } - test.handlePacket(t, ep, &nic) - if nic.testObject.dataCalls != 1 { - t.Errorf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) - } - if got := stat.Value(); got != 1 { - t.Errorf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 1", got) - } - }) - } -} - -func TestIPv4ReceiveControl(t *testing.T) { - const ( - mtu = 0xbeef - header.IPv4MinimumSize - dataLen = 8 - ) - - cases := []struct { - name string - expectedCount int - fragmentOffset uint16 - code header.ICMPv4Code - transErr transportError - trunc int - }{ - { - name: "FragmentationNeeded", - expectedCount: 1, - fragmentOffset: 0, - code: header.ICMPv4FragmentationNeeded, - transErr: transportError{ - origin: tcpip.SockExtErrorOriginICMP, - typ: uint8(header.ICMPv4DstUnreachable), - code: uint8(header.ICMPv4FragmentationNeeded), - info: mtu, - kind: stack.PacketTooBigTransportError, - }, - trunc: 0, - }, - { - name: "Truncated (missing IPv4 header)", - expectedCount: 0, - fragmentOffset: 0, - code: header.ICMPv4FragmentationNeeded, - trunc: header.IPv4MinimumSize + header.ICMPv4MinimumSize, - }, - { - name: "Truncated (partial offending packet's IP header)", - expectedCount: 0, - fragmentOffset: 0, - code: header.ICMPv4FragmentationNeeded, - trunc: header.IPv4MinimumSize + header.ICMPv4MinimumSize + header.IPv4MinimumSize - 1, - }, - { - name: "Truncated (partial offending packet's data)", - expectedCount: 0, - fragmentOffset: 0, - code: header.ICMPv4FragmentationNeeded, - trunc: header.ICMPv4MinimumSize + header.ICMPv4MinimumSize + header.IPv4MinimumSize + dataLen - 1, - }, - { - name: "Port unreachable", - expectedCount: 1, - fragmentOffset: 0, - code: header.ICMPv4PortUnreachable, - transErr: transportError{ - origin: tcpip.SockExtErrorOriginICMP, - typ: uint8(header.ICMPv4DstUnreachable), - code: uint8(header.ICMPv4PortUnreachable), - kind: stack.DestinationPortUnreachableTransportError, - }, - trunc: 0, - }, - { - name: "Non-zero fragment offset", - expectedCount: 0, - fragmentOffset: 100, - code: header.ICMPv4PortUnreachable, - trunc: 0, - }, - { - name: "Zero-length packet", - expectedCount: 0, - fragmentOffset: 100, - code: header.ICMPv4PortUnreachable, - trunc: 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + dataLen, - }, - } - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - s := buildDummyStack(t) - proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) - nic := testInterface{ - testObject: testObject{ - t: t, - }, - } - ep := proto.NewEndpoint(&nic, &nic.testObject) - defer ep.Close() - - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - - const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize - view := buffer.NewView(dataOffset + dataLen) - - // Create the outer IPv4 header. - ip := header.IPv4(view) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(len(view) - c.trunc), - TTL: 20, - Protocol: uint8(header.ICMPv4ProtocolNumber), - SrcAddr: "\x0a\x00\x00\xbb", - DstAddr: localIPv4Addr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - // Create the ICMP header. - icmp := header.ICMPv4(view[header.IPv4MinimumSize:]) - icmp.SetType(header.ICMPv4DstUnreachable) - icmp.SetCode(c.code) - icmp.SetIdent(0xdead) - icmp.SetSequence(0xbeef) - - // Create the inner IPv4 header. - ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:]) - ip.Encode(&header.IPv4Fields{ - TotalLength: 100, - TTL: 20, - Protocol: 10, - FragmentOffset: c.fragmentOffset, - SrcAddr: localIPv4Addr, - DstAddr: remoteIPv4Addr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - // Make payload be non-zero. - for i := dataOffset; i < len(view); i++ { - view[i] = uint8(i) - } - - icmp.SetChecksum(0) - checksum := ^header.Checksum(icmp, 0 /* initial */) - icmp.SetChecksum(checksum) - - // Give packet to IPv4 endpoint, dispatcher will validate that - // it's ok. - nic.testObject.protocol = 10 - nic.testObject.srcAddr = remoteIPv4Addr - nic.testObject.dstAddr = localIPv4Addr - nic.testObject.contents = view[dataOffset:] - nic.testObject.transErr = c.transErr - - addressableEndpoint, ok := ep.(stack.AddressableEndpoint) - if !ok { - t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint") - } - addr := localIPv4Addr.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) - } else { - ep.DecRef() - } - - pkt := truncatedPacket(view, c.trunc, header.IPv4MinimumSize) - ep.HandlePacket(pkt) - if want := c.expectedCount; nic.testObject.controlCalls != want { - t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) - } - }) - } -} - -func TestIPv4FragmentationReceive(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - }) - proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) - nic := testInterface{ - testObject: testObject{ - t: t, - v4: true, - }, - } - ep := proto.NewEndpoint(&nic, &nic.testObject) - defer ep.Close() - - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - - totalLen := header.IPv4MinimumSize + 24 - - frag1 := buffer.NewView(totalLen) - ip1 := header.IPv4(frag1) - ip1.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLen), - TTL: 20, - Protocol: 10, - FragmentOffset: 0, - Flags: header.IPv4FlagMoreFragments, - SrcAddr: remoteIPv4Addr, - DstAddr: localIPv4Addr, - }) - ip1.SetChecksum(^ip1.CalculateChecksum()) - - // Make payload be non-zero. - for i := header.IPv4MinimumSize; i < totalLen; i++ { - frag1[i] = uint8(i) - } - - frag2 := buffer.NewView(totalLen) - ip2 := header.IPv4(frag2) - ip2.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLen), - TTL: 20, - Protocol: 10, - FragmentOffset: 24, - SrcAddr: remoteIPv4Addr, - DstAddr: localIPv4Addr, - }) - ip2.SetChecksum(^ip2.CalculateChecksum()) - - // Make payload be non-zero. - for i := header.IPv4MinimumSize; i < totalLen; i++ { - frag2[i] = uint8(i) - } - - // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - nic.testObject.protocol = 10 - nic.testObject.srcAddr = remoteIPv4Addr - nic.testObject.dstAddr = localIPv4Addr - nic.testObject.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) - - // Send first segment. - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: frag1.ToVectorisedView(), - }) - - addressableEndpoint, ok := ep.(stack.AddressableEndpoint) - if !ok { - t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint") - } - addr := localIPv4Addr.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) - } else { - ep.DecRef() - } - - ep.HandlePacket(pkt) - if nic.testObject.dataCalls != 0 { - t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls) - } - - // Send second segment. - pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: frag2.ToVectorisedView(), - }) - ep.HandlePacket(pkt) - if nic.testObject.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) - } -} - -func TestIPv6Send(t *testing.T) { - s := buildDummyStack(t) - proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) - nic := testInterface{ - testObject: testObject{ - t: t, - }, - } - ep := proto.NewEndpoint(&nic, nil) - defer ep.Close() - - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - - // Allocate and initialize the payload view. - payload := buffer.NewView(100) - for i := 0; i < len(payload); i++ { - payload[i] = uint8(i) - } - - // Setup the packet buffer. - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(ep.MaxHeaderLength()), - Data: payload.ToVectorisedView(), - }) - - // Issue the write. - nic.testObject.protocol = 123 - nic.testObject.srcAddr = localIPv6Addr - nic.testObject.dstAddr = remoteIPv6Addr - nic.testObject.contents = payload - - r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - if err := ep.WritePacket(r, nil /* gso */, stack.NetworkHeaderParams{ - Protocol: 123, - TTL: 123, - TOS: stack.DefaultTOS, - }, pkt); err != nil { - t.Fatalf("WritePacket failed: %v", err) - } -} - -func TestIPv6ReceiveControl(t *testing.T) { - const ( - mtu = 0xffff - outerSrcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa" - dataLen = 8 - ) - - newUint16 := func(v uint16) *uint16 { return &v } - - portUnreachableTransErr := transportError{ - origin: tcpip.SockExtErrorOriginICMP6, - typ: uint8(header.ICMPv6DstUnreachable), - code: uint8(header.ICMPv6PortUnreachable), - kind: stack.DestinationPortUnreachableTransportError, - } - - cases := []struct { - name string - expectedCount int - fragmentOffset *uint16 - typ header.ICMPv6Type - code header.ICMPv6Code - transErr transportError - trunc int - }{ - { - name: "PacketTooBig", - expectedCount: 1, - fragmentOffset: nil, - typ: header.ICMPv6PacketTooBig, - code: header.ICMPv6UnusedCode, - transErr: transportError{ - origin: tcpip.SockExtErrorOriginICMP6, - typ: uint8(header.ICMPv6PacketTooBig), - code: uint8(header.ICMPv6UnusedCode), - info: mtu, - kind: stack.PacketTooBigTransportError, - }, - trunc: 0, - }, - { - name: "Truncated (missing offending packet's IPv6 header)", - expectedCount: 0, - fragmentOffset: nil, - typ: header.ICMPv6PacketTooBig, - code: header.ICMPv6UnusedCode, - trunc: header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize, - }, - { - name: "Truncated PacketTooBig (partial offending packet's IPv6 header)", - expectedCount: 0, - fragmentOffset: nil, - typ: header.ICMPv6PacketTooBig, - code: header.ICMPv6UnusedCode, - trunc: header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize - 1, - }, - { - name: "Truncated (partial offending packet's data)", - expectedCount: 0, - fragmentOffset: nil, - typ: header.ICMPv6PacketTooBig, - code: header.ICMPv6UnusedCode, - trunc: header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize + dataLen - 1, - }, - { - name: "Port unreachable", - expectedCount: 1, - fragmentOffset: nil, - typ: header.ICMPv6DstUnreachable, - code: header.ICMPv6PortUnreachable, - transErr: portUnreachableTransErr, - trunc: 0, - }, - { - name: "Truncated DstPortUnreachable (partial offending packet's IP header)", - expectedCount: 0, - fragmentOffset: nil, - typ: header.ICMPv6DstUnreachable, - code: header.ICMPv6PortUnreachable, - trunc: header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + header.IPv6MinimumSize - 1, - }, - { - name: "DstPortUnreachable for Fragmented, zero offset", - expectedCount: 1, - fragmentOffset: newUint16(0), - typ: header.ICMPv6DstUnreachable, - code: header.ICMPv6PortUnreachable, - transErr: portUnreachableTransErr, - trunc: 0, - }, - { - name: "DstPortUnreachable for Non-zero fragment offset", - expectedCount: 0, - fragmentOffset: newUint16(100), - typ: header.ICMPv6DstUnreachable, - code: header.ICMPv6PortUnreachable, - transErr: portUnreachableTransErr, - trunc: 0, - }, - { - name: "Zero-length packet", - expectedCount: 0, - fragmentOffset: nil, - typ: header.ICMPv6DstUnreachable, - code: header.ICMPv6PortUnreachable, - trunc: 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + dataLen, - }, - } - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - s := buildDummyStack(t) - proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) - nic := testInterface{ - testObject: testObject{ - t: t, - }, - } - ep := proto.NewEndpoint(&nic, &nic.testObject) - defer ep.Close() - - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - - dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize - if c.fragmentOffset != nil { - dataOffset += header.IPv6FragmentHeaderSize - } - view := buffer.NewView(dataOffset + dataLen) - - // Create the outer IPv6 header. - ip := header.IPv6(view) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: 20, - SrcAddr: outerSrcAddr, - DstAddr: localIPv6Addr, - }) - - // Create the ICMP header. - icmp := header.ICMPv6(view[header.IPv6MinimumSize:]) - icmp.SetType(c.typ) - icmp.SetCode(c.code) - icmp.SetIdent(0xdead) - icmp.SetSequence(0xbeef) - - var extHdrs header.IPv6ExtHdrSerializer - // Build the fragmentation header if needed. - if c.fragmentOffset != nil { - extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{ - FragmentOffset: *c.fragmentOffset, - M: true, - Identification: 0x12345678, - }) - } - - // Create the inner IPv6 header. - ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:]) - ip.Encode(&header.IPv6Fields{ - PayloadLength: 100, - TransportProtocol: 10, - HopLimit: 20, - SrcAddr: localIPv6Addr, - DstAddr: remoteIPv6Addr, - ExtensionHeaders: extHdrs, - }) - - // Make payload be non-zero. - for i := dataOffset; i < len(view); i++ { - view[i] = uint8(i) - } - - // Give packet to IPv6 endpoint, dispatcher will validate that - // it's ok. - nic.testObject.protocol = 10 - nic.testObject.srcAddr = remoteIPv6Addr - nic.testObject.dstAddr = localIPv6Addr - nic.testObject.contents = view[dataOffset:] - nic.testObject.transErr = c.transErr - - // Set ICMPv6 checksum. - icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{})) - - addressableEndpoint, ok := ep.(stack.AddressableEndpoint) - if !ok { - t.Fatal("expected IPv6 network endpoint to implement stack.AddressableEndpoint") - } - addr := localIPv6Addr.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) - } else { - ep.DecRef() - } - pkt := truncatedPacket(view, c.trunc, header.IPv6MinimumSize) - ep.HandlePacket(pkt) - if want := c.expectedCount; nic.testObject.controlCalls != want { - t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) - } - }) - } -} - -// truncatedPacket returns a PacketBuffer based on a truncated view. If view, -// after truncation, is large enough to hold a network header, it makes part of -// view the packet's NetworkHeader and the rest its Data. Otherwise all of view -// becomes Data. -func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer { - v := view[:len(view)-trunc] - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: v.ToVectorisedView(), - }) - return pkt -} - -func TestWriteHeaderIncludedPacket(t *testing.T) { - const ( - nicID = 1 - transportProto = 5 - - dataLen = 4 - ) - - dataBuf := [dataLen]byte{1, 2, 3, 4} - data := dataBuf[:] - - ipv4Options := header.IPv4OptionsSerializer{ - &header.IPv4SerializableListEndOption{}, - &header.IPv4SerializableNOPOption{}, - &header.IPv4SerializableListEndOption{}, - &header.IPv4SerializableNOPOption{}, - } - - expectOptions := header.IPv4Options{ - byte(header.IPv4OptionListEndType), - byte(header.IPv4OptionNOPType), - byte(header.IPv4OptionListEndType), - byte(header.IPv4OptionNOPType), - } - - ipv6FragmentExtHdrBuf := [header.IPv6FragmentExtHdrLength]byte{transportProto, 0, 62, 4, 1, 2, 3, 4} - ipv6FragmentExtHdr := ipv6FragmentExtHdrBuf[:] - - var ipv6PayloadWithExtHdrBuf [dataLen + header.IPv6FragmentExtHdrLength]byte - ipv6PayloadWithExtHdr := ipv6PayloadWithExtHdrBuf[:] - if n := copy(ipv6PayloadWithExtHdr, ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) { - t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr)) - } - if n := copy(ipv6PayloadWithExtHdr[header.IPv6FragmentExtHdrLength:], data); n != len(data) { - t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) - } - - tests := []struct { - name string - protoFactory stack.NetworkProtocolFactory - protoNum tcpip.NetworkProtocolNumber - nicAddr tcpip.Address - remoteAddr tcpip.Address - pktGen func(*testing.T, tcpip.Address) buffer.VectorisedView - checker func(*testing.T, *stack.PacketBuffer, tcpip.Address) - expectedErr tcpip.Error - }{ - { - name: "IPv4", - protoFactory: ipv4.NewProtocol, - protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, - remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - totalLen := header.IPv4MinimumSize + len(data) - hdr := buffer.NewPrependable(totalLen) - if n := copy(hdr.Prepend(len(data)), data); n != len(data) { - t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) - } - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - Protocol: transportProto, - TTL: ipv4.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, - }) - return hdr.View().ToVectorisedView() - }, - checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { - if src == header.IPv4Any { - src = localIPv4Addr - } - - netHdr := pkt.NetworkHeader() - - if len(netHdr.View()) != header.IPv4MinimumSize { - t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv4MinimumSize) - } - - checker.IPv4(t, stack.PayloadSince(netHdr), - checker.SrcAddr(src), - checker.DstAddr(remoteIPv4Addr), - checker.IPv4HeaderLength(header.IPv4MinimumSize), - checker.IPFullLength(uint16(header.IPv4MinimumSize+len(data))), - checker.IPPayload(data), - ) - }, - }, - { - name: "IPv4 with IHL too small", - protoFactory: ipv4.NewProtocol, - protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, - remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - totalLen := header.IPv4MinimumSize + len(data) - hdr := buffer.NewPrependable(totalLen) - if n := copy(hdr.Prepend(len(data)), data); n != len(data) { - t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) - } - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - Protocol: transportProto, - TTL: ipv4.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, - }) - ip.SetHeaderLength(header.IPv4MinimumSize - 1) - return hdr.View().ToVectorisedView() - }, - expectedErr: &tcpip.ErrMalformedHeader{}, - }, - { - name: "IPv4 too small", - protoFactory: ipv4.NewProtocol, - protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, - remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - Protocol: transportProto, - TTL: ipv4.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, - }) - return buffer.View(ip[:len(ip)-1]).ToVectorisedView() - }, - expectedErr: &tcpip.ErrMalformedHeader{}, - }, - { - name: "IPv4 minimum size", - protoFactory: ipv4.NewProtocol, - protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, - remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - Protocol: transportProto, - TTL: ipv4.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, - }) - return buffer.View(ip).ToVectorisedView() - }, - checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { - if src == header.IPv4Any { - src = localIPv4Addr - } - - netHdr := pkt.NetworkHeader() - - if len(netHdr.View()) != header.IPv4MinimumSize { - t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv4MinimumSize) - } - - checker.IPv4(t, stack.PayloadSince(netHdr), - checker.SrcAddr(src), - checker.DstAddr(remoteIPv4Addr), - checker.IPv4HeaderLength(header.IPv4MinimumSize), - checker.IPFullLength(header.IPv4MinimumSize), - checker.IPPayload(nil), - ) - }, - }, - { - name: "IPv4 with options", - protoFactory: ipv4.NewProtocol, - protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, - remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - ipHdrLen := int(header.IPv4MinimumSize + ipv4Options.Length()) - totalLen := ipHdrLen + len(data) - hdr := buffer.NewPrependable(totalLen) - if n := copy(hdr.Prepend(len(data)), data); n != len(data) { - t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) - } - ip := header.IPv4(hdr.Prepend(ipHdrLen)) - ip.Encode(&header.IPv4Fields{ - Protocol: transportProto, - TTL: ipv4.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, - Options: ipv4Options, - }) - return hdr.View().ToVectorisedView() - }, - checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { - if src == header.IPv4Any { - src = localIPv4Addr - } - - netHdr := pkt.NetworkHeader() - - hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length()) - if len(netHdr.View()) != hdrLen { - t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen) - } - - checker.IPv4(t, stack.PayloadSince(netHdr), - checker.SrcAddr(src), - checker.DstAddr(remoteIPv4Addr), - checker.IPv4HeaderLength(hdrLen), - checker.IPFullLength(uint16(hdrLen+len(data))), - checker.IPv4Options(expectOptions), - checker.IPPayload(data), - ) - }, - }, - { - name: "IPv4 with options and data across views", - protoFactory: ipv4.NewProtocol, - protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, - remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.Length())) - ip.Encode(&header.IPv4Fields{ - Protocol: transportProto, - TTL: ipv4.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, - Options: ipv4Options, - }) - vv := buffer.View(ip).ToVectorisedView() - vv.AppendView(data) - return vv - }, - checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { - if src == header.IPv4Any { - src = localIPv4Addr - } - - netHdr := pkt.NetworkHeader() - - hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length()) - if len(netHdr.View()) != hdrLen { - t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen) - } - - checker.IPv4(t, stack.PayloadSince(netHdr), - checker.SrcAddr(src), - checker.DstAddr(remoteIPv4Addr), - checker.IPv4HeaderLength(hdrLen), - checker.IPFullLength(uint16(hdrLen+len(data))), - checker.IPv4Options(expectOptions), - checker.IPPayload(data), - ) - }, - }, - { - name: "IPv6", - protoFactory: ipv6.NewProtocol, - protoNum: ipv6.ProtocolNumber, - nicAddr: localIPv6Addr, - remoteAddr: remoteIPv6Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - totalLen := header.IPv6MinimumSize + len(data) - hdr := buffer.NewPrependable(totalLen) - if n := copy(hdr.Prepend(len(data)), data); n != len(data) { - t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) - } - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - TransportProtocol: transportProto, - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, - }) - return hdr.View().ToVectorisedView() - }, - checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { - if src == header.IPv6Any { - src = localIPv6Addr - } - - netHdr := pkt.NetworkHeader() - - if len(netHdr.View()) != header.IPv6MinimumSize { - t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv6MinimumSize) - } - - checker.IPv6(t, stack.PayloadSince(netHdr), - checker.SrcAddr(src), - checker.DstAddr(remoteIPv6Addr), - checker.IPFullLength(uint16(header.IPv6MinimumSize+len(data))), - checker.IPPayload(data), - ) - }, - }, - { - name: "IPv6 with extension header", - protoFactory: ipv6.NewProtocol, - protoNum: ipv6.ProtocolNumber, - nicAddr: localIPv6Addr, - remoteAddr: remoteIPv6Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data) - hdr := buffer.NewPrependable(totalLen) - if n := copy(hdr.Prepend(len(data)), data); n != len(data) { - t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) - } - if n := copy(hdr.Prepend(len(ipv6FragmentExtHdr)), ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) { - t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr)) - } - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - // NB: we're lying about transport protocol here to verify the raw - // fragment header bytes. - TransportProtocol: tcpip.TransportProtocolNumber(header.IPv6FragmentExtHdrIdentifier), - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, - }) - return hdr.View().ToVectorisedView() - }, - checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { - if src == header.IPv6Any { - src = localIPv6Addr - } - - netHdr := pkt.NetworkHeader() - - if want := header.IPv6MinimumSize + len(ipv6FragmentExtHdr); len(netHdr.View()) != want { - t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), want) - } - - checker.IPv6(t, stack.PayloadSince(netHdr), - checker.SrcAddr(src), - checker.DstAddr(remoteIPv6Addr), - checker.IPFullLength(uint16(header.IPv6MinimumSize+len(ipv6PayloadWithExtHdr))), - checker.IPPayload(ipv6PayloadWithExtHdr), - ) - }, - }, - { - name: "IPv6 minimum size", - protoFactory: ipv6.NewProtocol, - protoNum: ipv6.ProtocolNumber, - nicAddr: localIPv6Addr, - remoteAddr: remoteIPv6Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - TransportProtocol: transportProto, - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, - }) - return buffer.View(ip).ToVectorisedView() - }, - checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { - if src == header.IPv6Any { - src = localIPv6Addr - } - - netHdr := pkt.NetworkHeader() - - if len(netHdr.View()) != header.IPv6MinimumSize { - t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv6MinimumSize) - } - - checker.IPv6(t, stack.PayloadSince(netHdr), - checker.SrcAddr(src), - checker.DstAddr(remoteIPv6Addr), - checker.IPFullLength(header.IPv6MinimumSize), - checker.IPPayload(nil), - ) - }, - }, - { - name: "IPv6 too small", - protoFactory: ipv6.NewProtocol, - protoNum: ipv6.ProtocolNumber, - nicAddr: localIPv6Addr, - remoteAddr: remoteIPv6Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - TransportProtocol: transportProto, - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, - }) - return buffer.View(ip[:len(ip)-1]).ToVectorisedView() - }, - expectedErr: &tcpip.ErrMalformedHeader{}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - subTests := []struct { - name string - srcAddr tcpip.Address - }{ - { - name: "unspecified source", - srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr))), - }, - { - name: "random source", - srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr))), - }, - } - - for _, subTest := range subTests { - t.Run(subTest.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory}, - }) - e := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - if err := s.AddAddress(nicID, test.protoNum, test.nicAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.protoNum, test.nicAddr, err) - } - - s.SetRouteTable([]tcpip.Route{{Destination: test.remoteAddr.WithPrefix().Subnet(), NIC: nicID}}) - - r, err := s.FindRoute(nicID, test.nicAddr, test.remoteAddr, test.protoNum, false /* multicastLoop */) - if err != nil { - t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr, test.protoNum, err) - } - defer r.Release() - - { - err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: test.pktGen(t, subTest.srcAddr), - })) - if diff := cmp.Diff(test.expectedErr, err); diff != "" { - t.Fatalf("unexpected error from r.WriteHeaderIncludedPacket(_), (-want, +got):\n%s", diff) - } - } - - if test.expectedErr != nil { - return - } - - pkt, ok := e.Read() - if !ok { - t.Fatal("expected a packet to be written") - } - test.checker(t, pkt.Pkt, subTest.srcAddr) - }) - } - }) - } -} - -// Test that the included data in an ICMP error packet conforms to the -// requirements of RFC 972, RFC 4443 section 2.4 and RFC 1812 Section 4.3.2.3 -func TestICMPInclusionSize(t *testing.T) { - const ( - replyHeaderLength4 = header.IPv4MinimumSize + header.IPv4MinimumSize + header.ICMPv4MinimumSize - replyHeaderLength6 = header.IPv6MinimumSize + header.IPv6MinimumSize + header.ICMPv6MinimumSize - targetSize4 = header.IPv4MinimumProcessableDatagramSize - targetSize6 = header.IPv6MinimumMTU - // A protocol number that will cause an error response. - reservedProtocol = 254 - ) - - // IPv4 function to create a IP packet and send it to the stack. - // The packet should generate an error response. We can do that by using an - // unknown transport protocol (254). - rxIPv4Bad := func(e *channel.Endpoint, src tcpip.Address, payload []byte) buffer.View { - totalLen := header.IPv4MinimumSize + len(payload) - hdr := buffer.NewPrependable(header.IPv4MinimumSize) - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLen), - Protocol: reservedProtocol, - TTL: ipv4.DefaultTTL, - SrcAddr: src, - DstAddr: localIPv4Addr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - vv := hdr.View().ToVectorisedView() - vv.AppendView(buffer.View(payload)) - // Take a copy before InjectInbound takes ownership of vv - // as vv may be changed during the call. - v := vv.ToView() - e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - })) - return v - } - - // IPv6 function to create a packet and send it to the stack. - // The packet should be errant in a way that causes the stack to send an - // ICMP error response and have enough data to allow the testing of the - // inclusion of the errant packet. Use `unknown next header' to generate - // the error. - rxIPv6Bad := func(e *channel.Endpoint, src tcpip.Address, payload []byte) buffer.View { - hdr := buffer.NewPrependable(header.IPv6MinimumSize) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(payload)), - TransportProtocol: reservedProtocol, - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: localIPv6Addr, - }) - vv := hdr.View().ToVectorisedView() - vv.AppendView(buffer.View(payload)) - // Take a copy before InjectInbound takes ownership of vv - // as vv may be changed during the call. - v := vv.ToView() - - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - })) - return v - } - - v4Checker := func(t *testing.T, pkt *stack.PacketBuffer, payload buffer.View) { - // We already know the entire packet is the right size so we can use its - // length to calculate the right payload size to check. - expectedPayloadLength := pkt.Size() - header.IPv4MinimumSize - header.ICMPv4MinimumSize - checker.IPv4(t, stack.PayloadSince(pkt.NetworkHeader()), - checker.SrcAddr(localIPv4Addr), - checker.DstAddr(remoteIPv4Addr), - checker.IPv4HeaderLength(header.IPv4MinimumSize), - checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+expectedPayloadLength)), - checker.ICMPv4( - checker.ICMPv4Checksum(), - checker.ICMPv4Type(header.ICMPv4DstUnreachable), - checker.ICMPv4Code(header.ICMPv4ProtoUnreachable), - checker.ICMPv4Payload(payload[:expectedPayloadLength]), - ), - ) - } - - v6Checker := func(t *testing.T, pkt *stack.PacketBuffer, payload buffer.View) { - // We already know the entire packet is the right size so we can use its - // length to calculate the right payload size to check. - expectedPayloadLength := pkt.Size() - header.IPv6MinimumSize - header.ICMPv6MinimumSize - checker.IPv6(t, stack.PayloadSince(pkt.NetworkHeader()), - checker.SrcAddr(localIPv6Addr), - checker.DstAddr(remoteIPv6Addr), - checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+expectedPayloadLength)), - checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6ParamProblem), - checker.ICMPv6Code(header.ICMPv6UnknownHeader), - checker.ICMPv6Payload(payload[:expectedPayloadLength]), - ), - ) - } - tests := []struct { - name string - srcAddress tcpip.Address - injector func(*channel.Endpoint, tcpip.Address, []byte) buffer.View - checker func(*testing.T, *stack.PacketBuffer, buffer.View) - payloadLength int // Not including IP header. - linkMTU uint32 // Largest IP packet that the link can send as payload. - replyLength int // Total size of IP/ICMP packet expected back. - }{ - { - name: "IPv4 exact match", - srcAddress: remoteIPv4Addr, - injector: rxIPv4Bad, - checker: v4Checker, - payloadLength: targetSize4 - replyHeaderLength4, - linkMTU: targetSize4, - replyLength: targetSize4, - }, - { - name: "IPv4 larger MTU", - srcAddress: remoteIPv4Addr, - injector: rxIPv4Bad, - checker: v4Checker, - payloadLength: targetSize4, - linkMTU: targetSize4 + 1000, - replyLength: targetSize4, - }, - { - name: "IPv4 smaller MTU", - srcAddress: remoteIPv4Addr, - injector: rxIPv4Bad, - checker: v4Checker, - payloadLength: targetSize4, - linkMTU: targetSize4 - 50, - replyLength: targetSize4 - 50, - }, - { - name: "IPv4 payload exceeds", - srcAddress: remoteIPv4Addr, - injector: rxIPv4Bad, - checker: v4Checker, - payloadLength: targetSize4 + 10, - linkMTU: targetSize4, - replyLength: targetSize4, - }, - { - name: "IPv4 1 byte less", - srcAddress: remoteIPv4Addr, - injector: rxIPv4Bad, - checker: v4Checker, - payloadLength: targetSize4 - replyHeaderLength4 - 1, - linkMTU: targetSize4, - replyLength: targetSize4 - 1, - }, - { - name: "IPv4 No payload", - srcAddress: remoteIPv4Addr, - injector: rxIPv4Bad, - checker: v4Checker, - payloadLength: 0, - linkMTU: targetSize4, - replyLength: replyHeaderLength4, - }, - { - name: "IPv6 exact match", - srcAddress: remoteIPv6Addr, - injector: rxIPv6Bad, - checker: v6Checker, - payloadLength: targetSize6 - replyHeaderLength6, - linkMTU: targetSize6, - replyLength: targetSize6, - }, - { - name: "IPv6 larger MTU", - srcAddress: remoteIPv6Addr, - injector: rxIPv6Bad, - checker: v6Checker, - payloadLength: targetSize6, - linkMTU: targetSize6 + 400, - replyLength: targetSize6, - }, - // NB. No "smaller MTU" test here as less than 1280 is not permitted - // in IPv6. - { - name: "IPv6 payload exceeds", - srcAddress: remoteIPv6Addr, - injector: rxIPv6Bad, - checker: v6Checker, - payloadLength: targetSize6, - linkMTU: targetSize6, - replyLength: targetSize6, - }, - { - name: "IPv6 1 byte less", - srcAddress: remoteIPv6Addr, - injector: rxIPv6Bad, - checker: v6Checker, - payloadLength: targetSize6 - replyHeaderLength6 - 1, - linkMTU: targetSize6, - replyLength: targetSize6 - 1, - }, - { - name: "IPv6 no payload", - srcAddress: remoteIPv6Addr, - injector: rxIPv6Bad, - checker: v6Checker, - payloadLength: 0, - linkMTU: targetSize6, - replyLength: replyHeaderLength6, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s, e := buildDummyStackWithLinkEndpoint(t, test.linkMTU) - // Allocate and initialize the payload view. - payload := buffer.NewView(test.payloadLength) - for i := 0; i < len(payload); i++ { - payload[i] = uint8(i) - } - // Default routes for IPv4&6 so ICMP can find a route to the remote - // node when attempting to send the ICMP error Reply. - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: nicID, - }, - }) - v := test.injector(e, test.srcAddress, payload) - pkt, ok := e.Read() - if !ok { - t.Fatal("expected a packet to be written") - } - if got, want := pkt.Pkt.Size(), test.replyLength; got != want { - t.Fatalf("got %d bytes of icmp error packet, want %d", got, want) - } - test.checker(t, pkt.Pkt, v) - }) - } -} diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD deleted file mode 100644 index 4b21ee79c..000000000 --- a/pkg/tcpip/network/ipv4/BUILD +++ /dev/null @@ -1,65 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "ipv4", - srcs = [ - "icmp.go", - "igmp.go", - "ipv4.go", - "stats.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/header/parse", - "//pkg/tcpip/network/hash", - "//pkg/tcpip/network/internal/fragmentation", - "//pkg/tcpip/network/internal/ip", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "ipv4_test", - size = "small", - srcs = [ - "igmp_test.go", - "ipv4_test.go", - ], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/faketime", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/arp", - "//pkg/tcpip/network/internal/testutil", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/icmp", - "//pkg/tcpip/transport/raw", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) - -go_test( - name = "stats_test", - size = "small", - srcs = ["stats_test.go"], - library = ":ipv4", - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/network/internal/testutil", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go deleted file mode 100644 index c5f68e411..000000000 --- a/pkg/tcpip/network/ipv4/igmp_test.go +++ /dev/null @@ -1,383 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv4_test - -import ( - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -const ( - linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - stackAddr = tcpip.Address("\x0a\x00\x00\x01") - remoteAddr = tcpip.Address("\x0a\x00\x00\x02") - multicastAddr = tcpip.Address("\xe0\x00\x00\x03") - nicID = 1 - defaultTTL = 1 - defaultPrefixLength = 24 -) - -// validateIgmpPacket checks that a passed PacketInfo is an IPv4 IGMP packet -// sent to the provided address with the passed fields set. Raises a t.Error if -// any field does not match. -func validateIgmpPacket(t *testing.T, p channel.PacketInfo, igmpType header.IGMPType, maxRespTime byte, srcAddr, dstAddr, groupAddress tcpip.Address) { - t.Helper() - - payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) - checker.IPv4(t, payload, - checker.SrcAddr(srcAddr), - checker.DstAddr(dstAddr), - // TTL for an IGMP message must be 1 as per RFC 2236 section 2. - checker.TTL(1), - checker.IPv4RouterAlert(), - checker.IGMP( - checker.IGMPType(igmpType), - checker.IGMPMaxRespTime(header.DecisecondToDuration(maxRespTime)), - checker.IGMPGroupAddress(groupAddress), - ), - ) -} - -func createStack(t *testing.T, igmpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) { - t.Helper() - - // Create an endpoint of queue size 1, since no more than 1 packets are ever - // queued in the tests in this file. - e := channel.New(1, 1280, linkAddr) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocolWithOptions(ipv4.Options{ - IGMP: ipv4.IGMPOptions{ - Enabled: igmpEnabled, - }, - })}, - Clock: clock, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - return e, s, clock -} - -func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, maxRespTime byte, ttl uint8, srcAddr, dstAddr, groupAddress tcpip.Address, hasRouterAlertOption bool) { - var options header.IPv4OptionsSerializer - if hasRouterAlertOption { - options = header.IPv4OptionsSerializer{ - &header.IPv4SerializableRouterAlertOption{}, - } - } - buf := buffer.NewView(header.IPv4MinimumSize + int(options.Length()) + header.IGMPQueryMinimumSize) - - ip := header.IPv4(buf) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(len(buf)), - TTL: ttl, - Protocol: uint8(header.IGMPProtocolNumber), - SrcAddr: srcAddr, - DstAddr: dstAddr, - Options: options, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - igmp := header.IGMP(ip.Payload()) - igmp.SetType(igmpType) - igmp.SetMaxRespTime(maxRespTime) - igmp.SetGroupAddress(groupAddress) - igmp.SetChecksum(header.IGMPCalculateChecksum(igmp)) - - e.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) -} - -// TestIGMPV1Present tests the node's ability to fallback to V1 when a V1 -// router is detected. V1 present status is expected to be reset when the NIC -// cycles. -func TestIGMPV1Present(t *testing.T) { - e, s, clock := createStack(t, true) - addr := tcpip.AddressWithPrefix{Address: stackAddr, PrefixLen: defaultPrefixLength} - if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err) - } - - if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { - t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err) - } - - // This NIC will send an IGMPv2 report immediately, before this test can get - // the IGMPv1 General Membership Query in. - { - p, ok := e.Read() - if !ok { - t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") - } - if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 { - t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got) - } - validateIgmpPacket(t, p, header.IGMPv2MembershipReport, 0, stackAddr, multicastAddr, multicastAddr) - } - if t.Failed() { - t.FailNow() - } - - // Inject an IGMPv1 General Membership Query which is identical to a standard - // membership query except the Max Response Time is set to 0, which will tell - // the stack that this is a router using IGMPv1. Send it to the all systems - // group which is the only group this host belongs to. - createAndInjectIGMPPacket(e, header.IGMPMembershipQuery, 0, defaultTTL, remoteAddr, stackAddr, header.IPv4AllSystems, true /* hasRouterAlertOption */) - if got := s.Stats().IGMP.PacketsReceived.MembershipQuery.Value(); got != 1 { - t.Fatalf("got Membership Queries received = %d, want = 1", got) - } - - // Before advancing the clock, verify that this host has not sent a - // V1MembershipReport yet. - if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 0 { - t.Fatalf("got V1MembershipReport messages sent = %d, want = 0", got) - } - - // Verify the solicited Membership Report is sent. Now that this NIC has seen - // an IGMPv1 query, it should send an IGMPv1 Membership Report. - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet, expected V1MembershipReport only after advancing the clock = %+v", p.Pkt) - } - clock.Advance(ipv4.UnsolicitedReportIntervalMax) - { - p, ok := e.Read() - if !ok { - t.Fatal("unable to Read IGMP packet, expected V1MembershipReport") - } - if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 1 { - t.Fatalf("got V1MembershipReport messages sent = %d, want = 1", got) - } - validateIgmpPacket(t, p, header.IGMPv1MembershipReport, 0, stackAddr, multicastAddr, multicastAddr) - } - - // Cycling the interface should reset the V1 present flag. - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - { - p, ok := e.Read() - if !ok { - t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") - } - if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 2 { - t.Fatalf("got V2MembershipReport messages sent = %d, want = 2", got) - } - validateIgmpPacket(t, p, header.IGMPv2MembershipReport, 0, stackAddr, multicastAddr, multicastAddr) - } -} - -func TestSendQueuedIGMPReports(t *testing.T) { - e, s, clock := createStack(t, true) - - // Joining a group without an assigned address should queue IGMP packets; none - // should be sent without an assigned address. - if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { - t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv4.ProtocolNumber, nicID, multicastAddr, err) - } - reportStat := s.Stats().IGMP.PacketsSent.V2MembershipReport - if got := reportStat.Value(); got != 0 { - t.Errorf("got reportStat.Value() = %d, want = 0", got) - } - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("got unexpected packet = %#v", p) - } - - // The initial set of IGMP reports that were queued should be sent once an - // address is assigned. - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, stackAddr, err) - } - if got := reportStat.Value(); got != 1 { - t.Errorf("got reportStat.Value() = %d, want = 1", got) - } - if p, ok := e.Read(); !ok { - t.Error("expected to send an IGMP membership report") - } else { - validateIgmpPacket(t, p, header.IGMPv2MembershipReport, 0, stackAddr, multicastAddr, multicastAddr) - } - if t.Failed() { - t.FailNow() - } - clock.Advance(ipv4.UnsolicitedReportIntervalMax) - if got := reportStat.Value(); got != 2 { - t.Errorf("got reportStat.Value() = %d, want = 2", got) - } - if p, ok := e.Read(); !ok { - t.Error("expected to send an IGMP membership report") - } else { - validateIgmpPacket(t, p, header.IGMPv2MembershipReport, 0, stackAddr, multicastAddr, multicastAddr) - } - if t.Failed() { - t.FailNow() - } - - // Should have no more packets to send after the initial set of unsolicited - // reports. - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("got unexpected packet = %#v", p) - } -} - -func TestIGMPPacketValidation(t *testing.T) { - tests := []struct { - name string - messageType header.IGMPType - stackAddresses []tcpip.AddressWithPrefix - srcAddr tcpip.Address - includeRouterAlertOption bool - ttl uint8 - expectValidIGMP bool - getMessageTypeStatValue func(tcpip.Stats) uint64 - }{ - { - name: "valid", - messageType: header.IGMPLeaveGroup, - includeRouterAlertOption: true, - stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: remoteAddr, - ttl: 1, - expectValidIGMP: true, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.LeaveGroup.Value() }, - }, - { - name: "bad ttl", - messageType: header.IGMPv1MembershipReport, - includeRouterAlertOption: true, - stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: remoteAddr, - ttl: 2, - expectValidIGMP: false, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V1MembershipReport.Value() }, - }, - { - name: "missing router alert ip option", - messageType: header.IGMPv2MembershipReport, - includeRouterAlertOption: false, - stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: remoteAddr, - ttl: 1, - expectValidIGMP: false, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V2MembershipReport.Value() }, - }, - { - name: "igmp leave group and src ip does not belong to nic subnet", - messageType: header.IGMPLeaveGroup, - includeRouterAlertOption: true, - stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: tcpip.Address("\x0a\x00\x01\x02"), - ttl: 1, - expectValidIGMP: false, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.LeaveGroup.Value() }, - }, - { - name: "igmp query and src ip does not belong to nic subnet", - messageType: header.IGMPMembershipQuery, - includeRouterAlertOption: true, - stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: tcpip.Address("\x0a\x00\x01\x02"), - ttl: 1, - expectValidIGMP: true, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.MembershipQuery.Value() }, - }, - { - name: "igmp report v1 and src ip does not belong to nic subnet", - messageType: header.IGMPv1MembershipReport, - includeRouterAlertOption: true, - stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: tcpip.Address("\x0a\x00\x01\x02"), - ttl: 1, - expectValidIGMP: false, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V1MembershipReport.Value() }, - }, - { - name: "igmp report v2 and src ip does not belong to nic subnet", - messageType: header.IGMPv2MembershipReport, - includeRouterAlertOption: true, - stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: tcpip.Address("\x0a\x00\x01\x02"), - ttl: 1, - expectValidIGMP: false, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V2MembershipReport.Value() }, - }, - { - name: "src ip belongs to the subnet of the nic's second address", - messageType: header.IGMPv2MembershipReport, - includeRouterAlertOption: true, - stackAddresses: []tcpip.AddressWithPrefix{ - {Address: tcpip.Address("\x0a\x00\x0f\x01"), PrefixLen: 24}, - {Address: stackAddr, PrefixLen: 24}, - }, - srcAddr: remoteAddr, - ttl: 1, - expectValidIGMP: true, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V2MembershipReport.Value() }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e, s, _ := createStack(t, true) - for _, address := range test.stackAddresses { - if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, address); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, address, err) - } - } - stats := s.Stats() - // Verify that every relevant stats is zero'd before we send a packet. - if got := test.getMessageTypeStatValue(s.Stats()); got != 0 { - t.Errorf("got test.getMessageTypeStatValue(s.Stats()) = %d, want = 0", got) - } - if got := stats.IGMP.PacketsReceived.Invalid.Value(); got != 0 { - t.Errorf("got stats.IGMP.PacketsReceived.Invalid.Value() = %d, want = 0", got) - } - if got := stats.IP.PacketsDelivered.Value(); got != 0 { - t.Fatalf("got stats.IP.PacketsDelivered.Value() = %d, want = 0", got) - } - createAndInjectIGMPPacket(e, test.messageType, 0, test.ttl, test.srcAddr, header.IPv4AllSystems, header.IPv4AllSystems, test.includeRouterAlertOption) - // We always expect the packet to pass IP validation. - if got := stats.IP.PacketsDelivered.Value(); got != 1 { - t.Fatalf("got stats.IP.PacketsDelivered.Value() = %d, want = 1", got) - } - // Even when the IGMP-specific validation checks fail, we expect the - // corresponding IGMP counter to be incremented. - if got := test.getMessageTypeStatValue(s.Stats()); got != 1 { - t.Errorf("got test.getMessageTypeStatValue(s.Stats()) = %d, want = 1", got) - } - var expectedInvalidCount uint64 - if !test.expectValidIGMP { - expectedInvalidCount = 1 - } - if got := stats.IGMP.PacketsReceived.Invalid.Value(); got != expectedInvalidCount { - t.Errorf("got stats.IGMP.PacketsReceived.Invalid.Value() = %d, want = %d", got, expectedInvalidCount) - } - }) - } -} diff --git a/pkg/tcpip/network/ipv4/ipv4_state_autogen.go b/pkg/tcpip/network/ipv4/ipv4_state_autogen.go new file mode 100644 index 000000000..87a48e2ce --- /dev/null +++ b/pkg/tcpip/network/ipv4/ipv4_state_autogen.go @@ -0,0 +1,105 @@ +// automatically generated by stateify. + +package ipv4 + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (i *icmpv4DestinationUnreachableSockError) StateTypeName() string { + return "pkg/tcpip/network/ipv4.icmpv4DestinationUnreachableSockError" +} + +func (i *icmpv4DestinationUnreachableSockError) StateFields() []string { + return []string{} +} + +func (i *icmpv4DestinationUnreachableSockError) beforeSave() {} + +func (i *icmpv4DestinationUnreachableSockError) StateSave(stateSinkObject state.Sink) { + i.beforeSave() +} + +func (i *icmpv4DestinationUnreachableSockError) afterLoad() {} + +func (i *icmpv4DestinationUnreachableSockError) StateLoad(stateSourceObject state.Source) { +} + +func (i *icmpv4DestinationHostUnreachableSockError) StateTypeName() string { + return "pkg/tcpip/network/ipv4.icmpv4DestinationHostUnreachableSockError" +} + +func (i *icmpv4DestinationHostUnreachableSockError) StateFields() []string { + return []string{ + "icmpv4DestinationUnreachableSockError", + } +} + +func (i *icmpv4DestinationHostUnreachableSockError) beforeSave() {} + +func (i *icmpv4DestinationHostUnreachableSockError) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.icmpv4DestinationUnreachableSockError) +} + +func (i *icmpv4DestinationHostUnreachableSockError) afterLoad() {} + +func (i *icmpv4DestinationHostUnreachableSockError) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.icmpv4DestinationUnreachableSockError) +} + +func (i *icmpv4DestinationPortUnreachableSockError) StateTypeName() string { + return "pkg/tcpip/network/ipv4.icmpv4DestinationPortUnreachableSockError" +} + +func (i *icmpv4DestinationPortUnreachableSockError) StateFields() []string { + return []string{ + "icmpv4DestinationUnreachableSockError", + } +} + +func (i *icmpv4DestinationPortUnreachableSockError) beforeSave() {} + +func (i *icmpv4DestinationPortUnreachableSockError) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.icmpv4DestinationUnreachableSockError) +} + +func (i *icmpv4DestinationPortUnreachableSockError) afterLoad() {} + +func (i *icmpv4DestinationPortUnreachableSockError) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.icmpv4DestinationUnreachableSockError) +} + +func (e *icmpv4FragmentationNeededSockError) StateTypeName() string { + return "pkg/tcpip/network/ipv4.icmpv4FragmentationNeededSockError" +} + +func (e *icmpv4FragmentationNeededSockError) StateFields() []string { + return []string{ + "icmpv4DestinationUnreachableSockError", + "mtu", + } +} + +func (e *icmpv4FragmentationNeededSockError) beforeSave() {} + +func (e *icmpv4FragmentationNeededSockError) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.icmpv4DestinationUnreachableSockError) + stateSinkObject.Save(1, &e.mtu) +} + +func (e *icmpv4FragmentationNeededSockError) afterLoad() {} + +func (e *icmpv4FragmentationNeededSockError) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.icmpv4DestinationUnreachableSockError) + stateSourceObject.Load(1, &e.mtu) +} + +func init() { + state.Register((*icmpv4DestinationUnreachableSockError)(nil)) + state.Register((*icmpv4DestinationHostUnreachableSockError)(nil)) + state.Register((*icmpv4DestinationPortUnreachableSockError)(nil)) + state.Register((*icmpv4FragmentationNeededSockError)(nil)) +} diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go deleted file mode 100644 index dc4db6e5f..000000000 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ /dev/null @@ -1,2987 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv4_test - -import ( - "bytes" - "context" - "encoding/hex" - "fmt" - "io/ioutil" - "math" - "net" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/raw" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - extraHeaderReserve = 50 - defaultMTU = 65536 -) - -func TestExcludeBroadcast(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - - ep := stack.LinkEndpoint(channel.New(256, defaultMTU, "")) - if testing.Verbose() { - ep = sniffer.New(ep) - } - if err := s.CreateNIC(1, ep); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv4EmptySubnet, - NIC: 1, - }}) - - randomAddr := tcpip.FullAddress{NIC: 1, Addr: "\x0a\x00\x00\x01", Port: 53} - - var wq waiter.Queue - t.Run("WithoutPrimaryAddress", func(t *testing.T) { - ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatal(err) - } - defer ep.Close() - - // Cannot connect using a broadcast address as the source. - { - err := ep.Connect(randomAddr) - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Errorf("got ep.Connect(...) = %v, want = %v", err, &tcpip.ErrNoRoute{}) - } - } - - // However, we can bind to a broadcast address to listen. - if err := ep.Bind(tcpip.FullAddress{Addr: header.IPv4Broadcast, Port: 53, NIC: 1}); err != nil { - t.Errorf("Bind failed: %v", err) - } - }) - - t.Run("WithPrimaryAddress", func(t *testing.T) { - ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatal(err) - } - defer ep.Close() - - // Add a valid primary endpoint address, now we can connect. - if err := s.AddAddress(1, ipv4.ProtocolNumber, "\x0a\x00\x00\x02"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - if err := ep.Connect(randomAddr); err != nil { - t.Errorf("Connect failed: %v", err) - } - }) -} - -func TestForwarding(t *testing.T) { - const ( - nicID1 = 1 - nicID2 = 2 - randomSequence = 123 - randomIdent = 42 - randomTimeOffset = 0x10203040 - ) - - ipv4Addr1 := tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()), - PrefixLen: 8, - } - ipv4Addr2 := tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()), - PrefixLen: 8, - } - remoteIPv4Addr1 := tcpip.Address(net.ParseIP("10.0.0.2").To4()) - remoteIPv4Addr2 := tcpip.Address(net.ParseIP("11.0.0.2").To4()) - - tests := []struct { - name string - TTL uint8 - expectErrorICMP bool - options header.IPv4Options - forwardedOptions header.IPv4Options - icmpType header.ICMPv4Type - icmpCode header.ICMPv4Code - }{ - { - name: "TTL of zero", - TTL: 0, - expectErrorICMP: true, - icmpType: header.ICMPv4TimeExceeded, - icmpCode: header.ICMPv4TTLExceeded, - }, - { - name: "TTL of one", - TTL: 1, - expectErrorICMP: false, - }, - { - name: "TTL of two", - TTL: 2, - expectErrorICMP: false, - }, - { - name: "Max TTL", - TTL: math.MaxUint8, - expectErrorICMP: false, - }, - { - name: "four EOL options", - TTL: 2, - expectErrorICMP: false, - options: header.IPv4Options{0, 0, 0, 0}, - forwardedOptions: header.IPv4Options{0, 0, 0, 0}, - }, - { - name: "TS type 1 full", - TTL: 2, - options: header.IPv4Options{ - 68, 12, 13, 0xF1, - 192, 168, 1, 12, - 1, 2, 3, 4, - }, - expectErrorICMP: true, - icmpType: header.ICMPv4ParamProblem, - icmpCode: header.ICMPv4UnusedCode, - }, - { - name: "TS type 0", - TTL: 2, - options: header.IPv4Options{ - 68, 24, 21, 0x00, - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 0, 0, 0, 0, - }, - forwardedOptions: header.IPv4Options{ - 68, 24, 25, 0x00, - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock - }, - }, - { - name: "end of options list", - TTL: 2, - options: header.IPv4Options{ - 68, 12, 13, 0x11, - 192, 168, 1, 12, - 1, 2, 3, 4, - 0, 10, 3, 99, // EOL followed by junk - 1, 2, 3, 4, - }, - forwardedOptions: header.IPv4Options{ - 68, 12, 13, 0x21, - 192, 168, 1, 12, - 1, 2, 3, 4, - 0, // End of Options hides following bytes. - 0, 0, 0, // 7 bytes unknown option removed. - 0, 0, 0, 0, - }, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, - Clock: clock, - }) - - // Advance the clock by some unimportant amount to make - // it give a more recognisable signature than 00,00,00,00. - clock.Advance(time.Millisecond * randomTimeOffset) - - // We expect at most a single packet in response to our ICMP Echo Request. - e1 := channel.New(1, ipv4.MaxTotalSize, "") - if err := s.CreateNIC(nicID1, e1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) - } - ipv4ProtoAddr1 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr1} - if err := s.AddProtocolAddress(nicID1, ipv4ProtoAddr1); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv4ProtoAddr1, err) - } - - e2 := channel.New(1, ipv4.MaxTotalSize, "") - if err := s.CreateNIC(nicID2, e2); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) - } - ipv4ProtoAddr2 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr2} - if err := s.AddProtocolAddress(nicID2, ipv4ProtoAddr2); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv4ProtoAddr2, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: ipv4Addr1.Subnet(), - NIC: nicID1, - }, - { - Destination: ipv4Addr2.Subnet(), - NIC: nicID2, - }, - }) - - if err := s.SetForwarding(header.IPv4ProtocolNumber, true); err != nil { - t.Fatalf("SetForwarding(%d, true): %s", header.IPv4ProtocolNumber, err) - } - - ipHeaderLength := header.IPv4MinimumSize + len(test.options) - if ipHeaderLength > header.IPv4MaximumHeaderSize { - t.Fatalf("got ipHeaderLength = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) - } - totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize) - hdr := buffer.NewPrependable(int(totalLen)) - icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - icmp.SetIdent(randomIdent) - icmp.SetSequence(randomSequence) - icmp.SetType(header.ICMPv4Echo) - icmp.SetCode(header.ICMPv4UnusedCode) - icmp.SetChecksum(0) - icmp.SetChecksum(^header.Checksum(icmp, 0)) - ip := header.IPv4(hdr.Prepend(ipHeaderLength)) - ip.Encode(&header.IPv4Fields{ - TotalLength: totalLen, - Protocol: uint8(header.ICMPv4ProtocolNumber), - TTL: test.TTL, - SrcAddr: remoteIPv4Addr1, - DstAddr: remoteIPv4Addr2, - }) - if len(test.options) != 0 { - ip.SetHeaderLength(uint8(ipHeaderLength)) - // Copy options manually. We do not use Encode for options so we can - // verify malformed options with handcrafted payloads. - if want, got := copy(ip.Options(), test.options), len(test.options); want != got { - t.Fatalf("got copy(ip.Options(), test.options) = %d, want = %d", got, want) - } - } - ip.SetChecksum(0) - ip.SetChecksum(^ip.CalculateChecksum()) - requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - }) - e1.InjectInbound(header.IPv4ProtocolNumber, requestPkt) - - if test.expectErrorICMP { - reply, ok := e1.Read() - if !ok { - t.Fatalf("expected ICMP packet type %d through incoming NIC", test.icmpType) - } - - checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), - checker.SrcAddr(ipv4Addr1.Address), - checker.DstAddr(remoteIPv4Addr1), - checker.TTL(ipv4.DefaultTTL), - checker.ICMPv4( - checker.ICMPv4Checksum(), - checker.ICMPv4Type(test.icmpType), - checker.ICMPv4Code(test.icmpCode), - checker.ICMPv4Payload([]byte(hdr.View())), - ), - ) - - if n := e2.Drain(); n != 0 { - t.Fatalf("got e2.Drain() = %d, want = 0", n) - } - } else { - reply, ok := e2.Read() - if !ok { - t.Fatal("expected ICMP Echo packet through outgoing NIC") - } - - checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), - checker.SrcAddr(remoteIPv4Addr1), - checker.DstAddr(remoteIPv4Addr2), - checker.TTL(test.TTL-1), - checker.IPv4Options(test.forwardedOptions), - checker.ICMPv4( - checker.ICMPv4Checksum(), - checker.ICMPv4Type(header.ICMPv4Echo), - checker.ICMPv4Code(header.ICMPv4UnusedCode), - checker.ICMPv4Payload(nil), - ), - ) - - if n := e1.Drain(); n != 0 { - t.Fatalf("got e1.Drain() = %d, want = 0", n) - } - } - }) - } -} - -// TestIPv4Sanity sends IP/ICMP packets with various problems to the stack and -// checks the response. -func TestIPv4Sanity(t *testing.T) { - const ( - ttl = 255 - nicID = 1 - randomSequence = 123 - randomIdent = 42 - // In some cases Linux sets the error pointer to the start of the option - // (offset 0) instead of the actual wrong value, which is the length byte - // (offset 1). For compatibility we must do the same. Use this constant - // to indicate where this happens. - pointerOffsetForInvalidLength = 0 - randomTimeOffset = 0x10203040 - ) - var ( - ipv4Addr = tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()), - PrefixLen: 24, - } - remoteIPv4Addr = tcpip.Address(net.ParseIP("10.0.0.1").To4()) - ) - - tests := []struct { - name string - headerLength uint8 // value of 0 means "use correct size" - badHeaderChecksum bool - maxTotalLength uint16 - transportProtocol uint8 - TTL uint8 - options header.IPv4Options - replyOptions header.IPv4Options // reply should look like this - shouldFail bool - expectErrorICMP bool - ICMPType header.ICMPv4Type - ICMPCode header.ICMPv4Code - paramProblemPointer uint8 - }{ - { - name: "valid no options", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - }, - { - name: "bad header checksum", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - badHeaderChecksum: true, - shouldFail: true, - }, - // The TTL tests check that we are not rejecting an incoming packet - // with a zero or one TTL, which has been a point of confusion in the - // past as RFC 791 says: "If this field contains the value zero, then the - // datagram must be destroyed". However RFC 1122 section 3.2.1.7 clarifies - // for the case of the destination host, stating as follows. - // - // A host MUST NOT send a datagram with a Time-to-Live (TTL) - // value of zero. - // - // A host MUST NOT discard a datagram just because it was - // received with TTL less than 2. - { - name: "zero TTL", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: 0, - }, - { - name: "one TTL", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: 1, - }, - { - name: "End options", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{0, 0, 0, 0}, - replyOptions: header.IPv4Options{0, 0, 0, 0}, - }, - { - name: "NOP options", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{1, 1, 1, 1}, - replyOptions: header.IPv4Options{1, 1, 1, 1}, - }, - { - name: "NOP and End options", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{1, 1, 0, 0}, - replyOptions: header.IPv4Options{1, 1, 0, 0}, - }, - { - name: "bad header length", - headerLength: header.IPv4MinimumSize - 1, - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - shouldFail: true, - }, - { - name: "bad total length (0)", - maxTotalLength: 0, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - shouldFail: true, - }, - { - name: "bad total length (ip - 1)", - maxTotalLength: uint16(header.IPv4MinimumSize - 1), - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - shouldFail: true, - }, - { - name: "bad total length (ip + icmp - 1)", - maxTotalLength: uint16(header.IPv4MinimumSize + header.ICMPv4MinimumSize - 1), - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - shouldFail: true, - }, - { - name: "bad protocol", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: 99, - TTL: ttl, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4DstUnreachable, - ICMPCode: header.ICMPv4ProtoUnreachable, - }, - { - name: "timestamp option overflow", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 12, 13, 0x11, - 192, 168, 1, 12, - 1, 2, 3, 4, - }, - replyOptions: header.IPv4Options{ - 68, 12, 13, 0x21, - 192, 168, 1, 12, - 1, 2, 3, 4, - }, - }, - { - name: "timestamp option overflow full", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 12, 13, 0xF1, - // ^ Counter full (15/0xF) - 192, 168, 1, 12, - 1, 2, 3, 4, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + 3, - replyOptions: header.IPv4Options{}, - }, - { - name: "unknown option", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{10, 4, 9, 0}, - // ^^ - // The unknown option should be stripped out of the reply. - replyOptions: header.IPv4Options{}, - }, - { - name: "bad option - no length", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 1, 1, 1, 68, - // ^-start of timestamp.. but no length.. - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + 3, - }, - { - name: "bad option - length 0", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 0, 9, 0, - // ^ - 1, 2, 3, 4, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength, - }, - { - name: "bad option - length 1", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 1, 9, 0, - // ^ - 1, 2, 3, 4, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength, - }, - { - name: "bad option - length big", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 9, 9, 0, - // ^ - // There are only 8 bytes allocated to options so 9 bytes of timestamp - // space is not possible. (Second byte) - 1, 2, 3, 4, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength, - }, - { - // This tests for some linux compatible behaviour. - // The ICMP pointer returned is 22 for Linux but the - // error is actually in spot 21. - name: "bad option - length bad", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - // Timestamps are in multiples of 4 or 8 but never 7. - // The option space should be padded out. - options: header.IPv4Options{ - 68, 7, 5, 0, - // ^ ^ Linux points here which is wrong. - // | Not a multiple of 4 - 1, 2, 3, 0, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset, - }, - { - name: "multiple type 0 with room", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 24, 21, 0x00, - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 0, 0, 0, 0, - }, - replyOptions: header.IPv4Options{ - 68, 24, 25, 0x00, - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock - }, - }, - { - // The timestamp area is full so add to the overflow count. - name: "multiple type 1 timestamps", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 20, 21, 0x11, - // ^ - 192, 168, 1, 12, - 1, 2, 3, 4, - 192, 168, 1, 13, - 5, 6, 7, 8, - }, - // Overflow count is the top nibble of the 4th byte. - replyOptions: header.IPv4Options{ - 68, 20, 21, 0x21, - // ^ - 192, 168, 1, 12, - 1, 2, 3, 4, - 192, 168, 1, 13, - 5, 6, 7, 8, - }, - }, - { - name: "multiple type 1 timestamps with room", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 28, 21, 0x01, - 192, 168, 1, 12, - 1, 2, 3, 4, - 192, 168, 1, 13, - 5, 6, 7, 8, - 0, 0, 0, 0, - 0, 0, 0, 0, - }, - replyOptions: header.IPv4Options{ - 68, 28, 29, 0x01, - 192, 168, 1, 12, - 1, 2, 3, 4, - 192, 168, 1, 13, - 5, 6, 7, 8, - 192, 168, 1, 58, // New IP Address. - 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock - }, - }, - { - // Timestamp pointer uses one based counting so 0 is invalid. - name: "timestamp pointer invalid", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 8, 0, 0x00, - // ^ 0 instead of 5 or more. - 0, 0, 0, 0, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + 2, - }, - { - // Timestamp pointer cannot be less than 5. It must point past the header - // which is 4 bytes. (1 based counting) - name: "timestamp pointer too small by 1", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 8, header.IPv4OptionTimestampHdrLength, 0x00, - // ^ header is 4 bytes, so 4 should fail. - 0, 0, 0, 0, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset, - }, - { - name: "valid timestamp pointer", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 8, header.IPv4OptionTimestampHdrLength + 1, 0x00, - // ^ header is 4 bytes, so 5 should succeed. - 0, 0, 0, 0, - }, - replyOptions: header.IPv4Options{ - 68, 8, 9, 0x00, - 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock - }, - }, - { - // Needs 8 bytes for a type 1 timestamp but there are only 4 free. - name: "bad timer element alignment", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 20, 17, 0x01, - // ^^ ^^ 20 byte area, next free spot at 17. - 192, 168, 1, 12, - 1, 2, 3, 4, - 0, 0, 0, 0, - 0, 0, 0, 0, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset, - }, - // End of option list with illegal option after it, which should be ignored. - { - name: "end of options list", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 12, 13, 0x11, - 192, 168, 1, 12, - 1, 2, 3, 4, - 0, 10, 3, 99, // EOL followed by junk - }, - replyOptions: header.IPv4Options{ - 68, 12, 13, 0x21, - 192, 168, 1, 12, - 1, 2, 3, 4, - 0, // End of Options hides following bytes. - 0, 0, 0, // 3 bytes unknown option removed. - }, - }, - { - // Timestamp with a size much too small. - name: "timestamp truncated", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 1, 0, 0, - // ^ Smallest possible is 8. Linux points at the 68. - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength, - }, - { - name: "single record route with room", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 7, 4, // 3 byte header - 0, 0, 0, 0, - 0, - }, - replyOptions: header.IPv4Options{ - 7, 7, 8, // 3 byte header - 192, 168, 1, 58, // New IP Address. - 0, // padding to multiple of 4 bytes. - }, - }, - { - name: "multiple record route with room", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 23, 20, // 3 byte header - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 0, 0, 0, 0, - 0, - }, - replyOptions: header.IPv4Options{ - 7, 23, 24, - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 192, 168, 1, 58, // New IP Address. - 0, // padding to multiple of 4 bytes. - }, - }, - { - name: "single record route with no room", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 7, 8, // 3 byte header - 1, 2, 3, 4, - 0, - }, - replyOptions: header.IPv4Options{ - 7, 7, 8, // 3 byte header - 1, 2, 3, 4, - 0, // padding to multiple of 4 bytes. - }, - }, - { - // Unlike timestamp, this should just succeed. - name: "multiple record route with no room", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 23, 24, // 3 byte header - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 0, - }, - replyOptions: header.IPv4Options{ - 7, 23, 24, - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 0, // padding to multiple of 4 bytes. - }, - }, - { - // Pointer uses one based counting so 0 is invalid. - name: "record route pointer zero", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 8, 0, // 3 byte header - 0, 0, 0, 0, - 0, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset, - }, - { - // Pointer must be 4 or more as it must point past the 3 byte header - // using 1 based counting. 3 should fail. - name: "record route pointer too small by 1", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 8, header.IPv4OptionRecordRouteHdrLength, // 3 byte header - 0, 0, 0, 0, - 0, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset, - }, - { - // Pointer must be 4 or more as it must point past the 3 byte header - // using 1 based counting. Check 4 passes. (Duplicates "single - // record route with room") - name: "valid record route pointer", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 7, header.IPv4OptionRecordRouteHdrLength + 1, // 3 byte header - 0, 0, 0, 0, - 0, - }, - replyOptions: header.IPv4Options{ - 7, 7, 8, // 3 byte header - 192, 168, 1, 58, // New IP Address. - 0, // padding to multiple of 4 bytes. - }, - }, - { - // Confirm Linux bug for bug compatibility. - // Linux returns slot 22 but the error is in slot 21. - name: "multiple record route with not enough room", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 8, 8, // 3 byte header - // ^ ^ Linux points here. We must too. - // | Not enough room. 1 byte free, need 4. - 1, 2, 3, 4, - 0, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset, - }, - { - name: "duplicate record route", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 7, 8, // 3 byte header - 1, 2, 3, 4, - 7, 7, 8, // 3 byte header - 1, 2, 3, 4, - 0, 0, // pad - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + 7, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, - Clock: clock, - }) - // Advance the clock by some unimportant amount to make - // it give a more recognisable signature than 00,00,00,00. - clock.Advance(time.Millisecond * randomTimeOffset) - - // We expect at most a single packet in response to our ICMP Echo Request. - e := channel.New(1, ipv4.MaxTotalSize, "") - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr} - if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err) - } - - // Default routes for IPv4 so ICMP can find a route to the remote - // node when attempting to send the ICMP Echo Reply. - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }, - }) - - if len(test.options)%4 != 0 { - t.Fatalf("options must be aligned to 32 bits, invalid test options: %x (len=%d)", test.options, len(test.options)) - } - ipHeaderLength := header.IPv4MinimumSize + len(test.options) - if ipHeaderLength > header.IPv4MaximumHeaderSize { - t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) - } - totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize) - hdr := buffer.NewPrependable(int(totalLen)) - icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - - // Specify ident/seq to make sure we get the same in the response. - icmp.SetIdent(randomIdent) - icmp.SetSequence(randomSequence) - icmp.SetType(header.ICMPv4Echo) - icmp.SetCode(header.ICMPv4UnusedCode) - icmp.SetChecksum(0) - icmp.SetChecksum(^header.Checksum(icmp, 0)) - ip := header.IPv4(hdr.Prepend(ipHeaderLength)) - if test.maxTotalLength < totalLen { - totalLen = test.maxTotalLength - } - ip.Encode(&header.IPv4Fields{ - TotalLength: totalLen, - Protocol: test.transportProtocol, - TTL: test.TTL, - SrcAddr: remoteIPv4Addr, - DstAddr: ipv4Addr.Address, - }) - if test.headerLength != 0 { - ip.SetHeaderLength(test.headerLength) - } else { - // Set the calculated header length, since we may manually add options. - ip.SetHeaderLength(uint8(ipHeaderLength)) - } - if len(test.options) != 0 { - // Copy options manually. We do not use Encode for options so we can - // verify malformed options with handcrafted payloads. - if want, got := copy(ip.Options(), test.options), len(test.options); want != got { - t.Fatalf("got copy(ip.Options(), test.options) = %d, want = %d", got, want) - } - } - ip.SetChecksum(0) - ipHeaderChecksum := ip.CalculateChecksum() - if test.badHeaderChecksum { - ipHeaderChecksum += 42 - } - ip.SetChecksum(^ipHeaderChecksum) - requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - }) - e.InjectInbound(header.IPv4ProtocolNumber, requestPkt) - reply, ok := e.Read() - if !ok { - if test.shouldFail { - if test.expectErrorICMP { - t.Fatalf("ICMP error response (type %d, code %d) missing", test.ICMPType, test.ICMPCode) - } - return // Expected silent failure. - } - t.Fatal("expected ICMP echo reply missing") - } - - // We didn't expect a packet. Register our surprise but carry on to - // provide more information about what we got. - if test.shouldFail && !test.expectErrorICMP { - t.Error("unexpected packet response") - } - - // Check the route that brought the packet to us. - if reply.Route.LocalAddress != ipv4Addr.Address { - t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", reply.Route.LocalAddress, ipv4Addr.Address) - } - if reply.Route.RemoteAddress != remoteIPv4Addr { - t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", reply.Route.RemoteAddress, remoteIPv4Addr) - } - - // Make sure it's all in one buffer for checker. - replyIPHeader := header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())) - - // At this stage we only know it's probably an IP+ICMP header so verify - // that much. - checker.IPv4(t, replyIPHeader, - checker.SrcAddr(ipv4Addr.Address), - checker.DstAddr(remoteIPv4Addr), - checker.ICMPv4( - checker.ICMPv4Checksum(), - ), - ) - - // Don't proceed any further if the checker found problems. - if t.Failed() { - t.FailNow() - } - - // OK it's ICMP. We can safely look at the type now. - replyICMPHeader := header.ICMPv4(replyIPHeader.Payload()) - switch replyICMPHeader.Type() { - case header.ICMPv4ParamProblem: - if !test.shouldFail { - t.Fatalf("got Parameter Problem with pointer %d, wanted Echo Reply", replyICMPHeader.Pointer()) - } - if !test.expectErrorICMP { - t.Fatalf("got Parameter Problem with pointer %d, wanted no response", replyICMPHeader.Pointer()) - } - checker.IPv4(t, replyIPHeader, - checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())), - checker.IPv4HeaderLength(header.IPv4MinimumSize), - checker.ICMPv4( - checker.ICMPv4Type(test.ICMPType), - checker.ICMPv4Code(test.ICMPCode), - checker.ICMPv4Pointer(test.paramProblemPointer), - checker.ICMPv4Payload([]byte(hdr.View())), - ), - ) - return - case header.ICMPv4DstUnreachable: - if !test.shouldFail { - t.Fatalf("got ICMP error packet type %d, code %d, wanted Echo Reply", - header.ICMPv4DstUnreachable, replyICMPHeader.Code()) - } - if !test.expectErrorICMP { - t.Fatalf("got ICMP error packet type %d, code %d, wanted no response", - header.ICMPv4DstUnreachable, replyICMPHeader.Code()) - } - checker.IPv4(t, replyIPHeader, - checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())), - checker.IPv4HeaderLength(header.IPv4MinimumSize), - checker.ICMPv4( - checker.ICMPv4Type(test.ICMPType), - checker.ICMPv4Code(test.ICMPCode), - checker.ICMPv4Payload([]byte(hdr.View())), - ), - ) - return - case header.ICMPv4EchoReply: - if test.shouldFail { - if !test.expectErrorICMP { - t.Error("got Echo Reply packet, want no response") - } else { - t.Errorf("got Echo Reply, want ICMP error type %d, code %d", test.ICMPType, test.ICMPCode) - } - } - // If the IP options change size then the packet will change size, so - // some IP header fields will need to be adjusted for the checks. - sizeChange := len(test.replyOptions) - len(test.options) - - checker.IPv4(t, replyIPHeader, - checker.IPv4HeaderLength(ipHeaderLength+sizeChange), - checker.IPv4Options(test.replyOptions), - checker.IPFullLength(uint16(requestPkt.Size()+sizeChange)), - checker.ICMPv4( - checker.ICMPv4Checksum(), - checker.ICMPv4Code(header.ICMPv4UnusedCode), - checker.ICMPv4Seq(randomSequence), - checker.ICMPv4Ident(randomIdent), - ), - ) - default: - t.Fatalf("unexpected ICMP response, got type %d, want = %d, %d or %d", - replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable, header.ICMPv4ParamProblem) - } - }) - } -} - -// comparePayloads compared the contents of all the packets against the contents -// of the source packet. -func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber) error { - // Make a complete array of the sourcePacket packet. - source := header.IPv4(packets[0].NetworkHeader().View()) - vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views()) - source = append(source, vv.ToView()...) - - // Make a copy of the IP header, which will be modified in some fields to make - // an expected header. - sourceCopy := header.IPv4(append(buffer.View(nil), source[:source.HeaderLength()]...)) - sourceCopy.SetChecksum(0) - sourceCopy.SetFlagsFragmentOffset(0, 0) - sourceCopy.SetTotalLength(0) - // Build up an array of the bytes sent. - var reassembledPayload buffer.VectorisedView - for i, packet := range packets { - // Confirm that the packet is valid. - allBytes := buffer.NewVectorisedView(packet.Size(), packet.Views()) - fragmentIPHeader := header.IPv4(allBytes.ToView()) - if !fragmentIPHeader.IsValid(len(fragmentIPHeader)) { - return fmt.Errorf("fragment #%d: IP packet is invalid:\n%s", i, hex.Dump(fragmentIPHeader)) - } - if got := len(fragmentIPHeader); got > int(mtu) { - return fmt.Errorf("fragment #%d: got len(fragmentIPHeader) = %d, want <= %d", i, got, mtu) - } - if got := fragmentIPHeader.TransportProtocol(); got != proto { - return fmt.Errorf("fragment #%d: got fragmentIPHeader.TransportProtocol() = %d, want = %d", i, got, uint8(proto)) - } - if got := packet.AvailableHeaderBytes(); got != extraHeaderReserve { - return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, extraHeaderReserve) - } - if got, want := packet.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber; got != want { - return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, got, want) - } - if got, want := fragmentIPHeader.CalculateChecksum(), uint16(0xffff); got != want { - return fmt.Errorf("fragment #%d: got ip.CalculateChecksum() = %#x, want = %#x", i, got, want) - } - if wantFragments[i].more { - sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, wantFragments[i].offset) - } else { - sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, wantFragments[i].offset) - } - reassembledPayload.AppendView(packet.TransportHeader().View()) - reassembledPayload.Append(packet.Data) - // Clear out the checksum and length from the ip because we can't compare - // it. - sourceCopy.SetTotalLength(wantFragments[i].payloadSize + header.IPv4MinimumSize) - sourceCopy.SetChecksum(0) - sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum()) - if diff := cmp.Diff(fragmentIPHeader[:fragmentIPHeader.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]); diff != "" { - return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff) - } - } - - expected := buffer.View(source[source.HeaderLength():]) - if diff := cmp.Diff(expected, reassembledPayload.ToView()); diff != "" { - return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff) - } - - return nil -} - -type fragmentInfo struct { - offset uint16 - more bool - payloadSize uint16 -} - -var fragmentationTests = []struct { - description string - mtu uint32 - gso *stack.GSO - transportHeaderLength int - payloadSize int - wantFragments []fragmentInfo -}{ - { - description: "No fragmentation", - mtu: 1280, - gso: nil, - transportHeaderLength: 0, - payloadSize: 1000, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1000, more: false}, - }, - }, - { - description: "Fragmented", - mtu: 1280, - gso: nil, - transportHeaderLength: 0, - payloadSize: 2000, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1256, more: true}, - {offset: 1256, payloadSize: 744, more: false}, - }, - }, - { - description: "Fragmented with the minimum mtu", - mtu: header.IPv4MinimumMTU, - gso: nil, - transportHeaderLength: 0, - payloadSize: 100, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 48, more: true}, - {offset: 48, payloadSize: 48, more: true}, - {offset: 96, payloadSize: 4, more: false}, - }, - }, - { - description: "Fragmented with mtu not a multiple of 8", - mtu: header.IPv4MinimumMTU + 1, - gso: nil, - transportHeaderLength: 0, - payloadSize: 100, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 48, more: true}, - {offset: 48, payloadSize: 48, more: true}, - {offset: 96, payloadSize: 4, more: false}, - }, - }, - { - description: "No fragmentation with big header", - mtu: 2000, - gso: nil, - transportHeaderLength: 100, - payloadSize: 1000, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1100, more: false}, - }, - }, - { - description: "Fragmented with gso none", - mtu: 1280, - gso: &stack.GSO{Type: stack.GSONone}, - transportHeaderLength: 0, - payloadSize: 1400, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1256, more: true}, - {offset: 1256, payloadSize: 144, more: false}, - }, - }, - { - description: "Fragmented with big header", - mtu: 1280, - gso: nil, - transportHeaderLength: 100, - payloadSize: 1200, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1256, more: true}, - {offset: 1256, payloadSize: 44, more: false}, - }, - }, - { - description: "Fragmented with MTU smaller than header", - mtu: 300, - gso: nil, - transportHeaderLength: 1000, - payloadSize: 500, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 280, more: true}, - {offset: 280, payloadSize: 280, more: true}, - {offset: 560, payloadSize: 280, more: true}, - {offset: 840, payloadSize: 280, more: true}, - {offset: 1120, payloadSize: 280, more: true}, - {offset: 1400, payloadSize: 100, more: false}, - }, - }, -} - -func TestFragmentationWritePacket(t *testing.T) { - const ttl = 42 - - for _, ft := range fragmentationTests { - t.Run(ft.description, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) - r := buildRoute(t, ep) - pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) - source := pkt.Clone() - err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{ - Protocol: tcp.ProtocolNumber, - TTL: ttl, - TOS: stack.DefaultTOS, - }, pkt) - if err != nil { - t.Fatalf("r.WritePacket(_, _, _) = %s", err) - } - if got := len(ep.WrittenPackets); got != len(ft.wantFragments) { - t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, len(ft.wantFragments)) - } - if got := int(r.Stats().IP.PacketsSent.Value()); got != len(ft.wantFragments) { - t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, len(ft.wantFragments)) - } - if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { - t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) - } - if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { - t.Error(err) - } - }) - } -} - -func TestFragmentationWritePackets(t *testing.T) { - const ttl = 42 - writePacketsTests := []struct { - description string - insertBefore int - insertAfter int - }{ - { - description: "Single packet", - insertBefore: 0, - insertAfter: 0, - }, - { - description: "With packet before", - insertBefore: 1, - insertAfter: 0, - }, - { - description: "With packet after", - insertBefore: 0, - insertAfter: 1, - }, - { - description: "With packet before and after", - insertBefore: 1, - insertAfter: 1, - }, - } - tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv4MinimumSize, []int{1}, header.IPv4ProtocolNumber) - - for _, test := range writePacketsTests { - t.Run(test.description, func(t *testing.T) { - for _, ft := range fragmentationTests { - t.Run(ft.description, func(t *testing.T) { - var pkts stack.PacketBufferList - for i := 0; i < test.insertBefore; i++ { - pkts.PushBack(tinyPacket.Clone()) - } - pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) - pkts.PushBack(pkt.Clone()) - for i := 0; i < test.insertAfter; i++ { - pkts.PushBack(tinyPacket.Clone()) - } - - ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) - r := buildRoute(t, ep) - - wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter - n, err := r.WritePackets(ft.gso, pkts, stack.NetworkHeaderParams{ - Protocol: tcp.ProtocolNumber, - TTL: ttl, - TOS: stack.DefaultTOS, - }) - if err != nil { - t.Errorf("got WritePackets(_, _, _) = (_, %s), want = (_, nil)", err) - } - if n != wantTotalPackets { - t.Errorf("got WritePackets(_, _, _) = (%d, _), want = (%d, _)", n, wantTotalPackets) - } - if got := len(ep.WrittenPackets); got != wantTotalPackets { - t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, wantTotalPackets) - } - if got := int(r.Stats().IP.PacketsSent.Value()); got != wantTotalPackets { - t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, wantTotalPackets) - } - if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != 0 { - t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) - } - - if wantTotalPackets == 0 { - return - } - - fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore] - if err := compareFragments(fragments, pkt, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { - t.Error(err) - } - }) - } - }) - } -} - -// TestFragmentationErrors checks that errors are returned from WritePacket -// correctly. -func TestFragmentationErrors(t *testing.T) { - const ttl = 42 - - tests := []struct { - description string - mtu uint32 - transportHeaderLength int - payloadSize int - allowPackets int - outgoingErrors int - mockError tcpip.Error - wantError tcpip.Error - }{ - { - description: "No frag", - mtu: 2000, - payloadSize: 1000, - transportHeaderLength: 0, - allowPackets: 0, - outgoingErrors: 1, - mockError: &tcpip.ErrAborted{}, - wantError: &tcpip.ErrAborted{}, - }, - { - description: "Error on first frag", - mtu: 500, - payloadSize: 1000, - transportHeaderLength: 0, - allowPackets: 0, - outgoingErrors: 3, - mockError: &tcpip.ErrAborted{}, - wantError: &tcpip.ErrAborted{}, - }, - { - description: "Error on second frag", - mtu: 500, - payloadSize: 1000, - transportHeaderLength: 0, - allowPackets: 1, - outgoingErrors: 2, - mockError: &tcpip.ErrAborted{}, - wantError: &tcpip.ErrAborted{}, - }, - { - description: "Error on first frag MTU smaller than header", - mtu: 500, - transportHeaderLength: 1000, - payloadSize: 500, - allowPackets: 0, - outgoingErrors: 4, - mockError: &tcpip.ErrAborted{}, - wantError: &tcpip.ErrAborted{}, - }, - { - description: "Error when MTU is smaller than IPv4 minimum MTU", - mtu: header.IPv4MinimumMTU - 1, - transportHeaderLength: 0, - payloadSize: 500, - allowPackets: 0, - outgoingErrors: 1, - mockError: nil, - wantError: &tcpip.ErrInvalidEndpointState{}, - }, - } - - for _, ft := range tests { - t.Run(ft.description, func(t *testing.T) { - pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) - ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) - r := buildRoute(t, ep) - err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ - Protocol: tcp.ProtocolNumber, - TTL: ttl, - TOS: stack.DefaultTOS, - }, pkt) - if diff := cmp.Diff(ft.wantError, err); diff != "" { - t.Fatalf("unexpected error from r.WritePacket(_, _, _), (-want, +got):\n%s", diff) - } - if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets { - t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets) - } - if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != ft.outgoingErrors { - t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, ft.outgoingErrors) - } - }) - } -} - -func TestInvalidFragments(t *testing.T) { - const ( - nicID = 1 - linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") - addr1 = "\x0a\x00\x00\x01" - addr2 = "\x0a\x00\x00\x02" - tos = 0 - ident = 1 - ttl = 48 - protocol = 6 - ) - - payloadGen := func(payloadLen int) []byte { - payload := make([]byte, payloadLen) - for i := 0; i < len(payload); i++ { - payload[i] = 0x30 - } - return payload - } - - type fragmentData struct { - ipv4fields header.IPv4Fields - // 0 means insert the correct IHL. Non 0 means override the correct IHL. - overrideIHL int // For 0 use 1 as it is an int and will be divided by 4. - payload []byte - autoChecksum bool // If true, the Checksum field will be overwritten. - } - - tests := []struct { - name string - fragments []fragmentData - wantMalformedIPPackets uint64 - wantMalformedFragments uint64 - }{ - { - name: "IHL and TotalLength zero, FragmentOffset non-zero", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: 0, - ID: ident, - Flags: header.IPv4FlagDontFragment | header.IPv4FlagMoreFragments, - FragmentOffset: 59776, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - overrideIHL: 1, // See note above. - payload: payloadGen(12), - autoChecksum: true, - }, - }, - wantMalformedIPPackets: 1, - wantMalformedFragments: 0, - }, - { - name: "IHL and TotalLength zero, FragmentOffset zero", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: 0, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - overrideIHL: 1, // See note above. - payload: payloadGen(12), - autoChecksum: true, - }, - }, - wantMalformedIPPackets: 1, - wantMalformedFragments: 0, - }, - { - // Payload 17 octets and Fragment offset 65520 - // Leading to the fragment end to be past 65536. - name: "fragment ends past 65536", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 17, - ID: ident, - Flags: 0, - FragmentOffset: 65520, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(17), - autoChecksum: true, - }, - }, - wantMalformedIPPackets: 1, - wantMalformedFragments: 1, - }, - { - // Payload 16 octets and fragment offset 65520 - // Leading to the fragment end to be exactly 65536. - name: "fragment ends exactly at 65536", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 16, - ID: ident, - Flags: 0, - FragmentOffset: 65520, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(16), - autoChecksum: true, - }, - }, - wantMalformedIPPackets: 0, - wantMalformedFragments: 0, - }, - { - name: "IHL less than IPv4 minimum size", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 28, - ID: ident, - Flags: 0, - FragmentOffset: 1944, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(28), - overrideIHL: header.IPv4MinimumSize - 12, - autoChecksum: true, - }, - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize - 12, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(28), - overrideIHL: header.IPv4MinimumSize - 12, - autoChecksum: true, - }, - }, - wantMalformedIPPackets: 2, - wantMalformedFragments: 0, - }, - { - name: "fragment with short TotalLength and extra payload", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 28, - ID: ident, - Flags: 0, - FragmentOffset: 28816, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(28), - overrideIHL: header.IPv4MinimumSize + 4, - autoChecksum: true, - }, - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 4, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(28), - overrideIHL: header.IPv4MinimumSize + 4, - autoChecksum: true, - }, - }, - wantMalformedIPPackets: 1, - wantMalformedFragments: 1, - }, - { - name: "multiple fragments with More Fragments flag set to false", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 8, - ID: ident, - Flags: 0, - FragmentOffset: 128, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(8), - autoChecksum: true, - }, - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 8, - ID: ident, - Flags: 0, - FragmentOffset: 8, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(8), - autoChecksum: true, - }, - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 8, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(8), - autoChecksum: true, - }, - }, - wantMalformedIPPackets: 1, - wantMalformedFragments: 1, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ - ipv4.NewProtocol, - }, - }) - e := channel.New(0, 1500, linkAddr) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) - } - - for _, f := range test.fragments { - pktSize := header.IPv4MinimumSize + len(f.payload) - hdr := buffer.NewPrependable(pktSize) - - ip := header.IPv4(hdr.Prepend(pktSize)) - ip.Encode(&f.ipv4fields) - if want, got := len(f.payload), copy(ip[header.IPv4MinimumSize:], f.payload); want != got { - t.Fatalf("copied %d bytes, expected %d bytes.", got, want) - } - // Encode sets this up correctly. If we want a different value for - // testing then we need to overwrite the good value. - if f.overrideIHL != 0 { - ip.SetHeaderLength(uint8(f.overrideIHL)) - // If we are asked to add options (type not specified) then pad - // with 0 (EOL). RFC 791 page 23 says "The padding is zero". - for i := header.IPv4MinimumSize; i < f.overrideIHL; i++ { - ip[i] = byte(header.IPv4OptionListEndType) - } - } - - if f.autoChecksum { - ip.SetChecksum(0) - ip.SetChecksum(^ip.CalculateChecksum()) - } - - vv := hdr.View().ToVectorisedView() - e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - })) - } - - if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want { - t.Errorf("incorrect Stats.IP.MalformedPacketsReceived, got: %d, want: %d", got, want) - } - if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want { - t.Errorf("incorrect Stats.IP.MalformedFragmentsReceived, got: %d, want: %d", got, want) - } - }) - } -} - -func TestFragmentReassemblyTimeout(t *testing.T) { - const ( - nicID = 1 - linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") - addr1 = "\x0a\x00\x00\x01" - addr2 = "\x0a\x00\x00\x02" - tos = 0 - ident = 1 - ttl = 48 - protocol = 99 - data = "TEST_FRAGMENT_REASSEMBLY_TIMEOUT" - ) - - type fragmentData struct { - ipv4fields header.IPv4Fields - payload []byte - } - - tests := []struct { - name string - fragments []fragmentData - expectICMP bool - }{ - { - name: "first fragment only", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 16, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: []byte(data)[:16], - }, - }, - expectICMP: true, - }, - { - name: "two first fragments", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 16, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: []byte(data)[:16], - }, - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 16, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: []byte(data)[:16], - }, - }, - expectICMP: true, - }, - { - name: "second fragment only", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16), - ID: ident, - Flags: 0, - FragmentOffset: 8, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: []byte(data)[16:], - }, - }, - expectICMP: false, - }, - { - name: "two fragments with a gap", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 8, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: []byte(data)[:8], - }, - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16), - ID: ident, - Flags: 0, - FragmentOffset: 16, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: []byte(data)[16:], - }, - }, - expectICMP: true, - }, - { - name: "two fragments with a gap in reverse order", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16), - ID: ident, - Flags: 0, - FragmentOffset: 16, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: []byte(data)[16:], - }, - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 8, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: []byte(data)[:8], - }, - }, - expectICMP: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ - ipv4.NewProtocol, - }, - Clock: clock, - }) - e := channel.New(1, 1500, linkAddr) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) - } - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }}) - - var firstFragmentSent buffer.View - for _, f := range test.fragments { - pktSize := header.IPv4MinimumSize - hdr := buffer.NewPrependable(pktSize) - - ip := header.IPv4(hdr.Prepend(pktSize)) - ip.Encode(&f.ipv4fields) - - ip.SetChecksum(0) - ip.SetChecksum(^ip.CalculateChecksum()) - - vv := hdr.View().ToVectorisedView() - vv.AppendView(f.payload) - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - }) - - if firstFragmentSent == nil && ip.FragmentOffset() == 0 { - firstFragmentSent = stack.PayloadSince(pkt.NetworkHeader()) - } - - e.InjectInbound(header.IPv4ProtocolNumber, pkt) - } - - clock.Advance(ipv4.ReassembleTimeout) - - reply, ok := e.Read() - if !test.expectICMP { - if ok { - t.Fatalf("unexpected ICMP error message received: %#v", reply) - } - return - } - if !ok { - t.Fatal("expected ICMP error message missing") - } - if firstFragmentSent == nil { - t.Fatalf("unexpected ICMP error message received: %#v", reply) - } - - checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), - checker.SrcAddr(addr2), - checker.DstAddr(addr1), - checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+firstFragmentSent.Size())), - checker.IPv4HeaderLength(header.IPv4MinimumSize), - checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4TimeExceeded), - checker.ICMPv4Code(header.ICMPv4ReassemblyTimeout), - checker.ICMPv4Checksum(), - checker.ICMPv4Payload([]byte(firstFragmentSent)), - ), - ) - }) - } -} - -// TestReceiveFragments feeds fragments in through the incoming packet path to -// test reassembly -func TestReceiveFragments(t *testing.T) { - const ( - nicID = 1 - - addr1 = "\x0c\xa8\x00\x01" // 192.168.0.1 - addr2 = "\x0c\xa8\x00\x02" // 192.168.0.2 - addr3 = "\x0c\xa8\x00\x03" // 192.168.0.3 - ) - - // Build and return a UDP header containing payload. - udpGen := func(payloadLen int, multiplier uint8, src, dst tcpip.Address) buffer.View { - payload := buffer.NewView(payloadLen) - for i := 0; i < len(payload); i++ { - payload[i] = uint8(i) * multiplier - } - - udpLength := header.UDPMinimumSize + len(payload) - - hdr := buffer.NewPrependable(udpLength) - u := header.UDP(hdr.Prepend(udpLength)) - u.Encode(&header.UDPFields{ - SrcPort: 5555, - DstPort: 80, - Length: uint16(udpLength), - }) - copy(u.Payload(), payload) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(udpLength)) - sum = header.Checksum(payload, sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - return hdr.View() - } - - // UDP header plus a payload of 0..256 - ipv4Payload1Addr1ToAddr2 := udpGen(256, 1, addr1, addr2) - udpPayload1Addr1ToAddr2 := ipv4Payload1Addr1ToAddr2[header.UDPMinimumSize:] - ipv4Payload1Addr3ToAddr2 := udpGen(256, 1, addr3, addr2) - udpPayload1Addr3ToAddr2 := ipv4Payload1Addr3ToAddr2[header.UDPMinimumSize:] - // UDP header plus a payload of 0..256 in increments of 2. - ipv4Payload2Addr1ToAddr2 := udpGen(128, 2, addr1, addr2) - udpPayload2Addr1ToAddr2 := ipv4Payload2Addr1ToAddr2[header.UDPMinimumSize:] - // UDP header plus a payload of 0..256 in increments of 3. - // Used to test cases where the fragment blocks are not a multiple of - // the fragment block size of 8 (RFC 791 section 3.1 page 14). - ipv4Payload3Addr1ToAddr2 := udpGen(127, 3, addr1, addr2) - udpPayload3Addr1ToAddr2 := ipv4Payload3Addr1ToAddr2[header.UDPMinimumSize:] - // Used to test the max reassembled IPv4 payload length. - ipv4Payload4Addr1ToAddr2 := udpGen(header.UDPMaximumSize-header.UDPMinimumSize, 4, addr1, addr2) - udpPayload4Addr1ToAddr2 := ipv4Payload4Addr1ToAddr2[header.UDPMinimumSize:] - - type fragmentData struct { - srcAddr tcpip.Address - dstAddr tcpip.Address - id uint16 - flags uint8 - fragmentOffset uint16 - payload buffer.View - } - - tests := []struct { - name string - fragments []fragmentData - expectedPayloads [][]byte - }{ - { - name: "No fragmentation", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2, - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, - }, - { - name: "No fragmentation with size not a multiple of fragment block size", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 0, - payload: ipv4Payload3Addr1ToAddr2, - }, - }, - expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2}, - }, - { - name: "More fragments without payload", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2, - }, - }, - expectedPayloads: nil, - }, - { - name: "Non-zero fragment offset without payload", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 8, - payload: ipv4Payload1Addr1ToAddr2, - }, - }, - expectedPayloads: nil, - }, - { - name: "Two fragments", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2[:64], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 64, - payload: ipv4Payload1Addr1ToAddr2[64:], - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, - }, - { - name: "Two fragments out of order", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 64, - payload: ipv4Payload1Addr1ToAddr2[64:], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2[:64], - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, - }, - { - name: "Two fragments with last fragment size not a multiple of fragment block size", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload3Addr1ToAddr2[:64], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 64, - payload: ipv4Payload3Addr1ToAddr2[64:], - }, - }, - expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2}, - }, - { - name: "Two fragments with first fragment size not a multiple of fragment block size", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload3Addr1ToAddr2[:63], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 63, - payload: ipv4Payload3Addr1ToAddr2[63:], - }, - }, - expectedPayloads: nil, - }, - { - name: "Second fragment has MoreFlags set", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2[:64], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 64, - payload: ipv4Payload1Addr1ToAddr2[64:], - }, - }, - expectedPayloads: nil, - }, - { - name: "Two fragments with different IDs", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2[:64], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 2, - flags: 0, - fragmentOffset: 64, - payload: ipv4Payload1Addr1ToAddr2[64:], - }, - }, - expectedPayloads: nil, - }, - { - name: "Two interleaved fragmented packets", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2[:64], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 2, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload2Addr1ToAddr2[:64], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 64, - payload: ipv4Payload1Addr1ToAddr2[64:], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 2, - flags: 0, - fragmentOffset: 64, - payload: ipv4Payload2Addr1ToAddr2[64:], - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload2Addr1ToAddr2}, - }, - { - name: "Two interleaved fragmented packets from different sources but with same ID", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2[:64], - }, - { - srcAddr: addr3, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr3ToAddr2[:32], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 64, - payload: ipv4Payload1Addr1ToAddr2[64:], - }, - { - srcAddr: addr3, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 32, - payload: ipv4Payload1Addr3ToAddr2[32:], - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload1Addr3ToAddr2}, - }, - { - name: "Fragment without followup", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2[:64], - }, - }, - expectedPayloads: nil, - }, - { - name: "Two fragments reassembled into a maximum UDP packet", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload4Addr1ToAddr2[:65512], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 65512, - payload: ipv4Payload4Addr1ToAddr2[65512:], - }, - }, - expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2}, - }, - { - name: "Two fragments with MF flag reassembled into a maximum UDP packet", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload4Addr1ToAddr2[:65512], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 65512, - payload: ipv4Payload4Addr1ToAddr2[65512:], - }, - }, - expectedPayloads: nil, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - // Setup a stack and endpoint. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - RawFactory: raw.EndpointFactory{}, - }) - e := channel.New(0, 1280, tcpip.LinkAddress("\xf0\x00")) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) - } - - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - defer close(ch) - ep, err := s.NewEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, header.IPv4ProtocolNumber, err) - } - defer ep.Close() - - bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80} - if err := ep.Bind(bindAddr); err != nil { - t.Fatalf("Bind(%+v): %s", bindAddr, err) - } - - // Bring up a raw endpoint so we can examine network headers. - epRaw, err := s.NewRawEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq, true /* associated */) - if err != nil { - t.Fatalf("NewRawEndpoint(%d, %d, _, true): %s", udp.ProtocolNumber, header.IPv4ProtocolNumber, err) - } - defer epRaw.Close() - - // Prepare and send the fragments. - for _, frag := range test.fragments { - hdr := buffer.NewPrependable(header.IPv4MinimumSize) - - // Serialize IPv4 fixed header. - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - TotalLength: header.IPv4MinimumSize + uint16(len(frag.payload)), - ID: frag.id, - Flags: frag.flags, - FragmentOffset: frag.fragmentOffset, - TTL: 64, - Protocol: uint8(header.UDPProtocolNumber), - SrcAddr: frag.srcAddr, - DstAddr: frag.dstAddr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - vv := hdr.View().ToVectorisedView() - vv.AppendView(frag.payload) - - e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - })) - } - - if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want { - t.Errorf("got UDP Rx Packets = %d, want = %d", got, want) - } - - for i, expectedPayload := range test.expectedPayloads { - // Check UDP payload delivered by UDP endpoint. - var buf bytes.Buffer - result, err := ep.Read(&buf, tcpip.ReadOptions{}) - if err != nil { - t.Fatalf("(i=%d) ep.Read: %s", i, err) - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: len(expectedPayload), - Total: len(expectedPayload), - }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { - t.Errorf("(i=%d) ep.Read: unexpected result (-want +got):\n%s", i, diff) - } - if diff := cmp.Diff(expectedPayload, buf.Bytes()); diff != "" { - t.Errorf("(i=%d) ep.Read: UDP payload mismatch (-want +got):\n%s", i, diff) - } - - // Check IPv4 header in packet delivered by raw endpoint. - buf.Reset() - result, err = epRaw.Read(&buf, tcpip.ReadOptions{}) - if err != nil { - t.Fatalf("(i=%d) epRaw.Read: %s", i, err) - } - // Reassambly does not take care of checksum. Here we write our own - // check routine instead of using checker.IPv4. - ip := header.IPv4(buf.Bytes()) - for _, check := range []checker.NetworkChecker{ - checker.FragmentFlags(0), - checker.FragmentOffset(0), - checker.IPFullLength(uint16(header.IPv4MinimumSize + header.UDPMinimumSize + len(expectedPayload))), - } { - check(t, []header.Network{ip}) - } - } - - res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("(last) got Read = (%#v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{}) - } - }) - } -} - -func TestWriteStats(t *testing.T) { - const nPackets = 3 - - tests := []struct { - name string - setup func(*testing.T, *stack.Stack) - allowPackets int - expectSent int - expectDropped int - expectWritten int - }{ - { - name: "Accept all", - // No setup needed, tables accept everything by default. - setup: func(*testing.T, *stack.Stack) {}, - allowPackets: math.MaxInt32, - expectSent: nPackets, - expectDropped: 0, - expectWritten: nPackets, - }, { - name: "Accept all with error", - // No setup needed, tables accept everything by default. - setup: func(*testing.T, *stack.Stack) {}, - allowPackets: nPackets - 1, - expectSent: nPackets - 1, - expectDropped: 0, - expectWritten: nPackets - 1, - }, { - name: "Drop all", - setup: func(t *testing.T, stk *stack.Stack) { - // Install Output DROP rule. - t.Helper() - ipt := stk.IPTables() - filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Output] - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("failed to replace table: %s", err) - } - }, - allowPackets: math.MaxInt32, - expectSent: 0, - expectDropped: nPackets, - expectWritten: nPackets, - }, { - name: "Drop some", - setup: func(t *testing.T, stk *stack.Stack) { - // Install Output DROP rule that matches only 1 - // of the 3 packets. - t.Helper() - ipt := stk.IPTables() - filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) - // We'll match and DROP the last packet. - ruleIdx := filter.BuiltinChains[stack.Output] - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} - // Make sure the next rule is ACCEPT. - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("failed to replace table: %s", err) - } - }, - allowPackets: math.MaxInt32, - expectSent: nPackets - 1, - expectDropped: 1, - expectWritten: nPackets, - }, - } - - // Parameterize the tests to run with both WritePacket and WritePackets. - writers := []struct { - name string - writePackets func(*stack.Route, stack.PacketBufferList) (int, tcpip.Error) - }{ - { - name: "WritePacket", - writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) { - nWritten := 0 - for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil { - return nWritten, err - } - nWritten++ - } - return nWritten, nil - }, - }, { - name: "WritePackets", - writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) { - return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{}) - }, - }, - } - - for _, writer := range writers { - t.Run(writer.name, func(t *testing.T) { - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets) - rt := buildRoute(t, ep) - - var pkts stack.PacketBufferList - for i := 0; i < nPackets; i++ { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()), - Data: buffer.NewView(0).ToVectorisedView(), - }) - pkt.TransportHeader().Push(header.UDPMinimumSize) - pkts.PushBack(pkt) - } - - test.setup(t, rt.Stack()) - - nWritten, _ := writer.writePackets(rt, pkts) - - if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent { - t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent) - } - if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped { - t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped) - } - if nWritten != test.expectWritten { - t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten) - } - }) - } - }) - } -} - -func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - }) - if err := s.CreateNIC(1, ep); err != nil { - t.Fatalf("CreateNIC(1, _) failed: %s", err) - } - const ( - src = "\x10\x00\x00\x01" - dst = "\x10\x00\x00\x02" - ) - if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil { - t.Fatalf("AddAddress(1, %d, %s) failed: %s", ipv4.ProtocolNumber, src, err) - } - { - mask := tcpip.AddressMask(header.IPv4Broadcast) - subnet, err := tcpip.NewSubnet(dst, mask) - if err != nil { - t.Fatalf("NewSubnet(%s, %s) failed: %v", dst, mask, err) - } - s.SetRouteTable([]tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}) - } - rt, err := s.FindRoute(1, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(1, %s, %s, %d, false) = %s", src, dst, ipv4.ProtocolNumber, err) - } - return rt -} - -// limitedMatcher is an iptables matcher that matches after a certain number of -// packets are checked against it. -type limitedMatcher struct { - limit int -} - -// Name implements Matcher.Name. -func (*limitedMatcher) Name() string { - return "limitedMatcher" -} - -// Match implements Matcher.Match. -func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string, string) (bool, bool) { - if lm.limit == 0 { - return true, false - } - lm.limit-- - return false, false -} - -func TestPacketQueing(t *testing.T) { - const nicID = 1 - - var ( - host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") - host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") - - host1IPv4Addr = tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), - PrefixLen: 24, - }, - } - host2IPv4Addr = tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), - PrefixLen: 8, - }, - } - ) - - tests := []struct { - name string - rxPkt func(*channel.Endpoint) - checkResp func(*testing.T, *channel.Endpoint) - }{ - { - name: "ICMP Error", - rxPkt: func(e *channel.Endpoint) { - hdr := buffer.NewPrependable(header.IPv4MinimumSize + header.UDPMinimumSize) - u := header.UDP(hdr.Prepend(header.UDPMinimumSize)) - u.Encode(&header.UDPFields{ - SrcPort: 5555, - DstPort: 80, - Length: header.UDPMinimumSize, - }) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv4Addr.AddressWithPrefix.Address, host1IPv4Addr.AddressWithPrefix.Address, header.UDPMinimumSize) - sum = header.Checksum(header.UDP([]byte{}), sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - TotalLength: header.IPv4MinimumSize + header.UDPMinimumSize, - TTL: ipv4.DefaultTTL, - Protocol: uint8(udp.ProtocolNumber), - SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, - DstAddr: host1IPv4Addr.AddressWithPrefix.Address, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - }, - checkResp: func(t *testing.T, e *channel.Endpoint) { - p, ok := e.ReadContext(context.Background()) - if !ok { - t.Fatalf("timed out waiting for packet") - } - if p.Proto != header.IPv4ProtocolNumber { - t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber) - } - if p.Route.RemoteLinkAddress != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) - } - checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), - checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address), - checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4DstUnreachable), - checker.ICMPv4Code(header.ICMPv4PortUnreachable))) - }, - }, - - { - name: "Ping", - rxPkt: func(e *channel.Endpoint) { - totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize - hdr := buffer.NewPrependable(totalLen) - pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - pkt.SetType(header.ICMPv4Echo) - pkt.SetCode(0) - pkt.SetChecksum(0) - pkt.SetChecksum(^header.Checksum(pkt, 0)) - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLen), - Protocol: uint8(icmp.ProtocolNumber4), - TTL: ipv4.DefaultTTL, - SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, - DstAddr: host1IPv4Addr.AddressWithPrefix.Address, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - }, - checkResp: func(t *testing.T, e *channel.Endpoint) { - p, ok := e.ReadContext(context.Background()) - if !ok { - t.Fatalf("timed out waiting for packet") - } - if p.Proto != header.IPv4ProtocolNumber { - t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber) - } - if p.Route.RemoteLinkAddress != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) - } - checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), - checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address), - checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4EchoReply), - checker.ICMPv4Code(header.ICMPv4UnusedCode))) - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e := channel.New(1, defaultMTU, host1NICLinkAddr) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - if err := s.AddProtocolAddress(nicID, host1IPv4Addr); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv4Addr, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), - NIC: nicID, - }, - }) - - // Receive a packet to trigger link resolution before a response is sent. - test.rxPkt(e) - - // Wait for a ARP request since link address resolution should be - // performed. - { - p, ok := e.ReadContext(context.Background()) - if !ok { - t.Fatalf("timed out waiting for packet") - } - if p.Proto != arp.ProtocolNumber { - t.Errorf("got p.Proto = %d, want = %d", p.Proto, arp.ProtocolNumber) - } - if p.Route.RemoteLinkAddress != header.EthernetBroadcastAddress { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, header.EthernetBroadcastAddress) - } - rep := header.ARP(p.Pkt.NetworkHeader().View()) - if got := rep.Op(); got != header.ARPRequest { - t.Errorf("got Op() = %d, want = %d", got, header.ARPRequest) - } - if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != host1NICLinkAddr { - t.Errorf("got HardwareAddressSender = %s, want = %s", got, host1NICLinkAddr) - } - if got := tcpip.Address(rep.ProtocolAddressSender()); got != host1IPv4Addr.AddressWithPrefix.Address { - t.Errorf("got ProtocolAddressSender = %s, want = %s", got, host1IPv4Addr.AddressWithPrefix.Address) - } - if got := tcpip.Address(rep.ProtocolAddressTarget()); got != host2IPv4Addr.AddressWithPrefix.Address { - t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, host2IPv4Addr.AddressWithPrefix.Address) - } - } - - // Send an ARP reply to complete link address resolution. - { - hdr := buffer.View(make([]byte, header.ARPSize)) - packet := header.ARP(hdr) - packet.SetIPv4OverEthernet() - packet.SetOp(header.ARPReply) - copy(packet.HardwareAddressSender(), host2NICLinkAddr) - copy(packet.ProtocolAddressSender(), host2IPv4Addr.AddressWithPrefix.Address) - copy(packet.HardwareAddressTarget(), host1NICLinkAddr) - copy(packet.ProtocolAddressTarget(), host1IPv4Addr.AddressWithPrefix.Address) - e.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.ToVectorisedView(), - })) - } - - // Expect the response now that the link address has resolved. - test.checkResp(t, e) - - // Since link resolution was already performed, it shouldn't be performed - // again. - test.rxPkt(e) - test.checkResp(t, e) - }) - } -} diff --git a/pkg/tcpip/network/ipv4/stats_test.go b/pkg/tcpip/network/ipv4/stats_test.go deleted file mode 100644 index a637f9d50..000000000 --- a/pkg/tcpip/network/ipv4/stats_test.go +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv4 - -import ( - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -var _ stack.NetworkInterface = (*testInterface)(nil) - -type testInterface struct { - stack.NetworkInterface - nicID tcpip.NICID -} - -func (t *testInterface) ID() tcpip.NICID { - return t.nicID -} - -func knownNICIDs(proto *protocol) []tcpip.NICID { - var nicIDs []tcpip.NICID - - for k := range proto.mu.eps { - nicIDs = append(nicIDs, k) - } - - return nicIDs -} - -func TestClearEndpointFromProtocolOnClose(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) - nic := testInterface{nicID: 1} - ep := proto.NewEndpoint(&nic, nil).(*endpoint) - var nicIDs []tcpip.NICID - - proto.mu.Lock() - foundEP, hasEndpointBeforeClose := proto.mu.eps[nic.ID()] - nicIDs = knownNICIDs(proto) - proto.mu.Unlock() - - if !hasEndpointBeforeClose { - t.Fatalf("expected to find the nic id %d in the protocol's endpoint map (%v)", nic.ID(), nicIDs) - } - if foundEP != ep { - t.Fatalf("found an incorrect endpoint mapped to nic id %d", nic.ID()) - } - - ep.Close() - - proto.mu.Lock() - _, hasEP := proto.mu.eps[nic.ID()] - nicIDs = knownNICIDs(proto) - proto.mu.Unlock() - if hasEP { - t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs) - } -} - -func TestMultiCounterStatsInitialization(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) - var nic testInterface - ep := proto.NewEndpoint(&nic, nil).(*endpoint) - // At this point, the Stack's stats and the NetworkEndpoint's stats are - // expected to be bound by a MultiCounterStat. - refStack := s.Stats() - refEP := ep.stats.localStats - if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.ip).Elem(), []reflect.Value{reflect.ValueOf(&refEP.IP).Elem(), reflect.ValueOf(&refStack.IP).Elem()}); err != nil { - t.Error(err) - } - if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.icmp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.ICMP).Elem(), reflect.ValueOf(&refStack.ICMP.V4).Elem()}); err != nil { - t.Error(err) - } - if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.igmp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.IGMP).Elem(), reflect.ValueOf(&refStack.IGMP).Elem()}); err != nil { - t.Error(err) - } -} diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD deleted file mode 100644 index bb9a02ed0..000000000 --- a/pkg/tcpip/network/ipv6/BUILD +++ /dev/null @@ -1,70 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "ipv6", - srcs = [ - "dhcpv6configurationfromndpra_string.go", - "icmp.go", - "ipv6.go", - "mld.go", - "ndp.go", - "stats.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/header/parse", - "//pkg/tcpip/network/hash", - "//pkg/tcpip/network/internal/fragmentation", - "//pkg/tcpip/network/internal/ip", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "ipv6_test", - size = "small", - srcs = [ - "icmp_test.go", - "ipv6_test.go", - "ndp_test.go", - ], - library = ":ipv6", - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/faketime", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/internal/testutil", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/icmp", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) - -go_test( - name = "ipv6_x_test", - size = "small", - srcs = ["mld_test.go"], - deps = [ - ":ipv6", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/faketime", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go deleted file mode 100644 index 69c1e4bea..000000000 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ /dev/null @@ -1,1646 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv6 - -import ( - "bytes" - "context" - "net" - "reflect" - "strings" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - nicID = 1 - - linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") - linkAddr2 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f") - - defaultChannelSize = 1 - defaultMTU = 65536 - - // Extra time to use when waiting for an async event to occur. - defaultAsyncPositiveEventTimeout = 30 * time.Second -) - -var ( - lladdr0 = header.LinkLocalAddr(linkAddr0) - lladdr1 = header.LinkLocalAddr(linkAddr1) - lladdr2 = header.LinkLocalAddr(linkAddr2) -) - -type stubLinkEndpoint struct { - stack.LinkEndpoint -} - -func (*stubLinkEndpoint) MTU() uint32 { - return defaultMTU -} - -func (*stubLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { - // Indicate that resolution for link layer addresses is required to send - // packets over this link. This is needed so the NIC knows to allocate a - // neighbor table. - return stack.CapabilityResolutionRequired -} - -func (*stubLinkEndpoint) MaxHeaderLength() uint16 { - return 0 -} - -func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress { - return "" -} - -func (*stubLinkEndpoint) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { - return nil -} - -func (*stubLinkEndpoint) Attach(stack.NetworkDispatcher) {} - -type stubDispatcher struct { - stack.TransportDispatcher -} - -func (*stubDispatcher) DeliverTransportPacket(tcpip.TransportProtocolNumber, *stack.PacketBuffer) stack.TransportPacketDisposition { - return stack.TransportPacketHandled -} - -var _ stack.NetworkInterface = (*testInterface)(nil) - -type testInterface struct { - stack.LinkEndpoint - - probeCount int - confirmationCount int - - nicID tcpip.NICID -} - -func (*testInterface) ID() tcpip.NICID { - return nicID -} - -func (*testInterface) IsLoopback() bool { - return false -} - -func (*testInterface) Name() string { - return "" -} - -func (*testInterface) Enabled() bool { - return true -} - -func (*testInterface) Promiscuous() bool { - return false -} - -func (*testInterface) Spoofing() bool { - return false -} - -func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - return t.LinkEndpoint.WritePacket(r.Fields(), gso, protocol, pkt) -} - -func (t *testInterface) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - return t.LinkEndpoint.WritePackets(r.Fields(), gso, pkts, protocol) -} - -func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - var r stack.RouteInfo - r.NetProto = protocol - r.RemoteLinkAddress = remoteLinkAddr - return t.LinkEndpoint.WritePacket(r, gso, protocol, pkt) -} - -func (t *testInterface) HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress) tcpip.Error { - t.probeCount++ - return nil -} - -func (t *testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) tcpip.Error { - t.confirmationCount++ - return nil -} - -func (*testInterface) CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool { - return false -} - -func handleICMPInIPv6(ep stack.NetworkEndpoint, src, dst tcpip.Address, icmp header.ICMPv6) { - ip := buffer.NewView(header.IPv6MinimumSize) - header.IPv6(ip).Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: header.NDPHopLimit, - SrcAddr: src, - DstAddr: dst, - }) - vv := ip.ToVectorisedView() - vv.AppendView(buffer.View(icmp)) - ep.HandlePacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.IPv6MinimumSize, - Data: vv, - })) -} - -func TestICMPCounts(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - }) - if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_, _) = %s", err) - } - { - subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: nicID, - }}, - ) - } - - netProto := s.NetworkProtocolInstance(ProtocolNumber) - if netProto == nil { - t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) - } - ep := netProto.NewEndpoint(&testInterface{}, &stubDispatcher{}) - defer ep.Close() - - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - - addressableEndpoint, ok := ep.(stack.AddressableEndpoint) - if !ok { - t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") - } - addr := lladdr0.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) - } else { - ep.DecRef() - } - - var tllData [header.NDPLinkLayerAddressSize]byte - header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - }) - - types := []struct { - typ header.ICMPv6Type - size int - extraData []byte - }{ - { - typ: header.ICMPv6DstUnreachable, - size: header.ICMPv6DstUnreachableMinimumSize, - }, - { - typ: header.ICMPv6PacketTooBig, - size: header.ICMPv6PacketTooBigMinimumSize, - }, - { - typ: header.ICMPv6TimeExceeded, - size: header.ICMPv6MinimumSize, - }, - { - typ: header.ICMPv6ParamProblem, - size: header.ICMPv6MinimumSize, - }, - { - typ: header.ICMPv6EchoRequest, - size: header.ICMPv6EchoMinimumSize, - }, - { - typ: header.ICMPv6EchoReply, - size: header.ICMPv6EchoMinimumSize, - }, - { - typ: header.ICMPv6RouterSolicit, - size: header.ICMPv6MinimumSize, - }, - { - typ: header.ICMPv6RouterAdvert, - size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, - }, - { - typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize, - }, - { - typ: header.ICMPv6NeighborAdvert, - size: header.ICMPv6NeighborAdvertMinimumSize, - extraData: tllData[:], - }, - { - typ: header.ICMPv6RedirectMsg, - size: header.ICMPv6MinimumSize, - }, - { - typ: header.ICMPv6MulticastListenerQuery, - size: header.MLDMinimumSize + header.ICMPv6HeaderSize, - }, - { - typ: header.ICMPv6MulticastListenerReport, - size: header.MLDMinimumSize + header.ICMPv6HeaderSize, - }, - { - typ: header.ICMPv6MulticastListenerDone, - size: header.MLDMinimumSize + header.ICMPv6HeaderSize, - }, - { - typ: 255, /* Unrecognized */ - size: 50, - }, - } - - for _, typ := range types { - icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) - copy(icmp[typ.size:], typ.extraData) - icmp.SetType(typ.typ) - icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) - handleICMPInIPv6(ep, lladdr1, lladdr0, icmp) - } - - // Construct an empty ICMP packet so that - // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented. - handleICMPInIPv6(ep, lladdr1, lladdr0, header.ICMPv6(buffer.NewView(header.IPv6MinimumSize))) - - icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived - visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) { - if got, want := s.Value(), uint64(1); got != want { - t.Errorf("got %s = %d, want = %d", name, got, want) - } - }) - if t.Failed() { - t.Logf("stats:\n%+v", s.Stats()) - } -} - -func visitStats(v reflect.Value, f func(string, *tcpip.StatCounter)) { - t := v.Type() - for i := 0; i < v.NumField(); i++ { - v := v.Field(i) - if s, ok := v.Interface().(*tcpip.StatCounter); ok { - f(t.Field(i).Name, s) - } else { - visitStats(v, f) - } - } -} - -type testContext struct { - s0 *stack.Stack - s1 *stack.Stack - - linkEP0 *channel.Endpoint - linkEP1 *channel.Endpoint -} - -type endpointWithResolutionCapability struct { - stack.LinkEndpoint -} - -func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapabilities { - return e.LinkEndpoint.Capabilities() | stack.CapabilityResolutionRequired -} - -func newTestContext(t *testing.T) *testContext { - c := &testContext{ - s0: stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - }), - s1: stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - }), - } - - c.linkEP0 = channel.New(defaultChannelSize, defaultMTU, linkAddr0) - - wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0}) - if testing.Verbose() { - wrappedEP0 = sniffer.New(wrappedEP0) - } - if err := c.s0.CreateNIC(nicID, wrappedEP0); err != nil { - t.Fatalf("CreateNIC s0: %v", err) - } - if err := c.s0.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress lladdr0: %v", err) - } - - c.linkEP1 = channel.New(defaultChannelSize, defaultMTU, linkAddr1) - wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1}) - if err := c.s1.CreateNIC(nicID, wrappedEP1); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - if err := c.s1.AddAddress(nicID, ProtocolNumber, lladdr1); err != nil { - t.Fatalf("AddAddress lladdr1: %v", err) - } - - subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - c.s0.SetRouteTable( - []tcpip.Route{{ - Destination: subnet0, - NIC: nicID, - }}, - ) - subnet1, err := tcpip.NewSubnet(lladdr0, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0)))) - if err != nil { - t.Fatal(err) - } - c.s1.SetRouteTable( - []tcpip.Route{{ - Destination: subnet1, - NIC: nicID, - }}, - ) - - t.Cleanup(func() { - if err := c.s0.RemoveNIC(nicID); err != nil { - t.Errorf("c.s0.RemoveNIC(%d): %s", nicID, err) - } - if err := c.s1.RemoveNIC(nicID); err != nil { - t.Errorf("c.s1.RemoveNIC(%d): %s", nicID, err) - } - - c.linkEP0.Close() - c.linkEP1.Close() - }) - - return c -} - -type routeArgs struct { - src, dst *channel.Endpoint - typ header.ICMPv6Type - remoteLinkAddr tcpip.LinkAddress -} - -func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) { - t.Helper() - - pi, _ := args.src.ReadContext(context.Background()) - - { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buffer.NewVectorisedView(pi.Pkt.Size(), pi.Pkt.Views()), - }) - args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), pkt) - } - - if pi.Proto != ProtocolNumber { - t.Errorf("unexpected protocol number %d", pi.Proto) - return - } - - if len(args.remoteLinkAddr) != 0 && pi.Route.RemoteLinkAddress != args.remoteLinkAddr { - t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr) - } - - // Pull the full payload since network header. Needed for header.IPv6 to - // extract its payload. - ipv6 := header.IPv6(stack.PayloadSince(pi.Pkt.NetworkHeader())) - transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader()) - if transProto != header.ICMPv6ProtocolNumber { - t.Errorf("unexpected transport protocol number %d", transProto) - return - } - icmpv6 := header.ICMPv6(ipv6.Payload()) - if got, want := icmpv6.Type(), args.typ; got != want { - t.Errorf("got ICMPv6 type = %d, want = %d", got, want) - return - } - if fn != nil { - fn(t, icmpv6) - } -} - -func TestLinkResolution(t *testing.T) { - c := newTestContext(t) - - r, err := c.s0.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) - } - defer r.Release() - - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6EchoMinimumSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) - pkt.SetType(header.ICMPv6EchoRequest) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - - // We can't send our payload directly over the route because that - // doesn't provoke NDP discovery. - var wq waiter.Queue - ep, err := c.s0.NewEndpoint(header.ICMPv6ProtocolNumber, ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint(_) = (_, %s), want = (_, nil)", err) - } - - { - var r bytes.Reader - r.Reset(hdr.View()) - if _, err := ep.Write(&r, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: nicID, Addr: lladdr1}}); err != nil { - t.Fatalf("ep.Write(_): %s", err) - } - } - for _, args := range []routeArgs{ - {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))}, - {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6NeighborAdvert}, - } { - routeICMPv6Packet(t, args, func(t *testing.T, icmpv6 header.ICMPv6) { - if got, want := tcpip.Address(icmpv6[8:][:16]), lladdr1; got != want { - t.Errorf("%d: got target = %s, want = %s", icmpv6.Type(), got, want) - } - }) - } - - for _, args := range []routeArgs{ - {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6EchoRequest}, - {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6EchoReply}, - } { - routeICMPv6Packet(t, args, nil) - } -} - -func TestICMPChecksumValidationSimple(t *testing.T) { - var tllData [header.NDPLinkLayerAddressSize]byte - header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - }) - - types := []struct { - name string - typ header.ICMPv6Type - size int - extraData []byte - statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter - routerOnly bool - }{ - { - name: "DstUnreachable", - typ: header.ICMPv6DstUnreachable, - size: header.ICMPv6DstUnreachableMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.DstUnreachable - }, - }, - { - name: "PacketTooBig", - typ: header.ICMPv6PacketTooBig, - size: header.ICMPv6PacketTooBigMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.PacketTooBig - }, - }, - { - name: "TimeExceeded", - typ: header.ICMPv6TimeExceeded, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.TimeExceeded - }, - }, - { - name: "ParamProblem", - typ: header.ICMPv6ParamProblem, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.ParamProblem - }, - }, - { - name: "EchoRequest", - typ: header.ICMPv6EchoRequest, - size: header.ICMPv6EchoMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.EchoRequest - }, - }, - { - name: "EchoReply", - typ: header.ICMPv6EchoReply, - size: header.ICMPv6EchoMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.EchoReply - }, - }, - { - name: "RouterSolicit", - typ: header.ICMPv6RouterSolicit, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterSolicit - }, - // Hosts MUST silently discard any received Router Solicitation messages. - routerOnly: true, - }, - { - name: "RouterAdvert", - typ: header.ICMPv6RouterAdvert, - size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterAdvert - }, - }, - { - name: "NeighborSolicit", - typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborSolicit - }, - }, - { - name: "NeighborAdvert", - typ: header.ICMPv6NeighborAdvert, - size: header.ICMPv6NeighborAdvertMinimumSize, - extraData: tllData[:], - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborAdvert - }, - }, - { - name: "RedirectMsg", - typ: header.ICMPv6RedirectMsg, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RedirectMsg - }, - }, - } - - for _, typ := range types { - for _, isRouter := range []bool{false, true} { - name := typ.name - if isRouter { - name += " (Router)" - } - t.Run(name, func(t *testing.T) { - e := channel.New(0, 1280, linkAddr0) - - // Indicate that resolution for link layer addresses is required to - // send packets over this link. This is needed so the NIC knows to - // allocate a neighbor table. - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - if isRouter { - // Enabling forwarding makes the stack act as a router. - s.SetForwarding(ProtocolNumber, true) - } - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(_, _) = %s", err) - } - - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } - { - subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: nicID, - }}, - ) - } - - handleIPv6Payload := func(checksum bool) { - icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) - copy(icmp[typ.size:], typ.extraData) - icmp.SetType(typ.typ) - if checksum { - icmp.SetChecksum(header.ICMPv6Checksum(icmp, lladdr1, lladdr0, buffer.View{}.ToVectorisedView())) - } - ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}), - }) - e.InjectInbound(ProtocolNumber, pkt) - } - - stats := s.Stats().ICMP.V6.PacketsReceived - invalid := stats.Invalid - routerOnly := stats.RouterOnlyPacketsDroppedByHost - typStat := typ.statCounter(stats) - - // Initial stat counts should be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - if got := routerOnly.Value(); got != 0 { - t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) - } - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // Without setting checksum, the incoming packet should - // be invalid. - handleIPv6Payload(false) - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - // Router only count should not have increased. - if got := routerOnly.Value(); got != 0 { - t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) - } - // Rx count of type typ.typ should not have increased. - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // When checksum is set, it should be received. - handleIPv6Payload(true) - if got := typStat.Value(); got != 1 { - t.Fatalf("got %s = %d, want = 1", typ.name, got) - } - // Invalid count should not have increased again. - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - if !isRouter && typ.routerOnly { - // Router only count should have increased. - if got := routerOnly.Value(); got != 1 { - t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 1", got) - } - } - }) - } - } -} - -func TestICMPChecksumValidationWithPayload(t *testing.T) { - const simpleBodySize = 64 - simpleBody := func(view buffer.View) { - for i := 0; i < simpleBodySize; i++ { - view[i] = uint8(i) - } - } - - const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize - errorICMPBody := func(view buffer.View) { - ip := header.IPv6(view) - ip.Encode(&header.IPv6Fields{ - PayloadLength: simpleBodySize, - TransportProtocol: 10, - HopLimit: 20, - SrcAddr: lladdr0, - DstAddr: lladdr1, - }) - simpleBody(view[header.IPv6MinimumSize:]) - } - - types := []struct { - name string - typ header.ICMPv6Type - size int - statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter - payloadSize int - payload func(buffer.View) - }{ - { - "DstUnreachable", - header.ICMPv6DstUnreachable, - header.ICMPv6DstUnreachableMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.DstUnreachable - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "PacketTooBig", - header.ICMPv6PacketTooBig, - header.ICMPv6PacketTooBigMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.PacketTooBig - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "TimeExceeded", - header.ICMPv6TimeExceeded, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.TimeExceeded - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "ParamProblem", - header.ICMPv6ParamProblem, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.ParamProblem - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "EchoRequest", - header.ICMPv6EchoRequest, - header.ICMPv6EchoMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.EchoRequest - }, - simpleBodySize, - simpleBody, - }, - { - "EchoReply", - header.ICMPv6EchoReply, - header.ICMPv6EchoMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.EchoReply - }, - simpleBodySize, - simpleBody, - }, - } - - for _, typ := range types { - t.Run(typ.name, func(t *testing.T) { - e := channel.New(10, 1280, linkAddr0) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(_, _) = %s", err) - } - - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } - { - subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: nicID, - }}, - ) - } - - handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { - icmpSize := size + payloadSize - hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) - icmpHdr := header.ICMPv6(hdr.Prepend(icmpSize)) - icmpHdr.SetType(typ) - payloadFn(icmpHdr.Payload()) - - if checksum { - icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, lladdr1, lladdr0, buffer.VectorisedView{})) - } - - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(icmpSize), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - }) - e.InjectInbound(ProtocolNumber, pkt) - } - - stats := s.Stats().ICMP.V6.PacketsReceived - invalid := stats.Invalid - typStat := typ.statCounter(stats) - - // Initial stat counts should be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - if got := typStat.Value(); got != 0 { - t.Fatalf("got = %d, want = 0", got) - } - - // Without setting checksum, the incoming packet should - // be invalid. - handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false) - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - // Rx count of type typ.typ should not have increased. - if got := typStat.Value(); got != 0 { - t.Fatalf("got = %d, want = 0", got) - } - - // When checksum is set, it should be received. - handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true) - if got := typStat.Value(); got != 1 { - t.Fatalf("got = %d, want = 0", got) - } - // Invalid count should not have increased again. - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - }) - } -} - -func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { - const simpleBodySize = 64 - simpleBody := func(view buffer.View) { - for i := 0; i < simpleBodySize; i++ { - view[i] = uint8(i) - } - } - - const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize - errorICMPBody := func(view buffer.View) { - ip := header.IPv6(view) - ip.Encode(&header.IPv6Fields{ - PayloadLength: simpleBodySize, - TransportProtocol: 10, - HopLimit: 20, - SrcAddr: lladdr0, - DstAddr: lladdr1, - }) - simpleBody(view[header.IPv6MinimumSize:]) - } - - types := []struct { - name string - typ header.ICMPv6Type - size int - statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter - payloadSize int - payload func(buffer.View) - }{ - { - "DstUnreachable", - header.ICMPv6DstUnreachable, - header.ICMPv6DstUnreachableMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.DstUnreachable - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "PacketTooBig", - header.ICMPv6PacketTooBig, - header.ICMPv6PacketTooBigMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.PacketTooBig - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "TimeExceeded", - header.ICMPv6TimeExceeded, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.TimeExceeded - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "ParamProblem", - header.ICMPv6ParamProblem, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.ParamProblem - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "EchoRequest", - header.ICMPv6EchoRequest, - header.ICMPv6EchoMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.EchoRequest - }, - simpleBodySize, - simpleBody, - }, - { - "EchoReply", - header.ICMPv6EchoReply, - header.ICMPv6EchoMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.EchoReply - }, - simpleBodySize, - simpleBody, - }, - } - - for _, typ := range types { - t.Run(typ.name, func(t *testing.T) { - e := channel.New(10, 1280, linkAddr0) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) - } - { - subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: nicID, - }}, - ) - } - - handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + size) - icmpHdr := header.ICMPv6(hdr.Prepend(size)) - icmpHdr.SetType(typ) - - payload := buffer.NewView(payloadSize) - payloadFn(payload) - - if checksum { - icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, lladdr1, lladdr0, payload.ToVectorisedView())) - } - - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(size + payloadSize), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}), - }) - e.InjectInbound(ProtocolNumber, pkt) - } - - stats := s.Stats().ICMP.V6.PacketsReceived - invalid := stats.Invalid - typStat := typ.statCounter(stats) - - // Initial stat counts should be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - if got := typStat.Value(); got != 0 { - t.Fatalf("got = %d, want = 0", got) - } - - // Without setting checksum, the incoming packet should - // be invalid. - handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false) - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - // Rx count of type typ.typ should not have increased. - if got := typStat.Value(); got != 0 { - t.Fatalf("got = %d, want = 0", got) - } - - // When checksum is set, it should be received. - handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true) - if got := typStat.Value(); got != 1 { - t.Fatalf("got = %d, want = 0", got) - } - // Invalid count should not have increased again. - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - }) - } -} - -func TestLinkAddressRequest(t *testing.T) { - const nicID = 1 - - snaddr := header.SolicitedNodeAddr(lladdr0) - mcaddr := header.EthernetAddressFromMulticastIPv6Address(snaddr) - - tests := []struct { - name string - nicAddr tcpip.Address - localAddr tcpip.Address - remoteLinkAddr tcpip.LinkAddress - - expectedErr tcpip.Error - expectedRemoteAddr tcpip.Address - expectedRemoteLinkAddr tcpip.LinkAddress - }{ - { - name: "Unicast", - nicAddr: lladdr1, - localAddr: lladdr1, - remoteLinkAddr: linkAddr1, - expectedRemoteAddr: lladdr0, - expectedRemoteLinkAddr: linkAddr1, - }, - { - name: "Multicast", - nicAddr: lladdr1, - localAddr: lladdr1, - remoteLinkAddr: "", - expectedRemoteAddr: snaddr, - expectedRemoteLinkAddr: mcaddr, - }, - { - name: "Unicast with unspecified source", - nicAddr: lladdr1, - remoteLinkAddr: linkAddr1, - expectedRemoteAddr: lladdr0, - expectedRemoteLinkAddr: linkAddr1, - }, - { - name: "Multicast with unspecified source", - nicAddr: lladdr1, - remoteLinkAddr: "", - expectedRemoteAddr: snaddr, - expectedRemoteLinkAddr: mcaddr, - }, - { - name: "Unicast with unassigned address", - localAddr: lladdr1, - remoteLinkAddr: linkAddr1, - expectedErr: &tcpip.ErrBadLocalAddress{}, - }, - { - name: "Multicast with unassigned address", - localAddr: lladdr1, - remoteLinkAddr: "", - expectedErr: &tcpip.ErrBadLocalAddress{}, - }, - { - name: "Unicast with no local address available", - remoteLinkAddr: linkAddr1, - expectedErr: &tcpip.ErrNetworkUnreachable{}, - }, - { - name: "Multicast with no local address available", - remoteLinkAddr: "", - expectedErr: &tcpip.ErrNetworkUnreachable{}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - - linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0) - if err := s.CreateNIC(nicID, linkEP); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - - ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber) - if err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ProtocolNumber, err) - } - linkRes, ok := ep.(stack.LinkAddressResolver) - if !ok { - t.Fatalf("expected %T to implement stack.LinkAddressResolver", ep) - } - - if len(test.nicAddr) != 0 { - if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err) - } - } - - { - err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr) - if diff := cmp.Diff(test.expectedErr, err); diff != "" { - t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", lladdr0, test.localAddr, test.remoteLinkAddr, diff) - } - } - - if test.expectedErr != nil { - return - } - - pkt, ok := linkEP.Read() - if !ok { - t.Fatal("expected to send a link address request") - } - - var want stack.RouteInfo - want.NetProto = ProtocolNumber - want.RemoteLinkAddress = test.expectedRemoteLinkAddr - if diff := cmp.Diff(want, pkt.Route, cmp.AllowUnexported(want)); diff != "" { - t.Errorf("route info mismatch (-want +got):\n%s", diff) - } - checker.IPv6(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), - checker.SrcAddr(lladdr1), - checker.DstAddr(test.expectedRemoteAddr), - checker.TTL(header.NDPHopLimit), - checker.NDPNS( - checker.NDPNSTargetAddress(lladdr0), - checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(linkAddr0)}), - )) - }) - } -} - -func TestPacketQueing(t *testing.T) { - const nicID = 1 - - var ( - host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") - host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") - - host1IPv6Addr = tcpip.ProtocolAddress{ - Protocol: ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("a::1").To16()), - PrefixLen: 64, - }, - } - host2IPv6Addr = tcpip.ProtocolAddress{ - Protocol: ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("a::2").To16()), - PrefixLen: 64, - }, - } - ) - - tests := []struct { - name string - rxPkt func(*channel.Endpoint) - checkResp func(*testing.T, *channel.Endpoint) - }{ - { - name: "ICMP Error", - rxPkt: func(e *channel.Endpoint) { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize) - u := header.UDP(hdr.Prepend(header.UDPMinimumSize)) - u.Encode(&header.UDPFields{ - SrcPort: 5555, - DstPort: 80, - Length: header.UDPMinimumSize, - }) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, header.UDPMinimumSize) - sum = header.Checksum(header.UDP([]byte{}), sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: udp.ProtocolNumber, - HopLimit: DefaultTTL, - SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, - DstAddr: host1IPv6Addr.AddressWithPrefix.Address, - }) - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - }, - checkResp: func(t *testing.T, e *channel.Endpoint) { - p, ok := e.ReadContext(context.Background()) - if !ok { - t.Fatalf("timed out waiting for packet") - } - if p.Proto != ProtocolNumber { - t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber) - } - if p.Route.RemoteLinkAddress != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) - } - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), - checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address), - checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6DstUnreachable), - checker.ICMPv6Code(header.ICMPv6PortUnreachable))) - }, - }, - - { - name: "Ping", - rxPkt: func(e *channel.Endpoint) { - totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize - hdr := buffer.NewPrependable(totalLen) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) - pkt.SetType(header.ICMPv6EchoRequest) - pkt.SetCode(0) - pkt.SetChecksum(0) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{})) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - TransportProtocol: icmp.ProtocolNumber6, - HopLimit: DefaultTTL, - SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, - DstAddr: host1IPv6Addr.AddressWithPrefix.Address, - }) - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - }, - checkResp: func(t *testing.T, e *channel.Endpoint) { - p, ok := e.ReadContext(context.Background()) - if !ok { - t.Fatalf("timed out waiting for packet") - } - if p.Proto != ProtocolNumber { - t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber) - } - if p.Route.RemoteLinkAddress != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) - } - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), - checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address), - checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6EchoReply), - checker.ICMPv6Code(header.ICMPv6UnusedCode))) - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - - e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - if err := s.AddProtocolAddress(nicID, host1IPv6Addr); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv6Addr, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), - NIC: nicID, - }, - }) - - // Receive a packet to trigger link resolution before a response is sent. - test.rxPkt(e) - - // Wait for a neighbor solicitation since link address resolution should - // be performed. - { - p, ok := e.ReadContext(context.Background()) - if !ok { - t.Fatalf("timed out waiting for packet") - } - if p.Proto != ProtocolNumber { - t.Errorf("got Proto = %d, want = %d", p.Proto, ProtocolNumber) - } - snmc := header.SolicitedNodeAddr(host2IPv6Addr.AddressWithPrefix.Address) - if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want) - } - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), - checker.DstAddr(snmc), - checker.TTL(header.NDPHopLimit), - checker.NDPNS( - checker.NDPNSTargetAddress(host2IPv6Addr.AddressWithPrefix.Address), - checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(host1NICLinkAddr)}), - )) - } - - // Send a neighbor advertisement to complete link address resolution. - { - naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize - hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize) - pkt := header.ICMPv6(hdr.Prepend(naSize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.MessageBody()) - na.SetSolicitedFlag(true) - na.SetOverrideFlag(true) - na.SetTargetAddress(host2IPv6Addr.AddressWithPrefix.Address) - na.Options().Serialize(header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(host2NICLinkAddr), - }) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: icmp.ProtocolNumber6, - HopLimit: header.NDPHopLimit, - SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, - DstAddr: host1IPv6Addr.AddressWithPrefix.Address, - }) - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - } - - // Expect the response now that the link address has resolved. - test.checkResp(t, e) - - // Since link resolution was already performed, it shouldn't be performed - // again. - test.rxPkt(e) - test.checkResp(t, e) - }) - } -} - -func TestCallsToNeighborCache(t *testing.T) { - tests := []struct { - name string - createPacket func() header.ICMPv6 - multicast bool - source tcpip.Address - destination tcpip.Address - wantProbeCount int - wantConfirmationCount int - }{ - { - name: "Unicast Neighbor Solicitation without source link-layer address option", - createPacket: func() header.ICMPv6 { - nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize - icmp := header.ICMPv6(buffer.NewView(nsSize)) - icmp.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmp.MessageBody()) - ns.SetTargetAddress(lladdr0) - return icmp - }, - source: lladdr1, - destination: lladdr0, - // "The source link-layer address option SHOULD be included in unicast - // solicitations." - RFC 4861 section 4.3 - // - // A Neighbor Advertisement needs to be sent in response, but the - // Neighbor Cache shouldn't be updated since we have no useful - // information about the sender. - wantProbeCount: 0, - }, - { - name: "Unicast Neighbor Solicitation with source link-layer address option", - createPacket: func() header.ICMPv6 { - nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize - icmp := header.ICMPv6(buffer.NewView(nsSize)) - icmp.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmp.MessageBody()) - ns.SetTargetAddress(lladdr0) - ns.Options().Serialize(header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(linkAddr1), - }) - return icmp - }, - source: lladdr1, - destination: lladdr0, - wantProbeCount: 1, - }, - { - name: "Multicast Neighbor Solicitation without source link-layer address option", - createPacket: func() header.ICMPv6 { - nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize - icmp := header.ICMPv6(buffer.NewView(nsSize)) - icmp.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmp.MessageBody()) - ns.SetTargetAddress(lladdr0) - return icmp - }, - source: lladdr1, - destination: header.SolicitedNodeAddr(lladdr0), - // "The source link-layer address option MUST be included in multicast - // solicitations." - RFC 4861 section 4.3 - wantProbeCount: 0, - }, - { - name: "Multicast Neighbor Solicitation with source link-layer address option", - createPacket: func() header.ICMPv6 { - nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize - icmp := header.ICMPv6(buffer.NewView(nsSize)) - icmp.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmp.MessageBody()) - ns.SetTargetAddress(lladdr0) - ns.Options().Serialize(header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(linkAddr1), - }) - return icmp - }, - source: lladdr1, - destination: header.SolicitedNodeAddr(lladdr0), - wantProbeCount: 1, - }, - { - name: "Unicast Neighbor Advertisement without target link-layer address option", - createPacket: func() header.ICMPv6 { - naSize := header.ICMPv6NeighborAdvertMinimumSize - icmp := header.ICMPv6(buffer.NewView(naSize)) - icmp.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(icmp.MessageBody()) - na.SetSolicitedFlag(true) - na.SetOverrideFlag(false) - na.SetTargetAddress(lladdr1) - return icmp - }, - source: lladdr1, - destination: lladdr0, - // "When responding to unicast solicitations, the target link-layer - // address option can be omitted since the sender of the solicitation has - // the correct link-layer address; otherwise, it would not be able to - // send the unicast solicitation in the first place." - // - RFC 4861 section 4.4 - wantConfirmationCount: 1, - }, - { - name: "Unicast Neighbor Advertisement with target link-layer address option", - createPacket: func() header.ICMPv6 { - naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize - icmp := header.ICMPv6(buffer.NewView(naSize)) - icmp.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(icmp.MessageBody()) - na.SetSolicitedFlag(true) - na.SetOverrideFlag(false) - na.SetTargetAddress(lladdr1) - na.Options().Serialize(header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - }) - return icmp - }, - source: lladdr1, - destination: lladdr0, - wantConfirmationCount: 1, - }, - { - name: "Multicast Neighbor Advertisement without target link-layer address option", - createPacket: func() header.ICMPv6 { - naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize - icmp := header.ICMPv6(buffer.NewView(naSize)) - icmp.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(icmp.MessageBody()) - na.SetSolicitedFlag(false) - na.SetOverrideFlag(false) - na.SetTargetAddress(lladdr1) - return icmp - }, - source: lladdr1, - destination: header.IPv6AllNodesMulticastAddress, - // "Target link-layer address MUST be included for multicast solicitations - // in order to avoid infinite Neighbor Solicitation "recursion" when the - // peer node does not have a cache entry to return a Neighbor - // Advertisements message." - RFC 4861 section 4.4 - wantConfirmationCount: 0, - }, - { - name: "Multicast Neighbor Advertisement with target link-layer address option", - createPacket: func() header.ICMPv6 { - naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize - icmp := header.ICMPv6(buffer.NewView(naSize)) - icmp.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(icmp.MessageBody()) - na.SetSolicitedFlag(false) - na.SetOverrideFlag(false) - na.SetTargetAddress(lladdr1) - na.Options().Serialize(header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - }) - return icmp - }, - source: lladdr1, - destination: header.IPv6AllNodesMulticastAddress, - wantConfirmationCount: 1, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - }) - { - if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_, _) = %s", err) - } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } - } - { - subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: nicID, - }}, - ) - } - - netProto := s.NetworkProtocolInstance(ProtocolNumber) - if netProto == nil { - t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) - } - - testInterface := testInterface{LinkEndpoint: channel.New(0, header.IPv6MinimumMTU, linkAddr0)} - ep := netProto.NewEndpoint(&testInterface, &stubDispatcher{}) - defer ep.Close() - - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - - addressableEndpoint, ok := ep.(stack.AddressableEndpoint) - if !ok { - t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") - } - addr := lladdr0.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) - } else { - ep.DecRef() - } - - icmp := test.createPacket() - icmp.SetChecksum(header.ICMPv6Checksum(icmp, test.source, test.destination, buffer.VectorisedView{})) - handleICMPInIPv6(ep, test.source, test.destination, icmp) - - // Confirm the endpoint calls the correct NUDHandler method. - if testInterface.probeCount != test.wantProbeCount { - t.Errorf("got testInterface.probeCount = %d, want = %d", testInterface.probeCount, test.wantProbeCount) - } - if testInterface.confirmationCount != test.wantConfirmationCount { - t.Errorf("got testInterface.confirmationCount = %d, want = %d", testInterface.confirmationCount, test.wantConfirmationCount) - } - }) - } -} diff --git a/pkg/tcpip/network/ipv6/ipv6_state_autogen.go b/pkg/tcpip/network/ipv6/ipv6_state_autogen.go new file mode 100644 index 000000000..675fdc220 --- /dev/null +++ b/pkg/tcpip/network/ipv6/ipv6_state_autogen.go @@ -0,0 +1,126 @@ +// automatically generated by stateify. + +package ipv6 + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (i *icmpv6DestinationUnreachableSockError) StateTypeName() string { + return "pkg/tcpip/network/ipv6.icmpv6DestinationUnreachableSockError" +} + +func (i *icmpv6DestinationUnreachableSockError) StateFields() []string { + return []string{} +} + +func (i *icmpv6DestinationUnreachableSockError) beforeSave() {} + +func (i *icmpv6DestinationUnreachableSockError) StateSave(stateSinkObject state.Sink) { + i.beforeSave() +} + +func (i *icmpv6DestinationUnreachableSockError) afterLoad() {} + +func (i *icmpv6DestinationUnreachableSockError) StateLoad(stateSourceObject state.Source) { +} + +func (i *icmpv6DestinationNetworkUnreachableSockError) StateTypeName() string { + return "pkg/tcpip/network/ipv6.icmpv6DestinationNetworkUnreachableSockError" +} + +func (i *icmpv6DestinationNetworkUnreachableSockError) StateFields() []string { + return []string{ + "icmpv6DestinationUnreachableSockError", + } +} + +func (i *icmpv6DestinationNetworkUnreachableSockError) beforeSave() {} + +func (i *icmpv6DestinationNetworkUnreachableSockError) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.icmpv6DestinationUnreachableSockError) +} + +func (i *icmpv6DestinationNetworkUnreachableSockError) afterLoad() {} + +func (i *icmpv6DestinationNetworkUnreachableSockError) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.icmpv6DestinationUnreachableSockError) +} + +func (i *icmpv6DestinationPortUnreachableSockError) StateTypeName() string { + return "pkg/tcpip/network/ipv6.icmpv6DestinationPortUnreachableSockError" +} + +func (i *icmpv6DestinationPortUnreachableSockError) StateFields() []string { + return []string{ + "icmpv6DestinationUnreachableSockError", + } +} + +func (i *icmpv6DestinationPortUnreachableSockError) beforeSave() {} + +func (i *icmpv6DestinationPortUnreachableSockError) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.icmpv6DestinationUnreachableSockError) +} + +func (i *icmpv6DestinationPortUnreachableSockError) afterLoad() {} + +func (i *icmpv6DestinationPortUnreachableSockError) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.icmpv6DestinationUnreachableSockError) +} + +func (i *icmpv6DestinationAddressUnreachableSockError) StateTypeName() string { + return "pkg/tcpip/network/ipv6.icmpv6DestinationAddressUnreachableSockError" +} + +func (i *icmpv6DestinationAddressUnreachableSockError) StateFields() []string { + return []string{ + "icmpv6DestinationUnreachableSockError", + } +} + +func (i *icmpv6DestinationAddressUnreachableSockError) beforeSave() {} + +func (i *icmpv6DestinationAddressUnreachableSockError) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.icmpv6DestinationUnreachableSockError) +} + +func (i *icmpv6DestinationAddressUnreachableSockError) afterLoad() {} + +func (i *icmpv6DestinationAddressUnreachableSockError) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.icmpv6DestinationUnreachableSockError) +} + +func (e *icmpv6PacketTooBigSockError) StateTypeName() string { + return "pkg/tcpip/network/ipv6.icmpv6PacketTooBigSockError" +} + +func (e *icmpv6PacketTooBigSockError) StateFields() []string { + return []string{ + "mtu", + } +} + +func (e *icmpv6PacketTooBigSockError) beforeSave() {} + +func (e *icmpv6PacketTooBigSockError) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.mtu) +} + +func (e *icmpv6PacketTooBigSockError) afterLoad() {} + +func (e *icmpv6PacketTooBigSockError) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.mtu) +} + +func init() { + state.Register((*icmpv6DestinationUnreachableSockError)(nil)) + state.Register((*icmpv6DestinationNetworkUnreachableSockError)(nil)) + state.Register((*icmpv6DestinationPortUnreachableSockError)(nil)) + state.Register((*icmpv6DestinationAddressUnreachableSockError)(nil)) + state.Register((*icmpv6PacketTooBigSockError)(nil)) +} diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go deleted file mode 100644 index 7e714b50e..000000000 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ /dev/null @@ -1,3089 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv6 - -import ( - "bytes" - "encoding/hex" - "fmt" - "io/ioutil" - "math" - "net" - "reflect" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - // The least significant 3 bytes are the same as addr2 so both addr2 and - // addr3 will have the same solicited-node address. - addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02" - addr4 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x03" - - // Tests use the extension header identifier values as uint8 instead of - // header.IPv6ExtensionHeaderIdentifier. - hopByHopExtHdrID = uint8(header.IPv6HopByHopOptionsExtHdrIdentifier) - routingExtHdrID = uint8(header.IPv6RoutingExtHdrIdentifier) - fragmentExtHdrID = uint8(header.IPv6FragmentExtHdrIdentifier) - destinationExtHdrID = uint8(header.IPv6DestinationOptionsExtHdrIdentifier) - noNextHdrID = uint8(header.IPv6NoNextHeaderIdentifier) - unknownHdrID = uint8(header.IPv6UnknownExtHdrIdentifier) - - extraHeaderReserve = 50 -) - -// testReceiveICMP tests receiving an ICMP packet from src to dst. want is the -// expected Neighbor Advertisement received count after receiving the packet. -func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) { - t.Helper() - - // Receive ICMP packet. - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertMinimumSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertMinimumSize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, dst, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: 255, - SrcAddr: src, - DstAddr: dst, - }) - - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - - stats := s.Stats().ICMP.V6.PacketsReceived - - if got := stats.NeighborAdvert.Value(); got != want { - t.Fatalf("got NeighborAdvert = %d, want = %d", got, want) - } -} - -// testReceiveUDP tests receiving a UDP packet from src to dst. want is the -// expected UDP received count after receiving the packet. -func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) { - t.Helper() - - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - defer close(ch) - - ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - - if err := ep.Bind(tcpip.FullAddress{Addr: dst, Port: 80}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) - } - - // Receive UDP Packet. - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize) - u := header.UDP(hdr.Prepend(header.UDPMinimumSize)) - u.Encode(&header.UDPFields{ - SrcPort: 5555, - DstPort: 80, - Length: header.UDPMinimumSize, - }) - - // UDP pseudo-header checksum. - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, header.UDPMinimumSize) - - // UDP checksum - sum = header.Checksum(header.UDP([]byte{}), sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: udp.ProtocolNumber, - HopLimit: 255, - SrcAddr: src, - DstAddr: dst, - }) - - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - - stat := s.Stats().UDP.PacketsReceived - - if got := stat.Value(); got != want { - t.Fatalf("got UDPPacketsReceived = %d, want = %d", got, want) - } -} - -func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber) error { - // sourcePacket does not have its IP Header populated. Let's copy the one - // from the first fragment. - source := header.IPv6(packets[0].NetworkHeader().View()) - sourceIPHeadersLen := len(source) - vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views()) - source = append(source, vv.ToView()...) - - var reassembledPayload buffer.VectorisedView - for i, fragment := range packets { - // Confirm that the packet is valid. - allBytes := buffer.NewVectorisedView(fragment.Size(), fragment.Views()) - fragmentIPHeaders := header.IPv6(allBytes.ToView()) - if !fragmentIPHeaders.IsValid(len(fragmentIPHeaders)) { - return fmt.Errorf("fragment #%d: IP packet is invalid:\n%s", i, hex.Dump(fragmentIPHeaders)) - } - - fragmentIPHeadersLength := fragment.NetworkHeader().View().Size() - if fragmentIPHeadersLength != sourceIPHeadersLen { - return fmt.Errorf("fragment #%d: got fragmentIPHeadersLength = %d, want = %d", i, fragmentIPHeadersLength, sourceIPHeadersLen) - } - - if got := len(fragmentIPHeaders); got > int(mtu) { - return fmt.Errorf("fragment #%d: got len(fragmentIPHeaders) = %d, want <= %d", i, got, mtu) - } - - sourceIPHeader := source[:header.IPv6MinimumSize] - fragmentIPHeader := fragmentIPHeaders[:header.IPv6MinimumSize] - - if got := fragmentIPHeaders.PayloadLength(); got != wantFragments[i].payloadSize { - return fmt.Errorf("fragment #%d: got fragmentIPHeaders.PayloadLength() = %d, want = %d", i, got, wantFragments[i].payloadSize) - } - - // We expect the IPv6 Header to be similar across each fragment, besides the - // payload length. - sourceIPHeader.SetPayloadLength(0) - fragmentIPHeader.SetPayloadLength(0) - if diff := cmp.Diff(fragmentIPHeader, sourceIPHeader); diff != "" { - return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff) - } - - if got := fragment.AvailableHeaderBytes(); got != extraHeaderReserve { - return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, extraHeaderReserve) - } - if fragment.NetworkProtocolNumber != sourcePacket.NetworkProtocolNumber { - return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, fragment.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber) - } - - if len(packets) > 1 { - // If the source packet was big enough that it needed fragmentation, let's - // inspect the fragment header. Because no other extension headers are - // supported, it will always be the last extension header. - fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[fragmentIPHeadersLength-header.IPv6FragmentHeaderSize : fragmentIPHeadersLength]) - - if got := fragmentHeader.More(); got != wantFragments[i].more { - return fmt.Errorf("fragment #%d: got fragmentHeader.More() = %t, want = %t", i, got, wantFragments[i].more) - } - if got := fragmentHeader.FragmentOffset(); got != wantFragments[i].offset { - return fmt.Errorf("fragment #%d: got fragmentHeader.FragmentOffset() = %d, want = %d", i, got, wantFragments[i].offset) - } - if got := fragmentHeader.NextHeader(); got != uint8(proto) { - return fmt.Errorf("fragment #%d: got fragmentHeader.NextHeader() = %d, want = %d", i, got, uint8(proto)) - } - } - - // Store the reassembled payload as we parse each fragment. The payload - // includes the Transport header and everything after. - reassembledPayload.AppendView(fragment.TransportHeader().View()) - reassembledPayload.Append(fragment.Data) - } - - if diff := cmp.Diff(buffer.View(source[sourceIPHeadersLen:]), reassembledPayload.ToView()); diff != "" { - return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff) - } - - return nil -} - -// TestReceiveOnAllNodesMulticastAddr tests that IPv6 endpoints receive ICMP and -// UDP packets destined to the IPv6 link-local all-nodes multicast address. -func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { - tests := []struct { - name string - protocolFactory stack.TransportProtocolFactory - rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) - }{ - {"ICMP", icmp.NewProtocol6, testReceiveICMP}, - {"UDP", udp.NewProtocol, testReceiveUDP}, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory}, - }) - e := channel.New(10, header.IPv6MinimumMTU, linkAddr1) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - // Should receive a packet destined to the all-nodes - // multicast address. - test.rxf(t, s, e, addr1, header.IPv6AllNodesMulticastAddress, 1) - }) - } -} - -// TestReceiveOnSolicitedNodeAddr tests that IPv6 endpoints receive ICMP and UDP -// packets destined to the IPv6 solicited-node address of an assigned IPv6 -// address. -func TestReceiveOnSolicitedNodeAddr(t *testing.T) { - tests := []struct { - name string - protocolFactory stack.TransportProtocolFactory - rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) - }{ - {"ICMP", icmp.NewProtocol6, testReceiveICMP}, - {"UDP", udp.NewProtocol, testReceiveUDP}, - } - - snmc := header.SolicitedNodeAddr(addr2) - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory}, - }) - e := channel.New(1, header.IPv6MinimumMTU, linkAddr1) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv6EmptySubnet, - NIC: nicID, - }, - }) - - // Should not receive a packet destined to the solicited node address of - // addr2/addr3 yet as we haven't added those addresses. - test.rxf(t, s, e, addr1, snmc, 0) - - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) - } - - // Should receive a packet destined to the solicited node address of - // addr2/addr3 now that we have added added addr2. - test.rxf(t, s, e, addr1, snmc, 1) - - if err := s.AddAddress(nicID, ProtocolNumber, addr3); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr3, err) - } - - // Should still receive a packet destined to the solicited node address of - // addr2/addr3 now that we have added addr3. - test.rxf(t, s, e, addr1, snmc, 2) - - if err := s.RemoveAddress(nicID, addr2); err != nil { - t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr2, err) - } - - // Should still receive a packet destined to the solicited node address of - // addr2/addr3 now that we have removed addr2. - test.rxf(t, s, e, addr1, snmc, 3) - - // Make sure addr3's endpoint does not get removed from the NIC by - // incrementing its reference count with a route. - r, err := s.FindRoute(nicID, addr3, addr4, ProtocolNumber, false) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr3, addr4, ProtocolNumber, err) - } - defer r.Release() - - if err := s.RemoveAddress(nicID, addr3); err != nil { - t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr3, err) - } - - // Should not receive a packet destined to the solicited node address of - // addr2/addr3 yet as both of them got removed, even though a route using - // addr3 exists. - test.rxf(t, s, e, addr1, snmc, 3) - }) - } -} - -// TestAddIpv6Address tests adding IPv6 addresses. -func TestAddIpv6Address(t *testing.T) { - tests := []struct { - name string - addr tcpip.Address - }{ - // This test is in response to b/140943433. - { - "Nil", - tcpip.Address([]byte(nil)), - }, - { - "ValidUnicast", - addr1, - }, - { - "ValidLinkLocalUnicast", - lladdr0, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - if err := s.AddAddress(1, ProtocolNumber, test.addr); err != nil { - t.Fatalf("AddAddress(_, %d, nil) = %s", ProtocolNumber, err) - } - - if addr, ok := s.GetMainNICAddress(1, header.IPv6ProtocolNumber); !ok { - t.Fatalf("got stack.GetMainNICAddress(1, %d) = (_, false), want = (_, true)", header.IPv6ProtocolNumber) - } else if addr.Address != test.addr { - t.Fatalf("got stack.GetMainNICAddress(1_, %d) = (%s, true), want = (%s, true)", header.IPv6ProtocolNumber, addr.Address, test.addr) - } - }) - } -} - -func TestReceiveIPv6ExtHdrs(t *testing.T) { - tests := []struct { - name string - extHdr func(nextHdr uint8) ([]byte, uint8) - shouldAccept bool - // Should we expect an ICMP response and if so, with what contents? - expectICMP bool - ICMPType header.ICMPv6Type - ICMPCode header.ICMPv6Code - pointer uint32 - multicast bool - }{ - { - name: "None", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, nextHdr }, - shouldAccept: true, - expectICMP: false, - }, - { - name: "hopbyhop with unknown option skippable action", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Skippable unknown. - 62, 6, 1, 2, 3, 4, 5, 6, - }, hopByHopExtHdrID - }, - shouldAccept: true, - }, - { - name: "hopbyhop with unknown option discard action", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard unknown. - 127, 6, 1, 2, 3, 4, 5, 6, - }, hopByHopExtHdrID - }, - shouldAccept: false, - expectICMP: false, - }, - { - name: "hopbyhop with unknown option discard and send icmp action (unicast)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard & send ICMP if option is unknown. - 191, 6, 1, 2, 3, 4, 5, 6, - //^ Unknown option. - }, hopByHopExtHdrID - }, - shouldAccept: false, - expectICMP: true, - ICMPType: header.ICMPv6ParamProblem, - ICMPCode: header.ICMPv6UnknownOption, - pointer: header.IPv6FixedHeaderSize + 8, - }, - { - name: "hopbyhop with unknown option discard and send icmp action (multicast)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard & send ICMP if option is unknown. - 191, 6, 1, 2, 3, 4, 5, 6, - //^ Unknown option. - }, hopByHopExtHdrID - }, - multicast: true, - shouldAccept: false, - expectICMP: true, - ICMPType: header.ICMPv6ParamProblem, - ICMPCode: header.ICMPv6UnknownOption, - pointer: header.IPv6FixedHeaderSize + 8, - }, - { - name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (unicast)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard & send ICMP unless packet is for multicast destination if - // option is unknown. - 255, 6, 1, 2, 3, 4, 5, 6, - //^ Unknown option. - }, hopByHopExtHdrID - }, - expectICMP: true, - ICMPType: header.ICMPv6ParamProblem, - ICMPCode: header.ICMPv6UnknownOption, - pointer: header.IPv6FixedHeaderSize + 8, - }, - { - name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (multicast)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard & send ICMP unless packet is for multicast destination if - // option is unknown. - 255, 6, 1, 2, 3, 4, 5, 6, - //^ Unknown option. - }, hopByHopExtHdrID - }, - multicast: true, - shouldAccept: false, - expectICMP: false, - }, - { - name: "routing with zero segments left", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 0, - 1, 0, 2, 3, 4, 5, - }, routingExtHdrID - }, - shouldAccept: true, - }, - { - name: "routing with non-zero segments left", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 0, - 1, 1, 2, 3, 4, 5, - }, routingExtHdrID - }, - shouldAccept: false, - expectICMP: true, - ICMPType: header.ICMPv6ParamProblem, - ICMPCode: header.ICMPv6ErroneousHeader, - pointer: header.IPv6FixedHeaderSize + 2, - }, - { - name: "atomic fragment with zero ID", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 0, - 0, 0, 0, 0, 0, 0, - }, fragmentExtHdrID - }, - shouldAccept: true, - }, - { - name: "atomic fragment with non-zero ID", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 0, - 0, 0, 1, 2, 3, 4, - }, fragmentExtHdrID - }, - shouldAccept: true, - expectICMP: false, - }, - { - name: "fragment", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 0, - 1, 0, 1, 2, 3, 4, - }, fragmentExtHdrID - }, - shouldAccept: false, - expectICMP: false, - }, - { - name: "No next header", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{}, - noNextHdrID - }, - shouldAccept: false, - expectICMP: false, - }, - { - name: "unknown next header (first)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 0, 63, 4, 1, 2, 3, 4, - }, unknownHdrID - }, - shouldAccept: false, - expectICMP: true, - ICMPType: header.ICMPv6ParamProblem, - ICMPCode: header.ICMPv6UnknownHeader, - pointer: header.IPv6NextHeaderOffset, - }, - { - name: "unknown next header (not first)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - unknownHdrID, 0, - 63, 4, 1, 2, 3, 4, - }, hopByHopExtHdrID - }, - shouldAccept: false, - expectICMP: true, - ICMPType: header.ICMPv6ParamProblem, - ICMPCode: header.ICMPv6UnknownHeader, - pointer: header.IPv6FixedHeaderSize, - }, - { - name: "destination with unknown option skippable action", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Skippable unknown. - 62, 6, 1, 2, 3, 4, 5, 6, - }, destinationExtHdrID - }, - shouldAccept: true, - expectICMP: false, - }, - { - name: "destination with unknown option discard action", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard unknown. - 127, 6, 1, 2, 3, 4, 5, 6, - }, destinationExtHdrID - }, - shouldAccept: false, - expectICMP: false, - }, - { - name: "destination with unknown option discard and send icmp action (unicast)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard & send ICMP if option is unknown. - 191, 6, 1, 2, 3, 4, 5, 6, - //^ 191 is an unknown option. - }, destinationExtHdrID - }, - shouldAccept: false, - expectICMP: true, - ICMPType: header.ICMPv6ParamProblem, - ICMPCode: header.ICMPv6UnknownOption, - pointer: header.IPv6FixedHeaderSize + 8, - }, - { - name: "destination with unknown option discard and send icmp action (muilticast)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard & send ICMP if option is unknown. - 191, 6, 1, 2, 3, 4, 5, 6, - //^ 191 is an unknown option. - }, destinationExtHdrID - }, - multicast: true, - shouldAccept: false, - expectICMP: true, - ICMPType: header.ICMPv6ParamProblem, - ICMPCode: header.ICMPv6UnknownOption, - pointer: header.IPv6FixedHeaderSize + 8, - }, - { - name: "destination with unknown option discard and send icmp action unless multicast dest (unicast)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard & send ICMP unless packet is for multicast destination if - // option is unknown. - 255, 6, 1, 2, 3, 4, 5, 6, - //^ 255 is unknown. - }, destinationExtHdrID - }, - shouldAccept: false, - expectICMP: true, - ICMPType: header.ICMPv6ParamProblem, - ICMPCode: header.ICMPv6UnknownOption, - pointer: header.IPv6FixedHeaderSize + 8, - }, - { - name: "destination with unknown option discard and send icmp action unless multicast dest (multicast)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard & send ICMP unless packet is for multicast destination if - // option is unknown. - 255, 6, 1, 2, 3, 4, 5, 6, - //^ 255 is unknown. - }, destinationExtHdrID - }, - shouldAccept: false, - expectICMP: false, - multicast: true, - }, - { - name: "atomic fragment - routing", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - // Fragment extension header. - routingExtHdrID, 0, 0, 0, 1, 2, 3, 4, - - // Routing extension header. - nextHdr, 0, 1, 0, 2, 3, 4, 5, - }, fragmentExtHdrID - }, - shouldAccept: true, - }, - { - name: "hop by hop (with skippable unknown) - routing", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - // Hop By Hop extension header with skippable unknown option. - routingExtHdrID, 0, 62, 4, 1, 2, 3, 4, - - // Routing extension header. - nextHdr, 0, 1, 0, 2, 3, 4, 5, - }, hopByHopExtHdrID - }, - shouldAccept: true, - }, - { - name: "routing - hop by hop (with skippable unknown)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - // Routing extension header. - hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5, - // ^^^ The HopByHop extension header may not appear after the first - // extension header. - - // Hop By Hop extension header with skippable unknown option. - nextHdr, 0, 62, 4, 1, 2, 3, 4, - }, routingExtHdrID - }, - shouldAccept: false, - expectICMP: true, - ICMPType: header.ICMPv6ParamProblem, - ICMPCode: header.ICMPv6UnknownHeader, - pointer: header.IPv6FixedHeaderSize, - }, - { - name: "routing - hop by hop (with send icmp unknown)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - // Routing extension header. - hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5, - // ^^^ The HopByHop extension header may not appear after the first - // extension header. - - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Skippable unknown. - 191, 6, 1, 2, 3, 4, 5, 6, - }, routingExtHdrID - }, - shouldAccept: false, - expectICMP: true, - ICMPType: header.ICMPv6ParamProblem, - ICMPCode: header.ICMPv6UnknownHeader, - pointer: header.IPv6FixedHeaderSize, - }, - { - name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with skippable unknown)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - // Hop By Hop extension header with skippable unknown option. - routingExtHdrID, 0, 62, 4, 1, 2, 3, 4, - - // Routing extension header. - fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5, - - // Fragment extension header. - destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4, - - // Destination extension header with skippable unknown option. - nextHdr, 0, 63, 4, 1, 2, 3, 4, - }, hopByHopExtHdrID - }, - shouldAccept: true, - }, - { - name: "hopbyhop (with discard unknown) - routing - atomic fragment - destination (with skippable unknown)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - // Hop By Hop extension header with discard action for unknown option. - routingExtHdrID, 0, 65, 4, 1, 2, 3, 4, - - // Routing extension header. - fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5, - - // Fragment extension header. - destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4, - - // Destination extension header with skippable unknown option. - nextHdr, 0, 63, 4, 1, 2, 3, 4, - }, hopByHopExtHdrID - }, - shouldAccept: false, - expectICMP: false, - }, - { - name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with discard unknown)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - // Hop By Hop extension header with skippable unknown option. - routingExtHdrID, 0, 62, 4, 1, 2, 3, 4, - - // Routing extension header. - fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5, - - // Fragment extension header. - destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4, - - // Destination extension header with discard action for unknown - // option. - nextHdr, 0, 65, 4, 1, 2, 3, 4, - }, hopByHopExtHdrID - }, - shouldAccept: false, - expectICMP: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - e := channel.New(1, header.IPv6MinimumMTU, linkAddr1) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) - } - - // Add a default route so that a return packet knows where to go. - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv6EmptySubnet, - NIC: nicID, - }, - }) - - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - defer close(ch) - ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, ProtocolNumber, err) - } - defer ep.Close() - - bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80} - if err := ep.Bind(bindAddr); err != nil { - t.Fatalf("Bind(%+v): %s", bindAddr, err) - } - - udpPayload := []byte{1, 2, 3, 4, 5, 6, 7, 8} - udpLength := header.UDPMinimumSize + len(udpPayload) - extHdrBytes, ipv6NextHdr := test.extHdr(uint8(header.UDPProtocolNumber)) - extHdrLen := len(extHdrBytes) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + extHdrLen + udpLength) - - // Serialize UDP message. - u := header.UDP(hdr.Prepend(udpLength)) - u.Encode(&header.UDPFields{ - SrcPort: 5555, - DstPort: 80, - Length: uint16(udpLength), - }) - copy(u.Payload(), udpPayload) - - dstAddr := tcpip.Address(addr2) - if test.multicast { - dstAddr = header.IPv6AllNodesMulticastAddress - } - - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, dstAddr, uint16(udpLength)) - sum = header.Checksum(udpPayload, sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - - // Copy extension header bytes between the UDP message and the IPv6 - // fixed header. - copy(hdr.Prepend(extHdrLen), extHdrBytes) - - // Serialize IPv6 fixed header. - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - // We're lying about transport protocol here to be able to generate - // raw extension headers from the test definitions. - TransportProtocol: tcpip.TransportProtocolNumber(ipv6NextHdr), - HopLimit: 255, - SrcAddr: addr1, - DstAddr: dstAddr, - }) - - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - - stats := s.Stats().UDP.PacketsReceived - - if !test.shouldAccept { - if got := stats.Value(); got != 0 { - t.Errorf("got UDP Rx Packets = %d, want = 0", got) - } - - if !test.expectICMP { - if p, ok := e.Read(); ok { - t.Fatalf("unexpected packet received: %#v", p) - } - return - } - - // ICMP required. - p, ok := e.Read() - if !ok { - t.Fatalf("expected packet wasn't written out") - } - - // Pack the output packet into a single buffer.View as the checkers - // assume that. - vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) - pkt := vv.ToView() - if got, want := len(pkt), header.IPv6FixedHeaderSize+header.ICMPv6MinimumSize+hdr.UsedLength(); got != want { - t.Fatalf("got an ICMP packet of size = %d, want = %d", got, want) - } - - ipHdr := header.IPv6(pkt) - checker.IPv6(t, ipHdr, checker.ICMPv6( - checker.ICMPv6Type(test.ICMPType), - checker.ICMPv6Code(test.ICMPCode))) - - // We know we are looking at no extension headers in the error ICMP - // packets. - icmpPkt := header.ICMPv6(ipHdr.Payload()) - // We know we sent small packets that won't be truncated when reflected - // back to us. - originalPacket := icmpPkt.Payload() - if got, want := icmpPkt.TypeSpecific(), test.pointer; got != want { - t.Errorf("unexpected ICMPv6 pointer, got = %d, want = %d\n", got, want) - } - if diff := cmp.Diff(hdr.View(), buffer.View(originalPacket)); diff != "" { - t.Errorf("ICMPv6 payload mismatch (-want +got):\n%s", diff) - } - return - } - - // Expect a UDP packet. - if got := stats.Value(); got != 1 { - t.Errorf("got UDP Rx Packets = %d, want = 1", got) - } - var buf bytes.Buffer - result, err := ep.Read(&buf, tcpip.ReadOptions{}) - if err != nil { - t.Fatalf("Read: %s", err) - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: len(udpPayload), - Total: len(udpPayload), - }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { - t.Errorf("Read: unexpected result (-want +got):\n%s", diff) - } - if diff := cmp.Diff(udpPayload, buf.Bytes()); diff != "" { - t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) - } - - // Should not have any more UDP packets. - res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got Read = (%v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{}) - } - }) - } -} - -// fragmentData holds the IPv6 payload for a fragmented IPv6 packet. -type fragmentData struct { - srcAddr tcpip.Address - dstAddr tcpip.Address - nextHdr uint8 - data buffer.VectorisedView -} - -func TestReceiveIPv6Fragments(t *testing.T) { - const ( - udpPayload1Length = 256 - udpPayload2Length = 128 - // Used to test cases where the fragment blocks are not a multiple of - // the fragment block size of 8 (RFC 8200 section 4.5). - udpPayload3Length = 127 - udpPayload4Length = header.IPv6MaximumPayloadSize - header.UDPMinimumSize - udpMaximumSizeMinus15 = header.UDPMaximumSize - 15 - fragmentExtHdrLen = 8 - // Note, not all routing extension headers will be 8 bytes but this test - // uses 8 byte routing extension headers for most sub tests. - routingExtHdrLen = 8 - ) - - udpGen := func(payload []byte, multiplier uint8, src, dst tcpip.Address) buffer.View { - payloadLen := len(payload) - for i := 0; i < payloadLen; i++ { - payload[i] = uint8(i) * multiplier - } - - udpLength := header.UDPMinimumSize + payloadLen - - hdr := buffer.NewPrependable(udpLength) - u := header.UDP(hdr.Prepend(udpLength)) - u.Encode(&header.UDPFields{ - SrcPort: 5555, - DstPort: 80, - Length: uint16(udpLength), - }) - copy(u.Payload(), payload) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(udpLength)) - sum = header.Checksum(payload, sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - return hdr.View() - } - - var udpPayload1Addr1ToAddr2Buf [udpPayload1Length]byte - udpPayload1Addr1ToAddr2 := udpPayload1Addr1ToAddr2Buf[:] - ipv6Payload1Addr1ToAddr2 := udpGen(udpPayload1Addr1ToAddr2, 1, addr1, addr2) - - var udpPayload1Addr3ToAddr2Buf [udpPayload1Length]byte - udpPayload1Addr3ToAddr2 := udpPayload1Addr3ToAddr2Buf[:] - ipv6Payload1Addr3ToAddr2 := udpGen(udpPayload1Addr3ToAddr2, 4, addr3, addr2) - - var udpPayload2Addr1ToAddr2Buf [udpPayload2Length]byte - udpPayload2Addr1ToAddr2 := udpPayload2Addr1ToAddr2Buf[:] - ipv6Payload2Addr1ToAddr2 := udpGen(udpPayload2Addr1ToAddr2, 2, addr1, addr2) - - var udpPayload3Addr1ToAddr2Buf [udpPayload3Length]byte - udpPayload3Addr1ToAddr2 := udpPayload3Addr1ToAddr2Buf[:] - ipv6Payload3Addr1ToAddr2 := udpGen(udpPayload3Addr1ToAddr2, 3, addr1, addr2) - - var udpPayload4Addr1ToAddr2Buf [udpPayload4Length]byte - udpPayload4Addr1ToAddr2 := udpPayload4Addr1ToAddr2Buf[:] - ipv6Payload4Addr1ToAddr2 := udpGen(udpPayload4Addr1ToAddr2, 4, addr1, addr2) - - tests := []struct { - name string - expectedPayload []byte - fragments []fragmentData - expectedPayloads [][]byte - }{ - { - name: "No fragmentation", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: uint8(header.UDPProtocolNumber), - data: ipv6Payload1Addr1ToAddr2.ToVectorisedView(), - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, - }, - { - name: "Atomic fragment", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2), - []buffer.View{ - // Fragment extension header. - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}), - - ipv6Payload1Addr1ToAddr2, - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, - }, - { - name: "Atomic fragment with size not a multiple of fragment block size", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2), - []buffer.View{ - // Fragment extension header. - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}), - - ipv6Payload3Addr1ToAddr2, - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2}, - }, - { - name: "Two fragments", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[:64], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[64:], - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, - }, - { - name: "Two fragments out of order", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[64:], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[:64], - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, - }, - { - name: "Two fragments with different Next Header values", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[:64], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 8, More = false, ID = 1 - // NextHeader value is different than the one in the first fragment, so - // this NextHeader should be ignored. - buffer.View([]byte{uint8(header.IPv6NoNextHeaderIdentifier), 0, 0, 64, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[64:], - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, - }, - { - name: "Two fragments with last fragment size not a multiple of fragment block size", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - - ipv6Payload3Addr1ToAddr2[:64], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2)-64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), - - ipv6Payload3Addr1ToAddr2[64:], - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2}, - }, - { - name: "Two fragments with first fragment size not a multiple of fragment block size", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+63, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - - ipv6Payload3Addr1ToAddr2[:63], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2)-63, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), - - ipv6Payload3Addr1ToAddr2[63:], - }, - ), - }, - }, - expectedPayloads: nil, - }, - { - name: "Two fragments with different IDs", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[:64], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 8, More = false, ID = 2 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 2}), - - ipv6Payload1Addr1ToAddr2[64:], - }, - ), - }, - }, - expectedPayloads: nil, - }, - { - name: "Two fragments reassembled into a maximum UDP packet", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+udpMaximumSizeMinus15, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - - ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-udpMaximumSizeMinus15, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = udpMaximumSizeMinus15/8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, - udpMaximumSizeMinus15 >> 8, - udpMaximumSizeMinus15 & 0xff, - 0, 0, 0, 1}), - - ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:], - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2}, - }, - { - name: "Two fragments with MF flag reassembled into a maximum UDP packet", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+udpMaximumSizeMinus15, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - - ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-udpMaximumSizeMinus15, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = udpMaximumSizeMinus15/8, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, - udpMaximumSizeMinus15 >> 8, - (udpMaximumSizeMinus15 & 0xff) + 1, - 0, 0, 0, 1}), - - ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:], - }, - ), - }, - }, - expectedPayloads: nil, - }, - { - name: "Two fragments with per-fragment routing header with zero segments left", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: routingExtHdrID, - data: buffer.NewVectorisedView( - routingExtHdrLen+fragmentExtHdrLen+64, - []buffer.View{ - // Routing extension header. - // - // Segments left = 0. - buffer.View([]byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}), - - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[:64], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: routingExtHdrID, - data: buffer.NewVectorisedView( - routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, - []buffer.View{ - // Routing extension header. - // - // Segments left = 0. - buffer.View([]byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}), - - // Fragment extension header. - // - // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[64:], - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, - }, - { - name: "Two fragments with per-fragment routing header with non-zero segments left", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: routingExtHdrID, - data: buffer.NewVectorisedView( - routingExtHdrLen+fragmentExtHdrLen+64, - []buffer.View{ - // Routing extension header. - // - // Segments left = 1. - buffer.View([]byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}), - - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[:64], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: routingExtHdrID, - data: buffer.NewVectorisedView( - routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, - []buffer.View{ - // Routing extension header. - // - // Segments left = 1. - buffer.View([]byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}), - - // Fragment extension header. - // - // Fragment offset = 9, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 72, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[64:], - }, - ), - }, - }, - expectedPayloads: nil, - }, - { - name: "Two fragments with routing header with zero segments left", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - routingExtHdrLen+fragmentExtHdrLen+64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}), - - // Routing extension header. - // - // Segments left = 0. - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 0, 2, 3, 4, 5}), - - ipv6Payload1Addr1ToAddr2[:64], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 9, More = false, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[64:], - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, - }, - { - name: "Two fragments with routing header with non-zero segments left", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - routingExtHdrLen+fragmentExtHdrLen+64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}), - - // Routing extension header. - // - // Segments left = 1. - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 1, 2, 3, 4, 5}), - - ipv6Payload1Addr1ToAddr2[:64], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 9, More = false, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[64:], - }, - ), - }, - }, - expectedPayloads: nil, - }, - { - name: "Two fragments with routing header with zero segments left across fragments", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - // The length of this payload is fragmentExtHdrLen+8 because the - // first 8 bytes of the 16 byte routing extension header is in - // this fragment. - fragmentExtHdrLen+8, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}), - - // Routing extension header (part 1) - // - // Segments left = 0. - buffer.View([]byte{uint8(header.UDPProtocolNumber), 1, 1, 0, 2, 3, 4, 5}), - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - // The length of this payload is - // fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2) because the last 8 bytes of - // the 16 byte routing extension header is in this fagment. - fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2), - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 1, More = false, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}), - - // Routing extension header (part 2) - buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}), - - ipv6Payload1Addr1ToAddr2, - }, - ), - }, - }, - expectedPayloads: nil, - }, - { - name: "Two fragments with routing header with non-zero segments left across fragments", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - // The length of this payload is fragmentExtHdrLen+8 because the - // first 8 bytes of the 16 byte routing extension header is in - // this fragment. - fragmentExtHdrLen+8, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}), - - // Routing extension header (part 1) - // - // Segments left = 1. - buffer.View([]byte{uint8(header.UDPProtocolNumber), 1, 1, 1, 2, 3, 4, 5}), - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - // The length of this payload is - // fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2) because the last 8 bytes of - // the 16 byte routing extension header is in this fagment. - fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2), - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 1, More = false, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}), - - // Routing extension header (part 2) - buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}), - - ipv6Payload1Addr1ToAddr2, - }, - ), - }, - }, - expectedPayloads: nil, - }, - // As per RFC 6946, IPv6 atomic fragments MUST NOT interfere with "normal" - // fragmented traffic. - { - name: "Two fragments with atomic", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[:64], - }, - ), - }, - // This fragment has the same ID as the other fragments but is an atomic - // fragment. It should not interfere with the other fragments. - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload2Addr1ToAddr2), - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 1}), - - ipv6Payload2Addr1ToAddr2, - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[64:], - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload2Addr1ToAddr2, udpPayload1Addr1ToAddr2}, - }, - { - name: "Two interleaved fragmented packets", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[:64], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+32, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 2 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 2}), - - ipv6Payload2Addr1ToAddr2[:32], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[64:], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload2Addr1ToAddr2)-32, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 4, More = false, ID = 2 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 2}), - - ipv6Payload2Addr1ToAddr2[32:], - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload2Addr1ToAddr2}, - }, - { - name: "Two interleaved fragmented packets from different sources but with same ID", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[:64], - }, - ), - }, - { - srcAddr: addr3, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+32, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - - ipv6Payload1Addr3ToAddr2[:32], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[64:], - }, - ), - }, - { - srcAddr: addr3, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-32, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 4, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 1}), - - ipv6Payload1Addr3ToAddr2[32:], - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload1Addr3ToAddr2}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - e := channel.New(0, header.IPv6MinimumMTU, linkAddr1) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) - } - - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - defer close(ch) - ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, ProtocolNumber, err) - } - defer ep.Close() - - bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80} - if err := ep.Bind(bindAddr); err != nil { - t.Fatalf("Bind(%+v): %s", bindAddr, err) - } - - for _, f := range test.fragments { - hdr := buffer.NewPrependable(header.IPv6MinimumSize) - - // Serialize IPv6 fixed header. - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(f.data.Size()), - // We're lying about transport protocol here so that we can generate - // raw extension headers for the tests. - TransportProtocol: tcpip.TransportProtocolNumber(f.nextHdr), - HopLimit: 255, - SrcAddr: f.srcAddr, - DstAddr: f.dstAddr, - }) - - vv := hdr.View().ToVectorisedView() - vv.Append(f.data) - - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - })) - } - - if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want { - t.Errorf("got UDP Rx Packets = %d, want = %d", got, want) - } - - for i, p := range test.expectedPayloads { - var buf bytes.Buffer - _, err := ep.Read(&buf, tcpip.ReadOptions{}) - if err != nil { - t.Fatalf("(i=%d) Read: %s", i, err) - } - if diff := cmp.Diff(p, buf.Bytes()); diff != "" { - t.Errorf("(i=%d) got UDP payload mismatch (-want +got):\n%s", i, diff) - } - } - - res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{}) - } - }) - } -} - -func TestInvalidIPv6Fragments(t *testing.T) { - const ( - addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") - nicID = 1 - hoplimit = 255 - ident = 1 - data = "TEST_INVALID_IPV6_FRAGMENTS" - ) - - type fragmentData struct { - ipv6Fields header.IPv6Fields - ipv6FragmentFields header.IPv6SerializableFragmentExtHdr - payload []byte - } - - tests := []struct { - name string - fragments []fragmentData - wantMalformedIPPackets uint64 - wantMalformedFragments uint64 - expectICMP bool - expectICMPType header.ICMPv6Type - expectICMPCode header.ICMPv6Code - expectICMPTypeSpecific uint32 - }{ - { - name: "fragment size is not a multiple of 8 and the M flag is true", - fragments: []fragmentData{ - { - ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 9, - TransportProtocol: header.UDPProtocolNumber, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, - }, - ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ - FragmentOffset: 0 >> 3, - M: true, - Identification: ident, - }, - payload: []byte(data)[:9], - }, - }, - wantMalformedIPPackets: 1, - wantMalformedFragments: 1, - expectICMP: true, - expectICMPType: header.ICMPv6ParamProblem, - expectICMPCode: header.ICMPv6ErroneousHeader, - expectICMPTypeSpecific: header.IPv6PayloadLenOffset, - }, - { - name: "fragments reassembled into a payload exceeding the max IPv6 payload size", - fragments: []fragmentData{ - { - ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - TransportProtocol: header.UDPProtocolNumber, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, - }, - ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ - FragmentOffset: ((header.IPv6MaximumPayloadSize + 1) - 16) >> 3, - M: false, - Identification: ident, - }, - payload: []byte(data)[:16], - }, - }, - wantMalformedIPPackets: 1, - wantMalformedFragments: 1, - expectICMP: true, - expectICMPType: header.ICMPv6ParamProblem, - expectICMPCode: header.ICMPv6ErroneousHeader, - expectICMPTypeSpecific: header.IPv6MinimumSize + 2, /* offset for 'Fragment Offset' in the fragment header */ - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ - NewProtocol, - }, - }) - e := channel.New(1, 1500, linkAddr1) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) - } - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv6EmptySubnet, - NIC: nicID, - }}) - - var expectICMPPayload buffer.View - for _, f := range test.fragments { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize) - - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)) - encodeArgs := f.ipv6Fields - encodeArgs.ExtensionHeaders = append(encodeArgs.ExtensionHeaders, &f.ipv6FragmentFields) - ip.Encode(&encodeArgs) - - vv := hdr.View().ToVectorisedView() - vv.AppendView(f.payload) - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - }) - - if test.expectICMP { - expectICMPPayload = stack.PayloadSince(pkt.NetworkHeader()) - } - - e.InjectInbound(ProtocolNumber, pkt) - } - - if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want { - t.Errorf("got Stats.IP.MalformedPacketsReceived = %d, want = %d", got, want) - } - if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want { - t.Errorf("got Stats.IP.MalformedFragmentsReceived = %d, want = %d", got, want) - } - - reply, ok := e.Read() - if !test.expectICMP { - if ok { - t.Fatalf("unexpected ICMP error message received: %#v", reply) - } - return - } - if !ok { - t.Fatal("expected ICMP error message missing") - } - - checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), - checker.SrcAddr(addr2), - checker.DstAddr(addr1), - checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+expectICMPPayload.Size())), - checker.ICMPv6( - checker.ICMPv6Type(test.expectICMPType), - checker.ICMPv6Code(test.expectICMPCode), - checker.ICMPv6TypeSpecific(test.expectICMPTypeSpecific), - checker.ICMPv6Payload([]byte(expectICMPPayload)), - ), - ) - }) - } -} - -func TestFragmentReassemblyTimeout(t *testing.T) { - const ( - addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") - nicID = 1 - hoplimit = 255 - ident = 1 - data = "TEST_FRAGMENT_REASSEMBLY_TIMEOUT" - ) - - type fragmentData struct { - ipv6Fields header.IPv6Fields - ipv6FragmentFields header.IPv6SerializableFragmentExtHdr - payload []byte - } - - tests := []struct { - name string - fragments []fragmentData - expectICMP bool - }{ - { - name: "first fragment only", - fragments: []fragmentData{ - { - ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - TransportProtocol: header.UDPProtocolNumber, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, - }, - ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ - FragmentOffset: 0, - M: true, - Identification: ident, - }, - payload: []byte(data)[:16], - }, - }, - expectICMP: true, - }, - { - name: "two first fragments", - fragments: []fragmentData{ - { - ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - TransportProtocol: header.UDPProtocolNumber, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, - }, - ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ - FragmentOffset: 0, - M: true, - Identification: ident, - }, - payload: []byte(data)[:16], - }, - { - ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - TransportProtocol: header.UDPProtocolNumber, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, - }, - ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ - FragmentOffset: 0, - M: true, - Identification: ident, - }, - payload: []byte(data)[:16], - }, - }, - expectICMP: true, - }, - { - name: "second fragment only", - fragments: []fragmentData{ - { - ipv6Fields: header.IPv6Fields{ - PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), - TransportProtocol: header.UDPProtocolNumber, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, - }, - ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ - FragmentOffset: 8, - M: false, - Identification: ident, - }, - payload: []byte(data)[16:], - }, - }, - expectICMP: false, - }, - { - name: "two fragments with a gap", - fragments: []fragmentData{ - { - ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - TransportProtocol: header.UDPProtocolNumber, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, - }, - ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ - FragmentOffset: 0, - M: true, - Identification: ident, - }, - payload: []byte(data)[:16], - }, - { - ipv6Fields: header.IPv6Fields{ - PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), - TransportProtocol: header.UDPProtocolNumber, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, - }, - ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ - FragmentOffset: 8, - M: false, - Identification: ident, - }, - payload: []byte(data)[16:], - }, - }, - expectICMP: true, - }, - { - name: "two fragments with a gap in reverse order", - fragments: []fragmentData{ - { - ipv6Fields: header.IPv6Fields{ - PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), - TransportProtocol: header.UDPProtocolNumber, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, - }, - ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ - FragmentOffset: 8, - M: false, - Identification: ident, - }, - payload: []byte(data)[16:], - }, - { - ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - TransportProtocol: header.UDPProtocolNumber, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, - }, - ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ - FragmentOffset: 0, - M: true, - Identification: ident, - }, - payload: []byte(data)[:16], - }, - }, - expectICMP: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ - NewProtocol, - }, - Clock: clock, - }) - - e := channel.New(1, 1500, linkAddr1) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr2, err) - } - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv6EmptySubnet, - NIC: nicID, - }}) - - var firstFragmentSent buffer.View - for _, f := range test.fragments { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize) - - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)) - encodeArgs := f.ipv6Fields - encodeArgs.ExtensionHeaders = append(encodeArgs.ExtensionHeaders, &f.ipv6FragmentFields) - ip.Encode(&encodeArgs) - - fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:]) - - vv := hdr.View().ToVectorisedView() - vv.AppendView(f.payload) - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - }) - - if firstFragmentSent == nil && fragHDR.FragmentOffset() == 0 { - firstFragmentSent = stack.PayloadSince(pkt.NetworkHeader()) - } - - e.InjectInbound(ProtocolNumber, pkt) - } - - clock.Advance(ReassembleTimeout) - - reply, ok := e.Read() - if !test.expectICMP { - if ok { - t.Fatalf("unexpected ICMP error message received: %#v", reply) - } - return - } - if !ok { - t.Fatal("expected ICMP error message missing") - } - if firstFragmentSent == nil { - t.Fatalf("unexpected ICMP error message received: %#v", reply) - } - - checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), - checker.SrcAddr(addr2), - checker.DstAddr(addr1), - checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+firstFragmentSent.Size())), - checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6TimeExceeded), - checker.ICMPv6Code(header.ICMPv6ReassemblyTimeout), - checker.ICMPv6Payload([]byte(firstFragmentSent)), - ), - ) - }) - } -} - -func TestWriteStats(t *testing.T) { - const nPackets = 3 - tests := []struct { - name string - setup func(*testing.T, *stack.Stack) - allowPackets int - expectSent int - expectDropped int - expectWritten int - }{ - { - name: "Accept all", - // No setup needed, tables accept everything by default. - setup: func(*testing.T, *stack.Stack) {}, - allowPackets: math.MaxInt32, - expectSent: nPackets, - expectDropped: 0, - expectWritten: nPackets, - }, { - name: "Accept all with error", - // No setup needed, tables accept everything by default. - setup: func(*testing.T, *stack.Stack) {}, - allowPackets: nPackets - 1, - expectSent: nPackets - 1, - expectDropped: 0, - expectWritten: nPackets - 1, - }, { - name: "Drop all", - setup: func(t *testing.T, stk *stack.Stack) { - // Install Output DROP rule. - t.Helper() - ipt := stk.IPTables() - filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Output] - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("failed to replace table: %v", err) - } - }, - allowPackets: math.MaxInt32, - expectSent: 0, - expectDropped: nPackets, - expectWritten: nPackets, - }, { - name: "Drop some", - setup: func(t *testing.T, stk *stack.Stack) { - // Install Output DROP rule that matches only 1 - // of the 3 packets. - t.Helper() - ipt := stk.IPTables() - filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) - // We'll match and DROP the last packet. - ruleIdx := filter.BuiltinChains[stack.Output] - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} - // Make sure the next rule is ACCEPT. - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("failed to replace table: %v", err) - } - }, - allowPackets: math.MaxInt32, - expectSent: nPackets - 1, - expectDropped: 1, - expectWritten: nPackets, - }, - } - - writers := []struct { - name string - writePackets func(*stack.Route, stack.PacketBufferList) (int, tcpip.Error) - }{ - { - name: "WritePacket", - writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) { - nWritten := 0 - for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil { - return nWritten, err - } - nWritten++ - } - return nWritten, nil - }, - }, { - name: "WritePackets", - writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) { - return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{}) - }, - }, - } - - for _, writer := range writers { - t.Run(writer.name, func(t *testing.T) { - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets) - rt := buildRoute(t, ep) - var pkts stack.PacketBufferList - for i := 0; i < nPackets; i++ { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()), - Data: buffer.NewView(0).ToVectorisedView(), - }) - pkt.TransportHeader().Push(header.UDPMinimumSize) - pkts.PushBack(pkt) - } - - test.setup(t, rt.Stack()) - - nWritten, _ := writer.writePackets(rt, pkts) - - if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent { - t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent) - } - if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped { - t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped) - } - if nWritten != test.expectWritten { - t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten) - } - }) - } - }) - } -} - -func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - if err := s.CreateNIC(1, ep); err != nil { - t.Fatalf("CreateNIC(1, _) failed: %s", err) - } - const ( - src = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - ) - if err := s.AddAddress(1, ProtocolNumber, src); err != nil { - t.Fatalf("AddAddress(1, %d, %s) failed: %s", ProtocolNumber, src, err) - } - { - mask := tcpip.AddressMask("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff") - subnet, err := tcpip.NewSubnet(dst, mask) - if err != nil { - t.Fatalf("NewSubnet(%s, %s) failed: %v", dst, mask, err) - } - s.SetRouteTable([]tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}) - } - rt, err := s.FindRoute(1, src, dst, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(1, %s, %s, %d, false) = %s, want = nil", src, dst, ProtocolNumber, err) - } - return rt -} - -// limitedMatcher is an iptables matcher that matches after a certain number of -// packets are checked against it. -type limitedMatcher struct { - limit int -} - -// Name implements Matcher.Name. -func (*limitedMatcher) Name() string { - return "limitedMatcher" -} - -// Match implements Matcher.Match. -func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string, string) (bool, bool) { - if lm.limit == 0 { - return true, false - } - lm.limit-- - return false, false -} - -func knownNICIDs(proto *protocol) []tcpip.NICID { - var nicIDs []tcpip.NICID - - for k := range proto.mu.eps { - nicIDs = append(nicIDs, k) - } - - return nicIDs -} - -func TestClearEndpointFromProtocolOnClose(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) - var nic testInterface - ep := proto.NewEndpoint(&nic, nil).(*endpoint) - var nicIDs []tcpip.NICID - - proto.mu.Lock() - foundEP, hasEndpointBeforeClose := proto.mu.eps[nic.ID()] - nicIDs = knownNICIDs(proto) - proto.mu.Unlock() - if !hasEndpointBeforeClose { - t.Fatalf("expected to find the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs) - } - if foundEP != ep { - t.Fatalf("found an incorrect endpoint mapped to nic id %d", nic.ID()) - } - - ep.Close() - - proto.mu.Lock() - _, hasEndpointAfterClose := proto.mu.eps[nic.ID()] - nicIDs = knownNICIDs(proto) - proto.mu.Unlock() - if hasEndpointAfterClose { - t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs) - } -} - -type fragmentInfo struct { - offset uint16 - more bool - payloadSize uint16 -} - -var fragmentationTests = []struct { - description string - mtu uint32 - gso *stack.GSO - transHdrLen int - payloadSize int - wantFragments []fragmentInfo -}{ - { - description: "No fragmentation", - mtu: header.IPv6MinimumMTU, - gso: nil, - transHdrLen: 0, - payloadSize: 1000, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1000, more: false}, - }, - }, - { - description: "Fragmented", - mtu: header.IPv6MinimumMTU, - gso: nil, - transHdrLen: 0, - payloadSize: 2000, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1240, more: true}, - {offset: 154, payloadSize: 776, more: false}, - }, - }, - { - description: "Fragmented with mtu not a multiple of 8", - mtu: header.IPv6MinimumMTU + 1, - gso: nil, - transHdrLen: 0, - payloadSize: 2000, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1240, more: true}, - {offset: 154, payloadSize: 776, more: false}, - }, - }, - { - description: "No fragmentation with big header", - mtu: 2000, - gso: nil, - transHdrLen: 100, - payloadSize: 1000, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1100, more: false}, - }, - }, - { - description: "Fragmented with gso none", - mtu: header.IPv6MinimumMTU, - gso: &stack.GSO{Type: stack.GSONone}, - transHdrLen: 0, - payloadSize: 1400, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1240, more: true}, - {offset: 154, payloadSize: 176, more: false}, - }, - }, - { - description: "Fragmented with big header", - mtu: header.IPv6MinimumMTU, - gso: nil, - transHdrLen: 100, - payloadSize: 1200, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1240, more: true}, - {offset: 154, payloadSize: 76, more: false}, - }, - }, -} - -func TestFragmentationWritePacket(t *testing.T) { - const ( - ttl = 42 - tos = stack.DefaultTOS - transportProto = tcp.ProtocolNumber - ) - - for _, ft := range fragmentationTests { - t.Run(ft.description, func(t *testing.T) { - pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) - source := pkt.Clone() - ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) - r := buildRoute(t, ep) - err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{ - Protocol: tcp.ProtocolNumber, - TTL: ttl, - TOS: stack.DefaultTOS, - }, pkt) - if err != nil { - t.Fatalf("WritePacket(_, _, _): = %s", err) - } - if got := len(ep.WrittenPackets); got != len(ft.wantFragments) { - t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, len(ft.wantFragments)) - } - if got := int(r.Stats().IP.PacketsSent.Value()); got != len(ft.wantFragments) { - t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, len(ft.wantFragments)) - } - if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { - t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) - } - if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { - t.Error(err) - } - }) - } -} - -func TestFragmentationWritePackets(t *testing.T) { - const ttl = 42 - tests := []struct { - description string - insertBefore int - insertAfter int - }{ - { - description: "Single packet", - insertBefore: 0, - insertAfter: 0, - }, - { - description: "With packet before", - insertBefore: 1, - insertAfter: 0, - }, - { - description: "With packet after", - insertBefore: 0, - insertAfter: 1, - }, - { - description: "With packet before and after", - insertBefore: 1, - insertAfter: 1, - }, - } - tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber) - - for _, test := range tests { - t.Run(test.description, func(t *testing.T) { - for _, ft := range fragmentationTests { - t.Run(ft.description, func(t *testing.T) { - var pkts stack.PacketBufferList - for i := 0; i < test.insertBefore; i++ { - pkts.PushBack(tinyPacket.Clone()) - } - pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) - source := pkt - pkts.PushBack(pkt.Clone()) - for i := 0; i < test.insertAfter; i++ { - pkts.PushBack(tinyPacket.Clone()) - } - - ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) - r := buildRoute(t, ep) - - wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter - n, err := r.WritePackets(ft.gso, pkts, stack.NetworkHeaderParams{ - Protocol: tcp.ProtocolNumber, - TTL: ttl, - TOS: stack.DefaultTOS, - }) - if n != wantTotalPackets || err != nil { - t.Errorf("got WritePackets(_, _, _) = (%d, %s), want = (%d, nil)", n, err, wantTotalPackets) - } - if got := len(ep.WrittenPackets); got != wantTotalPackets { - t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, wantTotalPackets) - } - if got := int(r.Stats().IP.PacketsSent.Value()); got != wantTotalPackets { - t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, wantTotalPackets) - } - if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { - t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) - } - - if wantTotalPackets == 0 { - return - } - - fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore] - if err := compareFragments(fragments, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { - t.Error(err) - } - }) - } - }) - } -} - -// TestFragmentationErrors checks that errors are returned from WritePacket -// correctly. -func TestFragmentationErrors(t *testing.T) { - const ttl = 42 - - tests := []struct { - description string - mtu uint32 - transHdrLen int - payloadSize int - allowPackets int - outgoingErrors int - mockError tcpip.Error - wantError tcpip.Error - }{ - { - description: "No frag", - mtu: 2000, - payloadSize: 1000, - transHdrLen: 0, - allowPackets: 0, - outgoingErrors: 1, - mockError: &tcpip.ErrAborted{}, - wantError: &tcpip.ErrAborted{}, - }, - { - description: "Error on first frag", - mtu: 1300, - payloadSize: 3000, - transHdrLen: 0, - allowPackets: 0, - outgoingErrors: 3, - mockError: &tcpip.ErrAborted{}, - wantError: &tcpip.ErrAborted{}, - }, - { - description: "Error on second frag", - mtu: 1500, - payloadSize: 4000, - transHdrLen: 0, - allowPackets: 1, - outgoingErrors: 2, - mockError: &tcpip.ErrAborted{}, - wantError: &tcpip.ErrAborted{}, - }, - { - description: "Error when MTU is smaller than transport header", - mtu: header.IPv6MinimumMTU, - transHdrLen: 1500, - payloadSize: 500, - allowPackets: 0, - outgoingErrors: 1, - mockError: nil, - wantError: &tcpip.ErrMessageTooLong{}, - }, - { - description: "Error when MTU is smaller than IPv6 minimum MTU", - mtu: header.IPv6MinimumMTU - 1, - transHdrLen: 0, - payloadSize: 500, - allowPackets: 0, - outgoingErrors: 1, - mockError: nil, - wantError: &tcpip.ErrInvalidEndpointState{}, - }, - } - - for _, ft := range tests { - t.Run(ft.description, func(t *testing.T) { - pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) - ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) - r := buildRoute(t, ep) - err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ - Protocol: tcp.ProtocolNumber, - TTL: ttl, - TOS: stack.DefaultTOS, - }, pkt) - if diff := cmp.Diff(ft.wantError, err); diff != "" { - t.Errorf("unexpected error from WritePacket(_, _, _), (-want, +got):\n%s", diff) - } - if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets { - t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets) - } - if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != ft.outgoingErrors { - t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, ft.outgoingErrors) - } - }) - } -} - -func TestForwarding(t *testing.T) { - const ( - nicID1 = 1 - nicID2 = 2 - randomSequence = 123 - randomIdent = 42 - ) - - ipv6Addr1 := tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("10::1").To16()), - PrefixLen: 64, - } - ipv6Addr2 := tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("11::1").To16()), - PrefixLen: 64, - } - remoteIPv6Addr1 := tcpip.Address(net.ParseIP("10::2").To16()) - remoteIPv6Addr2 := tcpip.Address(net.ParseIP("11::2").To16()) - - tests := []struct { - name string - TTL uint8 - expectErrorICMP bool - }{ - { - name: "TTL of zero", - TTL: 0, - expectErrorICMP: true, - }, - { - name: "TTL of one", - TTL: 1, - expectErrorICMP: true, - }, - { - name: "TTL of two", - TTL: 2, - expectErrorICMP: false, - }, - { - name: "TTL of three", - TTL: 3, - expectErrorICMP: false, - }, - { - name: "Max TTL", - TTL: math.MaxUint8, - expectErrorICMP: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - }) - // We expect at most a single packet in response to our ICMP Echo Request. - e1 := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(nicID1, e1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) - } - ipv6ProtoAddr1 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr1} - if err := s.AddProtocolAddress(nicID1, ipv6ProtoAddr1); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv6ProtoAddr1, err) - } - - e2 := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(nicID2, e2); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) - } - ipv6ProtoAddr2 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr2} - if err := s.AddProtocolAddress(nicID2, ipv6ProtoAddr2); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv6ProtoAddr2, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: ipv6Addr1.Subnet(), - NIC: nicID1, - }, - { - Destination: ipv6Addr2.Subnet(), - NIC: nicID2, - }, - }) - - if err := s.SetForwarding(ProtocolNumber, true); err != nil { - t.Fatalf("SetForwarding(%d, true): %s", ProtocolNumber, err) - } - - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6MinimumSize) - icmp := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) - icmp.SetIdent(randomIdent) - icmp.SetSequence(randomSequence) - icmp.SetType(header.ICMPv6EchoRequest) - icmp.SetCode(header.ICMPv6UnusedCode) - icmp.SetChecksum(0) - icmp.SetChecksum(header.ICMPv6Checksum(icmp, remoteIPv6Addr1, remoteIPv6Addr2, buffer.VectorisedView{})) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: test.TTL, - SrcAddr: remoteIPv6Addr1, - DstAddr: remoteIPv6Addr2, - }) - requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - }) - e1.InjectInbound(ProtocolNumber, requestPkt) - - if test.expectErrorICMP { - reply, ok := e1.Read() - if !ok { - t.Fatal("expected ICMP Hop Limit Exceeded packet through incoming NIC") - } - - checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())), - checker.SrcAddr(ipv6Addr1.Address), - checker.DstAddr(remoteIPv6Addr1), - checker.TTL(DefaultTTL), - checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6TimeExceeded), - checker.ICMPv6Code(header.ICMPv6HopLimitExceeded), - checker.ICMPv6Payload([]byte(hdr.View())), - ), - ) - - if n := e2.Drain(); n != 0 { - t.Fatalf("got e2.Drain() = %d, want = 0", n) - } - } else { - reply, ok := e2.Read() - if !ok { - t.Fatal("expected ICMP Echo Request packet through outgoing NIC") - } - - checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())), - checker.SrcAddr(remoteIPv6Addr1), - checker.DstAddr(remoteIPv6Addr2), - checker.TTL(test.TTL-1), - checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6EchoRequest), - checker.ICMPv6Code(header.ICMPv6UnusedCode), - checker.ICMPv6Payload(nil), - ), - ) - - if n := e1.Drain(); n != 0 { - t.Fatalf("got e1.Drain() = %d, want = 0", n) - } - } - }) - } -} - -func TestMultiCounterStatsInitialization(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) - var nic testInterface - ep := proto.NewEndpoint(&nic, nil).(*endpoint) - // At this point, the Stack's stats and the NetworkEndpoint's stats are - // supposed to be bound. - refStack := s.Stats() - refEP := ep.stats.localStats - if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.ip).Elem(), []reflect.Value{reflect.ValueOf(&refStack.IP).Elem(), reflect.ValueOf(&refEP.IP).Elem()}); err != nil { - t.Error(err) - } - if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.icmp).Elem(), []reflect.Value{reflect.ValueOf(&refStack.ICMP.V6).Elem(), reflect.ValueOf(&refEP.ICMP).Elem()}); err != nil { - t.Error(err) - } -} diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go deleted file mode 100644 index fe39555e0..000000000 --- a/pkg/tcpip/network/ipv6/mld_test.go +++ /dev/null @@ -1,297 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv6_test - -import ( - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -const ( - linkLocalAddr = "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - globalAddr = "\x0a\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - globalMulticastAddr = "\xff\x05\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" -) - -var ( - linkLocalAddrSNMC = header.SolicitedNodeAddr(linkLocalAddr) - globalAddrSNMC = header.SolicitedNodeAddr(globalAddr) -) - -func validateMLDPacket(t *testing.T, p buffer.View, localAddress, remoteAddress tcpip.Address, mldType header.ICMPv6Type, groupAddress tcpip.Address) { - t.Helper() - - checker.IPv6WithExtHdr(t, p, - checker.IPv6ExtHdr( - checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)), - ), - checker.SrcAddr(localAddress), - checker.DstAddr(remoteAddress), - // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3. - checker.TTL(1), - checker.MLD(mldType, header.MLDMinimumSize, - checker.MLDMaxRespDelay(0), - checker.MLDMulticastAddress(groupAddress), - ), - ) -} - -func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) { - const nicID = 1 - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - MLD: ipv6.MLDOptions{ - Enabled: true, - }, - })}, - }) - e := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - - // The stack will join an address's solicited node multicast address when - // an address is added. An MLD report message should be sent for the - // solicited-node group. - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a report message to be sent") - } else { - validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC) - } - - // The stack will leave an address's solicited node multicast address when - // an address is removed. An MLD done message should be sent for the - // solicited-node group. - if err := s.RemoveAddress(nicID, linkLocalAddr); err != nil { - t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, linkLocalAddr, err) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a done message to be sent") - } else { - validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC) - } -} - -func TestSendQueuedMLDReports(t *testing.T) { - const ( - nicID = 1 - maxReports = 2 - ) - - tests := []struct { - name string - dadTransmits uint8 - retransmitTimer time.Duration - }{ - { - name: "DAD Disabled", - dadTransmits: 0, - retransmitTimer: 0, - }, - { - name: "DAD Enabled", - dadTransmits: 1, - retransmitTimer: time.Second, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - dadResolutionTime := test.retransmitTimer * time.Duration(test.dadTransmits) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - DADConfigs: stack.DADConfigurations{ - DupAddrDetectTransmits: test.dadTransmits, - RetransmitTimer: test.retransmitTimer, - }, - MLD: ipv6.MLDOptions{ - Enabled: true, - }, - })}, - Clock: clock, - }) - - // Allow space for an extra packet so we can observe packets that were - // unexpectedly sent. - e := channel.New(maxReports+int(test.dadTransmits)+1 /* extra */, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - - resolveDAD := func(addr, snmc tcpip.Address) { - clock.Advance(dadResolutionTime) - if p, ok := e.Read(); !ok { - t.Fatal("expected DAD packet") - } else { - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(header.IPv6Any), - checker.DstAddr(snmc), - checker.TTL(header.NDPHopLimit), - checker.NDPNS( - checker.NDPNSTargetAddress(addr), - checker.NDPNSOptions(nil), - )) - } - } - - var reportCounter uint64 - reportStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport - if got := reportStat.Value(); got != reportCounter { - t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) - } - var doneCounter uint64 - doneStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone - if got := doneStat.Value(); got != doneCounter { - t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter) - } - - // Joining a group without an assigned address should send an MLD report - // with the unspecified address. - if err := s.JoinGroup(ipv6.ProtocolNumber, nicID, globalMulticastAddr); err != nil { - t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalMulticastAddr, err) - } - reportCounter++ - if got := reportStat.Value(); got != reportCounter { - t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Errorf("expected MLD report for %s", globalMulticastAddr) - } else { - validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalMulticastAddr, header.ICMPv6MulticastListenerReport, globalMulticastAddr) - } - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Errorf("got unexpected packet = %#v", p) - } - if t.Failed() { - t.FailNow() - } - - // Adding a global address should not send reports for the already joined - // group since we should only send queued reports when a link-local - // addres sis assigned. - // - // Note, we will still expect to send a report for the global address's - // solicited node address from the unspecified address as per RFC 3590 - // section 4. - if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint, err) - } - reportCounter++ - if got := reportStat.Value(); got != reportCounter { - t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Errorf("expected MLD report for %s", globalAddrSNMC) - } else { - validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalAddrSNMC, header.ICMPv6MulticastListenerReport, globalAddrSNMC) - } - if dadResolutionTime != 0 { - // Reports should not be sent when the address resolves. - resolveDAD(globalAddr, globalAddrSNMC) - if got := reportStat.Value(); got != reportCounter { - t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) - } - } - // Leave the group since we don't care about the global address's - // solicited node multicast group membership. - if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, globalAddrSNMC); err != nil { - t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalAddrSNMC, err) - } - if got := doneStat.Value(); got != doneCounter { - t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter) - } - if p, ok := e.Read(); ok { - t.Errorf("got unexpected packet = %#v", p) - } - if t.Failed() { - t.FailNow() - } - - // Adding a link-local address should send a report for its solicited node - // address and globalMulticastAddr. - if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint, err) - } - if dadResolutionTime != 0 { - reportCounter++ - if got := reportStat.Value(); got != reportCounter { - t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Errorf("expected MLD report for %s", linkLocalAddrSNMC) - } else { - validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC) - } - resolveDAD(linkLocalAddr, linkLocalAddrSNMC) - } - - // We expect two batches of reports to be sent (1 batch when the - // link-local address is assigned, and another after the maximum - // unsolicited report interval. - for i := 0; i < 2; i++ { - // We expect reports to be sent (one for globalMulticastAddr and another - // for linkLocalAddrSNMC). - reportCounter += maxReports - if got := reportStat.Value(); got != reportCounter { - t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) - } - - addrs := map[tcpip.Address]bool{ - globalMulticastAddr: false, - linkLocalAddrSNMC: false, - } - for range addrs { - p, ok := e.Read() - if !ok { - t.Fatalf("expected MLD report for %s and %s; addrs = %#v", globalMulticastAddr, linkLocalAddrSNMC, addrs) - } - - addr := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())).DestinationAddress() - if seen, ok := addrs[addr]; !ok { - t.Fatalf("got unexpected packet destined to %s", addr) - } else if seen { - t.Fatalf("got another packet destined to %s", addr) - } - - addrs[addr] = true - validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, addr, header.ICMPv6MulticastListenerReport, addr) - - clock.Advance(ipv6.UnsolicitedReportIntervalMax) - } - } - - // Should not send any more reports. - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Errorf("got unexpected packet = %#v", p) - } - }) - } -} diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go deleted file mode 100644 index ce20af0e3..000000000 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ /dev/null @@ -1,1332 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv6 - -import ( - "context" - "strings" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" -) - -// setupStackAndEndpoint creates a stack with a single NIC with a link-local -// address llladdr and an IPv6 endpoint to a remote with link-local address -// rlladdr -func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) { - t.Helper() - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - }) - - if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - { - subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}, - ) - } - - netProto := s.NetworkProtocolInstance(ProtocolNumber) - if netProto == nil { - t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) - } - - ep := netProto.NewEndpoint(&testInterface{}, &stubDispatcher{}) - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - t.Cleanup(ep.Close) - - addressableEndpoint, ok := ep.(stack.AddressableEndpoint) - if !ok { - t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") - } - addr := llladdr.WithPrefix() - if addressEP, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) - } else { - addressEP.DecRef() - } - - return s, ep -} - -var _ NDPDispatcher = (*testNDPDispatcher)(nil) - -// testNDPDispatcher is an NDPDispatcher only allows default router discovery. -type testNDPDispatcher struct { - addr tcpip.Address -} - -func (*testNDPDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, tcpip.Error) { -} - -func (t *testNDPDispatcher) OnDefaultRouterDiscovered(_ tcpip.NICID, addr tcpip.Address) bool { - t.addr = addr - return true -} - -func (t *testNDPDispatcher) OnDefaultRouterInvalidated(_ tcpip.NICID, addr tcpip.Address) { - t.addr = addr -} - -func (*testNDPDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool { - return false -} - -func (*testNDPDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) { -} - -func (*testNDPDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool { - return false -} - -func (*testNDPDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) { -} - -func (*testNDPDispatcher) OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) { -} - -func (*testNDPDispatcher) OnRecursiveDNSServerOption(tcpip.NICID, []tcpip.Address, time.Duration) { -} - -func (*testNDPDispatcher) OnDNSSearchListOption(tcpip.NICID, []string, time.Duration) { -} - -func (*testNDPDispatcher) OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA) { -} - -func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) { - var ndpDisp testNDPDispatcher - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocolWithOptions(Options{ - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - - ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber) - if err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ProtocolNumber, err) - } - - ipv6EP := ep.(*endpoint) - ipv6EP.mu.Lock() - ipv6EP.mu.ndp.rememberDefaultRouter(lladdr1, time.Hour) - ipv6EP.mu.Unlock() - - if ndpDisp.addr != lladdr1 { - t.Fatalf("got ndpDisp.addr = %s, want = %s", ndpDisp.addr, lladdr1) - } - - ndpDisp.addr = "" - ndpEP := ep.(stack.NDPEndpoint) - ndpEP.InvalidateDefaultRouter(lladdr1) - if ndpDisp.addr != lladdr1 { - t.Fatalf("got ndpDisp.addr = %s, want = %s", ndpDisp.addr, lladdr1) - } -} - -type linkResolutionResult struct { - linkAddr tcpip.LinkAddress - ok bool -} - -// TestNeighborSolicitationWithSourceLinkLayerOption tests that receiving a -// valid NDP NS message with the Source Link Layer Address option results in a -// new entry in the link address cache for the sender of the message. -func TestNeighborSolicitationWithSourceLinkLayerOption(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - optsBuf []byte - expectedLinkAddr tcpip.LinkAddress - }{ - { - name: "Valid", - optsBuf: []byte{1, 1, 2, 3, 4, 5, 6, 7}, - expectedLinkAddr: "\x02\x03\x04\x05\x06\x07", - }, - { - name: "Too Small", - optsBuf: []byte{1, 1, 2, 3, 4, 5, 6}, - }, - { - name: "Invalid Length", - optsBuf: []byte{1, 2, 2, 3, 4, 5, 6, 7}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - e := channel.New(0, 1280, linkAddr0) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) - } - - ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) - pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.MessageBody()) - ns.SetTargetAddress(lladdr0) - opts := ns.Options() - copy(opts, test.optsBuf) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - - invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid - - // Invalid count should initially be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - - neighbors, err := s.Neighbors(nicID, ProtocolNumber) - if err != nil { - t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err) - } - - neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry) - for _, n := range neighbors { - if existing, ok := neighborByAddr[n.Addr]; ok { - if diff := cmp.Diff(existing, n); diff != "" { - t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, ProtocolNumber, diff) - } - t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry: %#v", nicID, ProtocolNumber, existing) - } - neighborByAddr[n.Addr] = n - } - - if neigh, ok := neighborByAddr[lladdr1]; len(test.expectedLinkAddr) != 0 { - // Invalid count should not have increased. - if got := invalid.Value(); got != 0 { - t.Errorf("got invalid = %d, want = 0", got) - } - - if !ok { - t.Fatalf("expected a neighbor entry for %q", lladdr1) - } - if neigh.LinkAddr != test.expectedLinkAddr { - t.Errorf("got link address = %s, want = %s", neigh.LinkAddr, test.expectedLinkAddr) - } - if neigh.State != stack.Stale { - t.Errorf("got NUD state = %s, want = %s", neigh.State, stack.Stale) - } - } else { - // Invalid count should have increased. - if got := invalid.Value(); got != 1 { - t.Errorf("got invalid = %d, want = 1", got) - } - - if ok { - t.Fatalf("unexpectedly got neighbor entry: %#v", neigh) - } - } - }) - } -} - -func TestNeighborSolicitationResponse(t *testing.T) { - const nicID = 1 - nicAddr := lladdr0 - remoteAddr := lladdr1 - nicAddrSNMC := header.SolicitedNodeAddr(nicAddr) - nicLinkAddr := linkAddr0 - remoteLinkAddr0 := linkAddr1 - remoteLinkAddr1 := linkAddr2 - - tests := []struct { - name string - nsOpts header.NDPOptionsSerializer - nsSrcLinkAddr tcpip.LinkAddress - nsSrc tcpip.Address - nsDst tcpip.Address - nsInvalid bool - naDstLinkAddr tcpip.LinkAddress - naSolicited bool - naSrc tcpip.Address - naDst tcpip.Address - performsLinkResolution bool - }{ - { - name: "Unspecified source to solicited-node multicast destination", - nsOpts: nil, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: header.IPv6Any, - nsDst: nicAddrSNMC, - nsInvalid: false, - naDstLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress), - naSolicited: false, - naSrc: nicAddr, - naDst: header.IPv6AllNodesMulticastAddress, - }, - { - name: "Unspecified source with source ll option to multicast destination", - nsOpts: header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), - }, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: header.IPv6Any, - nsDst: nicAddrSNMC, - nsInvalid: true, - }, - { - name: "Unspecified source to unicast destination", - nsOpts: nil, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: header.IPv6Any, - nsDst: nicAddr, - nsInvalid: true, - }, - { - name: "Unspecified source with source ll option to unicast destination", - nsOpts: header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), - }, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: header.IPv6Any, - nsDst: nicAddr, - nsInvalid: true, - }, - { - name: "Specified source with 1 source ll to multicast destination", - nsOpts: header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), - }, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: remoteAddr, - nsDst: nicAddrSNMC, - nsInvalid: false, - naDstLinkAddr: remoteLinkAddr0, - naSolicited: true, - naSrc: nicAddr, - naDst: remoteAddr, - }, - { - name: "Specified source with 1 source ll different from route to multicast destination", - nsOpts: header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]), - }, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: remoteAddr, - nsDst: nicAddrSNMC, - nsInvalid: false, - naDstLinkAddr: remoteLinkAddr1, - naSolicited: true, - naSrc: nicAddr, - naDst: remoteAddr, - }, - { - name: "Specified source to multicast destination", - nsOpts: nil, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: remoteAddr, - nsDst: nicAddrSNMC, - nsInvalid: true, - }, - { - name: "Specified source with 2 source ll to multicast destination", - nsOpts: header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), - header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]), - }, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: remoteAddr, - nsDst: nicAddrSNMC, - nsInvalid: true, - }, - - { - name: "Specified source to unicast destination", - nsOpts: nil, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: remoteAddr, - nsDst: nicAddr, - nsInvalid: false, - naDstLinkAddr: remoteLinkAddr0, - naSolicited: true, - naSrc: nicAddr, - naDst: remoteAddr, - // Since we send a unicast solicitations to a node without an entry for - // the remote, the node needs to perform neighbor discovery to get the - // remote's link address to send the advertisement response. - performsLinkResolution: true, - }, - { - name: "Specified source with 1 source ll to unicast destination", - nsOpts: header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), - }, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: remoteAddr, - nsDst: nicAddr, - nsInvalid: false, - naDstLinkAddr: remoteLinkAddr0, - naSolicited: true, - naSrc: nicAddr, - naDst: remoteAddr, - }, - { - name: "Specified source with 1 source ll different from route to unicast destination", - nsOpts: header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]), - }, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: remoteAddr, - nsDst: nicAddr, - nsInvalid: false, - naDstLinkAddr: remoteLinkAddr1, - naSolicited: true, - naSrc: nicAddr, - naDst: remoteAddr, - }, - { - name: "Specified source with 2 source ll to unicast destination", - nsOpts: header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), - header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]), - }, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: remoteAddr, - nsDst: nicAddr, - nsInvalid: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - e := channel.New(1, 1280, nicLinkAddr) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv6EmptySubnet, - NIC: 1, - }, - }) - - ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length() - hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) - pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.MessageBody()) - ns.SetTargetAddress(nicAddr) - opts := ns.Options() - opts.Serialize(test.nsOpts) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: 255, - SrcAddr: test.nsSrc, - DstAddr: test.nsDst, - }) - - invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid - - // Invalid count should initially be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - - e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - - if test.nsInvalid { - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - - if p, got := e.Read(); got { - t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt) - } - - // If we expected the NS to be invalid, we have nothing else to check. - return - } - - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - - if test.performsLinkResolution { - p, got := e.ReadContext(context.Background()) - if !got { - t.Fatal("expected an NDP NS response") - } - - respNSDst := header.SolicitedNodeAddr(test.nsSrc) - var want stack.RouteInfo - want.NetProto = ProtocolNumber - want.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(respNSDst) - if diff := cmp.Diff(want, p.Route, cmp.AllowUnexported(want)); diff != "" { - t.Errorf("route info mismatch (-want +got):\n%s", diff) - } - - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(nicAddr), - checker.DstAddr(respNSDst), - checker.TTL(header.NDPHopLimit), - checker.NDPNS( - checker.NDPNSTargetAddress(test.nsSrc), - checker.NDPNSOptions([]header.NDPOption{ - header.NDPSourceLinkLayerAddressOption(nicLinkAddr), - }), - )) - - ser := header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - } - ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + ser.Length() - hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) - pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.MessageBody()) - na.SetSolicitedFlag(true) - na.SetOverrideFlag(true) - na.SetTargetAddress(test.nsSrc) - na.Options().Serialize(ser) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, nicAddr, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: header.NDPHopLimit, - SrcAddr: test.nsSrc, - DstAddr: nicAddr, - }) - e.InjectLinkAddr(ProtocolNumber, "", stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - } - - p, got := e.ReadContext(context.Background()) - if !got { - t.Fatal("expected an NDP NA response") - } - - if p.Route.LocalAddress != test.naSrc { - t.Errorf("got p.Route.LocalAddress = %s, want = %s", p.Route.LocalAddress, test.naSrc) - } - if p.Route.LocalLinkAddress != nicLinkAddr { - t.Errorf("p.Route.LocalLinkAddress = %s, want = %s", p.Route.LocalLinkAddress, nicLinkAddr) - } - if p.Route.RemoteAddress != test.naDst { - t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, test.naDst) - } - if p.Route.RemoteLinkAddress != test.naDstLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr) - } - - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(test.naSrc), - checker.DstAddr(test.naDst), - checker.TTL(header.NDPHopLimit), - checker.NDPNA( - checker.NDPNASolicitedFlag(test.naSolicited), - checker.NDPNATargetAddress(nicAddr), - checker.NDPNAOptions([]header.NDPOption{ - header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]), - }), - )) - }) - } -} - -// TestNeighborAdvertisementWithTargetLinkLayerOption tests that receiving a -// valid NDP NA message with the Target Link Layer Address option does not -// result in a new entry in the neighbor cache for the target of the message. -func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - optsBuf []byte - isValid bool - }{ - { - name: "Valid", - optsBuf: []byte{2, 1, 2, 3, 4, 5, 6, 7}, - isValid: true, - }, - { - name: "Too Small", - optsBuf: []byte{2, 1, 2, 3, 4, 5, 6}, - }, - { - name: "Invalid Length", - optsBuf: []byte{2, 2, 2, 3, 4, 5, 6, 7}, - }, - { - name: "Multiple", - optsBuf: []byte{ - 2, 1, 2, 3, 4, 5, 6, 7, - 2, 1, 2, 3, 4, 5, 6, 8, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - e := channel.New(0, 1280, linkAddr0) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) - } - - ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) - pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - ns := header.NDPNeighborAdvert(pkt.MessageBody()) - ns.SetTargetAddress(lladdr1) - opts := ns.Options() - copy(opts, test.optsBuf) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - - invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid - - // Invalid count should initially be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - - neighbors, err := s.Neighbors(nicID, ProtocolNumber) - if err != nil { - t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err) - } - - neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry) - for _, n := range neighbors { - if existing, ok := neighborByAddr[n.Addr]; ok { - if diff := cmp.Diff(existing, n); diff != "" { - t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, ProtocolNumber, diff) - } - t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry: %#v", nicID, ProtocolNumber, existing) - } - neighborByAddr[n.Addr] = n - } - - if neigh, ok := neighborByAddr[lladdr1]; ok { - t.Fatalf("unexpectedly got neighbor entry: %#v", neigh) - } - - if test.isValid { - // Invalid count should not have increased. - if got := invalid.Value(); got != 0 { - t.Errorf("got invalid = %d, want = 0", got) - } - } else { - // Invalid count should have increased. - if got := invalid.Value(); got != 1 { - t.Errorf("got invalid = %d, want = 1", got) - } - } - }) - } -} - -func TestNDPValidation(t *testing.T) { - setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint) { - t.Helper() - - // Create a stack with the assigned link-local address lladdr0 - // and an endpoint to lladdr1. - s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1) - - return s, ep - } - - handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) { - var extHdrs header.IPv6ExtHdrSerializer - if atomicFragment { - extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{}) - } - extHdrsLen := extHdrs.Length() - - ip := buffer.NewView(header.IPv6MinimumSize + extHdrsLen) - header.IPv6(ip).Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(payload) + extHdrsLen), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: hopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, - ExtensionHeaders: extHdrs, - }) - vv := ip.ToVectorisedView() - vv.AppendView(payload) - ep.HandlePacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - })) - } - - var tllData [header.NDPLinkLayerAddressSize]byte - header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - }) - - var sllData [header.NDPLinkLayerAddressSize]byte - header.NDPOptions(sllData[:]).Serialize(header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(linkAddr1), - }) - - types := []struct { - name string - typ header.ICMPv6Type - size int - extraData []byte - statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter - routerOnly bool - }{ - { - name: "RouterSolicit", - typ: header.ICMPv6RouterSolicit, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterSolicit - }, - routerOnly: true, - }, - { - name: "RouterAdvert", - typ: header.ICMPv6RouterAdvert, - size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterAdvert - }, - }, - { - name: "NeighborSolicit", - typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize, - extraData: sllData[:], - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborSolicit - }, - }, - { - name: "NeighborAdvert", - typ: header.ICMPv6NeighborAdvert, - size: header.ICMPv6NeighborAdvertMinimumSize, - extraData: tllData[:], - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborAdvert - }, - }, - { - name: "RedirectMsg", - typ: header.ICMPv6RedirectMsg, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RedirectMsg - }, - }, - } - - subTests := []struct { - name string - atomicFragment bool - hopLimit uint8 - code header.ICMPv6Code - valid bool - }{ - { - name: "Valid", - atomicFragment: false, - hopLimit: header.NDPHopLimit, - code: 0, - valid: true, - }, - { - name: "Fragmented", - atomicFragment: true, - hopLimit: header.NDPHopLimit, - code: 0, - valid: false, - }, - { - name: "Invalid hop limit", - atomicFragment: false, - hopLimit: header.NDPHopLimit - 1, - code: 0, - valid: false, - }, - { - name: "Invalid ICMPv6 code", - atomicFragment: false, - hopLimit: header.NDPHopLimit, - code: 1, - valid: false, - }, - } - - for _, typ := range types { - for _, isRouter := range []bool{false, true} { - name := typ.name - if isRouter { - name += " (Router)" - } - - t.Run(name, func(t *testing.T) { - for _, test := range subTests { - t.Run(test.name, func(t *testing.T) { - s, ep := setup(t) - - if isRouter { - // Enabling forwarding makes the stack act as a router. - s.SetForwarding(ProtocolNumber, true) - } - - stats := s.Stats().ICMP.V6.PacketsReceived - invalid := stats.Invalid - routerOnly := stats.RouterOnlyPacketsDroppedByHost - typStat := typ.statCounter(stats) - - icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) - copy(icmp[typ.size:], typ.extraData) - icmp.SetType(typ.typ) - icmp.SetCode(test.code) - icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) - - // Rx count of the NDP message should initially be 0. - if got := typStat.Value(); got != 0 { - t.Errorf("got %s = %d, want = 0", typ.name, got) - } - - // Invalid count should initially be 0. - if got := invalid.Value(); got != 0 { - t.Errorf("got invalid = %d, want = 0", got) - } - - // RouterOnlyPacketsReceivedByHost count should initially be 0. - if got := routerOnly.Value(); got != 0 { - t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) - } - - if t.Failed() { - t.FailNow() - } - - handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep) - - // Rx count of the NDP packet should have increased. - if got := typStat.Value(); got != 1 { - t.Errorf("got %s = %d, want = 1", typ.name, got) - } - - want := uint64(0) - if !test.valid { - // Invalid count should have increased. - want = 1 - } - if got := invalid.Value(); got != want { - t.Errorf("got invalid = %d, want = %d", got, want) - } - - want = 0 - if test.valid && !isRouter && typ.routerOnly { - // RouterOnlyPacketsReceivedByHost count should have increased. - want = 1 - } - if got := routerOnly.Value(); got != want { - t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = %d", got, want) - } - - }) - } - }) - } - } -} - -// TestNeighborAdvertisementValidation tests that the NIC validates received -// Neighbor Advertisements. -// -// In particular, if the IP Destination Address is a multicast address, and the -// Solicited flag is not zero, the Neighbor Advertisement is invalid and should -// be discarded. -func TestNeighborAdvertisementValidation(t *testing.T) { - tests := []struct { - name string - ipDstAddr tcpip.Address - solicitedFlag bool - valid bool - }{ - { - name: "Multicast IP destination address with Solicited flag set", - ipDstAddr: header.IPv6AllNodesMulticastAddress, - solicitedFlag: true, - valid: false, - }, - { - name: "Multicast IP destination address with Solicited flag unset", - ipDstAddr: header.IPv6AllNodesMulticastAddress, - solicitedFlag: false, - valid: true, - }, - { - name: "Unicast IP destination address with Solicited flag set", - ipDstAddr: lladdr0, - solicitedFlag: true, - valid: true, - }, - { - name: "Unicast IP destination address with Solicited flag unset", - ipDstAddr: lladdr0, - solicitedFlag: false, - valid: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - e := channel.New(0, header.IPv6MinimumMTU, linkAddr0) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) - } - - ndpNASize := header.ICMPv6NeighborAdvertMinimumSize - hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) - pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.MessageBody()) - na.SetTargetAddress(lladdr1) - na.SetSolicitedFlag(test.solicitedFlag) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, test.ipDstAddr, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: test.ipDstAddr, - }) - - stats := s.Stats().ICMP.V6.PacketsReceived - invalid := stats.Invalid - rxNA := stats.NeighborAdvert - - if got := rxNA.Value(); got != 0 { - t.Fatalf("got rxNA = %d, want = 0", got) - } - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - - if got := rxNA.Value(); got != 1 { - t.Fatalf("got rxNA = %d, want = 1", got) - } - var wantInvalid uint64 = 1 - if test.valid { - wantInvalid = 0 - } - if got := invalid.Value(); got != wantInvalid { - t.Fatalf("got invalid = %d, want = %d", got, wantInvalid) - } - // As per RFC 4861 section 7.2.5: - // When a valid Neighbor Advertisement is received ... - // If no entry exists, the advertisement SHOULD be silently discarded. - // There is no need to create an entry if none exists, since the - // recipient has apparently not initiated any communication with the - // target. - if neighbors, err := s.Neighbors(nicID, ProtocolNumber); err != nil { - t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err) - } else if len(neighbors) != 0 { - t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors) - } - }) - } -} - -// TestRouterAdvertValidation tests that when the NIC is configured to handle -// NDP Router Advertisement packets, it validates the Router Advertisement -// properly before handling them. -func TestRouterAdvertValidation(t *testing.T) { - tests := []struct { - name string - src tcpip.Address - hopLimit uint8 - code header.ICMPv6Code - ndpPayload []byte - expectedSuccess bool - }{ - { - "OK", - lladdr0, - 255, - 0, - []byte{ - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - }, - true, - }, - { - "NonLinkLocalSourceAddr", - addr1, - 255, - 0, - []byte{ - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - }, - false, - }, - { - "HopLimitNot255", - lladdr0, - 254, - 0, - []byte{ - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - }, - false, - }, - { - "NonZeroCode", - lladdr0, - 255, - 1, - []byte{ - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - }, - false, - }, - { - "NDPPayloadTooSmall", - lladdr0, - 255, - 0, - []byte{ - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, - }, - false, - }, - { - "OKWithOptions", - lladdr0, - 255, - 0, - []byte{ - // RA payload - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - - // Option #1 (TargetLinkLayerAddress) - 2, 1, 0, 0, 0, 0, 0, 0, - - // Option #2 (unrecognized) - 255, 1, 0, 0, 0, 0, 0, 0, - - // Option #3 (PrefixInformation) - 3, 4, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - }, - true, - }, - { - "OptionWithZeroLength", - lladdr0, - 255, - 0, - []byte{ - // RA payload - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - - // Option #1 (TargetLinkLayerAddress) - // Invalid as it has 0 length. - 2, 0, 0, 0, 0, 0, 0, 0, - - // Option #2 (unrecognized) - 255, 1, 0, 0, 0, 0, 0, 0, - - // Option #3 (PrefixInformation) - 3, 4, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - }, - false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e := channel.New(10, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) - pkt := header.ICMPv6(hdr.Prepend(icmpSize)) - pkt.SetType(header.ICMPv6RouterAdvert) - pkt.SetCode(test.code) - copy(pkt.MessageBody(), test.ndpPayload) - payloadLength := hdr.UsedLength() - pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: icmp.ProtocolNumber6, - HopLimit: test.hopLimit, - SrcAddr: test.src, - DstAddr: header.IPv6AllNodesMulticastAddress, - }) - - stats := s.Stats().ICMP.V6.PacketsReceived - invalid := stats.Invalid - rxRA := stats.RouterAdvert - - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - if got := rxRA.Value(); got != 0 { - t.Fatalf("got rxRA = %d, want = 0", got) - } - - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - - if got := rxRA.Value(); got != 1 { - t.Fatalf("got rxRA = %d, want = 1", got) - } - - if test.expectedSuccess { - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - } else { - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - } - }) - } -} - -// TestCheckDuplicateAddress checks that calls to CheckDuplicateAddress and DAD -// performed when adding new addresses do not interfere with each other. -func TestCheckDuplicateAddress(t *testing.T) { - const nicID = 1 - - clock := faketime.NewManualClock() - dadConfigs := stack.DADConfigurations{ - DupAddrDetectTransmits: 1, - RetransmitTimer: time.Second, - } - s := stack.New(stack.Options{ - Clock: clock, - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocolWithOptions(Options{ - DADConfigs: dadConfigs, - })}, - }) - // This test is expected to send at max 2 DAD messages. We allow an extra - // packet to be stored to catch unexpected packets. - e := channel.New(3, header.IPv6MinimumMTU, linkAddr0) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - dadPacketsSent := 1 - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) - } - - // Start DAD for the address we just added. - // - // Even though the stack will perform DAD before the added address transitions - // from tentative to assigned, this DAD request should be independent of that. - ch := make(chan stack.DADResult, 3) - dadRequestsMade := 1 - dadPacketsSent++ - if res, err := s.CheckDuplicateAddress(nicID, ProtocolNumber, lladdr0, func(r stack.DADResult) { - ch <- r - }); err != nil { - t.Fatalf("s.CheckDuplicateAddress(%d, %d, %s, _): %s", nicID, ProtocolNumber, lladdr0, err) - } else if res != stack.DADStarting { - t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, ProtocolNumber, lladdr0, res, stack.DADStarting) - } - - // Remove the address and make sure our DAD request was not stopped. - if err := s.RemoveAddress(nicID, lladdr0); err != nil { - t.Fatalf("RemoveAddress(%d, %s): %s", nicID, lladdr0, err) - } - // Should not restart DAD since we already requested DAD above - the handler - // should be called when the original request compeletes so we should not send - // an extra DAD message here. - dadRequestsMade++ - if res, err := s.CheckDuplicateAddress(nicID, ProtocolNumber, lladdr0, func(r stack.DADResult) { - ch <- r - }); err != nil { - t.Fatalf("s.CheckDuplicateAddress(%d, %d, %s, _): %s", nicID, ProtocolNumber, lladdr0, err) - } else if res != stack.DADAlreadyRunning { - t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, ProtocolNumber, lladdr0, res, stack.DADAlreadyRunning) - } - - // Wait for DAD to resolve. - clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer) - for i := 0; i < dadRequestsMade; i++ { - if diff := cmp.Diff(stack.DADResult{Resolved: true}, <-ch); diff != "" { - t.Errorf("(i=%d) DAD result mismatch (-want +got):\n%s", i, diff) - } - } - // Should have no more results. - select { - case r := <-ch: - t.Errorf("unexpectedly got an extra DAD result; r = %#v", r) - default: - } - - snmc := header.SolicitedNodeAddr(lladdr0) - remoteLinkAddr := header.EthernetAddressFromMulticastIPv6Address(snmc) - - for i := 0; i < dadPacketsSent; i++ { - p, ok := e.Read() - if !ok { - t.Fatalf("expected %d-th DAD message", i) - } - - if p.Proto != header.IPv6ProtocolNumber { - t.Errorf("(i=%d) got p.Proto = %d, want = %d", i, p.Proto, header.IPv6ProtocolNumber) - } - - if p.Route.RemoteLinkAddress != remoteLinkAddr { - t.Errorf("(i=%d) got p.Route.RemoteLinkAddress = %s, want = %s", i, p.Route.RemoteLinkAddress, remoteLinkAddr) - } - - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(header.IPv6Any), - checker.DstAddr(snmc), - checker.TTL(header.NDPHopLimit), - checker.NDPNS( - checker.NDPNSTargetAddress(lladdr0), - checker.NDPNSOptions(nil), - )) - } - - // Should have no more packets. - if p, ok := e.Read(); ok { - t.Errorf("got unexpected packet = %#v", p) - } -} diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go deleted file mode 100644 index 03b7ecffb..000000000 --- a/pkg/tcpip/network/multicast_group_test.go +++ /dev/null @@ -1,1265 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ip_test - -import ( - "fmt" - "strings" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -const ( - linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - - stackIPv4Addr = tcpip.Address("\x0a\x00\x00\x01") - defaultIPv4PrefixLength = 24 - ipv6Addr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - - ipv4MulticastAddr1 = tcpip.Address("\xe0\x00\x00\x03") - ipv4MulticastAddr2 = tcpip.Address("\xe0\x00\x00\x04") - ipv4MulticastAddr3 = tcpip.Address("\xe0\x00\x00\x05") - ipv6MulticastAddr1 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") - ipv6MulticastAddr2 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04") - ipv6MulticastAddr3 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x05") - - igmpMembershipQuery = uint8(header.IGMPMembershipQuery) - igmpv1MembershipReport = uint8(header.IGMPv1MembershipReport) - igmpv2MembershipReport = uint8(header.IGMPv2MembershipReport) - igmpLeaveGroup = uint8(header.IGMPLeaveGroup) - mldQuery = uint8(header.ICMPv6MulticastListenerQuery) - mldReport = uint8(header.ICMPv6MulticastListenerReport) - mldDone = uint8(header.ICMPv6MulticastListenerDone) - - maxUnsolicitedReports = 2 -) - -var ( - // unsolicitedIGMPReportIntervalMaxTenthSec is the maximum amount of time the - // NIC will wait before sending an unsolicited report after joining a - // multicast group, in deciseconds. - unsolicitedIGMPReportIntervalMaxTenthSec = func() uint8 { - const decisecond = time.Second / 10 - if ipv4.UnsolicitedReportIntervalMax%decisecond != 0 { - panic(fmt.Sprintf("UnsolicitedReportIntervalMax of %d is a lossy conversion to deciseconds", ipv4.UnsolicitedReportIntervalMax)) - } - return uint8(ipv4.UnsolicitedReportIntervalMax / decisecond) - }() - - ipv6AddrSNMC = header.SolicitedNodeAddr(ipv6Addr) -) - -// validateMLDPacket checks that a passed PacketInfo is an IPv6 MLD packet -// sent to the provided address with the passed fields set. -func validateMLDPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, mldType uint8, maxRespTime byte, groupAddress tcpip.Address) { - t.Helper() - - payload := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())) - checker.IPv6WithExtHdr(t, payload, - checker.IPv6ExtHdr( - checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)), - ), - checker.SrcAddr(ipv6Addr), - checker.DstAddr(remoteAddress), - // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3. - checker.TTL(1), - checker.MLD(header.ICMPv6Type(mldType), header.MLDMinimumSize, - checker.MLDMaxRespDelay(time.Duration(maxRespTime)*time.Millisecond), - checker.MLDMulticastAddress(groupAddress), - ), - ) -} - -// validateIGMPPacket checks that a passed PacketInfo is an IPv4 IGMP packet -// sent to the provided address with the passed fields set. -func validateIGMPPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, igmpType uint8, maxRespTime byte, groupAddress tcpip.Address) { - t.Helper() - - payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) - checker.IPv4(t, payload, - checker.SrcAddr(stackIPv4Addr), - checker.DstAddr(remoteAddress), - // TTL for an IGMP message must be 1 as per RFC 2236 section 2. - checker.TTL(1), - checker.IPv4RouterAlert(), - checker.IGMP( - checker.IGMPType(header.IGMPType(igmpType)), - checker.IGMPMaxRespTime(header.DecisecondToDuration(maxRespTime)), - checker.IGMPGroupAddress(groupAddress), - ), - ) -} - -func createStack(t *testing.T, v4, mgpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) { - t.Helper() - - e := channel.New(maxUnsolicitedReports, header.IPv6MinimumMTU, linkAddr) - s, clock := createStackWithLinkEndpoint(t, v4, mgpEnabled, e) - return e, s, clock -} - -func createStackWithLinkEndpoint(t *testing.T, v4, mgpEnabled bool, e stack.LinkEndpoint) (*stack.Stack, *faketime.ManualClock) { - t.Helper() - - igmpEnabled := v4 && mgpEnabled - mldEnabled := !v4 && mgpEnabled - - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ - ipv4.NewProtocolWithOptions(ipv4.Options{ - IGMP: ipv4.IGMPOptions{ - Enabled: igmpEnabled, - }, - }), - ipv6.NewProtocolWithOptions(ipv6.Options{ - MLD: ipv6.MLDOptions{ - Enabled: mldEnabled, - }, - }), - }, - Clock: clock, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - addr := tcpip.AddressWithPrefix{ - Address: stackIPv4Addr, - PrefixLen: defaultIPv4PrefixLength, - } - if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err) - } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, ipv6Addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, ipv6Addr, err) - } - - return s, clock -} - -// checkInitialIPv6Groups checks the initial IPv6 groups that a NIC will join -// when it is created with an IPv6 address. -// -// To not interfere with tests, checkInitialIPv6Groups will leave the added -// address's solicited node multicast group so that the tests can all assume -// the NIC has not joined any IPv6 groups. -func checkInitialIPv6Groups(t *testing.T, e *channel.Endpoint, s *stack.Stack, clock *faketime.ManualClock) (reportCounter uint64, leaveCounter uint64) { - t.Helper() - - stats := s.Stats().ICMP.V6.PacketsSent - - reportCounter++ - if got := stats.MulticastListenerReport.Value(); got != reportCounter { - t.Errorf("got stats.MulticastListenerReport.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a report message to be sent") - } else { - validateMLDPacket(t, p, ipv6AddrSNMC, mldReport, 0, ipv6AddrSNMC) - } - - // Leave the group to not affect the tests. This is fine since we are not - // testing DAD or the solicited node address specifically. - if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, ipv6AddrSNMC); err != nil { - t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, ipv6AddrSNMC, err) - } - leaveCounter++ - if got := stats.MulticastListenerDone.Value(); got != leaveCounter { - t.Errorf("got stats.MulticastListenerDone.Value() = %d, want = %d", got, leaveCounter) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a report message to be sent") - } else { - validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6AddrSNMC) - } - - // Should not send any more packets. - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet = %#v", p) - } - - return reportCounter, leaveCounter -} - -// createAndInjectIGMPPacket creates and injects an IGMP packet with the -// specified fields. -func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType byte, maxRespTime byte, groupAddress tcpip.Address) { - options := header.IPv4OptionsSerializer{ - &header.IPv4SerializableRouterAlertOption{}, - } - buf := buffer.NewView(header.IPv4MinimumSize + int(options.Length()) + header.IGMPQueryMinimumSize) - ip := header.IPv4(buf) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(len(buf)), - TTL: header.IGMPTTL, - Protocol: uint8(header.IGMPProtocolNumber), - SrcAddr: remoteIPv4Addr, - DstAddr: header.IPv4AllSystems, - Options: options, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - igmp := header.IGMP(ip.Payload()) - igmp.SetType(header.IGMPType(igmpType)) - igmp.SetMaxRespTime(maxRespTime) - igmp.SetGroupAddress(groupAddress) - igmp.SetChecksum(header.IGMPCalculateChecksum(igmp)) - - e.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) -} - -// createAndInjectMLDPacket creates and injects an MLD packet with the -// specified fields. -// -// Note, the router alert option is not included in this packet. -// -// TODO(b/162198658): set the router alert option. -func createAndInjectMLDPacket(e *channel.Endpoint, mldType uint8, maxRespDelay byte, groupAddress tcpip.Address) { - icmpSize := header.ICMPv6HeaderSize + header.MLDMinimumSize - buf := buffer.NewView(header.IPv6MinimumSize + icmpSize) - - ip := header.IPv6(buf) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(icmpSize), - HopLimit: header.MLDHopLimit, - TransportProtocol: header.ICMPv6ProtocolNumber, - SrcAddr: header.IPv4Any, - DstAddr: header.IPv6AllNodesMulticastAddress, - }) - - icmp := header.ICMPv6(buf[header.IPv6MinimumSize:]) - icmp.SetType(header.ICMPv6Type(mldType)) - mld := header.MLD(icmp.MessageBody()) - mld.SetMaximumResponseDelay(uint16(maxRespDelay)) - mld.SetMulticastAddress(groupAddress) - icmp.SetChecksum(header.ICMPv6Checksum(icmp, header.IPv6Any, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) - - e.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) -} - -// TestMGPDisabled tests that the multicast group protocol is not enabled by -// default. -func TestMGPDisabled(t *testing.T) { - tests := []struct { - name string - protoNum tcpip.NetworkProtocolNumber - multicastAddr tcpip.Address - sentReportStat func(*stack.Stack) *tcpip.StatCounter - receivedQueryStat func(*stack.Stack) *tcpip.StatCounter - rxQuery func(*channel.Endpoint) - }{ - { - name: "IGMP", - protoNum: ipv4.ProtocolNumber, - multicastAddr: ipv4MulticastAddr1, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsSent.V2MembershipReport - }, - receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsReceived.MembershipQuery - }, - rxQuery: func(e *channel.Endpoint) { - createAndInjectIGMPPacket(e, igmpMembershipQuery, unsolicitedIGMPReportIntervalMaxTenthSec, header.IPv4Any) - }, - }, - { - name: "MLD", - protoNum: ipv6.ProtocolNumber, - multicastAddr: ipv6MulticastAddr1, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport - }, - receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery - }, - rxQuery: func(e *channel.Endpoint) { - createAndInjectMLDPacket(e, mldQuery, 0, header.IPv6Any) - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, false /* mgpEnabled */) - - // This NIC may join multicast groups when it is enabled but since MGP is - // disabled, no reports should be sent. - sentReportStat := test.sentReportStat(s) - if got := sentReportStat.Value(); got != 0 { - t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) - } - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet, stack with disabled MGP sent packet = %#v", p.Pkt) - } - - // Test joining a specific group explicitly and verify that no reports are - // sent. - if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { - t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) - } - if got := sentReportStat.Value(); got != 0 { - t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) - } - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %#v", p.Pkt) - } - - // Inject a general query message. This should only trigger a report to be - // sent if the MGP was enabled. - test.rxQuery(e) - if got := test.receivedQueryStat(s).Value(); got != 1 { - t.Fatalf("got receivedQueryStat(_).Value() = %d, want = 1", got) - } - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %+v", p.Pkt) - } - }) - } -} - -func TestMGPReceiveCounters(t *testing.T) { - tests := []struct { - name string - headerType uint8 - maxRespTime byte - groupAddress tcpip.Address - statCounter func(*stack.Stack) *tcpip.StatCounter - rxMGPkt func(*channel.Endpoint, byte, byte, tcpip.Address) - }{ - { - name: "IGMP Membership Query", - headerType: igmpMembershipQuery, - maxRespTime: unsolicitedIGMPReportIntervalMaxTenthSec, - groupAddress: header.IPv4Any, - statCounter: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsReceived.MembershipQuery - }, - rxMGPkt: createAndInjectIGMPPacket, - }, - { - name: "IGMPv1 Membership Report", - headerType: igmpv1MembershipReport, - maxRespTime: 0, - groupAddress: header.IPv4AllSystems, - statCounter: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsReceived.V1MembershipReport - }, - rxMGPkt: createAndInjectIGMPPacket, - }, - { - name: "IGMPv2 Membership Report", - headerType: igmpv2MembershipReport, - maxRespTime: 0, - groupAddress: header.IPv4AllSystems, - statCounter: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsReceived.V2MembershipReport - }, - rxMGPkt: createAndInjectIGMPPacket, - }, - { - name: "IGMP Leave Group", - headerType: igmpLeaveGroup, - maxRespTime: 0, - groupAddress: header.IPv4AllRoutersGroup, - statCounter: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsReceived.LeaveGroup - }, - rxMGPkt: createAndInjectIGMPPacket, - }, - { - name: "MLD Query", - headerType: mldQuery, - maxRespTime: 0, - groupAddress: header.IPv6Any, - statCounter: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery - }, - rxMGPkt: createAndInjectMLDPacket, - }, - { - name: "MLD Report", - headerType: mldReport, - maxRespTime: 0, - groupAddress: header.IPv6Any, - statCounter: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerReport - }, - rxMGPkt: createAndInjectMLDPacket, - }, - { - name: "MLD Done", - headerType: mldDone, - maxRespTime: 0, - groupAddress: header.IPv6Any, - statCounter: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerDone - }, - rxMGPkt: createAndInjectMLDPacket, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e, s, _ := createStack(t, len(test.groupAddress) == header.IPv4AddressSize /* v4 */, true /* mgpEnabled */) - - test.rxMGPkt(e, test.headerType, test.maxRespTime, test.groupAddress) - if got := test.statCounter(s).Value(); got != 1 { - t.Fatalf("got %s received = %d, want = 1", test.name, got) - } - }) - } -} - -// TestMGPJoinGroup tests that when explicitly joining a multicast group, the -// stack schedules and sends correct Membership Reports. -func TestMGPJoinGroup(t *testing.T) { - tests := []struct { - name string - protoNum tcpip.NetworkProtocolNumber - multicastAddr tcpip.Address - maxUnsolicitedResponseDelay time.Duration - sentReportStat func(*stack.Stack) *tcpip.StatCounter - receivedQueryStat func(*stack.Stack) *tcpip.StatCounter - validateReport func(*testing.T, channel.PacketInfo) - checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) - }{ - { - name: "IGMP", - protoNum: ipv4.ProtocolNumber, - multicastAddr: ipv4MulticastAddr1, - maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsSent.V2MembershipReport - }, - receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsReceived.MembershipQuery - }, - validateReport: func(t *testing.T, p channel.PacketInfo) { - t.Helper() - - validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) - }, - }, - { - name: "MLD", - protoNum: ipv6.ProtocolNumber, - multicastAddr: ipv6MulticastAddr1, - maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport - }, - receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery - }, - validateReport: func(t *testing.T, p channel.PacketInfo) { - t.Helper() - - validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) - }, - checkInitialGroups: checkInitialIPv6Groups, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) - - var reportCounter uint64 - if test.checkInitialGroups != nil { - reportCounter, _ = test.checkInitialGroups(t, e, s, clock) - } - - // Test joining a specific address explicitly and verify a Report is sent - // immediately. - if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { - t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) - } - reportCounter++ - sentReportStat := test.sentReportStat(s) - if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a report message to be sent") - } else { - test.validateReport(t, p) - } - if t.Failed() { - t.FailNow() - } - - // Verify the second report is sent by the maximum unsolicited response - // interval. - p, ok := e.Read() - if ok { - t.Fatalf("sent unexpected packet, expected report only after advancing the clock = %#v", p.Pkt) - } - clock.Advance(test.maxUnsolicitedResponseDelay) - reportCounter++ - if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a report message to be sent") - } else { - test.validateReport(t, p) - } - - // Should not send any more packets. - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet = %#v", p) - } - }) - } -} - -// TestMGPLeaveGroup tests that when leaving a previously joined multicast -// group the stack sends a leave/done message. -func TestMGPLeaveGroup(t *testing.T) { - tests := []struct { - name string - protoNum tcpip.NetworkProtocolNumber - multicastAddr tcpip.Address - sentReportStat func(*stack.Stack) *tcpip.StatCounter - sentLeaveStat func(*stack.Stack) *tcpip.StatCounter - validateReport func(*testing.T, channel.PacketInfo) - validateLeave func(*testing.T, channel.PacketInfo) - checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) - }{ - { - name: "IGMP", - protoNum: ipv4.ProtocolNumber, - multicastAddr: ipv4MulticastAddr1, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsSent.V2MembershipReport - }, - sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsSent.LeaveGroup - }, - validateReport: func(t *testing.T, p channel.PacketInfo) { - t.Helper() - - validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) - }, - validateLeave: func(t *testing.T, p channel.PacketInfo) { - t.Helper() - - validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, ipv4MulticastAddr1) - }, - }, - { - name: "MLD", - protoNum: ipv6.ProtocolNumber, - multicastAddr: ipv6MulticastAddr1, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport - }, - sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone - }, - validateReport: func(t *testing.T, p channel.PacketInfo) { - t.Helper() - - validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) - }, - validateLeave: func(t *testing.T, p channel.PacketInfo) { - t.Helper() - - validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6MulticastAddr1) - }, - checkInitialGroups: checkInitialIPv6Groups, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) - - var reportCounter uint64 - var leaveCounter uint64 - if test.checkInitialGroups != nil { - reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock) - } - - if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { - t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) - } - reportCounter++ - if got := test.sentReportStat(s).Value(); got != reportCounter { - t.Errorf("got sentReportStat(_).Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a report message to be sent") - } else { - test.validateReport(t, p) - } - if t.Failed() { - t.FailNow() - } - - // Leaving the group should trigger an leave/done message to be sent. - if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil { - t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err) - } - leaveCounter++ - if got := test.sentLeaveStat(s).Value(); got != leaveCounter { - t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a leave message to be sent") - } else { - test.validateLeave(t, p) - } - - // Should not send any more packets. - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet = %#v", p) - } - }) - } -} - -// TestMGPQueryMessages tests that a report is sent in response to query -// messages. -func TestMGPQueryMessages(t *testing.T) { - tests := []struct { - name string - protoNum tcpip.NetworkProtocolNumber - multicastAddr tcpip.Address - maxUnsolicitedResponseDelay time.Duration - sentReportStat func(*stack.Stack) *tcpip.StatCounter - receivedQueryStat func(*stack.Stack) *tcpip.StatCounter - rxQuery func(*channel.Endpoint, uint8, tcpip.Address) - validateReport func(*testing.T, channel.PacketInfo) - maxRespTimeToDuration func(uint8) time.Duration - checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) - }{ - { - name: "IGMP", - protoNum: ipv4.ProtocolNumber, - multicastAddr: ipv4MulticastAddr1, - maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsSent.V2MembershipReport - }, - receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsReceived.MembershipQuery - }, - rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) { - createAndInjectIGMPPacket(e, igmpMembershipQuery, maxRespTime, groupAddress) - }, - validateReport: func(t *testing.T, p channel.PacketInfo) { - t.Helper() - - validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) - }, - maxRespTimeToDuration: header.DecisecondToDuration, - }, - { - name: "MLD", - protoNum: ipv6.ProtocolNumber, - multicastAddr: ipv6MulticastAddr1, - maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport - }, - receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery - }, - rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) { - createAndInjectMLDPacket(e, mldQuery, maxRespTime, groupAddress) - }, - validateReport: func(t *testing.T, p channel.PacketInfo) { - t.Helper() - - validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) - }, - maxRespTimeToDuration: func(d uint8) time.Duration { - return time.Duration(d) * time.Millisecond - }, - checkInitialGroups: checkInitialIPv6Groups, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - subTests := []struct { - name string - multicastAddr tcpip.Address - expectReport bool - }{ - { - name: "Unspecified", - multicastAddr: tcpip.Address(strings.Repeat("\x00", len(test.multicastAddr))), - expectReport: true, - }, - { - name: "Specified", - multicastAddr: test.multicastAddr, - expectReport: true, - }, - { - name: "Specified other address", - multicastAddr: func() tcpip.Address { - addrBytes := []byte(test.multicastAddr) - addrBytes[len(addrBytes)-1]++ - return tcpip.Address(addrBytes) - }(), - expectReport: false, - }, - } - - for _, subTest := range subTests { - t.Run(subTest.name, func(t *testing.T) { - e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) - - var reportCounter uint64 - if test.checkInitialGroups != nil { - reportCounter, _ = test.checkInitialGroups(t, e, s, clock) - } - - if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { - t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) - } - sentReportStat := test.sentReportStat(s) - for i := 0; i < maxUnsolicitedReports; i++ { - sentReportStat := test.sentReportStat(s) - reportCounter++ - if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("(i=%d) got sentReportStat.Value() = %d, want = %d", i, got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Fatalf("expected %d-th report message to be sent", i) - } else { - test.validateReport(t, p) - } - clock.Advance(test.maxUnsolicitedResponseDelay) - } - if t.Failed() { - t.FailNow() - } - - // Should not send any more packets until a query. - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet = %#v", p) - } - - // Receive a query message which should trigger a report to be sent at - // some time before the maximum response time if the report is - // targeted at the host. - const maxRespTime = 100 - test.rxQuery(e, maxRespTime, subTest.multicastAddr) - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet = %#v", p.Pkt) - } - - if subTest.expectReport { - clock.Advance(test.maxRespTimeToDuration(maxRespTime)) - reportCounter++ - if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a report message to be sent") - } else { - test.validateReport(t, p) - } - } - - // Should not send any more packets. - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet = %#v", p) - } - }) - } - }) - } -} - -// TestMGPQueryMessages tests that no further reports or leave/done messages -// are sent after receiving a report. -func TestMGPReportMessages(t *testing.T) { - tests := []struct { - name string - protoNum tcpip.NetworkProtocolNumber - multicastAddr tcpip.Address - sentReportStat func(*stack.Stack) *tcpip.StatCounter - sentLeaveStat func(*stack.Stack) *tcpip.StatCounter - rxReport func(*channel.Endpoint) - validateReport func(*testing.T, channel.PacketInfo) - maxRespTimeToDuration func(uint8) time.Duration - checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) - }{ - { - name: "IGMP", - protoNum: ipv4.ProtocolNumber, - multicastAddr: ipv4MulticastAddr1, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsSent.V2MembershipReport - }, - sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsSent.LeaveGroup - }, - rxReport: func(e *channel.Endpoint) { - createAndInjectIGMPPacket(e, igmpv2MembershipReport, 0, ipv4MulticastAddr1) - }, - validateReport: func(t *testing.T, p channel.PacketInfo) { - t.Helper() - - validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) - }, - maxRespTimeToDuration: header.DecisecondToDuration, - }, - { - name: "MLD", - protoNum: ipv6.ProtocolNumber, - multicastAddr: ipv6MulticastAddr1, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport - }, - sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone - }, - rxReport: func(e *channel.Endpoint) { - createAndInjectMLDPacket(e, mldReport, 0, ipv6MulticastAddr1) - }, - validateReport: func(t *testing.T, p channel.PacketInfo) { - t.Helper() - - validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) - }, - maxRespTimeToDuration: func(d uint8) time.Duration { - return time.Duration(d) * time.Millisecond - }, - checkInitialGroups: checkInitialIPv6Groups, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) - - var reportCounter uint64 - var leaveCounter uint64 - if test.checkInitialGroups != nil { - reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock) - } - - if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { - t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) - } - sentReportStat := test.sentReportStat(s) - reportCounter++ - if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a report message to be sent") - } else { - test.validateReport(t, p) - } - if t.Failed() { - t.FailNow() - } - - // Receiving a report for a group we joined should cancel any further - // reports. - test.rxReport(e) - clock.Advance(time.Hour) - if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); ok { - t.Errorf("sent unexpected packet = %#v", p) - } - if t.Failed() { - t.FailNow() - } - - // Leaving a group after getting a report should not send a leave/done - // message. - if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil { - t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err) - } - clock.Advance(time.Hour) - if got := test.sentLeaveStat(s).Value(); got != leaveCounter { - t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter) - } - - // Should not send any more packets. - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet = %#v", p) - } - }) - } -} - -func TestMGPWithNICLifecycle(t *testing.T) { - tests := []struct { - name string - protoNum tcpip.NetworkProtocolNumber - multicastAddrs []tcpip.Address - finalMulticastAddr tcpip.Address - maxUnsolicitedResponseDelay time.Duration - sentReportStat func(*stack.Stack) *tcpip.StatCounter - sentLeaveStat func(*stack.Stack) *tcpip.StatCounter - validateReport func(*testing.T, channel.PacketInfo, tcpip.Address) - validateLeave func(*testing.T, channel.PacketInfo, tcpip.Address) - getAndCheckGroupAddress func(*testing.T, map[tcpip.Address]bool, channel.PacketInfo) tcpip.Address - checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) - }{ - { - name: "IGMP", - protoNum: ipv4.ProtocolNumber, - multicastAddrs: []tcpip.Address{ipv4MulticastAddr1, ipv4MulticastAddr2}, - finalMulticastAddr: ipv4MulticastAddr3, - maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsSent.V2MembershipReport - }, - sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsSent.LeaveGroup - }, - validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { - t.Helper() - - validateIGMPPacket(t, p, addr, igmpv2MembershipReport, 0, addr) - }, - validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { - t.Helper() - - validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, addr) - }, - getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address { - t.Helper() - - ipv4 := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) - if got := tcpip.TransportProtocolNumber(ipv4.Protocol()); got != header.IGMPProtocolNumber { - t.Fatalf("got ipv4.Protocol() = %d, want = %d", got, header.IGMPProtocolNumber) - } - addr := header.IGMP(ipv4.Payload()).GroupAddress() - s, ok := seen[addr] - if !ok { - t.Fatalf("unexpectedly got a packet for group %s", addr) - } - if s { - t.Fatalf("already saw packet for group %s", addr) - } - seen[addr] = true - return addr - }, - }, - { - name: "MLD", - protoNum: ipv6.ProtocolNumber, - multicastAddrs: []tcpip.Address{ipv6MulticastAddr1, ipv6MulticastAddr2}, - finalMulticastAddr: ipv6MulticastAddr3, - maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport - }, - sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone - }, - validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { - t.Helper() - - validateMLDPacket(t, p, addr, mldReport, 0, addr) - }, - validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { - t.Helper() - - validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, addr) - }, - getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address { - t.Helper() - - ipv6 := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())) - - ipv6HeaderIter := header.MakeIPv6PayloadIterator( - header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()), - buffer.View(ipv6.Payload()).ToVectorisedView(), - ) - - var transport header.IPv6RawPayloadHeader - for { - h, done, err := ipv6HeaderIter.Next() - if err != nil { - t.Fatalf("ipv6HeaderIter.Next(): %s", err) - } - if done { - t.Fatalf("ipv6HeaderIter.Next() = (%T, %t, _), want = (_, false, _)", h, done) - } - if t, ok := h.(header.IPv6RawPayloadHeader); ok { - transport = t - break - } - } - - if got := tcpip.TransportProtocolNumber(transport.Identifier); got != header.ICMPv6ProtocolNumber { - t.Fatalf("got ipv6.NextHeader() = %d, want = %d", got, header.ICMPv6ProtocolNumber) - } - icmpv6 := header.ICMPv6(transport.Buf.ToView()) - if got := icmpv6.Type(); got != header.ICMPv6MulticastListenerReport && got != header.ICMPv6MulticastListenerDone { - t.Fatalf("got icmpv6.Type() = %d, want = %d or %d", got, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone) - } - addr := header.MLD(icmpv6.MessageBody()).MulticastAddress() - s, ok := seen[addr] - if !ok { - t.Fatalf("unexpectedly got a packet for group %s", addr) - } - if s { - t.Fatalf("already saw packet for group %s", addr) - } - seen[addr] = true - return addr - }, - checkInitialGroups: checkInitialIPv6Groups, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) - - var reportCounter uint64 - var leaveCounter uint64 - if test.checkInitialGroups != nil { - reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock) - } - - sentReportStat := test.sentReportStat(s) - for _, a := range test.multicastAddrs { - if err := s.JoinGroup(test.protoNum, nicID, a); err != nil { - t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, a, err) - } - reportCounter++ - if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Fatalf("expected a report message to be sent for %s", a) - } else { - test.validateReport(t, p, a) - } - } - if t.Failed() { - t.FailNow() - } - - // Leave messages should be sent for the joined groups when the NIC is - // disabled. - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("DisableNIC(%d): %s", nicID, err) - } - sentLeaveStat := test.sentLeaveStat(s) - leaveCounter += uint64(len(test.multicastAddrs)) - if got := sentLeaveStat.Value(); got != leaveCounter { - t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter) - } - { - seen := make(map[tcpip.Address]bool) - for _, a := range test.multicastAddrs { - seen[a] = false - } - - for i := range test.multicastAddrs { - p, ok := e.Read() - if !ok { - t.Fatalf("expected (%d-th) leave message to be sent", i) - } - - test.validateLeave(t, p, test.getAndCheckGroupAddress(t, seen, p)) - } - } - if t.Failed() { - t.FailNow() - } - - // Reports should be sent for the joined groups when the NIC is enabled. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("EnableNIC(%d): %s", nicID, err) - } - reportCounter += uint64(len(test.multicastAddrs)) - if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) - } - { - seen := make(map[tcpip.Address]bool) - for _, a := range test.multicastAddrs { - seen[a] = false - } - - for i := range test.multicastAddrs { - p, ok := e.Read() - if !ok { - t.Fatalf("expected (%d-th) report message to be sent", i) - } - - test.validateReport(t, p, test.getAndCheckGroupAddress(t, seen, p)) - } - } - if t.Failed() { - t.FailNow() - } - - // Joining/leaving a group while disabled should not send any messages. - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("DisableNIC(%d): %s", nicID, err) - } - leaveCounter += uint64(len(test.multicastAddrs)) - if got := sentLeaveStat.Value(); got != leaveCounter { - t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter) - } - for i := range test.multicastAddrs { - if _, ok := e.Read(); !ok { - t.Fatalf("expected (%d-th) leave message to be sent", i) - } - } - for _, a := range test.multicastAddrs { - if err := s.LeaveGroup(test.protoNum, nicID, a); err != nil { - t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, a, err) - } - if got := sentLeaveStat.Value(); got != leaveCounter { - t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter) - } - if p, ok := e.Read(); ok { - t.Fatalf("leaving group %s on disabled NIC sent unexpected packet = %#v", a, p.Pkt) - } - } - if err := s.JoinGroup(test.protoNum, nicID, test.finalMulticastAddr); err != nil { - t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.finalMulticastAddr, err) - } - if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); ok { - t.Fatalf("joining group %s on disabled NIC sent unexpected packet = %#v", test.finalMulticastAddr, p.Pkt) - } - - // A report should only be sent for the group we last joined after - // enabling the NIC since the original groups were all left. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("EnableNIC(%d): %s", nicID, err) - } - reportCounter++ - if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a report message to be sent") - } else { - test.validateReport(t, p, test.finalMulticastAddr) - } - - clock.Advance(test.maxUnsolicitedResponseDelay) - reportCounter++ - if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a report message to be sent") - } else { - test.validateReport(t, p, test.finalMulticastAddr) - } - - // Should not send any more packets. - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet = %#v", p) - } - }) - } -} - -// TestMGPDisabledOnLoopback tests that the multicast group protocol is not -// performed on loopback interfaces since they have no neighbours. -func TestMGPDisabledOnLoopback(t *testing.T) { - tests := []struct { - name string - protoNum tcpip.NetworkProtocolNumber - multicastAddr tcpip.Address - sentReportStat func(*stack.Stack) *tcpip.StatCounter - }{ - { - name: "IGMP", - protoNum: ipv4.ProtocolNumber, - multicastAddr: ipv4MulticastAddr1, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsSent.V2MembershipReport - }, - }, - { - name: "MLD", - protoNum: ipv6.ProtocolNumber, - multicastAddr: ipv6MulticastAddr1, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s, clock := createStackWithLinkEndpoint(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */, loopback.New()) - - sentReportStat := test.sentReportStat(s) - if got := sentReportStat.Value(); got != 0 { - t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) - } - clock.Advance(time.Hour) - if got := sentReportStat.Value(); got != 0 { - t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) - } - - // Test joining a specific group explicitly and verify that no reports are - // sent. - if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { - t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) - } - if got := sentReportStat.Value(); got != 0 { - t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) - } - clock.Advance(time.Hour) - if got := sentReportStat.Value(); got != 0 { - t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) - } - }) - } -} diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD deleted file mode 100644 index 57abec5c9..000000000 --- a/pkg/tcpip/ports/BUILD +++ /dev/null @@ -1,23 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "ports", - srcs = ["ports.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - ], -) - -go_test( - name = "ports_test", - srcs = ["ports_test.go"], - library = ":ports", - deps = [ - "//pkg/tcpip", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) diff --git a/pkg/tcpip/ports/ports_state_autogen.go b/pkg/tcpip/ports/ports_state_autogen.go new file mode 100644 index 000000000..1e1d9cd4c --- /dev/null +++ b/pkg/tcpip/ports/ports_state_autogen.go @@ -0,0 +1,40 @@ +// automatically generated by stateify. + +package ports + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (f *Flags) StateTypeName() string { + return "pkg/tcpip/ports.Flags" +} + +func (f *Flags) StateFields() []string { + return []string{ + "MostRecent", + "LoadBalanced", + "TupleOnly", + } +} + +func (f *Flags) beforeSave() {} + +func (f *Flags) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.MostRecent) + stateSinkObject.Save(1, &f.LoadBalanced) + stateSinkObject.Save(2, &f.TupleOnly) +} + +func (f *Flags) afterLoad() {} + +func (f *Flags) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.MostRecent) + stateSourceObject.Load(1, &f.LoadBalanced) + stateSourceObject.Load(2, &f.TupleOnly) +} + +func init() { + state.Register((*Flags)(nil)) +} diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go deleted file mode 100644 index e70fbb72b..000000000 --- a/pkg/tcpip/ports/ports_test.go +++ /dev/null @@ -1,457 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ports - -import ( - "math/rand" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" -) - -const ( - fakeTransNumber tcpip.TransportProtocolNumber = 1 - fakeNetworkNumber tcpip.NetworkProtocolNumber = 2 - - fakeIPAddress = tcpip.Address("\x08\x08\x08\x08") - fakeIPAddress1 = tcpip.Address("\x08\x08\x08\x09") -) - -type portReserveTestAction struct { - port uint16 - ip tcpip.Address - want tcpip.Error - flags Flags - release bool - device tcpip.NICID - dest tcpip.FullAddress -} - -func TestPortReservation(t *testing.T) { - for _, test := range []struct { - tname string - actions []portReserveTestAction - }{ - { - tname: "bind to ip", - actions: []portReserveTestAction{ - {port: 80, ip: fakeIPAddress, want: nil}, - {port: 80, ip: fakeIPAddress1, want: nil}, - /* N.B. Order of tests matters! */ - {port: 80, ip: anyIPAddress, want: &tcpip.ErrPortInUse{}}, - {port: 80, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}, flags: Flags{LoadBalanced: true}}, - }, - }, - { - tname: "bind to inaddr any", - actions: []portReserveTestAction{ - {port: 22, ip: anyIPAddress, want: nil}, - {port: 22, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}}, - /* release fakeIPAddress, but anyIPAddress is still inuse */ - {port: 22, ip: fakeIPAddress, release: true}, - {port: 22, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}}, - {port: 22, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}, flags: Flags{LoadBalanced: true}}, - /* Release port 22 from any IP address, then try to reserve fake IP address on 22 */ - {port: 22, ip: anyIPAddress, want: nil, release: true}, - {port: 22, ip: fakeIPAddress, want: nil}, - }, - }, { - tname: "bind to zero port", - actions: []portReserveTestAction{ - {port: 00, ip: fakeIPAddress, want: nil}, - {port: 00, ip: fakeIPAddress, want: nil}, - {port: 00, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - }, - }, { - tname: "bind to ip with reuseport", - actions: []portReserveTestAction{ - {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - - {port: 25, ip: fakeIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}}, - {port: 25, ip: anyIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}}, - - {port: 25, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - }, - }, { - tname: "bind to inaddr any with reuseport", - actions: []portReserveTestAction{ - {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - - {port: 24, ip: anyIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}}, - {port: 24, ip: fakeIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}}, - - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, release: true, want: nil}, - - {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true}, - {port: 24, ip: anyIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}}, - - {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true}, - {port: 24, ip: anyIPAddress, flags: Flags{}, want: nil}, - }, - }, { - tname: "bind twice with device fails", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 3, want: nil}, - {port: 24, ip: fakeIPAddress, device: 3, want: &tcpip.ErrPortInUse{}}, - }, - }, { - tname: "bind to device", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 1, want: nil}, - {port: 24, ip: fakeIPAddress, device: 2, want: nil}, - }, - }, { - tname: "bind to device and then without device", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}}, - }, - }, { - tname: "bind without device", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}}, - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, - {port: 24, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, - }, - }, { - tname: "bind with device", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}}, - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, - {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, - {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 789, want: nil}, - {port: 24, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, - }, - }, { - tname: "bind with reuseport", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}}, - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, - }, - }, { - tname: "binding with reuseport and device", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}}, - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}}, - {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 999, want: &tcpip.ErrPortInUse{}}, - }, - }, { - tname: "mixing reuseport and not reuseport by binding to device", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 456, want: nil}, - {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 999, want: nil}, - }, - }, { - tname: "can't bind to 0 after mixing reuseport and not reuseport", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 456, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, - }, - }, { - tname: "bind and release", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: &tcpip.ErrPortInUse{}}, - {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, - - // Release the bind to device 0 and try again. - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil, release: true}, - {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: nil}, - }, - }, { - tname: "bind twice with reuseport once", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, - }, - }, { - tname: "release an unreserved device", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{}, want: nil}, - // The below don't exist. - {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: nil, release: true}, - {port: 9999, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil, release: true}, - // Release all. - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil, release: true}, - {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{}, want: nil, release: true}, - }, - }, { - tname: "bind with reuseaddr", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}}, - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{MostRecent: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: nil}, - }, - }, { - tname: "bind twice with reuseaddr once", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: &tcpip.ErrPortInUse{}}, - }, - }, { - tname: "bind with reuseaddr and reuseport", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - }, - }, { - tname: "bind with reuseaddr and reuseport, and then reuseaddr", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, - }, - }, { - tname: "bind with reuseaddr and reuseport, and then reuseport", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: &tcpip.ErrPortInUse{}}, - }, - }, { - tname: "bind with reuseaddr and reuseport twice, and then reuseaddr", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, - }, - }, { - tname: "bind with reuseaddr and reuseport twice, and then reuseport", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - }, - }, { - tname: "bind with reuseaddr, and then reuseaddr and reuseport", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, - }, - }, { - tname: "bind with reuseport, and then reuseaddr and reuseport", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: &tcpip.ErrPortInUse{}}, - }, - }, { - tname: "bind tuple with reuseaddr, and then wildcard with reuseaddr", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{}, want: nil}, - }, - }, { - tname: "bind tuple with reuseaddr, and then wildcard", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, - {port: 24, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}}, - }, - }, { - tname: "bind wildcard with reuseaddr, and then tuple with reuseaddr", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, - }, - }, { - tname: "bind tuple with reuseaddr, and then wildcard", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: &tcpip.ErrPortInUse{}}, - }, - }, { - tname: "bind two tuples with reuseaddr", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 25}, want: nil}, - }, - }, { - tname: "bind two tuples", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, - {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 25}, want: nil}, - }, - }, { - tname: "bind wildcard, and then tuple with reuseaddr", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: &tcpip.ErrPortInUse{}}, - }, - }, { - tname: "bind wildcard twice with reuseaddr", - actions: []portReserveTestAction{ - {port: 24, ip: anyIPAddress, flags: Flags{TupleOnly: true}, want: nil}, - {port: 24, ip: anyIPAddress, flags: Flags{TupleOnly: true}, want: nil}, - }, - }, - } { - t.Run(test.tname, func(t *testing.T) { - pm := NewPortManager() - net := []tcpip.NetworkProtocolNumber{fakeNetworkNumber} - - for _, test := range test.actions { - if test.release { - pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest) - continue - } - gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest, nil /* testPort */) - if diff := cmp.Diff(test.want, err); diff != "" { - t.Fatalf("unexpected error from ReservePort(.., .., %s, %d, %+v, %d, %v), (-want, +got):\n%s", test.ip, test.port, test.flags, test.device, test.dest, diff) - } - if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) { - t.Fatalf("ReservePort(.., .., .., 0, ..) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral) - } - } - }) - } -} - -func TestPickEphemeralPort(t *testing.T) { - for _, test := range []struct { - name string - f func(port uint16) (bool, tcpip.Error) - wantErr tcpip.Error - wantPort uint16 - }{ - { - name: "no-port-available", - f: func(port uint16) (bool, tcpip.Error) { - return false, nil - }, - wantErr: &tcpip.ErrNoPortAvailable{}, - }, - { - name: "port-tester-error", - f: func(port uint16) (bool, tcpip.Error) { - return false, &tcpip.ErrBadBuffer{} - }, - wantErr: &tcpip.ErrBadBuffer{}, - }, - { - name: "only-port-16042-available", - f: func(port uint16) (bool, tcpip.Error) { - if port == FirstEphemeral+42 { - return true, nil - } - return false, nil - }, - wantPort: FirstEphemeral + 42, - }, - { - name: "only-port-under-16000-available", - f: func(port uint16) (bool, tcpip.Error) { - if port < FirstEphemeral { - return true, nil - } - return false, nil - }, - wantErr: &tcpip.ErrNoPortAvailable{}, - }, - } { - t.Run(test.name, func(t *testing.T) { - pm := NewPortManager() - port, err := pm.PickEphemeralPort(test.f) - if diff := cmp.Diff(test.wantErr, err); diff != "" { - t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff) - } - if port != test.wantPort { - t.Errorf("got PickEphemeralPort(..) = (%d, nil); want (%d, nil)", port, test.wantPort) - } - }) - } -} - -func TestPickEphemeralPortStable(t *testing.T) { - for _, test := range []struct { - name string - f func(port uint16) (bool, tcpip.Error) - wantErr tcpip.Error - wantPort uint16 - }{ - { - name: "no-port-available", - f: func(port uint16) (bool, tcpip.Error) { - return false, nil - }, - wantErr: &tcpip.ErrNoPortAvailable{}, - }, - { - name: "port-tester-error", - f: func(port uint16) (bool, tcpip.Error) { - return false, &tcpip.ErrBadBuffer{} - }, - wantErr: &tcpip.ErrBadBuffer{}, - }, - { - name: "only-port-16042-available", - f: func(port uint16) (bool, tcpip.Error) { - if port == FirstEphemeral+42 { - return true, nil - } - return false, nil - }, - wantPort: FirstEphemeral + 42, - }, - { - name: "only-port-under-16000-available", - f: func(port uint16) (bool, tcpip.Error) { - if port < FirstEphemeral { - return true, nil - } - return false, nil - }, - wantErr: &tcpip.ErrNoPortAvailable{}, - }, - } { - t.Run(test.name, func(t *testing.T) { - pm := NewPortManager() - portOffset := uint32(rand.Int31n(int32(numEphemeralPorts))) - port, err := pm.PickEphemeralPortStable(portOffset, test.f) - if diff := cmp.Diff(test.wantErr, err); diff != "" { - t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff) - } - if port != test.wantPort { - t.Errorf("got PickEphemeralPort(..) = (%d, nil); want (%d, nil)", port, test.wantPort) - } - }) - } -} diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD deleted file mode 100644 index db9b91815..000000000 --- a/pkg/tcpip/sample/tun_tcp_connect/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("//tools:defs.bzl", "go_binary") - -package(licenses = ["notice"]) - -go_binary( - name = "tun_tcp_connect", - srcs = ["main.go"], - visibility = ["//:sandbox"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/header", - "//pkg/tcpip/link/fdbased", - "//pkg/tcpip/link/rawfile", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/link/tun", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go deleted file mode 100644 index 856ea998d..000000000 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ /dev/null @@ -1,219 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build linux - -// This sample creates a stack with TCP and IPv4 protocols on top of a TUN -// device, and connects to a peer. Similar to "nc <address> <port>". While the -// sample is running, attempts to connect to its IPv4 address will result in -// a RST segment. -// -// As an example of how to run it, a TUN device can be created and enabled on -// a linux host as follows (this only needs to be done once per boot): -// -// [sudo] ip tuntap add user <username> mode tun <device-name> -// [sudo] ip link set <device-name> up -// [sudo] ip addr add <ipv4-address>/<mask-length> dev <device-name> -// -// A concrete example: -// -// $ sudo ip tuntap add user wedsonaf mode tun tun0 -// $ sudo ip link set tun0 up -// $ sudo ip addr add 192.168.1.1/24 dev tun0 -// -// Then one can run tun_tcp_connect as such: -// -// $ ./tun/tun_tcp_connect tun0 192.168.1.2 0 192.168.1.1 1234 -// -// This will attempt to connect to the linux host's stack. One can run nc in -// listen mode to accept a connect from tun_tcp_connect and exchange data. -package main - -import ( - "bytes" - "fmt" - "log" - "math/rand" - "net" - "os" - "strconv" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" - "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/link/tun" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/waiter" -) - -// writer reads from standard input and writes to the endpoint until standard -// input is closed. It signals that it's done by closing the provided channel. -func writer(ch chan struct{}, ep tcpip.Endpoint) { - defer func() { - ep.Shutdown(tcpip.ShutdownWrite) - close(ch) - }() - - var b bytes.Buffer - if err := func() error { - for { - if _, err := b.ReadFrom(os.Stdin); err != nil { - return fmt.Errorf("b.ReadFrom failed: %w", err) - } - - for b.Len() != 0 { - if _, err := ep.Write(&b, tcpip.WriteOptions{Atomic: true}); err != nil { - return fmt.Errorf("ep.Write failed: %s", err) - } - } - } - }(); err != nil { - fmt.Println(err) - } -} - -func main() { - if len(os.Args) != 6 { - log.Fatal("Usage: ", os.Args[0], " <tun-device> <local-ipv4-address> <local-port> <remote-ipv4-address> <remote-port>") - } - - tunName := os.Args[1] - addrName := os.Args[2] - portName := os.Args[3] - remoteAddrName := os.Args[4] - remotePortName := os.Args[5] - - rand.Seed(time.Now().UnixNano()) - - addr := tcpip.Address(net.ParseIP(addrName).To4()) - remote := tcpip.FullAddress{ - NIC: 1, - Addr: tcpip.Address(net.ParseIP(remoteAddrName).To4()), - } - - var localPort uint16 - if v, err := strconv.Atoi(portName); err != nil { - log.Fatalf("Unable to convert port %v: %v", portName, err) - } else { - localPort = uint16(v) - } - - if v, err := strconv.Atoi(remotePortName); err != nil { - log.Fatalf("Unable to convert port %v: %v", remotePortName, err) - } else { - remote.Port = uint16(v) - } - - // Create the stack with ipv4 and tcp protocols, then add a tun-based - // NIC and ipv4 address. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, - }) - - mtu, err := rawfile.GetMTU(tunName) - if err != nil { - log.Fatal(err) - } - - fd, err := tun.Open(tunName) - if err != nil { - log.Fatal(err) - } - - linkEP, err := fdbased.New(&fdbased.Options{FDs: []int{fd}, MTU: mtu}) - if err != nil { - log.Fatal(err) - } - if err := s.CreateNIC(1, sniffer.New(linkEP)); err != nil { - log.Fatal(err) - } - - if err := s.AddAddress(1, ipv4.ProtocolNumber, addr); err != nil { - log.Fatal(err) - } - - // Add default route. - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: 1, - }, - }) - - // Create TCP endpoint. - var wq waiter.Queue - ep, e := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if e != nil { - log.Fatal(e) - } - - // Bind if a port is specified. - if localPort != 0 { - if err := ep.Bind(tcpip.FullAddress{0, "", localPort}); err != nil { - log.Fatal("Bind failed: ", err) - } - } - - // Issue connect request and wait for it to complete. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - wq.EventRegister(&waitEntry, waiter.EventOut) - terr := ep.Connect(remote) - if _, ok := terr.(*tcpip.ErrConnectStarted); ok { - fmt.Println("Connect is pending...") - <-notifyCh - terr = ep.LastError() - } - wq.EventUnregister(&waitEntry) - - if terr != nil { - log.Fatal("Unable to connect: ", terr) - } - - fmt.Println("Connected") - - // Start the writer in its own goroutine. - writerCompletedCh := make(chan struct{}) - go writer(writerCompletedCh, ep) // S/R-SAFE: sample code. - - // Read data and write to standard output until the peer closes the - // connection from its side. - wq.EventRegister(&waitEntry, waiter.EventIn) - for { - _, err := ep.Read(os.Stdout, tcpip.ReadOptions{}) - if err != nil { - if _, ok := err.(*tcpip.ErrClosedForReceive); ok { - break - } - - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - <-notifyCh - continue - } - - log.Fatal("Read() failed:", err) - } - } - wq.EventUnregister(&waitEntry) - - // The reader has completed. Now wait for the writer as well. - <-writerCompletedCh - - ep.Close() -} diff --git a/pkg/tcpip/sample/tun_tcp_echo/BUILD b/pkg/tcpip/sample/tun_tcp_echo/BUILD deleted file mode 100644 index 43264b76d..000000000 --- a/pkg/tcpip/sample/tun_tcp_echo/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("//tools:defs.bzl", "go_binary") - -package(licenses = ["notice"]) - -go_binary( - name = "tun_tcp_echo", - srcs = ["main.go"], - visibility = ["//:sandbox"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/link/fdbased", - "//pkg/tcpip/link/rawfile", - "//pkg/tcpip/link/tun", - "//pkg/tcpip/network/arp", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go deleted file mode 100644 index 9b23df3a9..000000000 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ /dev/null @@ -1,230 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build linux - -// This sample creates a stack with TCP and IPv4 protocols on top of a TUN -// device, and listens on a port. Data received by the server in the accepted -// connections is echoed back to the clients. -package main - -import ( - "bytes" - "flag" - "io" - "log" - "math/rand" - "net" - "os" - "strconv" - "strings" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" - "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" - "gvisor.dev/gvisor/pkg/tcpip/link/tun" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/waiter" -) - -var tap = flag.Bool("tap", false, "use tap istead of tun") -var mac = flag.String("mac", "aa:00:01:01:01:01", "mac address to use in tap device") - -type endpointWriter struct { - ep tcpip.Endpoint -} - -type tcpipError struct { - inner tcpip.Error -} - -func (e *tcpipError) Error() string { - return e.inner.String() -} - -func (e *endpointWriter) Write(p []byte) (int, error) { - var r bytes.Reader - r.Reset(p) - n, err := e.ep.Write(&r, tcpip.WriteOptions{}) - if err != nil { - return int(n), &tcpipError{ - inner: err, - } - } - if n != int64(len(p)) { - return int(n), io.ErrShortWrite - } - return int(n), nil -} - -func echo(wq *waiter.Queue, ep tcpip.Endpoint) { - defer ep.Close() - - // Create wait queue entry that notifies a channel. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - - wq.EventRegister(&waitEntry, waiter.EventIn) - defer wq.EventUnregister(&waitEntry) - - w := endpointWriter{ - ep: ep, - } - - for { - _, err := ep.Read(&w, tcpip.ReadOptions{}) - if err != nil { - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - <-notifyCh - continue - } - - return - } - } -} - -func main() { - flag.Parse() - if len(flag.Args()) != 3 { - log.Fatal("Usage: ", os.Args[0], " <tun-device> <local-address> <local-port>") - } - - tunName := flag.Arg(0) - addrName := flag.Arg(1) - portName := flag.Arg(2) - - rand.Seed(time.Now().UnixNano()) - - // Parse the mac address. - maddr, err := net.ParseMAC(*mac) - if err != nil { - log.Fatalf("Bad MAC address: %v", *mac) - } - - // Parse the IP address. Support both ipv4 and ipv6. - parsedAddr := net.ParseIP(addrName) - if parsedAddr == nil { - log.Fatalf("Bad IP address: %v", addrName) - } - - var addr tcpip.Address - var proto tcpip.NetworkProtocolNumber - if parsedAddr.To4() != nil { - addr = tcpip.Address(parsedAddr.To4()) - proto = ipv4.ProtocolNumber - } else if parsedAddr.To16() != nil { - addr = tcpip.Address(parsedAddr.To16()) - proto = ipv6.ProtocolNumber - } else { - log.Fatalf("Unknown IP type: %v", addrName) - } - - localPort, err := strconv.Atoi(portName) - if err != nil { - log.Fatalf("Unable to convert port %v: %v", portName, err) - } - - // Create the stack with ip and tcp protocols, then add a tun-based - // NIC and address. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol, arp.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, - }) - - mtu, err := rawfile.GetMTU(tunName) - if err != nil { - log.Fatal(err) - } - - var fd int - if *tap { - fd, err = tun.OpenTAP(tunName) - } else { - fd, err = tun.Open(tunName) - } - if err != nil { - log.Fatal(err) - } - - linkEP, err := fdbased.New(&fdbased.Options{ - FDs: []int{fd}, - MTU: mtu, - EthernetHeader: *tap, - Address: tcpip.LinkAddress(maddr), - }) - if err != nil { - log.Fatal(err) - } - if err := s.CreateNIC(1, linkEP); err != nil { - log.Fatal(err) - } - - if err := s.AddAddress(1, proto, addr); err != nil { - log.Fatal(err) - } - - subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addr))), tcpip.AddressMask(strings.Repeat("\x00", len(addr)))) - if err != nil { - log.Fatal(err) - } - - // Add default route. - s.SetRouteTable([]tcpip.Route{ - { - Destination: subnet, - NIC: 1, - }, - }) - - // Create TCP endpoint, bind it, then start listening. - var wq waiter.Queue - ep, e := s.NewEndpoint(tcp.ProtocolNumber, proto, &wq) - if e != nil { - log.Fatal(e) - } - - defer ep.Close() - - if err := ep.Bind(tcpip.FullAddress{0, "", uint16(localPort)}); err != nil { - log.Fatal("Bind failed: ", err) - } - - if err := ep.Listen(10); err != nil { - log.Fatal("Listen failed: ", err) - } - - // Wait for connections to appear. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - wq.EventRegister(&waitEntry, waiter.EventIn) - defer wq.EventUnregister(&waitEntry) - - for { - n, wq, err := ep.Accept(nil) - if err != nil { - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - <-notifyCh - continue - } - - log.Fatal("Accept() failed:", err) - } - - go echo(wq, n) // S/R-SAFE: sample code. - } -} diff --git a/pkg/tcpip/seqnum/BUILD b/pkg/tcpip/seqnum/BUILD deleted file mode 100644 index 45f503845..000000000 --- a/pkg/tcpip/seqnum/BUILD +++ /dev/null @@ -1,9 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "seqnum", - srcs = ["seqnum.go"], - visibility = ["//visibility:public"], -) diff --git a/pkg/tcpip/seqnum/seqnum_state_autogen.go b/pkg/tcpip/seqnum/seqnum_state_autogen.go new file mode 100644 index 000000000..23e79811d --- /dev/null +++ b/pkg/tcpip/seqnum/seqnum_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package seqnum diff --git a/pkg/tcpip/sock_err_list.go b/pkg/tcpip/sock_err_list.go new file mode 100644 index 000000000..0be1993af --- /dev/null +++ b/pkg/tcpip/sock_err_list.go @@ -0,0 +1,221 @@ +package tcpip + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type sockErrorElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (sockErrorElementMapper) linkerFor(elem *SockError) *SockError { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type sockErrorList struct { + head *SockError + tail *SockError +} + +// Reset resets list l to the empty state. +func (l *sockErrorList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *sockErrorList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *sockErrorList) Front() *SockError { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *sockErrorList) Back() *SockError { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *sockErrorList) Len() (count int) { + for e := l.Front(); e != nil; e = (sockErrorElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *sockErrorList) PushFront(e *SockError) { + linker := sockErrorElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + sockErrorElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *sockErrorList) PushBack(e *SockError) { + linker := sockErrorElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + sockErrorElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *sockErrorList) PushBackList(m *sockErrorList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + sockErrorElementMapper{}.linkerFor(l.tail).SetNext(m.head) + sockErrorElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *sockErrorList) InsertAfter(b, e *SockError) { + bLinker := sockErrorElementMapper{}.linkerFor(b) + eLinker := sockErrorElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + sockErrorElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *sockErrorList) InsertBefore(a, e *SockError) { + aLinker := sockErrorElementMapper{}.linkerFor(a) + eLinker := sockErrorElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + sockErrorElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *sockErrorList) Remove(e *SockError) { + linker := sockErrorElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + sockErrorElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + sockErrorElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type sockErrorEntry struct { + next *SockError + prev *SockError +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *sockErrorEntry) Next() *SockError { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *sockErrorEntry) Prev() *SockError { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *sockErrorEntry) SetNext(elem *SockError) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *sockErrorEntry) SetPrev(elem *SockError) { + e.prev = elem +} diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD deleted file mode 100644 index 49362333a..000000000 --- a/pkg/tcpip/stack/BUILD +++ /dev/null @@ -1,145 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test", "most_shards") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "neighbor_entry_list", - out = "neighbor_entry_list.go", - package = "stack", - prefix = "neighborEntry", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*neighborEntry", - "Linker": "*neighborEntry", - }, -) - -go_template_instance( - name = "packet_buffer_list", - out = "packet_buffer_list.go", - package = "stack", - prefix = "PacketBuffer", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*PacketBuffer", - "Linker": "*PacketBuffer", - }, -) - -go_template_instance( - name = "tuple_list", - out = "tuple_list.go", - package = "stack", - prefix = "tuple", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*tuple", - "Linker": "*tuple", - }, -) - -go_library( - name = "stack", - srcs = [ - "addressable_endpoint_state.go", - "conntrack.go", - "headertype_string.go", - "icmp_rate_limit.go", - "iptables.go", - "iptables_state.go", - "iptables_targets.go", - "iptables_types.go", - "neighbor_cache.go", - "neighbor_entry.go", - "neighbor_entry_list.go", - "neighborstate_string.go", - "nic.go", - "nud.go", - "packet_buffer.go", - "packet_buffer_list.go", - "packet_buffer_unsafe.go", - "pending_packets.go", - "rand.go", - "registration.go", - "route.go", - "stack.go", - "stack_global_state.go", - "stack_options.go", - "transport_demuxer.go", - "tuple_list.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/ilist", - "//pkg/log", - "//pkg/rand", - "//pkg/sleep", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/hash/jenkins", - "//pkg/tcpip/header", - "//pkg/tcpip/ports", - "//pkg/tcpip/seqnum", - "//pkg/tcpip/transport/tcpconntrack", - "//pkg/waiter", - "@org_golang_x_time//rate:go_default_library", - ], -) - -go_test( - name = "stack_x_test", - size = "medium", - srcs = [ - "addressable_endpoint_state_test.go", - "ndp_test.go", - "nud_test.go", - "stack_test.go", - "transport_demuxer_test.go", - "transport_test.go", - ], - shard_count = most_shards, - deps = [ - ":stack", - "//pkg/rand", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/faketime", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/network/arp", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/ports", - "//pkg/tcpip/transport/icmp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) - -go_test( - name = "stack_test", - size = "small", - srcs = [ - "forwarding_test.go", - "neighbor_cache_test.go", - "neighbor_entry_test.go", - "nic_test.go", - "packet_buffer_test.go", - ], - library = ":stack", - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/faketime", - "//pkg/tcpip/header", - "@com_github_google_go_cmp//cmp:go_default_library", - "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", - ], -) diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go deleted file mode 100644 index 140f146f6..000000000 --- a/pkg/tcpip/stack/addressable_endpoint_state_test.go +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -// TestAddressableEndpointStateCleanup tests that cleaning up an addressable -// endpoint state removes permanent addresses and leaves groups. -func TestAddressableEndpointStateCleanup(t *testing.T) { - var ep fakeNetworkEndpoint - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - - var s stack.AddressableEndpointState - s.Init(&ep) - - addr := tcpip.AddressWithPrefix{ - Address: "\x01", - PrefixLen: 8, - } - - { - ep, err := s.AddAndAcquirePermanentAddress(addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */) - if err != nil { - t.Fatalf("s.AddAndAcquirePermanentAddress(%s, %d, %d, false): %s", addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, err) - } - // We don't need the address endpoint. - ep.DecRef() - } - { - ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint) - if ep == nil { - t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = nil, want = non-nil", addr.Address) - } - ep.DecRef() - } - - s.Cleanup() - if ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint); ep != nil { - ep.DecRef() - t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix()) - } -} diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go deleted file mode 100644 index c987c1851..000000000 --- a/pkg/tcpip/stack/forwarding_test.go +++ /dev/null @@ -1,785 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import ( - "encoding/binary" - "math" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -const ( - fwdTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 - fwdTestNetHeaderLen = 12 - fwdTestNetDefaultPrefixLen = 8 - - // fwdTestNetDefaultMTU is the MTU, in bytes, used throughout the tests, - // except where another value is explicitly used. It is chosen to match - // the MTU of loopback interfaces on linux systems. - fwdTestNetDefaultMTU = 65536 - - dstAddrOffset = 0 - srcAddrOffset = 1 - protocolNumberOffset = 2 -) - -var _ LinkAddressResolver = (*fwdTestNetworkEndpoint)(nil) -var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil) - -// fwdTestNetworkEndpoint is a network-layer protocol endpoint. -// Headers of this protocol are fwdTestNetHeaderLen bytes, but we currently only -// use the first three: destination address, source address, and transport -// protocol. They're all one byte fields to simplify parsing. -type fwdTestNetworkEndpoint struct { - AddressableEndpointState - - nic NetworkInterface - proto *fwdTestNetworkProtocol - dispatcher TransportDispatcher -} - -func (*fwdTestNetworkEndpoint) Enable() tcpip.Error { - return nil -} - -func (*fwdTestNetworkEndpoint) Enabled() bool { - return true -} - -func (*fwdTestNetworkEndpoint) Disable() {} - -func (f *fwdTestNetworkEndpoint) MTU() uint32 { - return f.nic.MTU() - uint32(f.MaxHeaderLength()) -} - -func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 { - return 123 -} - -func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) { - if _, _, ok := f.proto.Parse(pkt); !ok { - return - } - - netHdr := pkt.NetworkHeader().View() - _, dst := f.proto.ParseAddresses(netHdr) - - addressEndpoint := f.AcquireAssignedAddress(dst, f.nic.Promiscuous(), CanBePrimaryEndpoint) - if addressEndpoint != nil { - addressEndpoint.DecRef() - // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(netHdr[protocolNumberOffset]), pkt) - return - } - - r, err := f.proto.stack.FindRoute(0, "", dst, fwdTestNetNumber, false /* multicastLoop */) - if err != nil { - return - } - defer r.Release() - - vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) - pkt = NewPacketBuffer(PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()), - Data: vv.ToView().ToVectorisedView(), - }) - // TODO(b/143425874) Decrease the TTL field in forwarded packets. - _ = r.WriteHeaderIncludedPacket(pkt) -} - -func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { - return f.nic.MaxHeaderLength() + fwdTestNetHeaderLen -} - -func (*fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { - return 0 -} - -func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { - return f.proto.Number() -} - -func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error { - // Add the protocol's header to the packet and send it to the link - // endpoint. - b := pkt.NetworkHeader().Push(fwdTestNetHeaderLen) - b[dstAddrOffset] = r.RemoteAddress[0] - b[srcAddrOffset] = r.LocalAddress[0] - b[protocolNumberOffset] = byte(params.Protocol) - - return f.nic.WritePacket(r, gso, fwdTestNetNumber, pkt) -} - -// WritePackets implements LinkEndpoint.WritePackets. -func (*fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) { - panic("not implemented") -} - -func (f *fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) tcpip.Error { - // The network header should not already be populated. - if _, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen); !ok { - return &tcpip.ErrMalformedHeader{} - } - - return f.nic.WritePacket(r, nil /* gso */, fwdTestNetNumber, pkt) -} - -func (f *fwdTestNetworkEndpoint) Close() { - f.AddressableEndpointState.Cleanup() -} - -// Stats implements stack.NetworkEndpoint. -func (*fwdTestNetworkEndpoint) Stats() NetworkEndpointStats { - return &fwdTestNetworkEndpointStats{} -} - -var _ NetworkEndpointStats = (*fwdTestNetworkEndpointStats)(nil) - -type fwdTestNetworkEndpointStats struct{} - -// IsNetworkEndpointStats implements stack.NetworkEndpointStats. -func (*fwdTestNetworkEndpointStats) IsNetworkEndpointStats() {} - -var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil) - -// fwdTestNetworkProtocol is a network-layer protocol that implements Address -// resolution. -type fwdTestNetworkProtocol struct { - stack *Stack - - neigh *neighborCache - addrResolveDelay time.Duration - onLinkAddressResolved func(*neighborCache, tcpip.Address, tcpip.LinkAddress) - onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool) - - mu struct { - sync.RWMutex - forwarding bool - } -} - -func (*fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber { - return fwdTestNetNumber -} - -func (*fwdTestNetworkProtocol) MinimumPacketSize() int { - return fwdTestNetHeaderLen -} - -func (*fwdTestNetworkProtocol) DefaultPrefixLen() int { - return fwdTestNetDefaultPrefixLen -} - -func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { - return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) -} - -func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { - netHeader, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen) - if !ok { - return 0, false, false - } - return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true -} - -func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, dispatcher TransportDispatcher) NetworkEndpoint { - e := &fwdTestNetworkEndpoint{ - nic: nic, - proto: f, - dispatcher: dispatcher, - } - e.AddressableEndpointState.Init(e) - return e -} - -func (*fwdTestNetworkProtocol) SetOption(tcpip.SettableNetworkProtocolOption) tcpip.Error { - return &tcpip.ErrUnknownProtocolOption{} -} - -func (*fwdTestNetworkProtocol) Option(tcpip.GettableNetworkProtocolOption) tcpip.Error { - return &tcpip.ErrUnknownProtocolOption{} -} - -func (*fwdTestNetworkProtocol) Close() {} - -func (*fwdTestNetworkProtocol) Wait() {} - -func (f *fwdTestNetworkEndpoint) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error { - if fn := f.proto.onLinkAddressResolved; fn != nil { - time.AfterFunc(f.proto.addrResolveDelay, func() { - fn(f.proto.neigh, addr, remoteLinkAddr) - }) - } - return nil -} - -func (f *fwdTestNetworkEndpoint) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { - if fn := f.proto.onResolveStaticAddress; fn != nil { - return fn(addr) - } - return "", false -} - -func (*fwdTestNetworkEndpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber { - return fwdTestNetNumber -} - -// Forwarding implements stack.ForwardingNetworkProtocol. -func (f *fwdTestNetworkProtocol) Forwarding() bool { - f.mu.RLock() - defer f.mu.RUnlock() - return f.mu.forwarding - -} - -// SetForwarding implements stack.ForwardingNetworkProtocol. -func (f *fwdTestNetworkProtocol) SetForwarding(v bool) { - f.mu.Lock() - defer f.mu.Unlock() - f.mu.forwarding = v -} - -// fwdTestPacketInfo holds all the information about an outbound packet. -type fwdTestPacketInfo struct { - RemoteLinkAddress tcpip.LinkAddress - LocalLinkAddress tcpip.LinkAddress - Pkt *PacketBuffer -} - -type fwdTestLinkEndpoint struct { - dispatcher NetworkDispatcher - mtu uint32 - linkAddr tcpip.LinkAddress - - // C is where outbound packets are queued. - C chan fwdTestPacketInfo -} - -// InjectInbound injects an inbound packet. -func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { - e.InjectLinkAddr(protocol, "", pkt) -} - -// InjectLinkAddr injects an inbound packet with a remote link address. -func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt *PacketBuffer) { - e.dispatcher.DeliverNetworkPacket(remote, "" /* local */, protocol, pkt) -} - -// Attach saves the stack network-layer dispatcher for use later when packets -// are injected. -func (e *fwdTestLinkEndpoint) Attach(dispatcher NetworkDispatcher) { - e.dispatcher = dispatcher -} - -// IsAttached implements stack.LinkEndpoint.IsAttached. -func (e *fwdTestLinkEndpoint) IsAttached() bool { - return e.dispatcher != nil -} - -// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized -// during construction. -func (e *fwdTestLinkEndpoint) MTU() uint32 { - return e.mtu -} - -// Capabilities implements stack.LinkEndpoint.Capabilities. -func (e fwdTestLinkEndpoint) Capabilities() LinkEndpointCapabilities { - caps := LinkEndpointCapabilities(0) - return caps | CapabilityResolutionRequired -} - -// GSOMaxSize returns the maximum GSO packet size. -func (*fwdTestLinkEndpoint) GSOMaxSize() uint32 { - return 1 << 15 -} - -// MaxHeaderLength returns the maximum size of the link layer header. Given it -// doesn't have a header, it just returns 0. -func (*fwdTestLinkEndpoint) MaxHeaderLength() uint16 { - return 0 -} - -// LinkAddress returns the link address of this endpoint. -func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { - return e.linkAddr -} - -func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { - p := fwdTestPacketInfo{ - RemoteLinkAddress: r.RemoteLinkAddress, - LocalLinkAddress: r.LocalLinkAddress, - Pkt: pkt, - } - - select { - case e.C <- p: - default: - } - - return nil -} - -// WritePackets stores outbound packets into the channel. -func (e *fwdTestLinkEndpoint) WritePackets(r RouteInfo, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - n := 0 - for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.WritePacket(r, gso, protocol, pkt) - n++ - } - - return n, nil -} - -// Wait implements stack.LinkEndpoint.Wait. -func (*fwdTestLinkEndpoint) Wait() {} - -// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. -func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType { - panic("not implemented") -} - -// AddHeader implements stack.LinkEndpoint.AddHeader. -func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { - panic("not implemented") -} - -func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) { - // Create a stack with the network protocol and two NICs. - s := New(Options{ - NetworkProtocols: []NetworkProtocolFactory{func(s *Stack) NetworkProtocol { - proto.stack = s - return proto - }}, - }) - - // Enable forwarding. - s.SetForwarding(proto.Number(), true) - - // NIC 1 has the link address "a", and added the network address 1. - ep1 = &fwdTestLinkEndpoint{ - C: make(chan fwdTestPacketInfo, 300), - mtu: fwdTestNetDefaultMTU, - linkAddr: "a", - } - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC #1 failed:", err) - } - if err := s.AddAddress(1, fwdTestNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress #1 failed:", err) - } - - // NIC 2 has the link address "b", and added the network address 2. - ep2 = &fwdTestLinkEndpoint{ - C: make(chan fwdTestPacketInfo, 300), - mtu: fwdTestNetDefaultMTU, - linkAddr: "b", - } - if err := s.CreateNIC(2, ep2); err != nil { - t.Fatal("CreateNIC #2 failed:", err) - } - if err := s.AddAddress(2, fwdTestNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress #2 failed:", err) - } - - nic, ok := s.nics[2] - if !ok { - t.Fatal("NIC 2 does not exist") - } - - if l, ok := nic.linkAddrResolvers[fwdTestNetNumber]; ok { - proto.neigh = &l.neigh - } - - // Route all packets to NIC 2. - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: 2}}) - } - - return ep1, ep2 -} - -func TestForwardingWithStaticResolver(t *testing.T) { - // Create a network protocol with a static resolver. - proto := &fwdTestNetworkProtocol{ - onResolveStaticAddress: - // The network address 3 is resolved to the link address "c". - func(addr tcpip.Address) (tcpip.LinkAddress, bool) { - if addr == "\x03" { - return "c", true - } - return "", false - }, - } - - ep1, ep2 := fwdTestNetFactory(t, proto) - - // Inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - default: - t.Fatal("packet not forwarded") - } - - // Test that the static address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } -} - -func TestForwardingWithFakeResolver(t *testing.T) { - proto := fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - t.Helper() - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) - } - // Any address will be resolved to the link address "c". - neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - }, - } - ep1, ep2 := fwdTestNetFactory(t, &proto) - - // Inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } - - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } -} - -func TestForwardingWithNoResolver(t *testing.T) { - // Create a network protocol without a resolver. - proto := &fwdTestNetworkProtocol{} - - // Whether or not we use the neighbor cache here does not matter since - // neither linkAddrCache nor neighborCache will be used. - ep1, ep2 := fwdTestNetFactory(t, proto) - - // inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - select { - case <-ep2.C: - t.Fatal("Packet should not be forwarded") - case <-time.After(time.Second): - } -} - -func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) { - proto := &fwdTestNetworkProtocol{ - addrResolveDelay: 50 * time.Millisecond, - onLinkAddressResolved: func(*neighborCache, tcpip.Address, tcpip.LinkAddress) { - // Don't resolve the link address. - }, - } - - ep1, ep2 := fwdTestNetFactory(t, proto) - - const numPackets int = 5 - // These packets will all be enqueued in the packet queue to wait for link - // address resolution. - for i := 0; i < numPackets; i++ { - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - } - - // All packets should fail resolution. - // TODO(gvisor.dev/issue/5141): Use a fake clock. - for i := 0; i < numPackets; i++ { - select { - case got := <-ep2.C: - t.Fatalf("got %#v; packets should have failed resolution and not been forwarded", got) - case <-time.After(100 * time.Millisecond): - } - } -} - -func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { - proto := fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - t.Helper() - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) - } - // Only packets to address 3 will be resolved to the - // link address "c". - if addr == "\x03" { - neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - } - }, - } - ep1, ep2 := fwdTestNetFactory(t, &proto) - - // Inject an inbound packet to address 4 on NIC 1. This packet should - // not be forwarded. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 4 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - // Inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf = buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } - - if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) - } - - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } -} - -func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { - proto := fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - t.Helper() - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) - } - // Any packets will be resolved to the link address "c". - neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - }, - } - ep1, ep2 := fwdTestNetFactory(t, &proto) - - // Inject two inbound packets to address 3 on NIC 1. - for i := 0; i < 2; i++ { - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - } - - for i := 0; i < 2; i++ { - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } - - if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) - } - - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - } -} - -func TestForwardingWithFakeResolverManyPackets(t *testing.T) { - proto := fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - t.Helper() - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) - } - // Any packets will be resolved to the link address "c". - neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - }, - } - ep1, ep2 := fwdTestNetFactory(t, &proto) - - for i := 0; i < maxPendingPacketsPerResolution+5; i++ { - // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - // Set the packet sequence number. - binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - } - - for i := 0; i < maxPendingPacketsPerResolution; i++ { - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } - - b := PayloadSince(p.Pkt.NetworkHeader()) - if b[dstAddrOffset] != 3 { - t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset]) - } - if len(b) < fwdTestNetHeaderLen+2 { - t.Fatalf("packet is too short to hold a sequence number: len(b) = %d", b) - } - seqNumBuf := b[fwdTestNetHeaderLen:] - - // The first 5 packets should not be forwarded so the sequence number should - // start with 5. - want := uint16(i + 5) - if n := binary.BigEndian.Uint16(seqNumBuf); n != want { - t.Fatalf("got the packet #%d, want = #%d", n, want) - } - - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - } -} - -func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { - proto := fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - t.Helper() - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) - } - // Any packets will be resolved to the link address "c". - neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - }, - } - ep1, ep2 := fwdTestNetFactory(t, &proto) - - for i := 0; i < maxPendingResolutions+5; i++ { - // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1. - // Each packet has a different destination address (3 to - // maxPendingResolutions + 7). - buf := buffer.NewView(30) - buf[dstAddrOffset] = byte(3 + i) - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - } - - for i := 0; i < maxPendingResolutions; i++ { - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } - - // The first 5 packets (address 3 to 7) should not be forwarded - // because their address resolutions are interrupted. - if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] < 8 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", nh[dstAddrOffset]) - } - - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - } -} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go deleted file mode 100644 index 3b6ba9509..000000000 --- a/pkg/tcpip/stack/ndp_test.go +++ /dev/null @@ -1,5332 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack_test - -import ( - "context" - "encoding/binary" - "fmt" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/rand" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") - linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07") - linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08") - linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09") - - defaultPrefixLen = 128 - - // Extra time to use when waiting for an async event to occur. - defaultAsyncPositiveEventTimeout = 10 * time.Second - - // Extra time to use when waiting for an async event to not occur. - // - // Since a negative check is used to make sure an event did not happen, it is - // okay to use a smaller timeout compared to the positive case since execution - // stall in regards to the monotonic clock will not affect the expected - // outcome. - defaultAsyncNegativeEventTimeout = time.Second -) - -var ( - llAddr1 = header.LinkLocalAddr(linkAddr1) - llAddr2 = header.LinkLocalAddr(linkAddr2) - llAddr3 = header.LinkLocalAddr(linkAddr3) - llAddr4 = header.LinkLocalAddr(linkAddr4) - dstAddr = tcpip.FullAddress{ - Addr: "\x0a\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - Port: 25, - } -) - -func addrForSubnet(subnet tcpip.Subnet, linkAddr tcpip.LinkAddress) tcpip.AddressWithPrefix { - if !header.IsValidUnicastEthernetAddress(linkAddr) { - return tcpip.AddressWithPrefix{} - } - - addrBytes := []byte(subnet.ID()) - header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) - return tcpip.AddressWithPrefix{ - Address: tcpip.Address(addrBytes), - PrefixLen: 64, - } -} - -// prefixSubnetAddr returns a prefix (Address + Length), the prefix's equivalent -// tcpip.Subnet, and an address where the lower half of the address is composed -// of the EUI-64 of linkAddr if it is a valid unicast ethernet address. -func prefixSubnetAddr(offset uint8, linkAddr tcpip.LinkAddress) (tcpip.AddressWithPrefix, tcpip.Subnet, tcpip.AddressWithPrefix) { - prefixBytes := []byte{1, 2, 3, 4, 5, 6, 7, 8 + offset, 0, 0, 0, 0, 0, 0, 0, 0} - prefix := tcpip.AddressWithPrefix{ - Address: tcpip.Address(prefixBytes), - PrefixLen: 64, - } - - subnet := prefix.Subnet() - - return prefix, subnet, addrForSubnet(subnet, linkAddr) -} - -// ndpDADEvent is a set of parameters that was passed to -// ndpDispatcher.OnDuplicateAddressDetectionStatus. -type ndpDADEvent struct { - nicID tcpip.NICID - addr tcpip.Address - resolved bool - err tcpip.Error -} - -type ndpRouterEvent struct { - nicID tcpip.NICID - addr tcpip.Address - // true if router was discovered, false if invalidated. - discovered bool -} - -type ndpPrefixEvent struct { - nicID tcpip.NICID - prefix tcpip.Subnet - // true if prefix was discovered, false if invalidated. - discovered bool -} - -type ndpAutoGenAddrEventType int - -const ( - newAddr ndpAutoGenAddrEventType = iota - deprecatedAddr - invalidatedAddr -) - -type ndpAutoGenAddrEvent struct { - nicID tcpip.NICID - addr tcpip.AddressWithPrefix - eventType ndpAutoGenAddrEventType -} - -type ndpRDNSS struct { - addrs []tcpip.Address - lifetime time.Duration -} - -type ndpRDNSSEvent struct { - nicID tcpip.NICID - rdnss ndpRDNSS -} - -type ndpDNSSLEvent struct { - nicID tcpip.NICID - domainNames []string - lifetime time.Duration -} - -type ndpDHCPv6Event struct { - nicID tcpip.NICID - configuration ipv6.DHCPv6ConfigurationFromNDPRA -} - -var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil) - -// ndpDispatcher implements NDPDispatcher so tests can know when various NDP -// related events happen for test purposes. -type ndpDispatcher struct { - dadC chan ndpDADEvent - routerC chan ndpRouterEvent - rememberRouter bool - prefixC chan ndpPrefixEvent - rememberPrefix bool - autoGenAddrC chan ndpAutoGenAddrEvent - rdnssC chan ndpRDNSSEvent - dnsslC chan ndpDNSSLEvent - routeTable []tcpip.Route - dhcpv6ConfigurationC chan ndpDHCPv6Event -} - -// Implements ipv6.NDPDispatcher.OnDuplicateAddressDetectionStatus. -func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err tcpip.Error) { - if n.dadC != nil { - n.dadC <- ndpDADEvent{ - nicID, - addr, - resolved, - err, - } - } -} - -// Implements ipv6.NDPDispatcher.OnDefaultRouterDiscovered. -func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool { - if c := n.routerC; c != nil { - c <- ndpRouterEvent{ - nicID, - addr, - true, - } - } - - return n.rememberRouter -} - -// Implements ipv6.NDPDispatcher.OnDefaultRouterInvalidated. -func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) { - if c := n.routerC; c != nil { - c <- ndpRouterEvent{ - nicID, - addr, - false, - } - } -} - -// Implements ipv6.NDPDispatcher.OnOnLinkPrefixDiscovered. -func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool { - if c := n.prefixC; c != nil { - c <- ndpPrefixEvent{ - nicID, - prefix, - true, - } - } - - return n.rememberPrefix -} - -// Implements ipv6.NDPDispatcher.OnOnLinkPrefixInvalidated. -func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) { - if c := n.prefixC; c != nil { - c <- ndpPrefixEvent{ - nicID, - prefix, - false, - } - } -} - -func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) bool { - if c := n.autoGenAddrC; c != nil { - c <- ndpAutoGenAddrEvent{ - nicID, - addr, - newAddr, - } - } - return true -} - -func (n *ndpDispatcher) OnAutoGenAddressDeprecated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { - if c := n.autoGenAddrC; c != nil { - c <- ndpAutoGenAddrEvent{ - nicID, - addr, - deprecatedAddr, - } - } -} - -func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { - if c := n.autoGenAddrC; c != nil { - c <- ndpAutoGenAddrEvent{ - nicID, - addr, - invalidatedAddr, - } - } -} - -// Implements ipv6.NDPDispatcher.OnRecursiveDNSServerOption. -func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) { - if c := n.rdnssC; c != nil { - c <- ndpRDNSSEvent{ - nicID, - ndpRDNSS{ - addrs, - lifetime, - }, - } - } -} - -// Implements ipv6.NDPDispatcher.OnDNSSearchListOption. -func (n *ndpDispatcher) OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration) { - if n.dnsslC != nil { - n.dnsslC <- ndpDNSSLEvent{ - nicID, - domainNames, - lifetime, - } - } -} - -// Implements ipv6.NDPDispatcher.OnDHCPv6Configuration. -func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration ipv6.DHCPv6ConfigurationFromNDPRA) { - if c := n.dhcpv6ConfigurationC; c != nil { - c <- ndpDHCPv6Event{ - nicID, - configuration, - } - } -} - -// channelLinkWithHeaderLength is a channel.Endpoint with a configurable -// header length. -type channelLinkWithHeaderLength struct { - *channel.Endpoint - headerLength uint16 -} - -func (l *channelLinkWithHeaderLength) MaxHeaderLength() uint16 { - return l.headerLength -} - -// Check e to make sure that the event is for addr on nic with ID 1, and the -// resolved flag set to resolved with the specified err. -func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, resolved bool, err tcpip.Error) string { - return cmp.Diff(ndpDADEvent{nicID: nicID, addr: addr, resolved: resolved, err: err}, e, cmp.AllowUnexported(e)) -} - -// TestDADDisabled tests that an address successfully resolves immediately -// when DAD is not enabled (the default for an empty stack.Options). -func TestDADDisabled(t *testing.T) { - const nicID = 1 - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPDisp: &ndpDisp, - })}, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - addrWithPrefix := tcpip.AddressWithPrefix{ - Address: addr1, - PrefixLen: defaultPrefixLen, - } - if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) - } - - // Should get the address immediately since we should not have performed - // DAD on it. - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected DAD event") - } - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { - t.Fatal(err) - } - - // We should not have sent any NDP NS messages. - if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got != 0 { - t.Fatalf("got NeighborSolicit = %d, want = 0", got) - } -} - -// TestDADResolve tests that an address successfully resolves after performing -// DAD for various values of DupAddrDetectTransmits and RetransmitTimer. -// Included in the subtests is a test to make sure that an invalid -// RetransmitTimer (<1ms) values get fixed to the default RetransmitTimer of 1s. -// This tests also validates the NDP NS packet that is transmitted. -func TestDADResolve(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - linkHeaderLen uint16 - dupAddrDetectTransmits uint8 - retransTimer time.Duration - expectedRetransmitTimer time.Duration - }{ - { - name: "1:1s:1s", - dupAddrDetectTransmits: 1, - retransTimer: time.Second, - expectedRetransmitTimer: time.Second, - }, - { - name: "2:1s:1s", - linkHeaderLen: 1, - dupAddrDetectTransmits: 2, - retransTimer: time.Second, - expectedRetransmitTimer: time.Second, - }, - { - name: "1:2s:2s", - linkHeaderLen: 2, - dupAddrDetectTransmits: 1, - retransTimer: 2 * time.Second, - expectedRetransmitTimer: 2 * time.Second, - }, - // 0s is an invalid RetransmitTimer timer and will be fixed to - // the default RetransmitTimer value of 1s. - { - name: "1:0s:1s", - linkHeaderLen: 3, - dupAddrDetectTransmits: 1, - retransTimer: 0, - expectedRetransmitTimer: time.Second, - }, - } - - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), - } - - e := channelLinkWithHeaderLength{ - Endpoint: channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1), - headerLength: test.linkHeaderLen, - } - e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPDisp: &ndpDisp, - DADConfigs: stack.DADConfigurations{ - RetransmitTimer: test.retransTimer, - DupAddrDetectTransmits: test.dupAddrDetectTransmits, - }, - })}, - }) - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - // We add a default route so the call to FindRoute below will succeed - // once we have an assigned address. - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv6EmptySubnet, - Gateway: addr3, - NIC: nicID, - }}) - - addrWithPrefix := tcpip.AddressWithPrefix{ - Address: addr1, - PrefixLen: defaultPrefixLen, - } - if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) - } - - // Address should not be considered bound to the NIC yet (DAD ongoing). - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - - // Make sure the address does not resolve before the resolution time has - // passed. - time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncNegativeEventTimeout) - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Error(err) - } - // Should not get a route even if we specify the local address as the - // tentative address. - { - r, err := s.FindRoute(nicID, "", addr2, header.IPv6ProtocolNumber, false) - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, &tcpip.ErrNoRoute{}) - } - if r != nil { - r.Release() - } - } - { - r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false) - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, &tcpip.ErrNoRoute{}) - } - if r != nil { - r.Release() - } - } - - if t.Failed() { - t.FailNow() - } - - // Wait for DAD to resolve. - select { - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for DAD resolution") - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } - } - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { - t.Error(err) - } - // Should get a route using the address now that it is resolved. - { - r, err := s.FindRoute(nicID, "", addr2, header.IPv6ProtocolNumber, false) - if err != nil { - t.Errorf("got FindRoute(%d, '', %s, %d, false): %s", nicID, addr2, header.IPv6ProtocolNumber, err) - } else if r.LocalAddress != addr1 { - t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, addr1) - } - r.Release() - } - { - r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false) - if err != nil { - t.Errorf("got FindRoute(%d, %s, %s, %d, false): %s", nicID, addr1, addr2, header.IPv6ProtocolNumber, err) - } else if r.LocalAddress != addr1 { - t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, addr1) - } - if r != nil { - r.Release() - } - } - - if t.Failed() { - t.FailNow() - } - - // Should not have sent any more NS messages. - if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) { - t.Fatalf("got NeighborSolicit = %d, want = %d", got, test.dupAddrDetectTransmits) - } - - // Validate the sent Neighbor Solicitation messages. - for i := uint8(0); i < test.dupAddrDetectTransmits; i++ { - p, _ := e.ReadContext(context.Background()) - - // Make sure its an IPv6 packet. - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } - - // Make sure the right remote link address is used. - snmc := header.SolicitedNodeAddr(addr1) - if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want { - t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) - } - - // Check NDP NS packet. - // - // As per RFC 4861 section 4.3, a possible option is the Source Link - // Layer option, but this option MUST NOT be included when the source - // address of the packet is the unspecified address. - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(header.IPv6Any), - checker.DstAddr(snmc), - checker.TTL(header.NDPHopLimit), - checker.NDPNS( - checker.NDPNSTargetAddress(addr1), - checker.NDPNSOptions(nil), - )) - - if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { - t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) - } - } - }) - } -} - -func rxNDPSolicit(e *channel.Endpoint, tgt tcpip.Address) { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.MessageBody()) - ns.SetTargetAddress(tgt) - snmc := header.SolicitedNodeAddr(tgt) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: icmp.ProtocolNumber6, - HopLimit: 255, - SrcAddr: header.IPv6Any, - DstAddr: snmc, - }) - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()})) -} - -// TestDADFail tests to make sure that the DAD process fails if another node is -// detected to be performing DAD on the same address (receive an NS message from -// a node doing DAD for the same address), or if another node is detected to own -// the address already (receive an NA message for the tentative address). -func TestDADFail(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - rxPkt func(e *channel.Endpoint, tgt tcpip.Address) - getStat func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter - }{ - { - name: "RxSolicit", - rxPkt: rxNDPSolicit, - getStat: func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return s.NeighborSolicit - }, - }, - { - name: "RxAdvert", - rxPkt: func(e *channel.Endpoint, tgt tcpip.Address) { - naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize - hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize) - pkt := header.ICMPv6(hdr.Prepend(naSize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.MessageBody()) - na.SetSolicitedFlag(true) - na.SetOverrideFlag(true) - na.SetTargetAddress(tgt) - na.Options().Serialize(header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - }) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, tgt, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: icmp.ProtocolNumber6, - HopLimit: 255, - SrcAddr: tgt, - DstAddr: header.IPv6AllNodesMulticastAddress, - }) - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()})) - }, - getStat: func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return s.NeighborAdvert - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - } - dadConfigs := stack.DefaultDADConfigurations() - dadConfigs.RetransmitTimer = time.Second * 2 - - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPDisp: &ndpDisp, - DADConfigs: dadConfigs, - })}, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) - } - - // Address should not be considered bound to the NIC yet - // (DAD ongoing). - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - - // Receive a packet to simulate an address conflict. - test.rxPkt(e, addr1) - - stat := test.getStat(s.Stats().ICMP.V6.PacketsReceived) - if got := stat.Value(); got != 1 { - t.Fatalf("got stat = %d, want = 1", got) - } - - // Wait for DAD to fail and make sure the address did - // not get resolved. - select { - case <-time.After(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer + time.Second): - // If we don't get a failure event after the - // expected resolution time + extra 1s buffer, - // something is wrong. - t.Fatal("timed out waiting for DAD failure") - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } - } - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - - // Attempting to add the address again should not fail if the address's - // state was cleaned up when DAD failed. - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) - } - }) - } -} - -func TestDADStop(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - stopFn func(t *testing.T, s *stack.Stack) - skipFinalAddrCheck bool - }{ - // Tests to make sure that DAD stops when an address is removed. - { - name: "Remove address", - stopFn: func(t *testing.T, s *stack.Stack) { - if err := s.RemoveAddress(nicID, addr1); err != nil { - t.Fatalf("RemoveAddress(%d, %s): %s", nicID, addr1, err) - } - }, - }, - - // Tests to make sure that DAD stops when the NIC is disabled. - { - name: "Disable NIC", - stopFn: func(t *testing.T, s *stack.Stack) { - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("DisableNIC(%d): %s", nicID, err) - } - }, - }, - - // Tests to make sure that DAD stops when the NIC is removed. - { - name: "Remove NIC", - stopFn: func(t *testing.T, s *stack.Stack) { - if err := s.RemoveNIC(nicID); err != nil { - t.Fatalf("RemoveNIC(%d): %s", nicID, err) - } - }, - // The NIC is removed so we can't check its addresses after calling - // stopFn. - skipFinalAddrCheck: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - } - - dadConfigs := stack.DADConfigurations{ - RetransmitTimer: time.Second, - DupAddrDetectTransmits: 2, - } - - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPDisp: &ndpDisp, - DADConfigs: dadConfigs, - })}, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, header.IPv6ProtocolNumber, addr1, err) - } - - // Address should not be considered bound to the NIC yet (DAD ongoing). - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - - test.stopFn(t, s) - - // Wait for DAD to fail (since the address was removed during DAD). - select { - case <-time.After(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer + time.Second): - // If we don't get a failure event after the expected resolution - // time + extra 1s buffer, something is wrong. - t.Fatal("timed out waiting for DAD failure") - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr1, false, &tcpip.ErrAborted{}); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } - } - - if !test.skipFinalAddrCheck { - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - } - - // Should not have sent more than 1 NS message. - if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got > 1 { - t.Errorf("got NeighborSolicit = %d, want <= 1", got) - } - }) - } -} - -// TestSetNDPConfigurations tests that we can update and use per-interface NDP -// configurations without affecting the default NDP configurations or other -// interfaces' configurations. -func TestSetNDPConfigurations(t *testing.T) { - const nicID1 = 1 - const nicID2 = 2 - const nicID3 = 3 - - tests := []struct { - name string - dupAddrDetectTransmits uint8 - retransmitTimer time.Duration - expectedRetransmitTimer time.Duration - }{ - { - "OK", - 1, - time.Second, - time.Second, - }, - { - "Invalid Retransmit Timer", - 1, - 0, - time.Second, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPDisp: &ndpDisp, - })}, - }) - - expectDADEvent := func(nicID tcpip.NICID, addr tcpip.Address) { - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatalf("expected DAD event for %s", addr) - } - } - - // This NIC(1)'s NDP configurations will be updated to - // be different from the default. - if err := s.CreateNIC(nicID1, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err) - } - - // Created before updating NIC(1)'s NDP configurations - // but updating NIC(1)'s NDP configurations should not - // affect other existing NICs. - if err := s.CreateNIC(nicID2, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err) - } - - // Update the configurations on NIC(1) to use DAD. - if ipv6Ep, err := s.GetNetworkEndpoint(nicID1, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, header.IPv6ProtocolNumber, err) - } else { - dad := ipv6Ep.(stack.DuplicateAddressDetector) - dad.SetDADConfigurations(stack.DADConfigurations{ - DupAddrDetectTransmits: test.dupAddrDetectTransmits, - RetransmitTimer: test.retransmitTimer, - }) - } - - // Created after updating NIC(1)'s NDP configurations - // but the stack's default NDP configurations should not - // have been updated. - if err := s.CreateNIC(nicID3, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID3, err) - } - - // Add addresses for each NIC. - addrWithPrefix1 := tcpip.AddressWithPrefix{Address: addr1, PrefixLen: defaultPrefixLen} - if err := s.AddAddressWithPrefix(nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addrWithPrefix1, err) - } - addrWithPrefix2 := tcpip.AddressWithPrefix{Address: addr2, PrefixLen: defaultPrefixLen} - if err := s.AddAddressWithPrefix(nicID2, header.IPv6ProtocolNumber, addrWithPrefix2); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addrWithPrefix2, err) - } - expectDADEvent(nicID2, addr2) - addrWithPrefix3 := tcpip.AddressWithPrefix{Address: addr3, PrefixLen: defaultPrefixLen} - if err := s.AddAddressWithPrefix(nicID3, header.IPv6ProtocolNumber, addrWithPrefix3); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addrWithPrefix3, err) - } - expectDADEvent(nicID3, addr3) - - // Address should not be considered bound to NIC(1) yet - // (DAD ongoing). - if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - - // Should get the address on NIC(2) and NIC(3) - // immediately since we should not have performed DAD on - // it as the stack was configured to not do DAD by - // default and we only updated the NDP configurations on - // NIC(1). - if err := checkGetMainNICAddress(s, nicID2, header.IPv6ProtocolNumber, addrWithPrefix2); err != nil { - t.Fatal(err) - } - if err := checkGetMainNICAddress(s, nicID3, header.IPv6ProtocolNumber, addrWithPrefix3); err != nil { - t.Fatal(err) - } - - // Sleep until right (500ms before) before resolution to - // make sure the address didn't resolve on NIC(1) yet. - const delta = 500 * time.Millisecond - time.Sleep(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta) - if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - - // Wait for DAD to resolve. - select { - case <-time.After(2 * delta): - // We should get a resolution event after 500ms - // (delta) since we wait for 500ms less than the - // expected resolution time above to make sure - // that the address did not yet resolve. Waiting - // for 1s (2x delta) without a resolution event - // means something is wrong. - t.Fatal("timed out waiting for DAD resolution") - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID1, addr1, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } - } - if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil { - t.Fatal(err) - } - }) - } -} - -// raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options -// and DHCPv6 configurations specified. -func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { - icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length()) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) - pkt := header.ICMPv6(hdr.Prepend(icmpSize)) - pkt.SetType(header.ICMPv6RouterAdvert) - pkt.SetCode(0) - raPayload := pkt.MessageBody() - ra := header.NDPRouterAdvert(raPayload) - // Populate the Router Lifetime. - binary.BigEndian.PutUint16(raPayload[2:], rl) - // Populate the Managed Address flag field. - if managedAddress { - // The Managed Addresses flag field is the 7th bit of byte #1 (0-indexing) - // of the RA payload. - raPayload[1] |= (1 << 7) - } - // Populate the Other Configurations flag field. - if otherConfigurations { - // The Other Configurations flag field is the 6th bit of byte #1 - // (0-indexing) of the RA payload. - raPayload[1] |= (1 << 6) - } - opts := ra.Options() - opts.Serialize(optSer) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, ip, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - iph.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: icmp.ProtocolNumber6, - HopLimit: header.NDPHopLimit, - SrcAddr: ip, - DstAddr: header.IPv6AllNodesMulticastAddress, - }) - - return stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - }) -} - -// raBufWithOpts returns a valid NDP Router Advertisement with options. -// -// Note, raBufWithOpts does not populate any of the RA fields other than the -// Router Lifetime. -func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { - return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer) -} - -// raBufWithDHCPv6 returns a valid NDP Router Advertisement with DHCPv6 related -// fields set. -// -// Note, raBufWithDHCPv6 does not populate any of the RA fields other than the -// DHCPv6 related ones. -func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) *stack.PacketBuffer { - return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfiguratiosns, header.NDPOptionsSerializer{}) -} - -// raBuf returns a valid NDP Router Advertisement. -// -// Note, raBuf does not populate any of the RA fields other than the -// Router Lifetime. -func raBuf(ip tcpip.Address, rl uint16) *stack.PacketBuffer { - return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{}) -} - -// raBufWithPI returns a valid NDP Router Advertisement with a single Prefix -// Information option. -// -// Note, raBufWithPI does not populate any of the RA fields other than the -// Router Lifetime. -func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) *stack.PacketBuffer { - flags := uint8(0) - if onLink { - // The OnLink flag is the 7th bit in the flags byte. - flags |= 1 << 7 - } - if auto { - // The Address Auto-Configuration flag is the 6th bit in the - // flags byte. - flags |= 1 << 6 - } - - // A valid header.NDPPrefixInformation must be 30 bytes. - buf := [30]byte{} - // The first byte in a header.NDPPrefixInformation is the Prefix Length - // field. - buf[0] = uint8(prefix.PrefixLen) - // The 2nd byte within a header.NDPPrefixInformation is the Flags field. - buf[1] = flags - // The Valid Lifetime field starts after the 2nd byte within a - // header.NDPPrefixInformation. - binary.BigEndian.PutUint32(buf[2:], vl) - // The Preferred Lifetime field starts after the 6th byte within a - // header.NDPPrefixInformation. - binary.BigEndian.PutUint32(buf[6:], pl) - // The Prefix Address field starts after the 14th byte within a - // header.NDPPrefixInformation. - copy(buf[14:], prefix.Address) - return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{ - header.NDPPrefixInformation(buf[:]), - }) -} - -// TestNoRouterDiscovery tests that router discovery will not be performed if -// configured not to. -func TestNoRouterDiscovery(t *testing.T) { - // Being configured to discover routers means handle and - // discover are set to true and forwarding is set to false. - // This tests all possible combinations of the configurations, - // except for the configuration where handle = true, discover = - // true and forwarding = false (the required configuration to do - // router discovery) - that will done in other tests. - for i := 0; i < 7; i++ { - handle := i&1 != 0 - discover := i&2 != 0 - forwarding := i&4 == 0 - - t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: handle, - DiscoverDefaultRouters: discover, - }, - NDPDisp: &ndpDisp, - })}, - }) - s.SetForwarding(ipv6.ProtocolNumber, forwarding) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Rx an RA with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - select { - case <-ndpDisp.routerC: - t.Fatal("unexpectedly discovered a router when configured not to") - default: - } - }) - } -} - -// Check e to make sure that the event is for addr on nic with ID 1, and the -// discovered flag set to discovered. -func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) string { - return cmp.Diff(ndpRouterEvent{nicID: 1, addr: addr, discovered: discovered}, e, cmp.AllowUnexported(e)) -} - -// TestRouterDiscoveryDispatcherNoRemember tests that the stack does not -// remember a discovered router when the dispatcher asks it not to. -func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Receive an RA for a router we should not remember. - const lifetimeSeconds = 1 - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, lifetimeSeconds)) - select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, llAddr2, true); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected router discovery event") - } - - // Wait for the invalidation time plus some buffer to make sure we do - // not actually receive any invalidation events as we should not have - // remembered the router in the first place. - select { - case <-ndpDisp.routerC: - t.Fatal("should not have received any router events") - case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): - } -} - -func TestRouterDiscovery(t *testing.T) { - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - rememberRouter: true, - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - expectRouterEvent := func(addr tcpip.Address, discovered bool) { - t.Helper() - - select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, addr, discovered); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected router discovery event") - } - } - - expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { - t.Helper() - - select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, addr, false); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) - } - case <-time.After(timeout): - t.Fatal("timed out waiting for router discovery event") - } - } - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Rx an RA from lladdr2 with zero lifetime. It should not be - // remembered. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) - select { - case <-ndpDisp.routerC: - t.Fatal("unexpectedly discovered a router with 0 lifetime") - default: - } - - // Rx an RA from lladdr2 with a huge lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - expectRouterEvent(llAddr2, true) - - // Rx an RA from another router (lladdr3) with non-zero lifetime. - const l3LifetimeSeconds = 6 - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds)) - expectRouterEvent(llAddr3, true) - - // Rx an RA from lladdr2 with lesser lifetime. - const l2LifetimeSeconds = 2 - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds)) - select { - case <-ndpDisp.routerC: - t.Fatal("Should not receive a router event when updating lifetimes for known routers") - default: - } - - // Wait for lladdr2's router invalidation job to execute. The lifetime - // of the router should have been updated to the most recent (smaller) - // lifetime. - // - // Wait for the normal lifetime plus an extra bit for the - // router to get invalidated. If we don't get an invalidation - // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) - - // Rx an RA from lladdr2 with huge lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - expectRouterEvent(llAddr2, true) - - // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) - expectRouterEvent(llAddr2, false) - - // Wait for lladdr3's router invalidation job to execute. The lifetime - // of the router should have been updated to the most recent (smaller) - // lifetime. - // - // Wait for the normal lifetime plus an extra bit for the - // router to get invalidated. If we don't get an invalidation - // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) -} - -// TestRouterDiscoveryMaxRouters tests that only -// ipv6.MaxDiscoveredDefaultRouters discovered routers are remembered. -func TestRouterDiscoveryMaxRouters(t *testing.T) { - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - rememberRouter: true, - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Receive an RA from 2 more than the max number of discovered routers. - for i := 1; i <= ipv6.MaxDiscoveredDefaultRouters+2; i++ { - linkAddr := []byte{2, 2, 3, 4, 5, 0} - linkAddr[5] = byte(i) - llAddr := header.LinkLocalAddr(tcpip.LinkAddress(linkAddr)) - - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr, 5)) - - if i <= ipv6.MaxDiscoveredDefaultRouters { - select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, llAddr, true); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected router discovery event") - } - - } else { - select { - case <-ndpDisp.routerC: - t.Fatal("should not have discovered a new router after we already discovered the max number of routers") - default: - } - } - } -} - -// TestNoPrefixDiscovery tests that prefix discovery will not be performed if -// configured not to. -func TestNoPrefixDiscovery(t *testing.T) { - prefix := tcpip.AddressWithPrefix{ - Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), - PrefixLen: 64, - } - - // Being configured to discover prefixes means handle and - // discover are set to true and forwarding is set to false. - // This tests all possible combinations of the configurations, - // except for the configuration where handle = true, discover = - // true and forwarding = false (the required configuration to do - // prefix discovery) - that will done in other tests. - for i := 0; i < 7; i++ { - handle := i&1 != 0 - discover := i&2 != 0 - forwarding := i&4 == 0 - - t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverOnLinkPrefixes(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: handle, - DiscoverOnLinkPrefixes: discover, - }, - NDPDisp: &ndpDisp, - })}, - }) - s.SetForwarding(ipv6.ProtocolNumber, forwarding) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Rx an RA with prefix with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0)) - - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly discovered a prefix when configured not to") - default: - } - }) - } -} - -// Check e to make sure that the event is for prefix on nic with ID 1, and the -// discovered flag set to discovered. -func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) string { - return cmp.Diff(ndpPrefixEvent{nicID: 1, prefix: prefix, discovered: discovered}, e, cmp.AllowUnexported(e)) -} - -// TestPrefixDiscoveryDispatcherNoRemember tests that the stack does not -// remember a discovered on-link prefix when the dispatcher asks it not to. -func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { - prefix, subnet, _ := prefixSubnetAddr(0, "") - - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: false, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Receive an RA with prefix that we should not remember. - const lifetimeSeconds = 1 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, lifetimeSeconds, 0)) - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, subnet, true); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected prefix discovery event") - } - - // Wait for the invalidation time plus some buffer to make sure we do - // not actually receive any invalidation events as we should not have - // remembered the prefix in the first place. - select { - case <-ndpDisp.prefixC: - t.Fatal("should not have received any prefix events") - case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): - } -} - -func TestPrefixDiscovery(t *testing.T) { - prefix1, subnet1, _ := prefixSubnetAddr(0, "") - prefix2, subnet2, _ := prefixSubnetAddr(1, "") - prefix3, subnet3, _ := prefixSubnetAddr(2, "") - - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - rememberPrefix: true, - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) { - t.Helper() - - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, prefix, discovered); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected prefix discovery event") - } - } - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly discovered a prefix with 0 lifetime") - default: - } - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0)) - expectPrefixEvent(subnet1, true) - - // Receive an RA with prefix2 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0)) - expectPrefixEvent(subnet2, true) - - // Receive an RA with prefix3 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0)) - expectPrefixEvent(subnet3, true) - - // Receive an RA with prefix1 in a PI with lifetime = 0. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) - expectPrefixEvent(subnet1, false) - - // Receive an RA with prefix2 in a PI with lesser lifetime. - lifetime := uint32(2) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0)) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly received prefix event when updating lifetime") - default: - } - - // Wait for prefix2's most recent invalidation job plus some buffer to - // expire. - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, subnet2, false); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) - } - case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for prefix discovery event") - } - - // Receive RA to invalidate prefix3. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0)) - expectPrefixEvent(subnet3, false) -} - -func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { - // Update the infinite lifetime value to a smaller value so we can test - // that when we receive a PI with such a lifetime value, we do not - // invalidate the prefix. - const testInfiniteLifetimeSeconds = 2 - const testInfiniteLifetime = testInfiniteLifetimeSeconds * time.Second - saved := header.NDPInfiniteLifetime - header.NDPInfiniteLifetime = testInfiniteLifetime - defer func() { - header.NDPInfiniteLifetime = saved - }() - - prefix := tcpip.AddressWithPrefix{ - Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), - PrefixLen: 64, - } - subnet := prefix.Subnet() - - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - rememberPrefix: true, - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) { - t.Helper() - - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, prefix, discovered); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected prefix discovery event") - } - } - - // Receive an RA with prefix in an NDP Prefix Information option (PI) - // with infinite valid lifetime which should not get invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) - expectPrefixEvent(subnet, true) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout): - } - - // Receive an RA with finite lifetime. - // The prefix should get invalidated after 1s. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, subnet, false); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) - } - case <-time.After(testInfiniteLifetime): - t.Fatal("timed out waiting for prefix discovery event") - } - - // Receive an RA with finite lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) - expectPrefixEvent(subnet, true) - - // Receive an RA with prefix with an infinite lifetime. - // The prefix should not be invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout): - } - - // Receive an RA with a prefix with a lifetime value greater than the - // set infinite lifetime value. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds+1, 0)) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - case <-time.After((testInfiniteLifetimeSeconds+1)*time.Second + defaultAsyncNegativeEventTimeout): - } - - // Receive an RA with 0 lifetime. - // The prefix should get invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 0, 0)) - expectPrefixEvent(subnet, false) -} - -// TestPrefixDiscoveryMaxRouters tests that only -// ipv6.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered. -func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, ipv6.MaxDiscoveredOnLinkPrefixes+3), - rememberPrefix: true, - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: false, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - optSer := make(header.NDPOptionsSerializer, ipv6.MaxDiscoveredOnLinkPrefixes+2) - prefixes := [ipv6.MaxDiscoveredOnLinkPrefixes + 2]tcpip.Subnet{} - - // Receive an RA with 2 more than the max number of discovered on-link - // prefixes. - for i := 0; i < ipv6.MaxDiscoveredOnLinkPrefixes+2; i++ { - prefixAddr := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0} - prefixAddr[7] = byte(i) - prefix := tcpip.AddressWithPrefix{ - Address: tcpip.Address(prefixAddr[:]), - PrefixLen: 64, - } - prefixes[i] = prefix.Subnet() - buf := [30]byte{} - buf[0] = uint8(prefix.PrefixLen) - buf[1] = 128 - binary.BigEndian.PutUint32(buf[2:], 10) - copy(buf[14:], prefix.Address) - - optSer[i] = header.NDPPrefixInformation(buf[:]) - } - - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer)) - for i := 0; i < ipv6.MaxDiscoveredOnLinkPrefixes+2; i++ { - if i < ipv6.MaxDiscoveredOnLinkPrefixes { - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, prefixes[i], true); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected prefix discovery event") - } - } else { - select { - case <-ndpDisp.prefixC: - t.Fatal("should not have discovered a new prefix after we already discovered the max number of prefixes") - default: - } - } - } -} - -// Checks to see if list contains an IPv6 address, item. -func containsV6Addr(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: item, - } - - return containsAddr(list, protocolAddress) -} - -// TestNoAutoGenAddr tests that SLAAC is not performed when configured not to. -func TestNoAutoGenAddr(t *testing.T) { - prefix, _, _ := prefixSubnetAddr(0, "") - - // Being configured to auto-generate addresses means handle and - // autogen are set to true and forwarding is set to false. - // This tests all possible combinations of the configurations, - // except for the configuration where handle = true, autogen = - // true and forwarding = false (the required configuration to do - // SLAAC) - that will done in other tests. - for i := 0; i < 7; i++ { - handle := i&1 != 0 - autogen := i&2 != 0 - forwarding := i&4 == 0 - - t.Run(fmt.Sprintf("HandleRAs(%t), AutoGenAddr(%t), Forwarding(%t)", handle, autogen, forwarding), func(t *testing.T) { - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: handle, - AutoGenGlobalAddresses: autogen, - }, - NDPDisp: &ndpDisp, - })}, - }) - s.SetForwarding(ipv6.ProtocolNumber, forwarding) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Rx an RA with prefix with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0)) - - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address when configured not to") - default: - } - }) - } -} - -// Check e to make sure that the event is for addr on nic with ID 1, and the -// event type is set to eventType. -func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) string { - return cmp.Diff(ndpAutoGenAddrEvent{nicID: 1, addr: addr, eventType: eventType}, e, cmp.AllowUnexported(e)) -} - -// TestAutoGenAddr tests that an address is properly generated and invalidated -// when configured to do so. -func TestAutoGenAddr2(t *testing.T) { - const newMinVL = 2 - newMinVLDuration := newMinVL * time.Second - saved := ipv6.MinPrefixInformationValidLifetimeForUpdate - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = saved - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration - - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address with 0 lifetime") - default: - } - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, newAddr) - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } - - // Receive an RA with prefix2 in an NDP Prefix Information option (PI) - // with preferred lifetime > valid lifetime - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime") - default: - } - - // Receive an RA with prefix2 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - expectAutoGenAddrEvent(addr2, newAddr) - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { - t.Fatalf("Should have %s in the list of addresses", addr2) - } - - // Refresh valid lifetime for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix") - default: - } - - // Wait for addr of prefix1 to be invalidated. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { - t.Fatalf("Should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { - t.Fatalf("Should have %s in the list of addresses", addr2) - } -} - -func addressCheck(addrs []tcpip.ProtocolAddress, containList, notContainList []tcpip.AddressWithPrefix) string { - ret := "" - for _, c := range containList { - if !containsV6Addr(addrs, c) { - ret += fmt.Sprintf("should have %s in the list of addresses\n", c) - } - } - for _, c := range notContainList { - if containsV6Addr(addrs, c) { - ret += fmt.Sprintf("should not have %s in the list of addresses\n", c) - } - } - return ret -} - -// TestAutoGenTempAddr tests that temporary SLAAC addresses are generated when -// configured to do so as part of IPv6 Privacy Extensions. -func TestAutoGenTempAddr(t *testing.T) { - const ( - nicID = 1 - newMinVL = 5 - newMinVLDuration = newMinVL * time.Second - ) - - savedMinPrefixInformationValidLifetimeForUpdate := ipv6.MinPrefixInformationValidLifetimeForUpdate - savedMaxDesync := ipv6.MaxDesyncFactor - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = savedMinPrefixInformationValidLifetimeForUpdate - ipv6.MaxDesyncFactor = savedMaxDesync - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration - ipv6.MaxDesyncFactor = time.Nanosecond - - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - - tests := []struct { - name string - dupAddrTransmits uint8 - retransmitTimer time.Duration - }{ - { - name: "DAD disabled", - }, - { - name: "DAD enabled", - dupAddrTransmits: 1, - retransmitTimer: time.Second, - }, - } - - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run("group", func(t *testing.T) { - for i, test := range tests { - i := i - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - seed := []byte{uint8(i)} - var tempIIDHistory [header.IIDSize]byte - header.InitialTempIID(tempIIDHistory[:], seed, nicID) - newTempAddr := func(stableAddr tcpip.Address) tcpip.AddressWithPrefix { - return header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableAddr) - } - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 2), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - DADConfigs: stack.DADConfigurations{ - DupAddrDetectTransmits: test.dupAddrTransmits, - RetransmitTimer: test.retransmitTimer, - }, - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - TempIIDSeed: seed, - })}, - }) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - } - - expectDADEventAsync := func(addr tcpip.Address) { - t.Helper() - - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } - case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for DAD event") - } - } - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly auto-generated an address with 0 lifetime; event = %+v", e) - default: - } - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, newAddr) - expectDADEventAsync(addr1.Address) - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly got an auto gen addr event = %+v", e) - default: - } - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1}, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero valid & preferred lifetimes. - tempAddr1 := newTempAddr(addr1.Address) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - expectAutoGenAddrEvent(tempAddr1, newAddr) - expectDADEventAsync(tempAddr1.Address) - if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // Receive an RA with prefix2 in an NDP Prefix Information option (PI) - // with preferred lifetime > valid lifetime - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly auto-generated an address with preferred lifetime > valid lifetime; event = %+v", e) - default: - } - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // Receive an RA with prefix2 in a PI w/ non-zero valid and preferred - // lifetimes. - tempAddr2 := newTempAddr(addr2.Address) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - expectAutoGenAddrEvent(addr2, newAddr) - expectDADEventAsync(addr2.Address) - expectAutoGenAddrEventAsync(tempAddr2, newAddr) - expectDADEventAsync(tempAddr2.Address) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // Deprecate prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, deprecatedAddr) - expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // Refresh lifetimes for prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // Reduce valid lifetime and deprecate addresses of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) - expectAutoGenAddrEvent(addr1, deprecatedAddr) - expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // Wait for addrs of prefix1 to be invalidated. They should be - // invalidated at the same time. - select { - case e := <-ndpDisp.autoGenAddrC: - var nextAddr tcpip.AddressWithPrefix - if e.addr == addr1 { - if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - nextAddr = tempAddr1 - } else { - if diff := checkAutoGenAddrEvent(e, tempAddr1, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - nextAddr = addr1 - } - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, nextAddr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { - t.Fatal(mismatch) - } - - // Receive an RA with prefix2 in a PI w/ 0 lifetimes. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 0, 0)) - expectAutoGenAddrEvent(addr2, deprecatedAddr) - expectAutoGenAddrEvent(tempAddr2, deprecatedAddr) - select { - case e := <-ndpDisp.autoGenAddrC: - t.Errorf("got unexpected auto gen addr event = %+v", e) - default: - } - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { - t.Fatal(mismatch) - } - }) - } - }) -} - -// TestNoAutoGenTempAddrForLinkLocal test that temporary SLAAC addresses are not -// generated for auto generated link-local addresses. -func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { - const nicID = 1 - - savedMaxDesyncFactor := ipv6.MaxDesyncFactor - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesyncFactor - }() - ipv6.MaxDesyncFactor = time.Nanosecond - - tests := []struct { - name string - dupAddrTransmits uint8 - retransmitTimer time.Duration - }{ - { - name: "DAD disabled", - }, - { - name: "DAD enabled", - dupAddrTransmits: 1, - retransmitTimer: time.Second, - }, - } - - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run("group", func(t *testing.T) { - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - AutoGenLinkLocal: true, - })}, - }) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - // The stable link-local address should auto-generate and resolve DAD. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IIDOffsetInIPv6Address * 8}, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, llAddr1, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } - case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for DAD event") - } - - // No new addresses should be generated. - select { - case e := <-ndpDisp.autoGenAddrC: - t.Errorf("got unxpected auto gen addr event = %+v", e) - case <-time.After(defaultAsyncNegativeEventTimeout): - } - }) - } - }) -} - -// TestNoAutoGenTempAddrWithoutStableAddr tests that a temporary SLAAC address -// will not be generated until after DAD completes, even if a new Router -// Advertisement is received to refresh lifetimes. -func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { - const ( - nicID = 1 - dadTransmits = 1 - retransmitTimer = 2 * time.Second - ) - - savedMaxDesyncFactor := ipv6.MaxDesyncFactor - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesyncFactor - }() - ipv6.MaxDesyncFactor = 0 - - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) - var tempIIDHistory [header.IIDSize]byte - header.InitialTempIID(tempIIDHistory[:], nil, nicID) - tempAddr := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - DADConfigs: stack.DADConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - }, - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - // Receive an RA to trigger SLAAC for prefix. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - - // DAD on the stable address for prefix has not yet completed. Receiving a new - // RA that would refresh lifetimes should not generate a temporary SLAAC - // address for the prefix. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpected auto gen addr event = %+v", e) - default: - } - - // Wait for DAD to complete for the stable address then expect the temporary - // address to be generated. - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for DAD event") - } - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, tempAddr, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } -} - -// TestAutoGenTempAddrRegen tests that temporary SLAAC addresses are -// regenerated. -func TestAutoGenTempAddrRegen(t *testing.T) { - const ( - nicID = 1 - regenAfter = 2 * time.Second - newMinVL = 10 - newMinVLDuration = newMinVL * time.Second - ) - - savedMaxDesyncFactor := ipv6.MaxDesyncFactor - savedMinMaxTempAddrPreferredLifetime := ipv6.MinMaxTempAddrPreferredLifetime - savedMinMaxTempAddrValidLifetime := ipv6.MinMaxTempAddrValidLifetime - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesyncFactor - ipv6.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime - ipv6.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime - }() - ipv6.MaxDesyncFactor = 0 - ipv6.MinMaxTempAddrPreferredLifetime = newMinVLDuration - ipv6.MinMaxTempAddrValidLifetime = newMinVLDuration - - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) - var tempIIDHistory [header.IIDSize]byte - header.InitialTempIID(tempIIDHistory[:], nil, nicID) - tempAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - tempAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - tempAddr3 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), - } - e := channel.New(0, 1280, linkAddr1) - ndpConfigs := ipv6.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - RegenAdvanceDuration: newMinVLDuration - regenAfter, - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(timeout): - t.Fatal("timed out waiting for addr auto gen event") - } - } - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero valid & preferred lifetimes. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) - expectAutoGenAddrEvent(addr, newAddr) - expectAutoGenAddrEvent(tempAddr1, newAddr) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1}, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // Wait for regeneration - expectAutoGenAddrEventAsync(tempAddr2, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2}, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // Wait for regeneration - expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2, tempAddr3}, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // Stop generating temporary addresses - ndpConfigs.AutoGenTempGlobalAddresses = false - if ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else { - ndpEP := ipv6Ep.(ipv6.NDPEndpoint) - ndpEP.SetNDPConfigurations(ndpConfigs) - } - - // Wait for all the temporary addresses to get invalidated. - tempAddrs := []tcpip.AddressWithPrefix{tempAddr1, tempAddr2, tempAddr3} - invalidateAfter := newMinVLDuration - 2*regenAfter - for _, addr := range tempAddrs { - // Wait for a deprecation then invalidation event, or just an invalidation - // event. We need to cover both cases but cannot deterministically hit both - // cases because the deprecation and invalidation jobs could execute in any - // order. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, deprecatedAddr); diff == "" { - // If we get a deprecation event first, we should get an invalidation - // event almost immediately after. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - } else if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff == "" { - // If we get an invalidation event first, we shouldn't get a deprecation - // event after. - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly got an auto-generated event = %+v", e) - case <-time.After(defaultAsyncNegativeEventTimeout): - } - } else { - t.Fatalf("got unexpected auto-generated event = %+v", e) - } - case <-time.After(invalidateAfter + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - - invalidateAfter = regenAfter - } - if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr}, tempAddrs); mismatch != "" { - t.Fatal(mismatch) - } -} - -// TestAutoGenTempAddrRegenJobUpdates tests that a temporary address's -// regeneration job gets updated when refreshing the address's lifetimes. -func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { - const ( - nicID = 1 - regenAfter = 2 * time.Second - newMinVL = 10 - newMinVLDuration = newMinVL * time.Second - ) - - savedMaxDesyncFactor := ipv6.MaxDesyncFactor - savedMinMaxTempAddrPreferredLifetime := ipv6.MinMaxTempAddrPreferredLifetime - savedMinMaxTempAddrValidLifetime := ipv6.MinMaxTempAddrValidLifetime - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesyncFactor - ipv6.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime - ipv6.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime - }() - ipv6.MaxDesyncFactor = 0 - ipv6.MinMaxTempAddrPreferredLifetime = newMinVLDuration - ipv6.MinMaxTempAddrValidLifetime = newMinVLDuration - - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) - var tempIIDHistory [header.IIDSize]byte - header.InitialTempIID(tempIIDHistory[:], nil, nicID) - tempAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - tempAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - tempAddr3 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), - } - e := channel.New(0, 1280, linkAddr1) - ndpConfigs := ipv6.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - RegenAdvanceDuration: newMinVLDuration - regenAfter, - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(timeout): - t.Fatal("timed out waiting for addr auto gen event") - } - } - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero valid & preferred lifetimes. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) - expectAutoGenAddrEvent(addr, newAddr) - expectAutoGenAddrEvent(tempAddr1, newAddr) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1}, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // Deprecate the prefix. - // - // A new temporary address should be generated after the regeneration - // time has passed since the prefix is deprecated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 0)) - expectAutoGenAddrEvent(addr, deprecatedAddr) - expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpected auto gen addr event = %+v", e) - case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout): - } - - // Prefer the prefix again. - // - // A new temporary address should immediately be generated since the - // regeneration time has already passed since the last address was generated - // - this regeneration does not depend on a job. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) - expectAutoGenAddrEvent(tempAddr2, newAddr) - - // Increase the maximum lifetimes for temporary addresses to large values - // then refresh the lifetimes of the prefix. - // - // A new address should not be generated after the regeneration time that was - // expected for the previous check. This is because the preferred lifetime for - // the temporary addresses has increased, so it will take more time to - // regenerate a new temporary address. Note, new addresses are only - // regenerated after the preferred lifetime - the regenerate advance duration - // as paased. - ndpConfigs.MaxTempAddrValidLifetime = 100 * time.Second - ndpConfigs.MaxTempAddrPreferredLifetime = 100 * time.Second - ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } - ndpEP := ipv6Ep.(ipv6.NDPEndpoint) - ndpEP.SetNDPConfigurations(ndpConfigs) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpected auto gen addr event = %+v", e) - case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout): - } - - // Set the maximum lifetimes for temporary addresses such that on the next - // RA, the regeneration job gets scheduled again. - // - // The maximum lifetime is the sum of the minimum lifetimes for temporary - // addresses + the time that has already passed since the last address was - // generated so that the regeneration job is needed to generate the next - // address. - newLifetimes := newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout - ndpConfigs.MaxTempAddrValidLifetime = newLifetimes - ndpConfigs.MaxTempAddrPreferredLifetime = newLifetimes - ndpEP.SetNDPConfigurations(ndpConfigs) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) - expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) -} - -// TestMixedSLAACAddrConflictRegen tests SLAAC address regeneration in response -// to a mix of DAD conflicts and NIC-local conflicts. -func TestMixedSLAACAddrConflictRegen(t *testing.T) { - const ( - nicID = 1 - nicName = "nic" - lifetimeSeconds = 9999 - // From stack.maxSLAACAddrLocalRegenAttempts - maxSLAACAddrLocalRegenAttempts = 10 - // We use 2 more addreses than the maximum local regeneration attempts - // because we want to also trigger regeneration in response to a DAD - // conflicts for this test. - maxAddrs = maxSLAACAddrLocalRegenAttempts + 2 - dupAddrTransmits = 1 - retransmitTimer = time.Second - ) - - var tempIIDHistoryWithModifiedEUI64 [header.IIDSize]byte - header.InitialTempIID(tempIIDHistoryWithModifiedEUI64[:], nil, nicID) - - var tempIIDHistoryWithOpaqueIID [header.IIDSize]byte - header.InitialTempIID(tempIIDHistoryWithOpaqueIID[:], nil, nicID) - - prefix, subnet, stableAddrWithModifiedEUI64 := prefixSubnetAddr(0, linkAddr1) - var stableAddrsWithOpaqueIID [maxAddrs]tcpip.AddressWithPrefix - var tempAddrsWithOpaqueIID [maxAddrs]tcpip.AddressWithPrefix - var tempAddrsWithModifiedEUI64 [maxAddrs]tcpip.AddressWithPrefix - addrBytes := []byte(subnet.ID()) - for i := 0; i < maxAddrs; i++ { - stableAddrsWithOpaqueIID[i] = tcpip.AddressWithPrefix{ - Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, uint8(i), nil)), - PrefixLen: header.IIDOffsetInIPv6Address * 8, - } - // When generating temporary addresses, the resolved stable address for the - // SLAAC prefix will be the first address stable address generated for the - // prefix as we will not simulate address conflicts for the stable addresses - // in tests involving temporary addresses. Address conflicts for stable - // addresses will be done in their own tests. - tempAddrsWithOpaqueIID[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistoryWithOpaqueIID[:], stableAddrsWithOpaqueIID[0].Address) - tempAddrsWithModifiedEUI64[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistoryWithModifiedEUI64[:], stableAddrWithModifiedEUI64.Address) - } - - tests := []struct { - name string - addrs []tcpip.AddressWithPrefix - tempAddrs bool - initialExpect tcpip.AddressWithPrefix - nicNameFromID func(tcpip.NICID, string) string - }{ - { - name: "Stable addresses with opaque IIDs", - addrs: stableAddrsWithOpaqueIID[:], - nicNameFromID: func(tcpip.NICID, string) string { - return nicName - }, - }, - { - name: "Temporary addresses with opaque IIDs", - addrs: tempAddrsWithOpaqueIID[:], - tempAddrs: true, - initialExpect: stableAddrsWithOpaqueIID[0], - nicNameFromID: func(tcpip.NICID, string) string { - return nicName - }, - }, - { - name: "Temporary addresses with modified EUI64", - addrs: tempAddrsWithModifiedEUI64[:], - tempAddrs: true, - initialExpect: stableAddrWithModifiedEUI64, - }, - } - - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: test.tempAddrs, - AutoGenAddressConflictRetries: 1, - }, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: test.nicNameFromID, - }, - })}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv6EmptySubnet, - Gateway: llAddr2, - NIC: nicID, - }}) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - for j := 0; j < len(test.addrs)-1; j++ { - // The NIC will not attempt to generate an address in response to a - // NIC-local conflict after some maximum number of attempts. We skip - // creating a conflict for the address that would be generated as part - // of the last attempt so we can simulate a DAD conflict for this - // address and restart the NIC-local generation process. - if j == maxSLAACAddrLocalRegenAttempts-1 { - continue - } - - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, test.addrs[j].Address); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, test.addrs[j].Address, err) - } - } - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - expectAutoGenAddrAsyncEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - } - - expectDADEventAsync := func(addr tcpip.Address) { - t.Helper() - - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } - case <-time.After(dupAddrTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for DAD event") - } - } - - // Enable DAD. - ndpDisp.dadC = make(chan ndpDADEvent, 2) - if ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else { - ndpEP := ipv6Ep.(stack.DuplicateAddressDetector) - ndpEP.SetDADConfigurations(stack.DADConfigurations{ - DupAddrDetectTransmits: dupAddrTransmits, - RetransmitTimer: retransmitTimer, - }) - } - - // Do SLAAC for prefix. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) - if test.initialExpect != (tcpip.AddressWithPrefix{}) { - expectAutoGenAddrEvent(test.initialExpect, newAddr) - expectDADEventAsync(test.initialExpect.Address) - } - - // The last local generation attempt should succeed, but we introduce a - // DAD failure to restart the local generation process. - addr := test.addrs[maxSLAACAddrLocalRegenAttempts-1] - expectAutoGenAddrAsyncEvent(addr, newAddr) - rxNDPSolicit(e, addr.Address) - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected DAD event") - } - expectAutoGenAddrEvent(addr, invalidatedAddr) - - // The last address generated should resolve DAD. - addr = test.addrs[len(test.addrs)-1] - expectAutoGenAddrAsyncEvent(addr, newAddr) - expectDADEventAsync(addr.Address) - - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpected auto gen addr event = %+v", e) - default: - } - }) - } -} - -// stackAndNdpDispatcherWithDefaultRoute returns an ndpDispatcher, -// channel.Endpoint and stack.Stack. -// -// stack.Stack will have a default route through the router (llAddr3) installed -// and a static link-address (linkAddr3) added to the link address cache for the -// router. -func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) { - t.Helper() - ndpDisp := &ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: ndpDisp, - })}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv6EmptySubnet, - Gateway: llAddr3, - NIC: nicID, - }}) - - if err := s.AddStaticNeighbor(nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3); err != nil { - t.Fatalf("s.AddStaticNeighbor(%d, %d, %s, %s): %s", nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3, err) - } - return ndpDisp, e, s -} - -// addrForNewConnectionTo returns the local address used when creating a new -// connection to addr. -func addrForNewConnectionTo(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address { - t.Helper() - - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - defer close(ch) - ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) - } - defer ep.Close() - ep.SocketOptions().SetV6Only(true) - if err := ep.Connect(addr); err != nil { - t.Fatalf("ep.Connect(%+v): %s", addr, err) - } - got, err := ep.GetLocalAddress() - if err != nil { - t.Fatalf("ep.GetLocalAddress(): %s", err) - } - return got.Addr -} - -// addrForNewConnection returns the local address used when creating a new -// connection. -func addrForNewConnection(t *testing.T, s *stack.Stack) tcpip.Address { - t.Helper() - - return addrForNewConnectionTo(t, s, dstAddr) -} - -// addrForNewConnectionWithAddr returns the local address used when creating a -// new connection with a specific local address. -func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address { - t.Helper() - - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - defer close(ch) - ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) - } - defer ep.Close() - ep.SocketOptions().SetV6Only(true) - if err := ep.Bind(addr); err != nil { - t.Fatalf("ep.Bind(%+v): %s", addr, err) - } - if err := ep.Connect(dstAddr); err != nil { - t.Fatalf("ep.Connect(%+v): %s", dstAddr, err) - } - got, err := ep.GetLocalAddress() - if err != nil { - t.Fatalf("ep.GetLocalAddress(): %s", err) - } - return got.Addr -} - -// TestAutoGenAddrDeprecateFromPI tests deprecating a SLAAC address when -// receiving a PI with 0 preferred lifetime. -func TestAutoGenAddrDeprecateFromPI(t *testing.T) { - const nicID = 1 - - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { - t.Helper() - - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatal(err) - } - - if got := addrForNewConnection(t, s); got != addr.Address { - t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) - } - } - - // Receive PI for prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - expectAutoGenAddrEvent(addr1, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - expectPrimaryAddr(addr1) - - // Deprecate addr for prefix1 immedaitely. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, deprecatedAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - // addr should still be the primary endpoint as there are no other addresses. - expectPrimaryAddr(addr1) - - // Refresh lifetimes of addr generated from prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr1) - - // Receive PI for prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - expectAutoGenAddrEvent(addr2, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr2) - - // Deprecate addr for prefix2 immedaitely. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - expectAutoGenAddrEvent(addr2, deprecatedAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - // addr1 should be the primary endpoint now since addr2 is deprecated but - // addr1 is not. - expectPrimaryAddr(addr1) - // addr2 is deprecated but if explicitly requested, it should be used. - fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID} - if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) - } - - // Another PI w/ 0 preferred lifetime should not result in a deprecation - // event. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr1) - if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) - } - - // Refresh lifetimes of addr generated from prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr2) -} - -// TestAutoGenAddrJobDeprecation tests that an address is properly deprecated -// when its preferred lifetime expires. -func TestAutoGenAddrJobDeprecation(t *testing.T) { - const nicID = 1 - const newMinVL = 2 - newMinVLDuration := newMinVL * time.Second - - saved := ipv6.MinPrefixInformationValidLifetimeForUpdate - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = saved - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration - - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(timeout): - t.Fatal("timed out waiting for addr auto gen event") - } - } - - expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { - t.Helper() - - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatal(err) - } - - if got := addrForNewConnection(t, s); got != addr.Address { - t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) - } - } - - // Receive PI for prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - expectAutoGenAddrEvent(addr2, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr2) - - // Receive a PI for prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90)) - expectAutoGenAddrEvent(addr1, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr1) - - // Refresh lifetime for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr1) - - // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - // addr2 should be the primary endpoint now since addr1 is deprecated but - // addr2 is not. - expectPrimaryAddr(addr2) - // addr1 is deprecated but if explicitly requested, it should be used. - fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID} - if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) - } - - // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make - // sure we do not get a deprecation event again. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr2) - if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) - } - - // Refresh lifetimes for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - // addr1 is the primary endpoint again since it is non-deprecated now. - expectPrimaryAddr(addr1) - - // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - // addr2 should be the primary endpoint now since it is not deprecated. - expectPrimaryAddr(addr2) - if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) - } - - // Wait for addr of prefix1 to be invalidated. - expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout) - if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr2) - - // Refresh both lifetimes for addr of prefix2 to the same value. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - - // Wait for a deprecation then invalidation events, or just an invalidation - // event. We need to cover both cases but cannot deterministically hit both - // cases because the deprecation and invalidation handlers could be handled in - // either deprecation then invalidation, or invalidation then deprecation - // (which should be cancelled by the invalidation handler). - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" { - // If we get a deprecation event first, we should get an invalidation - // event almost immediately after. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" { - // If we get an invalidation event first, we should not get a deprecation - // event after. - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - case <-time.After(defaultAsyncNegativeEventTimeout): - } - } else { - t.Fatalf("got unexpected auto-generated event") - } - case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should not have %s in the list of addresses", addr2) - } - // Should not have any primary endpoints. - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - defer close(ch) - ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) - } - defer ep.Close() - ep.SocketOptions().SetV6Only(true) - - { - err := ep.Connect(dstAddr) - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, &tcpip.ErrNoRoute{}) - } - } -} - -// Tests transitioning a SLAAC address's valid lifetime between finite and -// infinite values. -func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { - const infiniteVLSeconds = 2 - const minVLSeconds = 1 - savedIL := header.NDPInfiniteLifetime - savedMinVL := ipv6.MinPrefixInformationValidLifetimeForUpdate - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = savedMinVL - header.NDPInfiniteLifetime = savedIL - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second - header.NDPInfiniteLifetime = infiniteVLSeconds * time.Second - - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) - - tests := []struct { - name string - infiniteVL uint32 - }{ - { - name: "EqualToInfiniteVL", - infiniteVL: infiniteVLSeconds, - }, - // Our implementation supports changing header.NDPInfiniteLifetime for tests - // such that a packet can be received where the lifetime field has a value - // greater than header.NDPInfiniteLifetime. Because of this, we test to make - // sure that receiving a value greater than header.NDPInfiniteLifetime is - // handled the same as when receiving a value equal to - // header.NDPInfiniteLifetime. - { - name: "MoreThanInfiniteVL", - infiniteVL: infiniteVLSeconds + 1, - }, - } - - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run("group", func(t *testing.T) { - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Receive an RA with finite prefix. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - - default: - t.Fatal("expected addr auto gen event") - } - - // Receive an new RA with prefix with infinite VL. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.infiniteVL, 0)) - - // Receive a new RA with prefix with finite VL. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - - case <-time.After(minVLSeconds*time.Second + defaultAsyncPositiveEventTimeout): - t.Fatal("timeout waiting for addr auto gen event") - } - }) - } - }) -} - -// TestAutoGenAddrValidLifetimeUpdates tests that the valid lifetime of an -// auto-generated address only gets updated when required to, as specified in -// RFC 4862 section 5.5.3.e. -func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { - const infiniteVL = 4294967295 - const newMinVL = 4 - saved := ipv6.MinPrefixInformationValidLifetimeForUpdate - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = saved - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second - - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) - - tests := []struct { - name string - ovl uint32 - nvl uint32 - evl uint32 - }{ - // Should update the VL to the minimum VL for updating if the - // new VL is less than newMinVL but was originally greater than - // it. - { - "LargeVLToVLLessThanMinVLForUpdate", - 9999, - 1, - newMinVL, - }, - { - "LargeVLTo0", - 9999, - 0, - newMinVL, - }, - { - "InfiniteVLToVLLessThanMinVLForUpdate", - infiniteVL, - 1, - newMinVL, - }, - { - "InfiniteVLTo0", - infiniteVL, - 0, - newMinVL, - }, - - // Should not update VL if original VL was less than newMinVL - // and the new VL is also less than newMinVL. - { - "ShouldNotUpdateWhenBothOldAndNewAreLessThanMinVLForUpdate", - newMinVL - 1, - newMinVL - 3, - newMinVL - 1, - }, - - // Should take the new VL if the new VL is greater than the - // remaining time or is greater than newMinVL. - { - "MorethanMinVLToLesserButStillMoreThanMinVLForUpdate", - newMinVL + 5, - newMinVL + 3, - newMinVL + 3, - }, - { - "SmallVLToGreaterVLButStillLessThanMinVLForUpdate", - newMinVL - 3, - newMinVL - 1, - newMinVL - 1, - }, - { - "SmallVLToGreaterVLThatIsMoreThaMinVLForUpdate", - newMinVL - 3, - newMinVL + 1, - newMinVL + 1, - }, - } - - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run("group", func(t *testing.T) { - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), - } - e := channel.New(10, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Receive an RA with prefix with initial VL, - // test.ovl. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0)) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - - // Receive an new RA with prefix with new VL, - // test.nvl. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0)) - - // - // Validate that the VL for the address got set - // to test.evl. - // - - // The address should not be invalidated until the effective valid - // lifetime has passed. - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(time.Duration(test.evl)*time.Second - defaultAsyncNegativeEventTimeout): - } - - // Wait for the invalidation event. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timeout waiting for addr auto gen event") - } - }) - } - }) -} - -// TestAutoGenAddrRemoval tests that when auto-generated addresses are removed -// by the user, its resources will be cleaned up and an invalidation event will -// be sent to the integrator. -func TestAutoGenAddrRemoval(t *testing.T) { - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - // Receive a PI to auto-generate an address. - const lifetimeSeconds = 1 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, 0)) - expectAutoGenAddrEvent(addr, newAddr) - - // Removing the address should result in an invalidation event - // immediately. - if err := s.RemoveAddress(1, addr.Address); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr.Address, err) - } - expectAutoGenAddrEvent(addr, invalidatedAddr) - - // Wait for the original valid lifetime to make sure the original job got - // cancelled/cleaned up. - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): - } -} - -// TestAutoGenAddrAfterRemoval tests adding a SLAAC address that was previously -// assigned to the NIC but is in the permanentExpired state. -func TestAutoGenAddrAfterRemoval(t *testing.T) { - const nicID = 1 - - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { - t.Helper() - - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatal(err) - } - - if got := addrForNewConnection(t, s); got != addr.Address { - t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) - } - } - - // Receive a PI to auto-generate addr1 with a large valid and preferred - // lifetime. - const largeLifetimeSeconds = 999 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix1, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) - expectAutoGenAddrEvent(addr1, newAddr) - expectPrimaryAddr(addr1) - - // Add addr2 as a static address. - protoAddr2 := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: addr2, - } - if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err) - } - // addr2 should be more preferred now since it is at the front of the primary - // list. - expectPrimaryAddr(addr2) - - // Get a route using addr2 to increment its reference count then remove it - // to leave it in the permanentExpired state. - r, err := s.FindRoute(nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, false) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, err) - } - defer r.Release() - if err := s.RemoveAddress(nicID, addr2.Address); err != nil { - t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, addr2.Address, err) - } - // addr1 should be preferred again since addr2 is in the expired state. - expectPrimaryAddr(addr1) - - // Receive a PI to auto-generate addr2 as valid and preferred. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) - expectAutoGenAddrEvent(addr2, newAddr) - // addr2 should be more preferred now that it is closer to the front of the - // primary list and not deprecated. - expectPrimaryAddr(addr2) - - // Removing the address should result in an invalidation event immediately. - // It should still be in the permanentExpired state because r is still held. - // - // We remove addr2 here to make sure addr2 was marked as a SLAAC address - // (it was previously marked as a static address). - if err := s.RemoveAddress(1, addr2.Address); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) - } - expectAutoGenAddrEvent(addr2, invalidatedAddr) - // addr1 should be more preferred since addr2 is in the expired state. - expectPrimaryAddr(addr1) - - // Receive a PI to auto-generate addr2 as valid and deprecated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, 0)) - expectAutoGenAddrEvent(addr2, newAddr) - // addr1 should still be more preferred since addr2 is deprecated, even though - // it is closer to the front of the primary list. - expectPrimaryAddr(addr1) - - // Receive a PI to refresh addr2's preferred lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto gen addr event") - default: - } - // addr2 should be more preferred now that it is not deprecated. - expectPrimaryAddr(addr2) - - if err := s.RemoveAddress(1, addr2.Address); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) - } - expectAutoGenAddrEvent(addr2, invalidatedAddr) - expectPrimaryAddr(addr1) -} - -// TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that -// is already assigned to the NIC, the static address remains. -func TestAutoGenAddrStaticConflict(t *testing.T) { - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Add the address as a static address before SLAAC tries to add it. - if err := s.AddProtocolAddress(1, tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr.Address, err) - } - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } - - // Receive a PI where the generated address will be the same as the one - // that we already have assigned statically. - const lifetimeSeconds = 1 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly received an auto gen addr event for an address we already have statically") - default: - } - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } - - // Should not get an invalidation event after the PI's invalidation - // time. - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): - } - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } -} - -// TestAutoGenAddrWithOpaqueIID tests that SLAAC generated addresses will use -// opaque interface identifiers when configured to do so. -func TestAutoGenAddrWithOpaqueIID(t *testing.T) { - const nicID = 1 - const nicName = "nic1" - var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte - secretKey := secretKeyBuf[:] - n, err := rand.Read(secretKey) - if err != nil { - t.Fatalf("rand.Read(_): %s", err) - } - if n != header.OpaqueIIDSecretKeyMinBytes { - t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes) - } - - prefix1, subnet1, _ := prefixSubnetAddr(0, linkAddr1) - prefix2, subnet2, _ := prefixSubnetAddr(1, linkAddr1) - // addr1 and addr2 are the addresses that are expected to be generated when - // stack.Stack is configured to generate opaque interface identifiers as - // defined by RFC 7217. - addrBytes := []byte(subnet1.ID()) - addr1 := tcpip.AddressWithPrefix{ - Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet1, nicName, 0, secretKey)), - PrefixLen: 64, - } - addrBytes = []byte(subnet2.ID()) - addr2 := tcpip.AddressWithPrefix{ - Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet2, nicName, 0, secretKey)), - PrefixLen: 64, - } - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(_ tcpip.NICID, nicName string) string { - return nicName - }, - SecretKey: secretKey, - }, - })}, - }) - opts := stack.NICOptions{Name: nicName} - if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %+v, _) = %s", nicID, opts, err) - } - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - // Receive an RA with prefix1 in a PI. - const validLifetimeSecondPrefix1 = 1 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, validLifetimeSecondPrefix1, 0)) - expectAutoGenAddrEvent(addr1, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - - // Receive an RA with prefix2 in a PI with a large valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - expectAutoGenAddrEvent(addr2, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - - // Wait for addr of prefix1 to be invalidated. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } -} - -func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { - const nicID = 1 - const nicName = "nic" - const dadTransmits = 1 - const retransmitTimer = time.Second - const maxMaxRetries = 3 - const lifetimeSeconds = 10 - - // Needed for the temporary address sub test. - savedMaxDesync := ipv6.MaxDesyncFactor - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesync - }() - ipv6.MaxDesyncFactor = time.Nanosecond - - var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte - secretKey := secretKeyBuf[:] - n, err := rand.Read(secretKey) - if err != nil { - t.Fatalf("rand.Read(_): %s", err) - } - if n != header.OpaqueIIDSecretKeyMinBytes { - t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes) - } - - prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) - - addrForSubnet := func(subnet tcpip.Subnet, dadCounter uint8) tcpip.AddressWithPrefix { - addrBytes := []byte(subnet.ID()) - return tcpip.AddressWithPrefix{ - Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, dadCounter, secretKey)), - PrefixLen: 64, - } - } - - expectAutoGenAddrEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - expectAutoGenAddrEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - } - - expectDADEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, resolved bool, err tcpip.Error) { - t.Helper() - - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr, resolved, err); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected DAD event") - } - } - - expectDADEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, resolved bool) { - t.Helper() - - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr, resolved, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for DAD event") - } - } - - stableAddrForTempAddrTest := addrForSubnet(subnet, 0) - - addrTypes := []struct { - name string - ndpConfigs ipv6.NDPConfigurations - autoGenLinkLocal bool - prepareFn func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix - addrGenFn func(dadCounter uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix - }{ - { - name: "Global address", - ndpConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - prepareFn: func(_ *testing.T, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix { - // Receive an RA with prefix1 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) - return nil - - }, - addrGenFn: func(dadCounter uint8, _ []byte) tcpip.AddressWithPrefix { - return addrForSubnet(subnet, dadCounter) - }, - }, - { - name: "LinkLocal address", - ndpConfigs: ipv6.NDPConfigurations{}, - autoGenLinkLocal: true, - prepareFn: func(*testing.T, *ndpDispatcher, *channel.Endpoint, []byte) []tcpip.AddressWithPrefix { - return nil - }, - addrGenFn: func(dadCounter uint8, _ []byte) tcpip.AddressWithPrefix { - return addrForSubnet(header.IPv6LinkLocalPrefix.Subnet(), dadCounter) - }, - }, - { - name: "Temporary address", - ndpConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - }, - prepareFn: func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix { - header.InitialTempIID(tempIIDHistory, nil, nicID) - - // Generate a stable SLAAC address so temporary addresses will be - // generated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) - expectAutoGenAddrEvent(t, ndpDisp, stableAddrForTempAddrTest, newAddr) - expectDADEventAsync(t, ndpDisp, stableAddrForTempAddrTest.Address, true) - - // The stable address will be assigned throughout the test. - return []tcpip.AddressWithPrefix{stableAddrForTempAddrTest} - }, - addrGenFn: func(_ uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix { - return header.GenerateTempIPv6SLAACAddr(tempIIDHistory, stableAddrForTempAddrTest.Address) - }, - }, - } - - for _, addrType := range addrTypes { - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the parallel - // tests complete and limit the number of parallel tests running at the same - // time to reduce flakes. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run(addrType.name, func(t *testing.T) { - for maxRetries := uint8(0); maxRetries <= maxMaxRetries; maxRetries++ { - for numFailures := uint8(0); numFailures <= maxRetries+1; numFailures++ { - maxRetries := maxRetries - numFailures := numFailures - addrType := addrType - - t.Run(fmt.Sprintf("%d max retries and %d failures", maxRetries, numFailures), func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), - } - e := channel.New(0, 1280, linkAddr1) - ndpConfigs := addrType.ndpConfigs - ndpConfigs.AutoGenAddressConflictRetries = maxRetries - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenLinkLocal: addrType.autoGenLinkLocal, - DADConfigs: stack.DADConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - }, - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(_ tcpip.NICID, nicName string) string { - return nicName - }, - SecretKey: secretKey, - }, - })}, - }) - opts := stack.NICOptions{Name: nicName} - if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) - } - - var tempIIDHistory [header.IIDSize]byte - stableAddrs := addrType.prepareFn(t, &ndpDisp, e, tempIIDHistory[:]) - - // Simulate DAD conflicts so the address is regenerated. - for i := uint8(0); i < numFailures; i++ { - addr := addrType.addrGenFn(i, tempIIDHistory[:]) - expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr) - - // Should not have any new addresses assigned to the NIC. - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, stableAddrs, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // Simulate a DAD conflict. - rxNDPSolicit(e, addr.Address) - expectAutoGenAddrEvent(t, &ndpDisp, addr, invalidatedAddr) - expectDADEvent(t, &ndpDisp, addr.Address, false, nil) - - // Attempting to add the address manually should not fail if the - // address's state was cleaned up when DAD failed. - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr.Address); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr.Address, err) - } - if err := s.RemoveAddress(nicID, addr.Address); err != nil { - t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err) - } - expectDADEvent(t, &ndpDisp, addr.Address, false, &tcpip.ErrAborted{}) - } - - // Should not have any new addresses assigned to the NIC. - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, stableAddrs, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // If we had less failures than generation attempts, we should have - // an address after DAD resolves. - if maxRetries+1 > numFailures { - addr := addrType.addrGenFn(numFailures, tempIIDHistory[:]) - expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr) - expectDADEventAsync(t, &ndpDisp, addr.Address, true) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, append(stableAddrs, addr), nil); mismatch != "" { - t.Fatal(mismatch) - } - } - - // Should not attempt address generation again. - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) - case <-time.After(defaultAsyncNegativeEventTimeout): - } - }) - } - } - }) - } -} - -// TestAutoGenAddrWithEUI64IIDNoDADRetries tests that a regeneration attempt is -// not made for SLAAC addresses generated with an IID based on the NIC's link -// address. -func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { - const nicID = 1 - const dadTransmits = 1 - const retransmitTimer = time.Second - const maxRetries = 3 - const lifetimeSeconds = 10 - - prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) - - addrTypes := []struct { - name string - ndpConfigs ipv6.NDPConfigurations - autoGenLinkLocal bool - subnet tcpip.Subnet - triggerSLAACFn func(e *channel.Endpoint) - }{ - { - name: "Global address", - ndpConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - AutoGenAddressConflictRetries: maxRetries, - }, - subnet: subnet, - triggerSLAACFn: func(e *channel.Endpoint) { - // Receive an RA with prefix1 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) - - }, - }, - { - name: "LinkLocal address", - ndpConfigs: ipv6.NDPConfigurations{ - AutoGenAddressConflictRetries: maxRetries, - }, - autoGenLinkLocal: true, - subnet: header.IPv6LinkLocalPrefix.Subnet(), - triggerSLAACFn: func(e *channel.Endpoint) {}, - }, - } - - for _, addrType := range addrTypes { - addrType := addrType - - t.Run(addrType.name, func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenLinkLocal: addrType.autoGenLinkLocal, - NDPConfigs: addrType.ndpConfigs, - NDPDisp: &ndpDisp, - DADConfigs: stack.DADConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - }, - })}, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - addrType.triggerSLAACFn(e) - - addrBytes := []byte(addrType.subnet.ID()) - header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr1, addrBytes[header.IIDOffsetInIPv6Address:]) - addr := tcpip.AddressWithPrefix{ - Address: tcpip.Address(addrBytes), - PrefixLen: 64, - } - expectAutoGenAddrEvent(addr, newAddr) - - // Simulate a DAD conflict. - rxNDPSolicit(e, addr.Address) - expectAutoGenAddrEvent(addr, invalidatedAddr) - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected DAD event") - } - - // Should not attempt address regeneration. - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) - case <-time.After(defaultAsyncNegativeEventTimeout): - } - }) - } -} - -// TestAutoGenAddrContinuesLifetimesAfterRetry tests that retrying address -// generation in response to DAD conflicts does not refresh the lifetimes. -func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { - const nicID = 1 - const nicName = "nic" - const dadTransmits = 1 - const retransmitTimer = 2 * time.Second - const failureTimer = time.Second - const maxRetries = 1 - const lifetimeSeconds = 5 - - var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte - secretKey := secretKeyBuf[:] - n, err := rand.Read(secretKey) - if err != nil { - t.Fatalf("rand.Read(_): %s", err) - } - if n != header.OpaqueIIDSecretKeyMinBytes { - t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes) - } - - prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - DADConfigs: stack.DADConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - }, - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - AutoGenAddressConflictRetries: maxRetries, - }, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(_ tcpip.NICID, nicName string) string { - return nicName - }, - SecretKey: secretKey, - }, - })}, - }) - opts := stack.NICOptions{Name: nicName} - if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) - } - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - // Receive an RA with prefix in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) - - addrBytes := []byte(subnet.ID()) - addr := tcpip.AddressWithPrefix{ - Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, 0, secretKey)), - PrefixLen: 64, - } - expectAutoGenAddrEvent(addr, newAddr) - - // Simulate a DAD conflict after some time has passed. - time.Sleep(failureTimer) - rxNDPSolicit(e, addr.Address) - expectAutoGenAddrEvent(addr, invalidatedAddr) - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected DAD event") - } - - // Let the next address resolve. - addr.Address = tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, 1, secretKey)) - expectAutoGenAddrEvent(addr, newAddr) - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for DAD event") - } - - // Address should be deprecated/invalidated after the lifetime expires. - // - // Note, the remaining lifetime is calculated from when the PI was first - // processed. Since we wait for some time before simulating a DAD conflict - // and more time for the new address to resolve, the new address is only - // expected to be valid for the remaining time. The DAD conflict should - // not have reset the lifetimes. - // - // We expect either just the invalidation event or the deprecation event - // followed by the invalidation event. - select { - case e := <-ndpDisp.autoGenAddrC: - if e.eventType == deprecatedAddr { - if diff := checkAutoGenAddrEvent(e, addr, deprecatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for invalidated auto gen addr event after deprecation") - } - } else { - if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - } - case <-time.After(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for auto gen addr event") - } -} - -// TestNDPRecursiveDNSServerDispatch tests that we properly dispatch an event -// to the integrator when an RA is received with the NDP Recursive DNS Server -// option with at least one valid address. -func TestNDPRecursiveDNSServerDispatch(t *testing.T) { - tests := []struct { - name string - opt header.NDPRecursiveDNSServer - expected *ndpRDNSS - }{ - { - "Unspecified", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 2, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - }), - nil, - }, - { - "Multicast", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 2, - 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, - }), - nil, - }, - { - "OptionTooSmall", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 2, - 1, 2, 3, 4, 5, 6, 7, 8, - }), - nil, - }, - { - "0Addresses", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 2, - }), - nil, - }, - { - "Valid1Address", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 2, - 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, - }), - &ndpRDNSS{ - []tcpip.Address{ - "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", - }, - 2 * time.Second, - }, - }, - { - "Valid2Addresses", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 1, - 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, - 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2, - }), - &ndpRDNSS{ - []tcpip.Address{ - "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", - "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02", - }, - time.Second, - }, - }, - { - "Valid3Addresses", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 0, - 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, - 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2, - 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 3, - }), - &ndpRDNSS{ - []tcpip.Address{ - "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", - "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02", - "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x03", - }, - 0, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ndpDisp := ndpDispatcher{ - // We do not expect more than a single RDNSS - // event at any time for this test. - rdnssC: make(chan ndpRDNSSEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, header.NDPOptionsSerializer{test.opt})) - - if test.expected != nil { - select { - case e := <-ndpDisp.rdnssC: - if e.nicID != 1 { - t.Errorf("got rdnss nicID = %d, want = 1", e.nicID) - } - if diff := cmp.Diff(e.rdnss.addrs, test.expected.addrs); diff != "" { - t.Errorf("rdnss addrs mismatch (-want +got):\n%s", diff) - } - if e.rdnss.lifetime != test.expected.lifetime { - t.Errorf("got rdnss lifetime = %s, want = %s", e.rdnss.lifetime, test.expected.lifetime) - } - default: - t.Fatal("expected an RDNSS option event") - } - } - - // Should have no more RDNSS options. - select { - case e := <-ndpDisp.rdnssC: - t.Fatalf("unexpectedly got a new RDNSS option event: %+v", e) - default: - } - }) - } -} - -// TestNDPDNSSearchListDispatch tests that the integrator is informed when an -// NDP DNS Search List option is received with at least one domain name in the -// search list. -func TestNDPDNSSearchListDispatch(t *testing.T) { - const nicID = 1 - - ndpDisp := ndpDispatcher{ - dnsslC: make(chan ndpDNSSLEvent, 3), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - optSer := header.NDPOptionsSerializer{ - header.NDPDNSSearchList([]byte{ - 0, 0, - 0, 0, 0, 0, - 2, 'h', 'i', - 0, - }), - header.NDPDNSSearchList([]byte{ - 0, 0, - 0, 0, 0, 1, - 1, 'i', - 0, - 2, 'a', 'm', - 2, 'm', 'e', - 0, - }), - header.NDPDNSSearchList([]byte{ - 0, 0, - 0, 0, 1, 0, - 3, 'x', 'y', 'z', - 0, - 5, 'h', 'e', 'l', 'l', 'o', - 5, 'w', 'o', 'r', 'l', 'd', - 0, - 4, 't', 'h', 'i', 's', - 2, 'i', 's', - 1, 'a', - 4, 't', 'e', 's', 't', - 0, - }), - } - expected := []struct { - domainNames []string - lifetime time.Duration - }{ - { - domainNames: []string{ - "hi", - }, - lifetime: 0, - }, - { - domainNames: []string{ - "i", - "am.me", - }, - lifetime: time.Second, - }, - { - domainNames: []string{ - "xyz", - "hello.world", - "this.is.a.test", - }, - lifetime: 256 * time.Second, - }, - } - - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer)) - - for i, expected := range expected { - select { - case dnssl := <-ndpDisp.dnsslC: - if dnssl.nicID != nicID { - t.Errorf("got %d-th dnssl nicID = %d, want = %d", i, dnssl.nicID, nicID) - } - if diff := cmp.Diff(dnssl.domainNames, expected.domainNames); diff != "" { - t.Errorf("%d-th dnssl domain names mismatch (-want +got):\n%s", i, diff) - } - if dnssl.lifetime != expected.lifetime { - t.Errorf("got %d-th dnssl lifetime = %s, want = %s", i, dnssl.lifetime, expected.lifetime) - } - default: - t.Fatal("expected a DNSSL event") - } - } - - // Should have no more DNSSL options. - select { - case <-ndpDisp.dnsslC: - t.Fatal("unexpectedly got a DNSSL event") - default: - } -} - -// TestCleanupNDPState tests that all discovered routers and prefixes, and -// auto-generated addresses are invalidated when a NIC becomes a router. -func TestCleanupNDPState(t *testing.T) { - const ( - lifetimeSeconds = 5 - maxRouterAndPrefixEvents = 4 - nicID1 = 1 - nicID2 = 2 - ) - - prefix1, subnet1, e1Addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, subnet2, e1Addr2 := prefixSubnetAddr(1, linkAddr1) - e2Addr1 := addrForSubnet(subnet1, linkAddr2) - e2Addr2 := addrForSubnet(subnet2, linkAddr2) - llAddrWithPrefix1 := tcpip.AddressWithPrefix{ - Address: llAddr1, - PrefixLen: 64, - } - llAddrWithPrefix2 := tcpip.AddressWithPrefix{ - Address: llAddr2, - PrefixLen: 64, - } - - tests := []struct { - name string - cleanupFn func(t *testing.T, s *stack.Stack) - keepAutoGenLinkLocal bool - maxAutoGenAddrEvents int - skipFinalAddrCheck bool - }{ - // A NIC should still keep its auto-generated link-local address when - // becoming a router. - { - name: "Enable forwarding", - cleanupFn: func(t *testing.T, s *stack.Stack) { - t.Helper() - s.SetForwarding(ipv6.ProtocolNumber, true) - }, - keepAutoGenLinkLocal: true, - maxAutoGenAddrEvents: 4, - }, - - // A NIC should cleanup all NDP state when it is disabled. - { - name: "Disable NIC", - cleanupFn: func(t *testing.T, s *stack.Stack) { - t.Helper() - - if err := s.DisableNIC(nicID1); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID1, err) - } - if err := s.DisableNIC(nicID2); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID2, err) - } - }, - keepAutoGenLinkLocal: false, - maxAutoGenAddrEvents: 6, - }, - - // A NIC should cleanup all NDP state when it is removed. - { - name: "Remove NIC", - cleanupFn: func(t *testing.T, s *stack.Stack) { - t.Helper() - - if err := s.RemoveNIC(nicID1); err != nil { - t.Fatalf("s.RemoveNIC(%d): %s", nicID1, err) - } - if err := s.RemoveNIC(nicID2); err != nil { - t.Fatalf("s.RemoveNIC(%d): %s", nicID2, err) - } - }, - keepAutoGenLinkLocal: false, - maxAutoGenAddrEvents: 6, - // The NICs are removed so we can't check their addresses after calling - // stopFn. - skipFinalAddrCheck: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, maxRouterAndPrefixEvents), - rememberRouter: true, - prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents), - rememberPrefix: true, - autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents), - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenLinkLocal: true, - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - DiscoverOnLinkPrefixes: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - expectRouterEvent := func() (bool, ndpRouterEvent) { - select { - case e := <-ndpDisp.routerC: - return true, e - default: - } - - return false, ndpRouterEvent{} - } - - expectPrefixEvent := func() (bool, ndpPrefixEvent) { - select { - case e := <-ndpDisp.prefixC: - return true, e - default: - } - - return false, ndpPrefixEvent{} - } - - expectAutoGenAddrEvent := func() (bool, ndpAutoGenAddrEvent) { - select { - case e := <-ndpDisp.autoGenAddrC: - return true, e - default: - } - - return false, ndpAutoGenAddrEvent{} - } - - e1 := channel.New(0, 1280, linkAddr1) - if err := s.CreateNIC(nicID1, e1); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err) - } - // We have other tests that make sure we receive the *correct* events - // on normal discovery of routers/prefixes, and auto-generated - // addresses. Here we just make sure we get an event and let other tests - // handle the correctness check. - expectAutoGenAddrEvent() - - e2 := channel.New(0, 1280, linkAddr2) - if err := s.CreateNIC(nicID2, e2); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err) - } - expectAutoGenAddrEvent() - - // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr3 and - // llAddr4) w/ PI (for prefix1 in RA from llAddr3 and prefix2 in RA from - // llAddr4) to discover multiple routers and prefixes, and auto-gen - // multiple addresses. - - e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1) - } - if ok, _ := expectPrefixEvent(); !ok { - t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1) - } - if ok, _ := expectAutoGenAddrEvent(); !ok { - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr1, nicID1) - } - - e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1) - } - if ok, _ := expectPrefixEvent(); !ok { - t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1) - } - if ok, _ := expectAutoGenAddrEvent(); !ok { - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID1) - } - - e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2) - } - if ok, _ := expectPrefixEvent(); !ok { - t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2) - } - if ok, _ := expectAutoGenAddrEvent(); !ok { - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID2) - } - - e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2) - } - if ok, _ := expectPrefixEvent(); !ok { - t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2) - } - if ok, _ := expectAutoGenAddrEvent(); !ok { - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e2Addr2, nicID2) - } - - // We should have the auto-generated addresses added. - nicinfo := s.NICInfo() - nic1Addrs := nicinfo[nicID1].ProtocolAddresses - nic2Addrs := nicinfo[nicID2].ProtocolAddresses - if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) - } - if !containsV6Addr(nic1Addrs, e1Addr1) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) - } - if !containsV6Addr(nic1Addrs, e1Addr2) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) - } - if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) - } - if !containsV6Addr(nic2Addrs, e2Addr1) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) - } - if !containsV6Addr(nic2Addrs, e2Addr2) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) - } - - // We can't proceed any further if we already failed the test (missing - // some discovery/auto-generated address events or addresses). - if t.Failed() { - t.FailNow() - } - - test.cleanupFn(t, s) - - // Collect invalidation events after having NDP state cleaned up. - gotRouterEvents := make(map[ndpRouterEvent]int) - for i := 0; i < maxRouterAndPrefixEvents; i++ { - ok, e := expectRouterEvent() - if !ok { - t.Errorf("expected %d router events after becoming a router; got = %d", maxRouterAndPrefixEvents, i) - break - } - gotRouterEvents[e]++ - } - gotPrefixEvents := make(map[ndpPrefixEvent]int) - for i := 0; i < maxRouterAndPrefixEvents; i++ { - ok, e := expectPrefixEvent() - if !ok { - t.Errorf("expected %d prefix events after becoming a router; got = %d", maxRouterAndPrefixEvents, i) - break - } - gotPrefixEvents[e]++ - } - gotAutoGenAddrEvents := make(map[ndpAutoGenAddrEvent]int) - for i := 0; i < test.maxAutoGenAddrEvents; i++ { - ok, e := expectAutoGenAddrEvent() - if !ok { - t.Errorf("expected %d auto-generated address events after becoming a router; got = %d", test.maxAutoGenAddrEvents, i) - break - } - gotAutoGenAddrEvents[e]++ - } - - // No need to proceed any further if we already failed the test (missing - // some invalidation events). - if t.Failed() { - t.FailNow() - } - - expectedRouterEvents := map[ndpRouterEvent]int{ - {nicID: nicID1, addr: llAddr3, discovered: false}: 1, - {nicID: nicID1, addr: llAddr4, discovered: false}: 1, - {nicID: nicID2, addr: llAddr3, discovered: false}: 1, - {nicID: nicID2, addr: llAddr4, discovered: false}: 1, - } - if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" { - t.Errorf("router events mismatch (-want +got):\n%s", diff) - } - expectedPrefixEvents := map[ndpPrefixEvent]int{ - {nicID: nicID1, prefix: subnet1, discovered: false}: 1, - {nicID: nicID1, prefix: subnet2, discovered: false}: 1, - {nicID: nicID2, prefix: subnet1, discovered: false}: 1, - {nicID: nicID2, prefix: subnet2, discovered: false}: 1, - } - if diff := cmp.Diff(expectedPrefixEvents, gotPrefixEvents); diff != "" { - t.Errorf("prefix events mismatch (-want +got):\n%s", diff) - } - expectedAutoGenAddrEvents := map[ndpAutoGenAddrEvent]int{ - {nicID: nicID1, addr: e1Addr1, eventType: invalidatedAddr}: 1, - {nicID: nicID1, addr: e1Addr2, eventType: invalidatedAddr}: 1, - {nicID: nicID2, addr: e2Addr1, eventType: invalidatedAddr}: 1, - {nicID: nicID2, addr: e2Addr2, eventType: invalidatedAddr}: 1, - } - - if !test.keepAutoGenLinkLocal { - expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID1, addr: llAddrWithPrefix1, eventType: invalidatedAddr}] = 1 - expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID2, addr: llAddrWithPrefix2, eventType: invalidatedAddr}] = 1 - } - - if diff := cmp.Diff(expectedAutoGenAddrEvents, gotAutoGenAddrEvents); diff != "" { - t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff) - } - - if !test.skipFinalAddrCheck { - // Make sure the auto-generated addresses got removed. - nicinfo = s.NICInfo() - nic1Addrs = nicinfo[nicID1].ProtocolAddresses - nic2Addrs = nicinfo[nicID2].ProtocolAddresses - if containsV6Addr(nic1Addrs, llAddrWithPrefix1) != test.keepAutoGenLinkLocal { - if test.keepAutoGenLinkLocal { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) - } else { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) - } - } - if containsV6Addr(nic1Addrs, e1Addr1) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) - } - if containsV6Addr(nic1Addrs, e1Addr2) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) - } - if containsV6Addr(nic2Addrs, llAddrWithPrefix2) != test.keepAutoGenLinkLocal { - if test.keepAutoGenLinkLocal { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) - } else { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) - } - } - if containsV6Addr(nic2Addrs, e2Addr1) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) - } - if containsV6Addr(nic2Addrs, e2Addr2) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) - } - } - - // Should not get any more events (invalidation timers should have been - // cancelled when the NDP state was cleaned up). - time.Sleep(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout) - select { - case <-ndpDisp.routerC: - t.Error("unexpected router event") - default: - } - select { - case <-ndpDisp.prefixC: - t.Error("unexpected prefix event") - default: - } - select { - case <-ndpDisp.autoGenAddrC: - t.Error("unexpected auto-generated address event") - default: - } - }) - } -} - -// TestDHCPv6ConfigurationFromNDPDA tests that the NDPDispatcher is properly -// informed when new information about what configurations are available via -// DHCPv6 is learned. -func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { - const nicID = 1 - - ndpDisp := ndpDispatcher{ - dhcpv6ConfigurationC: make(chan ndpDHCPv6Event, 1), - rememberRouter: true, - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - expectDHCPv6Event := func(configuration ipv6.DHCPv6ConfigurationFromNDPRA) { - t.Helper() - select { - case e := <-ndpDisp.dhcpv6ConfigurationC: - if diff := cmp.Diff(ndpDHCPv6Event{nicID: nicID, configuration: configuration}, e, cmp.AllowUnexported(e)); diff != "" { - t.Errorf("dhcpv6 event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected DHCPv6 configuration event") - } - } - - expectNoDHCPv6Event := func() { - t.Helper() - select { - case <-ndpDisp.dhcpv6ConfigurationC: - t.Fatal("unexpected DHCPv6 configuration event") - default: - } - } - - // Even if the first RA reports no DHCPv6 configurations are available, the - // dispatcher should get an event. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) - expectDHCPv6Event(ipv6.DHCPv6NoConfiguration) - // Receiving the same update again should not result in an event to the - // dispatcher. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) - expectNoDHCPv6Event() - - // Receive an RA that updates the DHCPv6 configuration to Other - // Configurations. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) - expectDHCPv6Event(ipv6.DHCPv6OtherConfigurations) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) - expectNoDHCPv6Event() - - // Receive an RA that updates the DHCPv6 configuration to Managed Address. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) - expectDHCPv6Event(ipv6.DHCPv6ManagedAddress) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) - expectNoDHCPv6Event() - - // Receive an RA that updates the DHCPv6 configuration to none. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) - expectDHCPv6Event(ipv6.DHCPv6NoConfiguration) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) - expectNoDHCPv6Event() - - // Receive an RA that updates the DHCPv6 configuration to Managed Address. - // - // Note, when the M flag is set, the O flag is redundant. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) - expectDHCPv6Event(ipv6.DHCPv6ManagedAddress) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) - expectNoDHCPv6Event() - // Even though the DHCPv6 flags are different, the effective configuration is - // the same so we should not receive a new event. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) - expectNoDHCPv6Event() - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) - expectNoDHCPv6Event() - - // Receive an RA that updates the DHCPv6 configuration to Other - // Configurations. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) - expectDHCPv6Event(ipv6.DHCPv6OtherConfigurations) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) - expectNoDHCPv6Event() - - // Cycling the NIC should cause the last DHCPv6 configuration to be cleared. - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - - // Receive an RA that updates the DHCPv6 configuration to Other - // Configurations. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) - expectDHCPv6Event(ipv6.DHCPv6OtherConfigurations) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) - expectNoDHCPv6Event() -} - -// TestRouterSolicitation tests the initial Router Solicitations that are sent -// when a NIC newly becomes enabled. -func TestRouterSolicitation(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - linkHeaderLen uint16 - linkAddr tcpip.LinkAddress - nicAddr tcpip.Address - expectedSrcAddr tcpip.Address - expectedNDPOpts []header.NDPOption - maxRtrSolicit uint8 - rtrSolicitInt time.Duration - effectiveRtrSolicitInt time.Duration - maxRtrSolicitDelay time.Duration - effectiveMaxRtrSolicitDelay time.Duration - }{ - { - name: "Single RS with 2s delay and interval", - expectedSrcAddr: header.IPv6Any, - maxRtrSolicit: 1, - rtrSolicitInt: 2 * time.Second, - effectiveRtrSolicitInt: 2 * time.Second, - maxRtrSolicitDelay: 2 * time.Second, - effectiveMaxRtrSolicitDelay: 2 * time.Second, - }, - { - name: "Single RS with 4s delay and interval", - expectedSrcAddr: header.IPv6Any, - maxRtrSolicit: 1, - rtrSolicitInt: 4 * time.Second, - effectiveRtrSolicitInt: 4 * time.Second, - maxRtrSolicitDelay: 4 * time.Second, - effectiveMaxRtrSolicitDelay: 4 * time.Second, - }, - { - name: "Two RS with delay", - linkHeaderLen: 1, - nicAddr: llAddr1, - expectedSrcAddr: llAddr1, - maxRtrSolicit: 2, - rtrSolicitInt: 2 * time.Second, - effectiveRtrSolicitInt: 2 * time.Second, - maxRtrSolicitDelay: 500 * time.Millisecond, - effectiveMaxRtrSolicitDelay: 500 * time.Millisecond, - }, - { - name: "Single RS without delay", - linkHeaderLen: 2, - linkAddr: linkAddr1, - nicAddr: llAddr1, - expectedSrcAddr: llAddr1, - expectedNDPOpts: []header.NDPOption{ - header.NDPSourceLinkLayerAddressOption(linkAddr1), - }, - maxRtrSolicit: 1, - rtrSolicitInt: 2 * time.Second, - effectiveRtrSolicitInt: 2 * time.Second, - maxRtrSolicitDelay: 0, - effectiveMaxRtrSolicitDelay: 0, - }, - { - name: "Two RS without delay and invalid zero interval", - linkHeaderLen: 3, - linkAddr: linkAddr1, - expectedSrcAddr: header.IPv6Any, - maxRtrSolicit: 2, - rtrSolicitInt: 0, - effectiveRtrSolicitInt: 4 * time.Second, - maxRtrSolicitDelay: 0, - effectiveMaxRtrSolicitDelay: 0, - }, - { - name: "Three RS without delay", - linkAddr: linkAddr1, - expectedSrcAddr: header.IPv6Any, - maxRtrSolicit: 3, - rtrSolicitInt: 500 * time.Millisecond, - effectiveRtrSolicitInt: 500 * time.Millisecond, - maxRtrSolicitDelay: 0, - effectiveMaxRtrSolicitDelay: 0, - }, - { - name: "Two RS with invalid negative delay", - linkAddr: linkAddr1, - expectedSrcAddr: header.IPv6Any, - maxRtrSolicit: 2, - rtrSolicitInt: time.Second, - effectiveRtrSolicitInt: time.Second, - maxRtrSolicitDelay: -3 * time.Second, - effectiveMaxRtrSolicitDelay: time.Second, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - e := channelLinkWithHeaderLength{ - Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), - headerLength: test.linkHeaderLen, - } - e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired - waitForPkt := func(timeout time.Duration) { - t.Helper() - - clock.Advance(timeout) - p, ok := e.Read() - if !ok { - t.Fatal("expected router solicitation packet") - } - - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } - - // Make sure the right remote link address is used. - if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want { - t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) - } - - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(test.expectedSrcAddr), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), - ) - - if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { - t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) - } - } - waitForNothing := func(timeout time.Duration) { - t.Helper() - - clock.Advance(timeout) - if p, ok := e.Read(); ok { - t.Fatalf("unexpectedly got a packet = %#v", p) - } - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - MaxRtrSolicitations: test.maxRtrSolicit, - RtrSolicitationInterval: test.rtrSolicitInt, - MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, - }, - })}, - Clock: clock, - }) - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - if addr := test.nicAddr; addr != "" { - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) - } - } - - // Make sure each RS is sent at the right time. - remaining := test.maxRtrSolicit - if remaining > 0 { - waitForPkt(test.effectiveMaxRtrSolicitDelay) - remaining-- - } - - for ; remaining > 0; remaining-- { - if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout { - waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond) - waitForPkt(time.Nanosecond) - } else { - waitForPkt(test.effectiveRtrSolicitInt) - } - } - - // Make sure no more RS. - if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { - waitForNothing(test.effectiveRtrSolicitInt) - } else { - waitForNothing(test.effectiveMaxRtrSolicitDelay) - } - - if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want { - t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want) - } - }) - } -} - -func TestStopStartSolicitingRouters(t *testing.T) { - const nicID = 1 - const delay = 0 - const interval = 500 * time.Millisecond - const maxRtrSolicitations = 3 - - tests := []struct { - name string - startFn func(t *testing.T, s *stack.Stack) - // first is used to tell stopFn that it is being called for the first time - // after router solicitations were last enabled. - stopFn func(t *testing.T, s *stack.Stack, first bool) - }{ - // Tests that when forwarding is enabled or disabled, router solicitations - // are stopped or started, respectively. - { - name: "Enable and disable forwarding", - startFn: func(t *testing.T, s *stack.Stack) { - t.Helper() - s.SetForwarding(ipv6.ProtocolNumber, false) - }, - stopFn: func(t *testing.T, s *stack.Stack, _ bool) { - t.Helper() - s.SetForwarding(ipv6.ProtocolNumber, true) - }, - }, - - // Tests that when a NIC is enabled or disabled, router solicitations - // are started or stopped, respectively. - { - name: "Enable and disable NIC", - startFn: func(t *testing.T, s *stack.Stack) { - t.Helper() - - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - }, - stopFn: func(t *testing.T, s *stack.Stack, _ bool) { - t.Helper() - - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - }, - }, - - // Tests that when a NIC is removed, router solicitations are stopped. We - // cannot start router solications on a removed NIC. - { - name: "Remove NIC", - stopFn: func(t *testing.T, s *stack.Stack, first bool) { - t.Helper() - - // Only try to remove the NIC the first time stopFn is called since it's - // impossible to remove an already removed NIC. - if !first { - return - } - - if err := s.RemoveNIC(nicID); err != nil { - t.Fatalf("s.RemoveNIC(%d): %s", nicID, err) - } - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e := channel.New(maxRtrSolicitations, 1280, linkAddr1) - waitForPkt := func(timeout time.Duration) { - t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - p, ok := e.ReadContext(ctx) - if !ok { - t.Fatal("timed out waiting for packet") - } - - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(header.IPv6Any), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS()) - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - MaxRtrSolicitations: maxRtrSolicitations, - RtrSolicitationInterval: interval, - MaxRtrSolicitationDelay: delay, - }, - })}, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - // Stop soliciting routers. - test.stopFn(t, s, true /* first */) - ctx, cancel := context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout) - defer cancel() - if _, ok := e.ReadContext(ctx); ok { - // A single RS may have been sent before solicitations were stopped. - ctx, cancel := context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout) - defer cancel() - if _, ok = e.ReadContext(ctx); ok { - t.Fatal("should not have sent more than one RS message") - } - } - - // Stopping router solicitations after it has already been stopped should - // do nothing. - test.stopFn(t, s, false /* first */) - ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout) - defer cancel() - if _, ok := e.ReadContext(ctx); ok { - t.Fatal("unexpectedly got a packet after router solicitation has been stopepd") - } - - // If test.startFn is nil, there is no way to restart router solications. - if test.startFn == nil { - return - } - - // Start soliciting routers. - test.startFn(t, s) - waitForPkt(delay + defaultAsyncPositiveEventTimeout) - waitForPkt(interval + defaultAsyncPositiveEventTimeout) - waitForPkt(interval + defaultAsyncPositiveEventTimeout) - ctx, cancel = context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout) - defer cancel() - if _, ok := e.ReadContext(ctx); ok { - t.Fatal("unexpectedly got an extra packet after sending out the expected RSs") - } - - // Starting router solicitations after it has already completed should do - // nothing. - test.startFn(t, s) - ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout) - defer cancel() - if _, ok := e.ReadContext(ctx); ok { - t.Fatal("unexpectedly got a packet after finishing router solicitations") - } - }) - } -} diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go deleted file mode 100644 index 909912662..000000000 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ /dev/null @@ -1,1700 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import ( - "bytes" - "encoding/binary" - "fmt" - "math" - "math/rand" - "strings" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/faketime" -) - -const ( - // entryStoreSize is the default number of entries that will be generated and - // added to the entry store. This number needs to be larger than the size of - // the neighbor cache to give ample opportunity for verifying behavior during - // cache overflows. Four times the size of the neighbor cache allows for - // three complete cache overflows. - entryStoreSize = 4 * neighborCacheSize - - // typicalLatency is the typical latency for an ARP or NDP packet to travel - // to a router and back. - typicalLatency = time.Millisecond - - // testEntryBroadcastAddr is a special address that indicates a packet should - // be sent to all nodes. - testEntryBroadcastAddr = tcpip.Address("broadcast") - - // testEntryLocalAddr is the source address of neighbor probes. - testEntryLocalAddr = tcpip.Address("local_addr") - - // testEntryBroadcastLinkAddr is a special link address sent back to - // multicast neighbor probes. - testEntryBroadcastLinkAddr = tcpip.LinkAddress("mac_broadcast") - - // infiniteDuration indicates that a task will not occur in our lifetime. - infiniteDuration = time.Duration(math.MaxInt64) -) - -// entryDiffOpts returns the options passed to cmp.Diff to compare neighbor -// entries. The UpdatedAtNanos field is ignored due to a lack of a -// deterministic method to predict the time that an event will be dispatched. -func entryDiffOpts() []cmp.Option { - return []cmp.Option{ - cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAtNanos"), - } -} - -// entryDiffOptsWithSort is like entryDiffOpts but also includes an option to -// sort slices of entries for cases where ordering must be ignored. -func entryDiffOptsWithSort() []cmp.Option { - return append(entryDiffOpts(), cmpopts.SortSlices(func(a, b NeighborEntry) bool { - return strings.Compare(string(a.Addr), string(b.Addr)) < 0 - })) -} - -func newTestNeighborResolver(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *testNeighborResolver { - config.resetInvalidFields() - rng := rand.New(rand.NewSource(time.Now().UnixNano())) - linkRes := &testNeighborResolver{ - clock: clock, - entries: newTestEntryStore(), - delay: typicalLatency, - } - linkRes.neigh.init(&nic{ - stack: &Stack{ - clock: clock, - nudDisp: nudDisp, - nudConfigs: config, - randomGenerator: rng, - }, - id: 1, - stats: makeNICStats(), - }, linkRes) - return linkRes -} - -// testEntryStore contains a set of IP to NeighborEntry mappings. -type testEntryStore struct { - mu sync.RWMutex - entriesMap map[tcpip.Address]NeighborEntry -} - -func toAddress(i int) tcpip.Address { - buf := new(bytes.Buffer) - binary.Write(buf, binary.BigEndian, uint8(1)) - binary.Write(buf, binary.BigEndian, uint8(0)) - binary.Write(buf, binary.BigEndian, uint16(i)) - return tcpip.Address(buf.String()) -} - -func toLinkAddress(i int) tcpip.LinkAddress { - buf := new(bytes.Buffer) - binary.Write(buf, binary.BigEndian, uint8(1)) - binary.Write(buf, binary.BigEndian, uint8(0)) - binary.Write(buf, binary.BigEndian, uint32(i)) - return tcpip.LinkAddress(buf.String()) -} - -// newTestEntryStore returns a testEntryStore pre-populated with entries. -func newTestEntryStore() *testEntryStore { - store := &testEntryStore{ - entriesMap: make(map[tcpip.Address]NeighborEntry), - } - for i := 0; i < entryStoreSize; i++ { - addr := toAddress(i) - linkAddr := toLinkAddress(i) - - store.entriesMap[addr] = NeighborEntry{ - Addr: addr, - LinkAddr: linkAddr, - } - } - return store -} - -// size returns the number of entries in the store. -func (s *testEntryStore) size() int { - s.mu.RLock() - defer s.mu.RUnlock() - return len(s.entriesMap) -} - -// entry returns the entry at index i. Returns an empty entry and false if i is -// out of bounds. -func (s *testEntryStore) entry(i int) (NeighborEntry, bool) { - return s.entryByAddr(toAddress(i)) -} - -// entryByAddr returns the entry matching addr for situations when the index is -// not available. Returns an empty entry and false if no entries match addr. -func (s *testEntryStore) entryByAddr(addr tcpip.Address) (NeighborEntry, bool) { - s.mu.RLock() - defer s.mu.RUnlock() - entry, ok := s.entriesMap[addr] - return entry, ok -} - -// entries returns all entries in the store. -func (s *testEntryStore) entries() []NeighborEntry { - entries := make([]NeighborEntry, 0, len(s.entriesMap)) - s.mu.RLock() - defer s.mu.RUnlock() - for i := 0; i < entryStoreSize; i++ { - addr := toAddress(i) - if entry, ok := s.entriesMap[addr]; ok { - entries = append(entries, entry) - } - } - return entries -} - -// set modifies the link addresses of an entry. -func (s *testEntryStore) set(i int, linkAddr tcpip.LinkAddress) { - addr := toAddress(i) - s.mu.Lock() - defer s.mu.Unlock() - if entry, ok := s.entriesMap[addr]; ok { - entry.LinkAddr = linkAddr - s.entriesMap[addr] = entry - } -} - -// testNeighborResolver implements LinkAddressResolver to emulate sending a -// neighbor probe. -type testNeighborResolver struct { - clock tcpip.Clock - neigh neighborCache - entries *testEntryStore - delay time.Duration - onLinkAddressRequest func() - dropReplies bool -} - -var _ LinkAddressResolver = (*testNeighborResolver)(nil) - -func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress) tcpip.Error { - if !r.dropReplies { - // Delay handling the request to emulate network latency. - r.clock.AfterFunc(r.delay, func() { - r.fakeRequest(targetAddr) - }) - } - - // Execute post address resolution action, if available. - if f := r.onLinkAddressRequest; f != nil { - f() - } - return nil -} - -// fakeRequest emulates handling a response for a link address request. -func (r *testNeighborResolver) fakeRequest(addr tcpip.Address) { - if entry, ok := r.entries.entryByAddr(addr); ok { - r.neigh.handleConfirmation(addr, entry.LinkAddr, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - } -} - -func (*testNeighborResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { - if addr == testEntryBroadcastAddr { - return testEntryBroadcastLinkAddr, true - } - return "", false -} - -func (*testNeighborResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { - return 0 -} - -type entryEvent struct { - nicID tcpip.NICID - address tcpip.Address - linkAddr tcpip.LinkAddress - state NeighborState -} - -func TestNeighborCacheGetConfig(t *testing.T) { - nudDisp := testNUDDispatcher{} - c := DefaultNUDConfigurations() - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(&nudDisp, c, clock) - - if got, want := linkRes.neigh.config(), c; got != want { - t.Errorf("got linkRes.neigh.config() = %+v, want = %+v", got, want) - } - - // No events should have been dispatched. - nudDisp.mu.Lock() - defer nudDisp.mu.Unlock() - if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } -} - -func TestNeighborCacheSetConfig(t *testing.T) { - nudDisp := testNUDDispatcher{} - c := DefaultNUDConfigurations() - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(&nudDisp, c, clock) - - c.MinRandomFactor = 1 - c.MaxRandomFactor = 1 - linkRes.neigh.setConfig(c) - - if got, want := linkRes.neigh.config(), c; got != want { - t.Errorf("got linkRes.neigh.config() = %+v, want = %+v", got, want) - } - - // No events should have been dispatched. - nudDisp.mu.Lock() - defer nudDisp.mu.Unlock() - if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } -} - -func TestNeighborCacheEntry(t *testing.T) { - c := DefaultNUDConfigurations() - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(&nudDisp, c, clock) - - entry, ok := linkRes.entries.entry(0) - if !ok { - t.Fatal("linkRes.entries.entry(0) not found") - } - _, _, err := linkRes.neigh.entry(entry.Addr, "", nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - - clock.Advance(typicalLatency) - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, - } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) - nudDisp.events = nil - nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - - if _, _, err := linkRes.neigh.entry(entry.Addr, "", nil); err != nil { - t.Fatalf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) - } - - // No more events should have been dispatched. - nudDisp.mu.Lock() - defer nudDisp.mu.Unlock() - if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } -} - -func TestNeighborCacheRemoveEntry(t *testing.T) { - config := DefaultNUDConfigurations() - - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(&nudDisp, config, clock) - - entry, ok := linkRes.entries.entry(0) - if !ok { - t.Fatal("linkRes.entries.entry(0) not found") - } - - _, _, err := linkRes.neigh.entry(entry.Addr, "", nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - - clock.Advance(typicalLatency) - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, - } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) - nudDisp.events = nil - nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - - linkRes.neigh.removeEntry(entry.Addr) - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestRemoved, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, - } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) - nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - { - _, _, err := linkRes.neigh.entry(entry.Addr, "", nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - } -} - -type testContext struct { - clock *faketime.ManualClock - linkRes *testNeighborResolver - nudDisp *testNUDDispatcher -} - -func newTestContext(c NUDConfigurations) testContext { - nudDisp := &testNUDDispatcher{} - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(nudDisp, c, clock) - - return testContext{ - clock: clock, - linkRes: linkRes, - nudDisp: nudDisp, - } -} - -type overflowOptions struct { - startAtEntryIndex int - wantStaticEntries []NeighborEntry -} - -func (c *testContext) overflowCache(opts overflowOptions) error { - // Fill the neighbor cache to capacity to verify the LRU eviction strategy is - // working properly after the entry removal. - for i := opts.startAtEntryIndex; i < c.linkRes.entries.size(); i++ { - // Add a new entry - entry, ok := c.linkRes.entries.entry(i) - if !ok { - return fmt.Errorf("c.linkRes.entries.entry(%d) not found", i) - } - _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - return fmt.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - c.clock.Advance(c.linkRes.neigh.config().RetransmitTimer) - - var wantEvents []testEntryEventInfo - - // When beyond the full capacity, the cache will evict an entry as per the - // LRU eviction strategy. Note that the number of static entries should not - // affect the total number of dynamic entries that can be added. - if i >= neighborCacheSize+opts.startAtEntryIndex { - removedEntry, ok := c.linkRes.entries.entry(i - neighborCacheSize) - if !ok { - return fmt.Errorf("linkRes.entries.entry(%d) not found", i-neighborCacheSize) - } - wantEvents = append(wantEvents, testEntryEventInfo{ - EventType: entryTestRemoved, - NICID: 1, - Entry: NeighborEntry{ - Addr: removedEntry.Addr, - LinkAddr: removedEntry.LinkAddr, - State: Reachable, - }, - }) - } - - wantEvents = append(wantEvents, testEntryEventInfo{ - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, testEntryEventInfo{ - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }) - - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) - c.nudDisp.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - // Expect to find only the most recent entries. The order of entries reported - // by entries() is nondeterministic, so entries have to be sorted before - // comparison. - wantUnsortedEntries := opts.wantStaticEntries - for i := c.linkRes.entries.size() - neighborCacheSize; i < c.linkRes.entries.size(); i++ { - entry, ok := c.linkRes.entries.entry(i) - if !ok { - return fmt.Errorf("c.linkRes.entries.entry(%d) not found", i) - } - wantEntry := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - } - wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) - } - - if diff := cmp.Diff(wantUnsortedEntries, c.linkRes.neigh.entries(), entryDiffOptsWithSort()...); diff != "" { - return fmt.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) - } - - // No more events should have been dispatched. - c.nudDisp.mu.Lock() - defer c.nudDisp.mu.Unlock() - if diff := cmp.Diff([]testEntryEventInfo(nil), c.nudDisp.events); diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - - return nil -} - -// TestNeighborCacheOverflow verifies that the LRU cache eviction strategy -// respects the dynamic entry count. -func TestNeighborCacheOverflow(t *testing.T) { - config := DefaultNUDConfigurations() - // Stay in Reachable so the cache can overflow - config.BaseReachableTime = infiniteDuration - config.MinRandomFactor = 1 - config.MaxRandomFactor = 1 - - c := newTestContext(config) - opts := overflowOptions{ - startAtEntryIndex: 0, - } - if err := c.overflowCache(opts); err != nil { - t.Errorf("c.overflowCache(%+v): %s", opts, err) - } -} - -// TestNeighborCacheRemoveEntryThenOverflow verifies that the LRU cache -// eviction strategy respects the dynamic entry count when an entry is removed. -func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { - config := DefaultNUDConfigurations() - // Stay in Reachable so the cache can overflow - config.BaseReachableTime = infiniteDuration - config.MinRandomFactor = 1 - config.MaxRandomFactor = 1 - - c := newTestContext(config) - - // Add a dynamic entry - entry, ok := c.linkRes.entries.entry(0) - if !ok { - t.Fatal("c.linkRes.entries.entry(0) not found") - } - _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - c.clock.Advance(c.linkRes.neigh.config().RetransmitTimer) - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) - c.nudDisp.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - - // Remove the entry - c.linkRes.neigh.removeEntry(entry.Addr) - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestRemoved, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) - c.nudDisp.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - opts := overflowOptions{ - startAtEntryIndex: 0, - } - if err := c.overflowCache(opts); err != nil { - t.Errorf("c.overflowCache(%+v): %s", opts, err) - } -} - -// TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress verifies that -// adding a duplicate static entry with the same link address does not dispatch -// any events. -func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { - config := DefaultNUDConfigurations() - c := newTestContext(config) - - // Add a static entry - entry, ok := c.linkRes.entries.entry(0) - if !ok { - t.Fatal("c.linkRes.entries.entry(0) not found") - } - staticLinkAddr := entry.LinkAddr + "static" - c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) - c.nudDisp.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - - // Remove the static entry that was just added - c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) - - // No more events should have been dispatched. - c.nudDisp.mu.Lock() - defer c.nudDisp.mu.Unlock() - if diff := cmp.Diff([]testEntryEventInfo(nil), c.nudDisp.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } -} - -// TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress verifies that -// adding a duplicate static entry with a different link address dispatches a -// change event. -func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) { - config := DefaultNUDConfigurations() - c := newTestContext(config) - - // Add a static entry - entry, ok := c.linkRes.entries.entry(0) - if !ok { - t.Fatal("c.linkRes.entries.entry(0) not found") - } - staticLinkAddr := entry.LinkAddr + "static" - c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) - c.nudDisp.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - - // Add a duplicate entry with a different link address - staticLinkAddr += "duplicate" - c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - }, - }, - } - c.nudDisp.mu.Lock() - defer c.nudDisp.mu.Unlock() - if diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } -} - -// TestNeighborCacheRemoveStaticEntryThenOverflow verifies that the LRU cache -// eviction strategy respects the dynamic entry count when a static entry is -// added then removed. In this case, the dynamic entry count shouldn't have -// been touched. -func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { - config := DefaultNUDConfigurations() - // Stay in Reachable so the cache can overflow - config.BaseReachableTime = infiniteDuration - config.MinRandomFactor = 1 - config.MaxRandomFactor = 1 - - c := newTestContext(config) - - // Add a static entry - entry, ok := c.linkRes.entries.entry(0) - if !ok { - t.Fatal("c.linkRes.entries.entry(0) not found") - } - staticLinkAddr := entry.LinkAddr + "static" - c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) - c.nudDisp.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - - // Remove the static entry that was just added - c.linkRes.neigh.removeEntry(entry.Addr) - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestRemoved, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) - c.nudDisp.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - opts := overflowOptions{ - startAtEntryIndex: 0, - } - if err := c.overflowCache(opts); err != nil { - t.Errorf("c.overflowCache(%+v): %s", opts, err) - } -} - -// TestNeighborCacheOverwriteWithStaticEntryThenOverflow verifies that the LRU -// cache eviction strategy keeps count of the dynamic entry count when an entry -// is overwritten by a static entry. Static entries should not count towards -// the size of the LRU cache. -func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { - config := DefaultNUDConfigurations() - // Stay in Reachable so the cache can overflow - config.BaseReachableTime = infiniteDuration - config.MinRandomFactor = 1 - config.MaxRandomFactor = 1 - - c := newTestContext(config) - - // Add a dynamic entry - entry, ok := c.linkRes.entries.entry(0) - if !ok { - t.Fatal("c.linkRes.entries.entry(0) not found") - } - _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - c.clock.Advance(typicalLatency) - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) - c.nudDisp.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - - // Override the entry with a static one using the same address - staticLinkAddr := entry.LinkAddr + "static" - c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestRemoved, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) - c.nudDisp.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - opts := overflowOptions{ - startAtEntryIndex: 1, - wantStaticEntries: []NeighborEntry{ - { - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - }, - }, - } - if err := c.overflowCache(opts); err != nil { - t.Errorf("c.overflowCache(%+v): %s", opts, err) - } -} - -func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { - config := DefaultNUDConfigurations() - // Stay in Reachable so the cache can overflow - config.BaseReachableTime = infiniteDuration - config.MinRandomFactor = 1 - config.MaxRandomFactor = 1 - - c := newTestContext(config) - - entry, ok := c.linkRes.entries.entry(0) - if !ok { - t.Fatal("c.linkRes.entries.entry(0) not found") - } - c.linkRes.neigh.addStaticEntry(entry.Addr, entry.LinkAddr) - e, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) - if err != nil { - t.Errorf("unexpected error from c.linkRes.neigh.entry(%s, \"\", nil): %s", entry.Addr, err) - } - want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, - } - if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { - t.Errorf("c.linkRes.neigh.entry(%s, \"\", nil) mismatch (-want, +got):\n%s", entry.Addr, diff) - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, - }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) - c.nudDisp.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - - opts := overflowOptions{ - startAtEntryIndex: 1, - wantStaticEntries: []NeighborEntry{ - { - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, - }, - }, - } - if err := c.overflowCache(opts); err != nil { - t.Errorf("c.overflowCache(%+v): %s", opts, err) - } -} - -func TestNeighborCacheClear(t *testing.T) { - config := DefaultNUDConfigurations() - - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(&nudDisp, config, clock) - - // Add a dynamic entry. - entry, ok := linkRes.entries.entry(0) - if !ok { - t.Fatal("linkRes.entries.entry(0) not found") - } - _, _, err := linkRes.neigh.entry(entry.Addr, "", nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - clock.Advance(typicalLatency) - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, - } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) - nudDisp.events = nil - nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - - // Add a static entry. - linkRes.neigh.addStaticEntry(entryTestAddr1, entryTestLinkAddr1) - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Static, - }, - }, - } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) - nudDisp.events = nil - nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - // Clear should remove both dynamic and static entries. - linkRes.neigh.clear() - - // Remove events dispatched from clear() have no deterministic order so they - // need to be sorted beforehand. - wantUnsortedEvents := []testEntryEventInfo{ - { - EventType: entryTestRemoved, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, - { - EventType: entryTestRemoved, - NICID: 1, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Static, - }, - }, - } - nudDisp.mu.Lock() - defer nudDisp.mu.Unlock() - if diff := cmp.Diff(wantUnsortedEvents, nudDisp.events, eventDiffOptsWithSort()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } -} - -// TestNeighborCacheClearThenOverflow verifies that the LRU cache eviction -// strategy keeps count of the dynamic entry count when all entries are -// cleared. -func TestNeighborCacheClearThenOverflow(t *testing.T) { - config := DefaultNUDConfigurations() - // Stay in Reachable so the cache can overflow - config.BaseReachableTime = infiniteDuration - config.MinRandomFactor = 1 - config.MaxRandomFactor = 1 - - c := newTestContext(config) - - // Add a dynamic entry - entry, ok := c.linkRes.entries.entry(0) - if !ok { - t.Fatal("c.linkRes.entries.entry(0) not found") - } - _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - c.clock.Advance(typicalLatency) - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) - c.nudDisp.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - - // Clear the cache. - c.linkRes.neigh.clear() - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestRemoved, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) - c.nudDisp.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - opts := overflowOptions{ - startAtEntryIndex: 0, - } - if err := c.overflowCache(opts); err != nil { - t.Errorf("c.overflowCache(%+v): %s", opts, err) - } -} - -func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { - config := DefaultNUDConfigurations() - // Stay in Reachable so the cache can overflow - config.BaseReachableTime = infiniteDuration - config.MinRandomFactor = 1 - config.MaxRandomFactor = 1 - - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(&nudDisp, config, clock) - - frequentlyUsedEntry, ok := linkRes.entries.entry(0) - if !ok { - t.Fatal("linkRes.entries.entry(0) not found") - } - - // The following logic is very similar to overflowCache, but - // periodically refreshes the frequently used entry. - - // Fill the neighbor cache to capacity - for i := 0; i < neighborCacheSize; i++ { - entry, ok := linkRes.entries.entry(i) - if !ok { - t.Fatalf("linkRes.entries.entry(%d) not found", i) - } - _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { - if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { - t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) - } - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - clock.Advance(typicalLatency) - select { - case <-ch: - default: - t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) - } - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, - } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) - nudDisp.events = nil - nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - // Keep adding more entries - for i := neighborCacheSize; i < linkRes.entries.size(); i++ { - // Periodically refresh the frequently used entry - if i%(neighborCacheSize/2) == 0 { - if _, _, err := linkRes.neigh.entry(frequentlyUsedEntry.Addr, "", nil); err != nil { - t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", frequentlyUsedEntry.Addr, err) - } - } - - entry, ok := linkRes.entries.entry(i) - if !ok { - t.Fatalf("linkRes.entries.entry(%d) not found", i) - } - - _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { - if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { - t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) - } - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - clock.Advance(typicalLatency) - select { - case <-ch: - default: - t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) - } - - // An entry should have been removed, as per the LRU eviction strategy - removedEntry, ok := linkRes.entries.entry(i - neighborCacheSize + 1) - if !ok { - t.Fatalf("linkRes.entries.entry(%d) not found", i-neighborCacheSize+1) - } - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestRemoved, - NICID: 1, - Entry: NeighborEntry{ - Addr: removedEntry.Addr, - LinkAddr: removedEntry.LinkAddr, - State: Reachable, - }, - }, - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, - } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) - nudDisp.events = nil - nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - // Expect to find only the frequently used entry and the most recent entries. - // The order of entries reported by entries() is nondeterministic, so entries - // have to be sorted before comparison. - wantUnsortedEntries := []NeighborEntry{ - { - Addr: frequentlyUsedEntry.Addr, - LinkAddr: frequentlyUsedEntry.LinkAddr, - State: Reachable, - }, - } - - for i := linkRes.entries.size() - neighborCacheSize + 1; i < linkRes.entries.size(); i++ { - entry, ok := linkRes.entries.entry(i) - if !ok { - t.Fatalf("linkRes.entries.entry(%d) not found", i) - } - wantEntry := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - } - wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) - } - - if diff := cmp.Diff(wantUnsortedEntries, linkRes.neigh.entries(), entryDiffOptsWithSort()...); diff != "" { - t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) - } - - // No more events should have been dispatched. - nudDisp.mu.Lock() - defer nudDisp.mu.Unlock() - if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } -} - -func TestNeighborCacheConcurrent(t *testing.T) { - const concurrentProcesses = 16 - - config := DefaultNUDConfigurations() - - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(&nudDisp, config, clock) - - storeEntries := linkRes.entries.entries() - for _, entry := range storeEntries { - var wg sync.WaitGroup - for r := 0; r < concurrentProcesses; r++ { - wg.Add(1) - go func(entry NeighborEntry) { - defer wg.Done() - switch e, _, err := linkRes.neigh.entry(entry.Addr, "", nil); err.(type) { - case nil, *tcpip.ErrWouldBlock: - default: - t.Errorf("got linkRes.neigh.entry(%s, '', nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, &tcpip.ErrWouldBlock{}) - } - }(entry) - } - - // Wait for all goroutines to send a request - wg.Wait() - - // Process all the requests for a single entry concurrently - clock.Advance(typicalLatency) - } - - // All goroutines add in the same order and add more values than can fit in - // the cache. Our eviction strategy requires that the last entries are - // present, up to the size of the neighbor cache, and the rest are missing. - // The order of entries reported by entries() is nondeterministic, so entries - // have to be sorted before comparison. - var wantUnsortedEntries []NeighborEntry - for i := linkRes.entries.size() - neighborCacheSize; i < linkRes.entries.size(); i++ { - entry, ok := linkRes.entries.entry(i) - if !ok { - t.Errorf("linkRes.entries.entry(%d) not found", i) - } - wantEntry := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - } - wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) - } - - if diff := cmp.Diff(wantUnsortedEntries, linkRes.neigh.entries(), entryDiffOptsWithSort()...); diff != "" { - t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) - } -} - -func TestNeighborCacheReplace(t *testing.T) { - config := DefaultNUDConfigurations() - - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(&nudDisp, config, clock) - - // Add an entry - entry, ok := linkRes.entries.entry(0) - if !ok { - t.Fatal("linkRes.entries.entry(0) not found") - } - - _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { - if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { - t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) - } - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - clock.Advance(typicalLatency) - select { - case <-ch: - default: - t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) - } - - // Verify the entry exists - { - e, _, err := linkRes.neigh.entry(entry.Addr, "", nil) - if err != nil { - t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', _, _, nil): %s", entry.Addr, err) - } - if t.Failed() { - t.FailNow() - } - want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - } - if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { - t.Errorf("linkRes.neigh.entry(%s, '', _, _, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) - } - } - - // Notify of a link address change - var updatedLinkAddr tcpip.LinkAddress - { - entry, ok := linkRes.entries.entry(1) - if !ok { - t.Fatal("linkRes.entries.entry(1) not found") - } - updatedLinkAddr = entry.LinkAddr - } - linkRes.entries.set(0, updatedLinkAddr) - linkRes.neigh.handleConfirmation(entry.Addr, updatedLinkAddr, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - - // Requesting the entry again should start neighbor reachability confirmation. - // - // Verify the entry's new link address and the new state. - { - e, _, err := linkRes.neigh.entry(entry.Addr, "", nil) - if err != nil { - t.Fatalf("linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) - } - want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: updatedLinkAddr, - State: Delay, - } - if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { - t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) - } - clock.Advance(config.DelayFirstProbeTime + typicalLatency) - } - - // Verify that the neighbor is now reachable. - { - e, _, err := linkRes.neigh.entry(entry.Addr, "", nil) - clock.Advance(typicalLatency) - if err != nil { - t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) - } - want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: updatedLinkAddr, - State: Reachable, - } - if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { - t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) - } - } -} - -func TestNeighborCacheResolutionFailed(t *testing.T) { - config := DefaultNUDConfigurations() - - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(&nudDisp, config, clock) - - var requestCount uint32 - linkRes.onLinkAddressRequest = func() { - atomic.AddUint32(&requestCount, 1) - } - - entry, ok := linkRes.entries.entry(0) - if !ok { - t.Fatal("linkRes.entries.entry(0) not found") - } - - // First, sanity check that resolution is working - { - _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { - if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { - t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) - } - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - clock.Advance(typicalLatency) - select { - case <-ch: - default: - t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) - } - } - - got, _, err := linkRes.neigh.entry(entry.Addr, "", nil) - if err != nil { - t.Fatalf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) - } - want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - } - if diff := cmp.Diff(want, got, entryDiffOpts()...); diff != "" { - t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) - } - - // Verify address resolution fails for an unknown address. - before := atomic.LoadUint32(&requestCount) - - entry.Addr += "2" - { - _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { - if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" { - t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) - } - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) - clock.Advance(waitFor) - select { - case <-ch: - default: - t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) - } - } - - maxAttempts := linkRes.neigh.config().MaxUnicastProbes - if got, want := atomic.LoadUint32(&requestCount)-before, maxAttempts; got != want { - t.Errorf("got link address request count = %d, want = %d", got, want) - } -} - -// TestNeighborCacheResolutionTimeout simulates sending MaxMulticastProbes -// probes and not retrieving a confirmation before the duration defined by -// MaxMulticastProbes * RetransmitTimer. -func TestNeighborCacheResolutionTimeout(t *testing.T) { - config := DefaultNUDConfigurations() - config.RetransmitTimer = time.Millisecond // small enough to cause timeout - - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(nil, config, clock) - // large enough to cause timeout - linkRes.delay = time.Minute - - entry, ok := linkRes.entries.entry(0) - if !ok { - t.Fatal("linkRes.entries.entry(0) not found") - } - - _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { - if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" { - t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) - } - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) - clock.Advance(waitFor) - - select { - case <-ch: - default: - t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) - } -} - -// TestNeighborCacheRetryResolution simulates retrying communication after -// failing to perform address resolution. -func TestNeighborCacheRetryResolution(t *testing.T) { - config := DefaultNUDConfigurations() - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(nil, config, clock) - // Simulate a faulty link. - linkRes.dropReplies = true - - entry, ok := linkRes.entries.entry(0) - if !ok { - t.Fatal("linkRes.entries.entry(0) not found") - } - - // Perform address resolution with a faulty link, which will fail. - { - _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { - if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" { - t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) - } - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) - clock.Advance(waitFor) - - select { - case <-ch: - default: - t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) - } - } - - wantEntries := []NeighborEntry{ - { - Addr: entry.Addr, - LinkAddr: "", - State: Unreachable, - }, - } - if diff := cmp.Diff(linkRes.neigh.entries(), wantEntries, entryDiffOptsWithSort()...); diff != "" { - t.Fatalf("neighbor entries mismatch (-got, +want):\n%s", diff) - } - - // Retry address resolution with a working link. - linkRes.dropReplies = false - { - incompleteEntry, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { - if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { - t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) - } - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - if incompleteEntry.State != Incomplete { - t.Fatalf("got entry.State = %s, want = %s", incompleteEntry.State, Incomplete) - } - clock.Advance(typicalLatency) - - select { - case <-ch: - if !ok { - t.Fatal("expected successful address resolution") - } - reachableEntry, _, err := linkRes.neigh.entry(entry.Addr, "", nil) - if err != nil { - t.Fatalf("linkRes.neigh.entry(%s, '', _, _, nil): %v", entry.Addr, err) - } - if reachableEntry.Addr != entry.Addr { - t.Fatalf("got entry.Addr = %s, want = %s", reachableEntry.Addr, entry.Addr) - } - if reachableEntry.LinkAddr != entry.LinkAddr { - t.Fatalf("got entry.LinkAddr = %s, want = %s", reachableEntry.LinkAddr, entry.LinkAddr) - } - if reachableEntry.State != Reachable { - t.Fatalf("got entry.State = %s, want = %s", reachableEntry.State.String(), Reachable.String()) - } - default: - t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) - } - } -} - -func BenchmarkCacheClear(b *testing.B) { - b.StopTimer() - config := DefaultNUDConfigurations() - clock := &tcpip.StdClock{} - linkRes := newTestNeighborResolver(nil, config, clock) - linkRes.delay = 0 - - // Clear for every possible size of the cache - for cacheSize := 0; cacheSize < neighborCacheSize; cacheSize++ { - // Fill the neighbor cache to capacity. - for i := 0; i < cacheSize; i++ { - entry, ok := linkRes.entries.entry(i) - if !ok { - b.Fatalf("linkRes.entries.entry(%d) not found", i) - } - - _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { - if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { - b.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) - } - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - b.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - - select { - case <-ch: - default: - b.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) - } - } - - b.StartTimer() - linkRes.neigh.clear() - b.StopTimer() - } -} diff --git a/pkg/tcpip/stack/neighbor_entry_list.go b/pkg/tcpip/stack/neighbor_entry_list.go new file mode 100644 index 000000000..d78430080 --- /dev/null +++ b/pkg/tcpip/stack/neighbor_entry_list.go @@ -0,0 +1,221 @@ +package stack + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type neighborEntryElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (neighborEntryElementMapper) linkerFor(elem *neighborEntry) *neighborEntry { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type neighborEntryList struct { + head *neighborEntry + tail *neighborEntry +} + +// Reset resets list l to the empty state. +func (l *neighborEntryList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *neighborEntryList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *neighborEntryList) Front() *neighborEntry { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *neighborEntryList) Back() *neighborEntry { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *neighborEntryList) Len() (count int) { + for e := l.Front(); e != nil; e = (neighborEntryElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *neighborEntryList) PushFront(e *neighborEntry) { + linker := neighborEntryElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + neighborEntryElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *neighborEntryList) PushBack(e *neighborEntry) { + linker := neighborEntryElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + neighborEntryElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *neighborEntryList) PushBackList(m *neighborEntryList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + neighborEntryElementMapper{}.linkerFor(l.tail).SetNext(m.head) + neighborEntryElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *neighborEntryList) InsertAfter(b, e *neighborEntry) { + bLinker := neighborEntryElementMapper{}.linkerFor(b) + eLinker := neighborEntryElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + neighborEntryElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *neighborEntryList) InsertBefore(a, e *neighborEntry) { + aLinker := neighborEntryElementMapper{}.linkerFor(a) + eLinker := neighborEntryElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + neighborEntryElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *neighborEntryList) Remove(e *neighborEntry) { + linker := neighborEntryElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + neighborEntryElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + neighborEntryElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type neighborEntryEntry struct { + next *neighborEntry + prev *neighborEntry +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *neighborEntryEntry) Next() *neighborEntry { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *neighborEntryEntry) Prev() *neighborEntry { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *neighborEntryEntry) SetNext(elem *neighborEntry) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *neighborEntryEntry) SetPrev(elem *neighborEntry) { + e.prev = elem +} diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go deleted file mode 100644 index 47a9e2448..000000000 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ /dev/null @@ -1,3604 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import ( - "fmt" - "math" - "math/rand" - "strings" - "sync" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -const ( - entryTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 - - entryTestNICID tcpip.NICID = 1 - entryTestAddr1 = tcpip.Address("\x00\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - entryTestAddr2 = tcpip.Address("\x00\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - - entryTestLinkAddr1 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x01") - entryTestLinkAddr2 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x02") - - // entryTestNetDefaultMTU is the MTU, in bytes, used throughout the tests, - // except where another value is explicitly used. It is chosen to match the - // MTU of loopback interfaces on Linux systems. - entryTestNetDefaultMTU = 65536 -) - -// runImmediatelyScheduledJobs runs all jobs scheduled to run at the current -// time. -func runImmediatelyScheduledJobs(clock *faketime.ManualClock) { - clock.Advance(immediateDuration) -} - -// eventDiffOpts are the options passed to cmp.Diff to compare entry events. -// The UpdatedAtNanos field is ignored due to a lack of a deterministic method -// to predict the time that an event will be dispatched. -func eventDiffOpts() []cmp.Option { - return []cmp.Option{ - cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAtNanos"), - } -} - -// eventDiffOptsWithSort is like eventDiffOpts but also includes an option to -// sort slices of events for cases where ordering must be ignored. -func eventDiffOptsWithSort() []cmp.Option { - return append(eventDiffOpts(), cmpopts.SortSlices(func(a, b testEntryEventInfo) bool { - return strings.Compare(string(a.Entry.Addr), string(b.Entry.Addr)) < 0 - })) -} - -// The following unit tests exercise every state transition and verify its -// behavior with RFC 4681 and RFC 7048. -// -// | From | To | Cause | Update | Action | Event | -// | =========== | =========== | ========================================== | ======== | ===========| ======= | -// | Unknown | Unknown | Confirmation w/ unknown address | | | Added | -// | Unknown | Incomplete | Packet queued to unknown address | | Send probe | Added | -// | Unknown | Stale | Probe | | | Added | -// | Incomplete | Incomplete | Retransmit timer expired | | Send probe | Changed | -// | Incomplete | Reachable | Solicited confirmation | LinkAddr | Notify | Changed | -// | Incomplete | Stale | Unsolicited confirmation | LinkAddr | Notify | Changed | -// | Incomplete | Stale | Probe | LinkAddr | Notify | Changed | -// | Incomplete | Unreachable | Max probes sent without reply | | Notify | Changed | -// | Reachable | Reachable | Confirmation w/ different isRouter flag | IsRouter | | | -// | Reachable | Stale | Reachable timer expired | | | Changed | -// | Reachable | Stale | Probe or confirmation w/ different address | | | Changed | -// | Stale | Reachable | Solicited override confirmation | LinkAddr | | Changed | -// | Stale | Reachable | Solicited confirmation w/o address | | Notify | Changed | -// | Stale | Stale | Override confirmation | LinkAddr | | Changed | -// | Stale | Stale | Probe w/ different address | LinkAddr | | Changed | -// | Stale | Delay | Packet sent | | | Changed | -// | Delay | Reachable | Upper-layer confirmation | | | Changed | -// | Delay | Reachable | Solicited override confirmation | LinkAddr | | Changed | -// | Delay | Reachable | Solicited confirmation w/o address | | Notify | Changed | -// | Delay | Stale | Probe or confirmation w/ different address | | | Changed | -// | Delay | Probe | Delay timer expired | | Send probe | Changed | -// | Probe | Reachable | Solicited override confirmation | LinkAddr | | Changed | -// | Probe | Reachable | Solicited confirmation w/ same address | | Notify | Changed | -// | Probe | Reachable | Solicited confirmation w/o address | | Notify | Changed | -// | Probe | Stale | Probe or confirmation w/ different address | | | Changed | -// | Probe | Probe | Retransmit timer expired | | | Changed | -// | Probe | Unreachable | Max probes sent without reply | | Notify | Changed | -// | Unreachable | Incomplete | Packet queued | | Send probe | Changed | -// | Unreachable | Stale | Probe w/ different address | LinkAddr | | Changed | - -type testEntryEventType uint8 - -const ( - entryTestAdded testEntryEventType = iota - entryTestChanged - entryTestRemoved -) - -func (t testEntryEventType) String() string { - switch t { - case entryTestAdded: - return "add" - case entryTestChanged: - return "change" - case entryTestRemoved: - return "remove" - default: - return fmt.Sprintf("unknown (%d)", t) - } -} - -// Fields are exported for use with cmp.Diff. -type testEntryEventInfo struct { - EventType testEntryEventType - NICID tcpip.NICID - Entry NeighborEntry -} - -func (e testEntryEventInfo) String() string { - return fmt.Sprintf("%s event for NIC #%d, %#v", e.EventType, e.NICID, e.Entry) -} - -// testNUDDispatcher implements NUDDispatcher to validate the dispatching of -// events upon certain NUD state machine events. -type testNUDDispatcher struct { - mu sync.Mutex - events []testEntryEventInfo -} - -var _ NUDDispatcher = (*testNUDDispatcher)(nil) - -func (d *testNUDDispatcher) queueEvent(e testEntryEventInfo) { - d.mu.Lock() - defer d.mu.Unlock() - d.events = append(d.events, e) -} - -func (d *testNUDDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry NeighborEntry) { - d.queueEvent(testEntryEventInfo{ - EventType: entryTestAdded, - NICID: nicID, - Entry: entry, - }) -} - -func (d *testNUDDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry NeighborEntry) { - d.queueEvent(testEntryEventInfo{ - EventType: entryTestChanged, - NICID: nicID, - Entry: entry, - }) -} - -func (d *testNUDDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry NeighborEntry) { - d.queueEvent(testEntryEventInfo{ - EventType: entryTestRemoved, - NICID: nicID, - Entry: entry, - }) -} - -type entryTestLinkResolver struct { - mu sync.Mutex - probes []entryTestProbeInfo -} - -var _ LinkAddressResolver = (*entryTestLinkResolver)(nil) - -type entryTestProbeInfo struct { - RemoteAddress tcpip.Address - RemoteLinkAddress tcpip.LinkAddress - LocalAddress tcpip.Address -} - -func (p entryTestProbeInfo) String() string { - return fmt.Sprintf("probe with RemoteAddress=%q, RemoteLinkAddress=%q, LocalAddress=%q", p.RemoteAddress, p.RemoteLinkAddress, p.LocalAddress) -} - -// LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts -// to the local network if linkAddr is the zero value. -func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { - p := entryTestProbeInfo{ - RemoteAddress: targetAddr, - RemoteLinkAddress: linkAddr, - LocalAddress: localAddr, - } - r.mu.Lock() - defer r.mu.Unlock() - r.probes = append(r.probes, p) - return nil -} - -// ResolveStaticAddress attempts to resolve address without sending requests. -// It either resolves the name immediately or returns the empty LinkAddress. -func (r *entryTestLinkResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { - return "", false -} - -// LinkAddressProtocol returns the network protocol of the addresses this -// resolver can resolve. -func (r *entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { - return entryTestNetNumber -} - -func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *entryTestLinkResolver, *faketime.ManualClock) { - clock := faketime.NewManualClock() - disp := testNUDDispatcher{} - nic := nic{ - LinkEndpoint: nil, // entryTestLinkResolver doesn't use a LinkEndpoint - - id: entryTestNICID, - stack: &Stack{ - clock: clock, - nudDisp: &disp, - nudConfigs: c, - randomGenerator: rand.New(rand.NewSource(time.Now().UnixNano())), - }, - stats: makeNICStats(), - } - netEP := (&testIPv6Protocol{}).NewEndpoint(&nic, nil) - nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{ - header.IPv6ProtocolNumber: netEP, - } - - var linkRes entryTestLinkResolver - // Stub out the neighbor cache to verify deletion from the cache. - l := &linkResolver{ - resolver: &linkRes, - } - l.neigh.init(&nic, &linkRes) - - entry := newNeighborEntry(&l.neigh, entryTestAddr1 /* remoteAddr */, l.neigh.state) - l.neigh.mu.Lock() - l.neigh.mu.cache[entryTestAddr1] = entry - l.neigh.mu.Unlock() - nic.linkAddrResolvers = map[tcpip.NetworkProtocolNumber]*linkResolver{ - header.IPv6ProtocolNumber: l, - } - - return entry, &disp, &linkRes, clock -} - -// TestEntryInitiallyUnknown verifies that the state of a newly created -// neighborEntry is Unknown. -func TestEntryInitiallyUnknown(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - if e.mu.neigh.State != Unknown { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unknown) - } - e.mu.Unlock() - - clock.Advance(c.RetransmitTimer) - - // No probes should have been sent. - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - // No events should have been dispatched. - nudDisp.mu.Lock() - if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if e.mu.neigh.State != Unknown { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unknown) - } - e.mu.Unlock() - - clock.Advance(time.Hour) - - // No probes should have been sent. - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - // No events should have been dispatched. - nudDisp.mu.Lock() - if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryUnknownToIncomplete(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - if e.mu.neigh.State != Incomplete { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) - } - e.mu.Unlock() - - runImmediatelyScheduledJobs(clock) - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - } - { - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) - nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } -} - -func TestEntryUnknownToStale(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handleProbeLocked(entryTestLinkAddr1) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { - c := DefaultNUDConfigurations() - c.MaxMulticastProbes = 3 - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - if e.mu.neigh.State != Incomplete { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) - } - updatedAtNanos := e.mu.neigh.UpdatedAtNanos - e.mu.Unlock() - - clock.Advance(c.RetransmitTimer) - - // UpdatedAt should remain the same during address resolution. - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.probes = nil - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - e.mu.Lock() - if got, want := e.mu.neigh.UpdatedAtNanos, updatedAtNanos; got != want { - t.Errorf("got e.mu.neigh.UpdatedAt = %q, want = %q", got, want) - } - e.mu.Unlock() - - clock.Advance(c.RetransmitTimer) - - // UpdatedAt should change after failing address resolution. Timing out after - // sending the last probe transitions the entry to Unreachable. - { - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - clock.Advance(c.RetransmitTimer) - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Unreachable, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() - - e.mu.Lock() - if got, notWant := e.mu.neigh.UpdatedAtNanos, updatedAtNanos; got == notWant { - t.Errorf("expected e.mu.neigh.UpdatedAt to change, got = %q", got) - } - e.mu.Unlock() -} - -func TestEntryIncompleteToReachable(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - if e.mu.neigh.State != Incomplete { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) - } - e.mu.Unlock() - - runImmediatelyScheduledJobs(clock) - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if e.mu.neigh.State != Reachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - e.mu.Unlock() - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - if e.mu.neigh.State != Incomplete { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) - } - e.mu.Unlock() - - runImmediatelyScheduledJobs(clock) - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: true, - }) - if e.mu.neigh.State != Reachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - if !e.mu.isRouter { - t.Errorf("got e.mu.isRouter = %t, want = true", e.mu.isRouter) - } - e.mu.Unlock() - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - if e.mu.neigh.State != Incomplete { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) - } - e.mu.Unlock() - - runImmediatelyScheduledJobs(clock) - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.mu.Unlock() - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryIncompleteToStaleWhenProbe(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - if e.mu.neigh.State != Incomplete { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) - } - e.mu.Unlock() - - runImmediatelyScheduledJobs(clock) - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - e.mu.Lock() - e.handleProbeLocked(entryTestLinkAddr1) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.mu.Unlock() - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryIncompleteToUnreachable(t *testing.T) { - c := DefaultNUDConfigurations() - c.MaxMulticastProbes = 3 - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - if e.mu.neigh.State != Incomplete { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) - } - e.mu.Unlock() - - waitFor := c.RetransmitTimer * time.Duration(c.MaxMulticastProbes) - clock.Advance(waitFor) - - wantProbes := []entryTestProbeInfo{ - // The Incomplete-to-Incomplete state transition is tested here by - // verifying that 3 reachability probes were sent. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Unreachable, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() - - e.mu.Lock() - if e.mu.neigh.State != Unreachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unreachable) - } - e.mu.Unlock() -} - -type testLocker struct{} - -var _ sync.Locker = (*testLocker)(nil) - -func (*testLocker) Lock() {} -func (*testLocker) Unlock() {} - -func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - ipv6EP := e.cache.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint) - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - runImmediatelyScheduledJobs(clock) - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: true, - }) - if e.mu.neigh.State != Reachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - if got, want := e.mu.isRouter, true; got != want { - t.Errorf("got e.mu.isRouter = %t, want = %t", got, want) - } - - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.mu.isRouter, false; got != want { - t.Errorf("got e.mu.isRouter = %t, want = %t", got, want) - } - if ipv6EP.invalidatedRtr != e.mu.neigh.Addr { - t.Errorf("got ipv6EP.invalidatedRtr = %s, want = %s", ipv6EP.invalidatedRtr, e.mu.neigh.Addr) - } - e.mu.Unlock() - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() - - e.mu.Lock() - if e.mu.neigh.State != Reachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - e.mu.Unlock() -} - -func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - runImmediatelyScheduledJobs(clock) - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if e.mu.neigh.State != Reachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - e.handleProbeLocked(entryTestLinkAddr1) - if e.mu.neigh.State != Reachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - if e.mu.neigh.LinkAddr != entryTestLinkAddr1 { - t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1) - } - e.mu.Unlock() - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryReachableToStaleWhenTimeout(t *testing.T) { - c := DefaultNUDConfigurations() - // Eliminate random factors from ReachableTime computation so the transition - // from Stale to Reachable will only take BaseReachableTime duration. - c.MinRandomFactor = 1 - c.MaxRandomFactor = 1 - - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - runImmediatelyScheduledJobs(clock) - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if e.mu.neigh.State != Reachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - e.mu.Unlock() - - clock.Advance(c.BaseReachableTime) - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() - - e.mu.Lock() - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.mu.Unlock() -} - -func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - runImmediatelyScheduledJobs(clock) - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if e.mu.neigh.State != Reachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - e.handleProbeLocked(entryTestLinkAddr2) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.mu.Unlock() - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - runImmediatelyScheduledJobs(clock) - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if e.mu.neigh.State != Reachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.mu.Unlock() - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - runImmediatelyScheduledJobs(clock) - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if e.mu.neigh.State != Reachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.mu.Unlock() - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - runImmediatelyScheduledJobs(clock) - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.handleProbeLocked(entryTestLinkAddr1) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - if e.mu.neigh.LinkAddr != entryTestLinkAddr1 { - t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1) - } - e.mu.Unlock() - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - runImmediatelyScheduledJobs(clock) - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: true, - Override: true, - IsRouter: false, - }) - if e.mu.neigh.State != Reachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - if e.mu.neigh.LinkAddr != entryTestLinkAddr2 { - t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr2) - } - e.mu.Unlock() - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - runImmediatelyScheduledJobs(clock) - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.handleConfirmationLocked("" /* linkAddr */, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if e.mu.neigh.State != Reachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - if e.mu.neigh.LinkAddr != entryTestLinkAddr1 { - t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1) - } - e.mu.Unlock() - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - runImmediatelyScheduledJobs(clock) - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - if e.mu.neigh.LinkAddr != entryTestLinkAddr2 { - t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr2) - } - e.mu.Unlock() - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - runImmediatelyScheduledJobs(clock) - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.handleProbeLocked(entryTestLinkAddr2) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - if e.mu.neigh.LinkAddr != entryTestLinkAddr2 { - t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr2) - } - e.mu.Unlock() - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryStaleToDelay(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - runImmediatelyScheduledJobs(clock) - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.handlePacketQueuedLocked(entryTestAddr2) - if e.mu.neigh.State != Delay { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.mu.Unlock() - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { - c := DefaultNUDConfigurations() - // Eliminate random factors from ReachableTime computation so the transition - // from Stale to Reachable will only take BaseReachableTime duration. - c.MinRandomFactor = 1 - c.MaxRandomFactor = 1 - - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - runImmediatelyScheduledJobs(clock) - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked(entryTestAddr2) - if e.mu.neigh.State != Delay { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay) - } - e.handleUpperLevelConfirmationLocked() - if e.mu.neigh.State != Reachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - e.mu.Unlock() - - clock.Advance(c.BaseReachableTime) - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { - c := DefaultNUDConfigurations() - c.MaxMulticastProbes = 1 - // Eliminate random factors from ReachableTime computation so the transition - // from Stale to Reachable will only take BaseReachableTime duration. - c.MinRandomFactor = 1 - c.MaxRandomFactor = 1 - - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - runImmediatelyScheduledJobs(clock) - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked(entryTestAddr2) - if e.mu.neigh.State != Delay { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: true, - Override: true, - IsRouter: false, - }) - if e.mu.neigh.State != Reachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - if e.mu.neigh.LinkAddr != entryTestLinkAddr2 { - t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr2) - } - e.mu.Unlock() - - clock.Advance(c.BaseReachableTime) - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) { - c := DefaultNUDConfigurations() - c.MaxMulticastProbes = 1 - // Eliminate random factors from ReachableTime computation so the transition - // from Stale to Reachable will only take BaseReachableTime duration. - c.MinRandomFactor = 1 - c.MaxRandomFactor = 1 - - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - runImmediatelyScheduledJobs(clock) - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked(entryTestAddr2) - if e.mu.neigh.State != Delay { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay) - } - e.handleConfirmationLocked("" /* linkAddr */, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if e.mu.neigh.State != Reachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - if e.mu.neigh.LinkAddr != entryTestLinkAddr1 { - t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1) - } - e.mu.Unlock() - - clock.Advance(c.BaseReachableTime) - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - runImmediatelyScheduledJobs(clock) - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked(entryTestAddr2) - if e.mu.neigh.State != Delay { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay) - } - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if e.mu.neigh.State != Delay { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay) - } - if e.mu.neigh.LinkAddr != entryTestLinkAddr1 { - t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1) - } - e.mu.Unlock() - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - runImmediatelyScheduledJobs(clock) - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked(entryTestAddr2) - if e.mu.neigh.State != Delay { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay) - } - e.handleProbeLocked(entryTestLinkAddr2) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.mu.Unlock() - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - runImmediatelyScheduledJobs(clock) - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked(entryTestAddr2) - if e.mu.neigh.State != Delay { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.mu.Unlock() - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryDelayToProbe(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - runImmediatelyScheduledJobs(clock) - { - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.probes = nil - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked(entryTestAddr2) - if e.mu.neigh.State != Delay { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay) - } - e.mu.Unlock() - - clock.Advance(c.DelayFirstProbeTime) - { - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() - - e.mu.Lock() - if e.mu.neigh.State != Probe { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe) - } - e.mu.Unlock() -} - -func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - runImmediatelyScheduledJobs(clock) - { - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.probes = nil - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - clock.Advance(c.DelayFirstProbeTime) - { - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - e.mu.Lock() - if e.mu.neigh.State != Probe { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe) - } - e.handleProbeLocked(entryTestLinkAddr2) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.mu.Unlock() - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - runImmediatelyScheduledJobs(clock) - { - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.probes = nil - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - clock.Advance(c.DelayFirstProbeTime) - { - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - e.mu.Lock() - if e.mu.neigh.State != Probe { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.mu.Unlock() - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - runImmediatelyScheduledJobs(clock) - { - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.probes = nil - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - clock.Advance(c.DelayFirstProbeTime) - { - wantProbes := []entryTestProbeInfo{ - // The second probe is caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - e.mu.Lock() - if e.mu.neigh.State != Probe { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe) - } - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if e.mu.neigh.State != Probe { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe) - } - if got, want := e.mu.neigh.LinkAddr, entryTestLinkAddr1; got != want { - t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", got, want) - } - e.mu.Unlock() - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -// TestEntryUnknownToStaleToProbeToReachable exercises the following scenario: -// 1. Probe is received -// 2. Entry is created in Stale -// 3. Packet is queued on the entry -// 4. Entry transitions to Delay then Probe -// 5. Probe is sent -func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { - c := DefaultNUDConfigurations() - // Eliminate random factors from ReachableTime computation so the transition - // from Probe to Reachable will only take BaseReachableTime duration. - c.MinRandomFactor = 1 - c.MaxRandomFactor = 1 - - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handleProbeLocked(entryTestLinkAddr1) - e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - clock.Advance(c.DelayFirstProbeTime) - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - e.mu.Lock() - if e.mu.neigh.State != Probe { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: true, - Override: true, - IsRouter: false, - }) - if e.mu.neigh.State != Reachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - if got, want := e.mu.neigh.LinkAddr, entryTestLinkAddr2; got != want { - t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", got, want) - } - e.mu.Unlock() - - clock.Advance(c.BaseReachableTime) - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { - c := DefaultNUDConfigurations() - // Eliminate random factors from ReachableTime computation so the transition - // from Stale to Reachable will only take BaseReachableTime duration. - c.MinRandomFactor = 1 - c.MaxRandomFactor = 1 - - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - runImmediatelyScheduledJobs(clock) - { - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.probes = nil - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - clock.Advance(c.DelayFirstProbeTime) - { - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - e.mu.Lock() - if e.mu.neigh.State != Probe { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: true, - Override: true, - IsRouter: false, - }) - if e.mu.neigh.State != Reachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - if got, want := e.mu.neigh.LinkAddr, entryTestLinkAddr2; got != want { - t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", got, want) - } - e.mu.Unlock() - - clock.Advance(c.BaseReachableTime) - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testing.T) { - c := DefaultNUDConfigurations() - // Eliminate random factors from ReachableTime computation so the transition - // from Stale to Reachable will only take BaseReachableTime duration. - c.MinRandomFactor = 1 - c.MaxRandomFactor = 1 - - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - runImmediatelyScheduledJobs(clock) - { - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.probes = nil - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - clock.Advance(c.DelayFirstProbeTime) - { - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - e.mu.Lock() - if e.mu.neigh.State != Probe { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe) - } - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if e.mu.neigh.State != Reachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - e.mu.Unlock() - - clock.Advance(c.BaseReachableTime) - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) { - c := DefaultNUDConfigurations() - // Eliminate random factors from ReachableTime computation so the transition - // from Stale to Reachable will only take BaseReachableTime duration. - c.MinRandomFactor = 1 - c.MaxRandomFactor = 1 - - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - runImmediatelyScheduledJobs(clock) - { - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.probes = nil - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - clock.Advance(c.DelayFirstProbeTime) - { - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - e.mu.Lock() - if e.mu.neigh.State != Probe { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe) - } - e.handleConfirmationLocked("" /* linkAddr */, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if e.mu.neigh.State != Reachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - e.mu.Unlock() - - clock.Advance(c.BaseReachableTime) - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryProbeToUnreachable(t *testing.T) { - c := DefaultNUDConfigurations() - c.MaxMulticastProbes = 3 - c.MaxUnicastProbes = 3 - c.DelayFirstProbeTime = c.RetransmitTimer - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - runImmediatelyScheduledJobs(clock) - { - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.probes = nil - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - // Observe each probe sent while in the Probe state. - for i := uint32(0); i < c.MaxUnicastProbes; i++ { - clock.Advance(c.RetransmitTimer) - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.probes = nil - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probe #%d mismatch (-want, +got):\n%s", i+1, diff) - } - - e.mu.Lock() - if e.mu.neigh.State != Probe { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe) - } - e.mu.Unlock() - } - - // Wait for the last probe to expire, causing a transition to Unreachable. - clock.Advance(c.RetransmitTimer) - e.mu.Lock() - if e.mu.neigh.State != Unreachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unreachable) - } - e.mu.Unlock() - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Unreachable, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryUnreachableToIncomplete(t *testing.T) { - c := DefaultNUDConfigurations() - c.MaxMulticastProbes = 3 - e, nudDisp, linkRes, clock := entryTestSetup(c) - - // TODO(gvisor.dev/issue/4872): Use helper functions to start entry tests in - // their expected state. - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - if e.mu.neigh.State != Incomplete { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) - } - e.mu.Unlock() - - waitFor := c.RetransmitTimer * time.Duration(c.MaxMulticastProbes) - clock.Advance(waitFor) - - wantProbes := []entryTestProbeInfo{ - // The Incomplete-to-Incomplete state transition is tested here by - // verifying that 3 reachability probes were sent. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - e.mu.Lock() - if e.mu.neigh.State != Unreachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unreachable) - } - e.mu.Unlock() - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - if e.mu.neigh.State != Incomplete { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) - } - e.mu.Unlock() - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Unreachable, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryUnreachableToStale(t *testing.T) { - wantProbes := []entryTestProbeInfo{ - // The Incomplete-to-Incomplete state transition is tested here by - // verifying that 3 reachability probes were sent. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - - c := DefaultNUDConfigurations() - c.MaxMulticastProbes = uint32(len(wantProbes)) - e, nudDisp, linkRes, clock := entryTestSetup(c) - - // TODO(gvisor.dev/issue/4872): Use helper functions to start entry tests in - // their expected state. - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) - if e.mu.neigh.State != Incomplete { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) - } - e.mu.Unlock() - - waitFor := c.RetransmitTimer * time.Duration(c.MaxMulticastProbes) - clock.Advance(waitFor) - - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - e.mu.Lock() - if e.mu.neigh.State != Unreachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unreachable) - } - e.mu.Unlock() - - e.mu.Lock() - e.handleProbeLocked(entryTestLinkAddr2) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.mu.Unlock() - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Unreachable, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go deleted file mode 100644 index c0f956e53..000000000 --- a/pkg/tcpip/stack/nic_test.go +++ /dev/null @@ -1,210 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -var _ AddressableEndpoint = (*testIPv6Endpoint)(nil) -var _ NetworkEndpoint = (*testIPv6Endpoint)(nil) -var _ NDPEndpoint = (*testIPv6Endpoint)(nil) - -// An IPv6 NetworkEndpoint that throws away outgoing packets. -// -// We use this instead of ipv6.endpoint because the ipv6 package depends on -// the stack package which this test lives in, causing a cyclic dependency. -type testIPv6Endpoint struct { - AddressableEndpointState - - nic NetworkInterface - protocol *testIPv6Protocol - - invalidatedRtr tcpip.Address -} - -func (*testIPv6Endpoint) Enable() tcpip.Error { - return nil -} - -func (*testIPv6Endpoint) Enabled() bool { - return true -} - -func (*testIPv6Endpoint) Disable() {} - -// DefaultTTL implements NetworkEndpoint.DefaultTTL. -func (*testIPv6Endpoint) DefaultTTL() uint8 { - return 0 -} - -// MTU implements NetworkEndpoint.MTU. -func (e *testIPv6Endpoint) MTU() uint32 { - return e.nic.MTU() - header.IPv6MinimumSize -} - -// MaxHeaderLength implements NetworkEndpoint.MaxHeaderLength. -func (e *testIPv6Endpoint) MaxHeaderLength() uint16 { - return e.nic.MaxHeaderLength() + header.IPv6MinimumSize -} - -// WritePacket implements NetworkEndpoint.WritePacket. -func (*testIPv6Endpoint) WritePacket(*Route, *GSO, NetworkHeaderParams, *PacketBuffer) tcpip.Error { - return nil -} - -// WritePackets implements NetworkEndpoint.WritePackets. -func (*testIPv6Endpoint) WritePackets(*Route, *GSO, PacketBufferList, NetworkHeaderParams) (int, tcpip.Error) { - // Our tests don't use this so we don't support it. - return 0, &tcpip.ErrNotSupported{} -} - -// WriteHeaderIncludedPacket implements -// NetworkEndpoint.WriteHeaderIncludedPacket. -func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) tcpip.Error { - // Our tests don't use this so we don't support it. - return &tcpip.ErrNotSupported{} -} - -// HandlePacket implements NetworkEndpoint.HandlePacket. -func (*testIPv6Endpoint) HandlePacket(*PacketBuffer) {} - -// Close implements NetworkEndpoint.Close. -func (e *testIPv6Endpoint) Close() { - e.AddressableEndpointState.Cleanup() -} - -// NetworkProtocolNumber implements NetworkEndpoint.NetworkProtocolNumber. -func (*testIPv6Endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { - return header.IPv6ProtocolNumber -} - -func (e *testIPv6Endpoint) InvalidateDefaultRouter(rtr tcpip.Address) { - e.invalidatedRtr = rtr -} - -// Stats implements NetworkEndpoint. -func (*testIPv6Endpoint) Stats() NetworkEndpointStats { - return &testIPv6EndpointStats{} -} - -var _ NetworkEndpointStats = (*testIPv6EndpointStats)(nil) - -type testIPv6EndpointStats struct{} - -// IsNetworkEndpointStats implements stack.NetworkEndpointStats. -func (*testIPv6EndpointStats) IsNetworkEndpointStats() {} - -// We use this instead of ipv6.protocol because the ipv6 package depends on -// the stack package which this test lives in, causing a cyclic dependency. -type testIPv6Protocol struct{} - -// Number implements NetworkProtocol.Number. -func (*testIPv6Protocol) Number() tcpip.NetworkProtocolNumber { - return header.IPv6ProtocolNumber -} - -// MinimumPacketSize implements NetworkProtocol.MinimumPacketSize. -func (*testIPv6Protocol) MinimumPacketSize() int { - return header.IPv6MinimumSize -} - -// DefaultPrefixLen implements NetworkProtocol.DefaultPrefixLen. -func (*testIPv6Protocol) DefaultPrefixLen() int { - return header.IPv6AddressSize * 8 -} - -// ParseAddresses implements NetworkProtocol.ParseAddresses. -func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { - h := header.IPv6(v) - return h.SourceAddress(), h.DestinationAddress() -} - -// NewEndpoint implements NetworkProtocol.NewEndpoint. -func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ TransportDispatcher) NetworkEndpoint { - e := &testIPv6Endpoint{ - nic: nic, - protocol: p, - } - e.AddressableEndpointState.Init(e) - return e -} - -// SetOption implements NetworkProtocol.SetOption. -func (*testIPv6Protocol) SetOption(tcpip.SettableNetworkProtocolOption) tcpip.Error { - return nil -} - -// Option implements NetworkProtocol.Option. -func (*testIPv6Protocol) Option(tcpip.GettableNetworkProtocolOption) tcpip.Error { - return nil -} - -// Close implements NetworkProtocol.Close. -func (*testIPv6Protocol) Close() {} - -// Wait implements NetworkProtocol.Wait. -func (*testIPv6Protocol) Wait() {} - -// Parse implements NetworkProtocol.Parse. -func (*testIPv6Protocol) Parse(*PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { - return 0, false, false -} - -func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { - // When the NIC is disabled, the only field that matters is the stats field. - // This test is limited to stats counter checks. - nic := nic{ - stats: makeNICStats(), - } - - if got := nic.stats.DisabledRx.Packets.Value(); got != 0 { - t.Errorf("got DisabledRx.Packets = %d, want = 0", got) - } - if got := nic.stats.DisabledRx.Bytes.Value(); got != 0 { - t.Errorf("got DisabledRx.Bytes = %d, want = 0", got) - } - if got := nic.stats.Rx.Packets.Value(); got != 0 { - t.Errorf("got Rx.Packets = %d, want = 0", got) - } - if got := nic.stats.Rx.Bytes.Value(); got != 0 { - t.Errorf("got Rx.Bytes = %d, want = 0", got) - } - - if t.Failed() { - t.FailNow() - } - - nic.DeliverNetworkPacket("", "", 0, NewPacketBuffer(PacketBufferOptions{ - Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView(), - })) - - if got := nic.stats.DisabledRx.Packets.Value(); got != 1 { - t.Errorf("got DisabledRx.Packets = %d, want = 1", got) - } - if got := nic.stats.DisabledRx.Bytes.Value(); got != 4 { - t.Errorf("got DisabledRx.Bytes = %d, want = 4", got) - } - if got := nic.stats.Rx.Packets.Value(); got != 0 { - t.Errorf("got Rx.Packets = %d, want = 0", got) - } - if got := nic.stats.Rx.Bytes.Value(); got != 0 { - t.Errorf("got Rx.Bytes = %d, want = 0", got) - } -} diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go deleted file mode 100644 index e1253f310..000000000 --- a/pkg/tcpip/stack/nud_test.go +++ /dev/null @@ -1,813 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack_test - -import ( - "math" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -const ( - defaultBaseReachableTime = 30 * time.Second - minimumBaseReachableTime = time.Millisecond - defaultMinRandomFactor = 0.5 - defaultMaxRandomFactor = 1.5 - defaultRetransmitTimer = time.Second - minimumRetransmitTimer = time.Millisecond - defaultDelayFirstProbeTime = 5 * time.Second - defaultMaxMulticastProbes = 3 - defaultMaxUnicastProbes = 3 - defaultMaxAnycastDelayTime = time.Second - defaultMaxReachbilityConfirmations = 3 - - defaultFakeRandomNum = 0.5 -) - -// fakeRand is a deterministic random number generator. -type fakeRand struct { - num float32 -} - -var _ stack.Rand = (*fakeRand)(nil) - -func (f *fakeRand) Float32() float32 { - return f.num -} - -func TestNUDFunctions(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - nicID tcpip.NICID - netProtoFactory []stack.NetworkProtocolFactory - extraLinkCapabilities stack.LinkEndpointCapabilities - expectedErr tcpip.Error - }{ - { - name: "Invalid NICID", - nicID: nicID + 1, - netProtoFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - extraLinkCapabilities: stack.CapabilityResolutionRequired, - expectedErr: &tcpip.ErrUnknownNICID{}, - }, - { - name: "No network protocol", - nicID: nicID, - expectedErr: &tcpip.ErrNotSupported{}, - }, - { - name: "With IPv6", - nicID: nicID, - netProtoFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - expectedErr: &tcpip.ErrNotSupported{}, - }, - { - name: "With resolution capability", - nicID: nicID, - extraLinkCapabilities: stack.CapabilityResolutionRequired, - expectedErr: &tcpip.ErrNotSupported{}, - }, - { - name: "With IPv6 and resolution capability", - nicID: nicID, - netProtoFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - extraLinkCapabilities: stack.CapabilityResolutionRequired, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NUDConfigs: stack.DefaultNUDConfigurations(), - NetworkProtocols: test.netProtoFactory, - Clock: clock, - }) - - e := channel.New(0, 0, linkAddr1) - e.LinkEPCapabilities &^= stack.CapabilityResolutionRequired - e.LinkEPCapabilities |= test.extraLinkCapabilities - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - configs := stack.DefaultNUDConfigurations() - configs.BaseReachableTime = time.Hour - - { - err := s.SetNUDConfigurations(test.nicID, ipv6.ProtocolNumber, configs) - if diff := cmp.Diff(test.expectedErr, err); diff != "" { - t.Errorf("s.SetNUDConfigurations(%d, %d, _) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) - } - } - - { - gotConfigs, err := s.NUDConfigurations(test.nicID, ipv6.ProtocolNumber) - if diff := cmp.Diff(test.expectedErr, err); diff != "" { - t.Errorf("s.NUDConfigurations(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) - } else if test.expectedErr == nil { - if diff := cmp.Diff(configs, gotConfigs); diff != "" { - t.Errorf("got configs mismatch (-want +got):\n%s", diff) - } - } - } - - for _, addr := range []tcpip.Address{llAddr1, llAddr2} { - { - err := s.AddStaticNeighbor(test.nicID, ipv6.ProtocolNumber, addr, linkAddr1) - if diff := cmp.Diff(test.expectedErr, err); diff != "" { - t.Errorf("s.AddStaticNeighbor(%d, %d, %s, %s) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, addr, linkAddr1, diff) - } - } - } - - { - wantErr := test.expectedErr - for i := 0; i < 2; i++ { - { - err := s.RemoveNeighbor(test.nicID, ipv6.ProtocolNumber, llAddr1) - if diff := cmp.Diff(wantErr, err); diff != "" { - t.Errorf("s.RemoveNeighbor(%d, %d, '') error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) - } - } - - if test.expectedErr != nil { - break - } - - // Removing a neighbor that does not exist should give us a bad address - // error. - wantErr = &tcpip.ErrBadAddress{} - } - } - - { - neighbors, err := s.Neighbors(test.nicID, ipv6.ProtocolNumber) - if diff := cmp.Diff(test.expectedErr, err); diff != "" { - t.Errorf("s.Neigbors(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) - } else if test.expectedErr == nil { - if diff := cmp.Diff( - []stack.NeighborEntry{{Addr: llAddr2, LinkAddr: linkAddr1, State: stack.Static, UpdatedAtNanos: clock.NowNanoseconds()}}, - neighbors, - ); diff != "" { - t.Errorf("neighbors mismatch (-want +got):\n%s", diff) - } - } - } - - { - err := s.ClearNeighbors(test.nicID, ipv6.ProtocolNumber) - if diff := cmp.Diff(test.expectedErr, err); diff != "" { - t.Errorf("s.ClearNeigbors(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) - } else if test.expectedErr == nil { - if neighbors, err := s.Neighbors(test.nicID, ipv6.ProtocolNumber); err != nil { - t.Errorf("s.Neighbors(%d, %d): %s", test.nicID, ipv6.ProtocolNumber, err) - } else if len(neighbors) != 0 { - t.Errorf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors) - } - } - } - }) - } -} - -func TestDefaultNUDConfigurations(t *testing.T) { - const nicID = 1 - - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - // A neighbor cache is required to store NUDConfigurations. The networking - // stack will only allocate neighbor caches if a protocol providing link - // address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NUDConfigs: stack.DefaultNUDConfigurations(), - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - c, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) - if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) - } - if got, want := c, stack.DefaultNUDConfigurations(); got != want { - t.Errorf("got stack.NUDConfigurations(%d, %d) = %+v, want = %+v", nicID, ipv6.ProtocolNumber, got, want) - } -} - -func TestNUDConfigurationsBaseReachableTime(t *testing.T) { - tests := []struct { - name string - baseReachableTime time.Duration - want time.Duration - }{ - // Invalid cases - { - name: "EqualToZero", - baseReachableTime: 0, - want: defaultBaseReachableTime, - }, - // Valid cases - { - name: "MoreThanZero", - baseReachableTime: time.Millisecond, - want: time.Millisecond, - }, - { - name: "MoreThanDefaultBaseReachableTime", - baseReachableTime: 2 * defaultBaseReachableTime, - want: 2 * defaultBaseReachableTime, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - const nicID = 1 - - c := stack.DefaultNUDConfigurations() - c.BaseReachableTime = test.baseReachableTime - - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - // A neighbor cache is required to store NUDConfigurations. The - // networking stack will only allocate neighbor caches if a protocol - // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NUDConfigs: c, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) - if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) - } - if got := sc.BaseReachableTime; got != test.want { - t.Errorf("got BaseReachableTime = %q, want = %q", got, test.want) - } - }) - } -} - -func TestNUDConfigurationsMinRandomFactor(t *testing.T) { - tests := []struct { - name string - minRandomFactor float32 - want float32 - }{ - // Invalid cases - { - name: "LessThanZero", - minRandomFactor: -1, - want: defaultMinRandomFactor, - }, - { - name: "EqualToZero", - minRandomFactor: 0, - want: defaultMinRandomFactor, - }, - // Valid cases - { - name: "MoreThanZero", - minRandomFactor: 1, - want: 1, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - const nicID = 1 - - c := stack.DefaultNUDConfigurations() - c.MinRandomFactor = test.minRandomFactor - - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - // A neighbor cache is required to store NUDConfigurations. The - // networking stack will only allocate neighbor caches if a protocol - // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NUDConfigs: c, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) - if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) - } - if got := sc.MinRandomFactor; got != test.want { - t.Errorf("got MinRandomFactor = %f, want = %f", got, test.want) - } - }) - } -} - -func TestNUDConfigurationsMaxRandomFactor(t *testing.T) { - tests := []struct { - name string - minRandomFactor float32 - maxRandomFactor float32 - want float32 - }{ - // Invalid cases - { - name: "LessThanZero", - minRandomFactor: defaultMinRandomFactor, - maxRandomFactor: -1, - want: defaultMaxRandomFactor, - }, - { - name: "EqualToZero", - minRandomFactor: defaultMinRandomFactor, - maxRandomFactor: 0, - want: defaultMaxRandomFactor, - }, - { - name: "LessThanMinRandomFactor", - minRandomFactor: defaultMinRandomFactor, - maxRandomFactor: defaultMinRandomFactor * 0.99, - want: defaultMaxRandomFactor, - }, - { - name: "MoreThanMinRandomFactorWhenMinRandomFactorIsLargerThanMaxRandomFactorDefault", - minRandomFactor: defaultMaxRandomFactor * 2, - maxRandomFactor: defaultMaxRandomFactor, - want: defaultMaxRandomFactor * 6, - }, - // Valid cases - { - name: "EqualToMinRandomFactor", - minRandomFactor: defaultMinRandomFactor, - maxRandomFactor: defaultMinRandomFactor, - want: defaultMinRandomFactor, - }, - { - name: "MoreThanMinRandomFactor", - minRandomFactor: defaultMinRandomFactor, - maxRandomFactor: defaultMinRandomFactor * 1.1, - want: defaultMinRandomFactor * 1.1, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - const nicID = 1 - - c := stack.DefaultNUDConfigurations() - c.MinRandomFactor = test.minRandomFactor - c.MaxRandomFactor = test.maxRandomFactor - - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - // A neighbor cache is required to store NUDConfigurations. The - // networking stack will only allocate neighbor caches if a protocol - // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NUDConfigs: c, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) - if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) - } - if got := sc.MaxRandomFactor; got != test.want { - t.Errorf("got MaxRandomFactor = %f, want = %f", got, test.want) - } - }) - } -} - -func TestNUDConfigurationsRetransmitTimer(t *testing.T) { - tests := []struct { - name string - retransmitTimer time.Duration - want time.Duration - }{ - // Invalid cases - { - name: "EqualToZero", - retransmitTimer: 0, - want: defaultRetransmitTimer, - }, - { - name: "LessThanMinimumRetransmitTimer", - retransmitTimer: minimumRetransmitTimer - time.Nanosecond, - want: defaultRetransmitTimer, - }, - // Valid cases - { - name: "EqualToMinimumRetransmitTimer", - retransmitTimer: minimumRetransmitTimer, - want: minimumBaseReachableTime, - }, - { - name: "LargetThanMinimumRetransmitTimer", - retransmitTimer: 2 * minimumBaseReachableTime, - want: 2 * minimumBaseReachableTime, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - const nicID = 1 - - c := stack.DefaultNUDConfigurations() - c.RetransmitTimer = test.retransmitTimer - - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - // A neighbor cache is required to store NUDConfigurations. The - // networking stack will only allocate neighbor caches if a protocol - // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NUDConfigs: c, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) - if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) - } - if got := sc.RetransmitTimer; got != test.want { - t.Errorf("got RetransmitTimer = %q, want = %q", got, test.want) - } - }) - } -} - -func TestNUDConfigurationsDelayFirstProbeTime(t *testing.T) { - tests := []struct { - name string - delayFirstProbeTime time.Duration - want time.Duration - }{ - // Invalid cases - { - name: "EqualToZero", - delayFirstProbeTime: 0, - want: defaultDelayFirstProbeTime, - }, - // Valid cases - { - name: "MoreThanZero", - delayFirstProbeTime: time.Millisecond, - want: time.Millisecond, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - const nicID = 1 - - c := stack.DefaultNUDConfigurations() - c.DelayFirstProbeTime = test.delayFirstProbeTime - - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - // A neighbor cache is required to store NUDConfigurations. The - // networking stack will only allocate neighbor caches if a protocol - // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NUDConfigs: c, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) - if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) - } - if got := sc.DelayFirstProbeTime; got != test.want { - t.Errorf("got DelayFirstProbeTime = %q, want = %q", got, test.want) - } - }) - } -} - -func TestNUDConfigurationsMaxMulticastProbes(t *testing.T) { - tests := []struct { - name string - maxMulticastProbes uint32 - want uint32 - }{ - // Invalid cases - { - name: "EqualToZero", - maxMulticastProbes: 0, - want: defaultMaxMulticastProbes, - }, - // Valid cases - { - name: "MoreThanZero", - maxMulticastProbes: 1, - want: 1, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - const nicID = 1 - - c := stack.DefaultNUDConfigurations() - c.MaxMulticastProbes = test.maxMulticastProbes - - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - // A neighbor cache is required to store NUDConfigurations. The - // networking stack will only allocate neighbor caches if a protocol - // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NUDConfigs: c, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) - if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) - } - if got := sc.MaxMulticastProbes; got != test.want { - t.Errorf("got MaxMulticastProbes = %q, want = %q", got, test.want) - } - }) - } -} - -func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) { - tests := []struct { - name string - maxUnicastProbes uint32 - want uint32 - }{ - // Invalid cases - { - name: "EqualToZero", - maxUnicastProbes: 0, - want: defaultMaxUnicastProbes, - }, - // Valid cases - { - name: "MoreThanZero", - maxUnicastProbes: 1, - want: 1, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - const nicID = 1 - - c := stack.DefaultNUDConfigurations() - c.MaxUnicastProbes = test.maxUnicastProbes - - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - // A neighbor cache is required to store NUDConfigurations. The - // networking stack will only allocate neighbor caches if a protocol - // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NUDConfigs: c, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) - if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) - } - if got := sc.MaxUnicastProbes; got != test.want { - t.Errorf("got MaxUnicastProbes = %q, want = %q", got, test.want) - } - }) - } -} - -// TestNUDStateReachableTime verifies the correctness of the ReachableTime -// computation. -func TestNUDStateReachableTime(t *testing.T) { - tests := []struct { - name string - baseReachableTime time.Duration - minRandomFactor float32 - maxRandomFactor float32 - want time.Duration - }{ - { - name: "AllZeros", - baseReachableTime: 0, - minRandomFactor: 0, - maxRandomFactor: 0, - want: 0, - }, - { - name: "ZeroMaxRandomFactor", - baseReachableTime: time.Second, - minRandomFactor: 0, - maxRandomFactor: 0, - want: 0, - }, - { - name: "ZeroMinRandomFactor", - baseReachableTime: time.Second, - minRandomFactor: 0, - maxRandomFactor: 1, - want: time.Duration(defaultFakeRandomNum * float32(time.Second)), - }, - { - name: "FractionalRandomFactor", - baseReachableTime: time.Duration(math.MaxInt64), - minRandomFactor: 0.001, - maxRandomFactor: 0.002, - want: time.Duration((0.001 + (0.001 * defaultFakeRandomNum)) * float32(math.MaxInt64)), - }, - { - name: "MinAndMaxRandomFactorsEqual", - baseReachableTime: time.Second, - minRandomFactor: 1, - maxRandomFactor: 1, - want: time.Second, - }, - { - name: "MinAndMaxRandomFactorsDifferent", - baseReachableTime: time.Second, - minRandomFactor: 1, - maxRandomFactor: 2, - want: time.Duration((1.0 + defaultFakeRandomNum) * float32(time.Second)), - }, - { - name: "MaxInt64", - baseReachableTime: time.Duration(math.MaxInt64), - minRandomFactor: 1, - maxRandomFactor: 1, - want: time.Duration(math.MaxInt64), - }, - { - name: "Overflow", - baseReachableTime: time.Duration(math.MaxInt64), - minRandomFactor: 1.5, - maxRandomFactor: 1.5, - want: time.Duration(math.MaxInt64), - }, - { - name: "DoubleOverflow", - baseReachableTime: time.Duration(math.MaxInt64), - minRandomFactor: 2.5, - maxRandomFactor: 2.5, - want: time.Duration(math.MaxInt64), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := stack.NUDConfigurations{ - BaseReachableTime: test.baseReachableTime, - MinRandomFactor: test.minRandomFactor, - MaxRandomFactor: test.maxRandomFactor, - } - // A fake random number generator is used to ensure deterministic - // results. - rng := fakeRand{ - num: defaultFakeRandomNum, - } - s := stack.NewNUDState(c, &rng) - if got, want := s.ReachableTime(), test.want; got != want { - t.Errorf("got ReachableTime = %q, want = %q", got, want) - } - }) - } -} - -// TestNUDStateRecomputeReachableTime exercises the ReachableTime function -// twice to verify recomputation of reachable time when the min random factor, -// max random factor, or base reachable time changes. -func TestNUDStateRecomputeReachableTime(t *testing.T) { - const defaultBase = time.Second - const defaultMin = 2.0 * defaultMaxRandomFactor - const defaultMax = 3.0 * defaultMaxRandomFactor - - tests := []struct { - name string - baseReachableTime time.Duration - minRandomFactor float32 - maxRandomFactor float32 - want time.Duration - }{ - { - name: "BaseReachableTime", - baseReachableTime: 2 * defaultBase, - minRandomFactor: defaultMin, - maxRandomFactor: defaultMax, - want: time.Duration((defaultMin + (defaultMax-defaultMin)*defaultFakeRandomNum) * float32(2*defaultBase)), - }, - { - name: "MinRandomFactor", - baseReachableTime: defaultBase, - minRandomFactor: defaultMax, - maxRandomFactor: defaultMax, - want: time.Duration(defaultMax * float32(defaultBase)), - }, - { - name: "MaxRandomFactor", - baseReachableTime: defaultBase, - minRandomFactor: defaultMin, - maxRandomFactor: defaultMin, - want: time.Duration(defaultMin * float32(defaultBase)), - }, - { - name: "BothRandomFactor", - baseReachableTime: defaultBase, - minRandomFactor: 2 * defaultMin, - maxRandomFactor: 2 * defaultMax, - want: time.Duration((2*defaultMin + (2*defaultMax-2*defaultMin)*defaultFakeRandomNum) * float32(defaultBase)), - }, - { - name: "BaseReachableTimeAndBothRandomFactors", - baseReachableTime: 2 * defaultBase, - minRandomFactor: 2 * defaultMin, - maxRandomFactor: 2 * defaultMax, - want: time.Duration((2*defaultMin + (2*defaultMax-2*defaultMin)*defaultFakeRandomNum) * float32(2*defaultBase)), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := stack.DefaultNUDConfigurations() - c.BaseReachableTime = defaultBase - c.MinRandomFactor = defaultMin - c.MaxRandomFactor = defaultMax - - // A fake random number generator is used to ensure deterministic - // results. - rng := fakeRand{ - num: defaultFakeRandomNum, - } - s := stack.NewNUDState(c, &rng) - old := s.ReachableTime() - - if got, want := s.ReachableTime(), old; got != want { - t.Errorf("got ReachableTime = %q, want = %q", got, want) - } - - // Check for recomputation when changing the min random factor, the max - // random factor, the base reachability time, or any permutation of those - // three options. - c.BaseReachableTime = test.baseReachableTime - c.MinRandomFactor = test.minRandomFactor - c.MaxRandomFactor = test.maxRandomFactor - s.SetConfig(c) - - if got, want := s.ReachableTime(), test.want; got != want { - t.Errorf("got ReachableTime = %q, want = %q", got, want) - } - - // Verify that ReachableTime isn't recomputed when none of the - // configuration options change. The random factor is changed so that if - // a recompution were to occur, ReachableTime would change. - rng.num = defaultFakeRandomNum / 2.0 - if got, want := s.ReachableTime(), test.want; got != want { - t.Errorf("got ReachableTime = %q, want = %q", got, want) - } - }) - } -} diff --git a/pkg/tcpip/stack/packet_buffer_list.go b/pkg/tcpip/stack/packet_buffer_list.go new file mode 100644 index 000000000..ce7057d6b --- /dev/null +++ b/pkg/tcpip/stack/packet_buffer_list.go @@ -0,0 +1,221 @@ +package stack + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type PacketBufferElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (PacketBufferElementMapper) linkerFor(elem *PacketBuffer) *PacketBuffer { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type PacketBufferList struct { + head *PacketBuffer + tail *PacketBuffer +} + +// Reset resets list l to the empty state. +func (l *PacketBufferList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *PacketBufferList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *PacketBufferList) Front() *PacketBuffer { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *PacketBufferList) Back() *PacketBuffer { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *PacketBufferList) Len() (count int) { + for e := l.Front(); e != nil; e = (PacketBufferElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *PacketBufferList) PushFront(e *PacketBuffer) { + linker := PacketBufferElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + PacketBufferElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *PacketBufferList) PushBack(e *PacketBuffer) { + linker := PacketBufferElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + PacketBufferElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *PacketBufferList) PushBackList(m *PacketBufferList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + PacketBufferElementMapper{}.linkerFor(l.tail).SetNext(m.head) + PacketBufferElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *PacketBufferList) InsertAfter(b, e *PacketBuffer) { + bLinker := PacketBufferElementMapper{}.linkerFor(b) + eLinker := PacketBufferElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + PacketBufferElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *PacketBufferList) InsertBefore(a, e *PacketBuffer) { + aLinker := PacketBufferElementMapper{}.linkerFor(a) + eLinker := PacketBufferElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + PacketBufferElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *PacketBufferList) Remove(e *PacketBuffer) { + linker := PacketBufferElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + PacketBufferElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + PacketBufferElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type PacketBufferEntry struct { + next *PacketBuffer + prev *PacketBuffer +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *PacketBufferEntry) Next() *PacketBuffer { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *PacketBufferEntry) Prev() *PacketBuffer { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *PacketBufferEntry) SetNext(elem *PacketBuffer) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *PacketBufferEntry) SetPrev(elem *PacketBuffer) { + e.prev = elem +} diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go deleted file mode 100644 index c6fa8da5f..000000000 --- a/pkg/tcpip/stack/packet_buffer_test.go +++ /dev/null @@ -1,397 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import ( - "bytes" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/buffer" -) - -func TestPacketHeaderPush(t *testing.T) { - for _, test := range []struct { - name string - reserved int - link []byte - network []byte - transport []byte - data []byte - }{ - { - name: "construct empty packet", - }, - { - name: "construct link header only packet", - reserved: 60, - link: makeView(10), - }, - { - name: "construct link and network header only packet", - reserved: 60, - link: makeView(10), - network: makeView(20), - }, - { - name: "construct header only packet", - reserved: 60, - link: makeView(10), - network: makeView(20), - transport: makeView(30), - }, - { - name: "construct data only packet", - data: makeView(40), - }, - { - name: "construct L3 packet", - reserved: 60, - network: makeView(20), - transport: makeView(30), - data: makeView(40), - }, - { - name: "construct L2 packet", - reserved: 60, - link: makeView(10), - network: makeView(20), - transport: makeView(30), - data: makeView(40), - }, - } { - t.Run(test.name, func(t *testing.T) { - pk := NewPacketBuffer(PacketBufferOptions{ - ReserveHeaderBytes: test.reserved, - // Make a copy of data to make sure our truth data won't be taint by - // PacketBuffer. - Data: buffer.NewViewFromBytes(test.data).ToVectorisedView(), - }) - - allHdrSize := len(test.link) + len(test.network) + len(test.transport) - - // Check the initial values for packet. - checkInitialPacketBuffer(t, pk, PacketBufferOptions{ - ReserveHeaderBytes: test.reserved, - Data: buffer.View(test.data).ToVectorisedView(), - }) - - // Push headers. - if v := test.transport; len(v) > 0 { - copy(pk.TransportHeader().Push(len(v)), v) - } - if v := test.network; len(v) > 0 { - copy(pk.NetworkHeader().Push(len(v)), v) - } - if v := test.link; len(v) > 0 { - copy(pk.LinkHeader().Push(len(v)), v) - } - - // Check the after values for packet. - if got, want := pk.ReservedHeaderBytes(), test.reserved; got != want { - t.Errorf("After pk.ReservedHeaderBytes() = %d, want %d", got, want) - } - if got, want := pk.AvailableHeaderBytes(), test.reserved-allHdrSize; got != want { - t.Errorf("After pk.AvailableHeaderBytes() = %d, want %d", got, want) - } - if got, want := pk.HeaderSize(), allHdrSize; got != want { - t.Errorf("After pk.HeaderSize() = %d, want %d", got, want) - } - if got, want := pk.Size(), allHdrSize+len(test.data); got != want { - t.Errorf("After pk.Size() = %d, want %d", got, want) - } - checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), test.data) - checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), - concatViews(test.link, test.network, test.transport, test.data)) - // Check the after values for each header. - checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), test.link) - checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), test.network) - checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), test.transport) - // Check the after values for PayloadSince. - checkViewEqual(t, "After PayloadSince(LinkHeader)", - PayloadSince(pk.LinkHeader()), - concatViews(test.link, test.network, test.transport, test.data)) - checkViewEqual(t, "After PayloadSince(NetworkHeader)", - PayloadSince(pk.NetworkHeader()), - concatViews(test.network, test.transport, test.data)) - checkViewEqual(t, "After PayloadSince(TransportHeader)", - PayloadSince(pk.TransportHeader()), - concatViews(test.transport, test.data)) - }) - } -} - -func TestPacketHeaderConsume(t *testing.T) { - for _, test := range []struct { - name string - data []byte - link int - network int - transport int - }{ - { - name: "parse L2 packet", - data: concatViews(makeView(10), makeView(20), makeView(30), makeView(40)), - link: 10, - network: 20, - transport: 30, - }, - { - name: "parse L3 packet", - data: concatViews(makeView(20), makeView(30), makeView(40)), - network: 20, - transport: 30, - }, - } { - t.Run(test.name, func(t *testing.T) { - pk := NewPacketBuffer(PacketBufferOptions{ - // Make a copy of data to make sure our truth data won't be taint by - // PacketBuffer. - Data: buffer.NewViewFromBytes(test.data).ToVectorisedView(), - }) - - // Check the initial values for packet. - checkInitialPacketBuffer(t, pk, PacketBufferOptions{ - Data: buffer.View(test.data).ToVectorisedView(), - }) - - // Consume headers. - if size := test.link; size > 0 { - if _, ok := pk.LinkHeader().Consume(size); !ok { - t.Fatalf("pk.LinkHeader().Consume() = false, want true") - } - } - if size := test.network; size > 0 { - if _, ok := pk.NetworkHeader().Consume(size); !ok { - t.Fatalf("pk.NetworkHeader().Consume() = false, want true") - } - } - if size := test.transport; size > 0 { - if _, ok := pk.TransportHeader().Consume(size); !ok { - t.Fatalf("pk.TransportHeader().Consume() = false, want true") - } - } - - allHdrSize := test.link + test.network + test.transport - - // Check the after values for packet. - if got, want := pk.ReservedHeaderBytes(), 0; got != want { - t.Errorf("After pk.ReservedHeaderBytes() = %d, want %d", got, want) - } - if got, want := pk.AvailableHeaderBytes(), 0; got != want { - t.Errorf("After pk.AvailableHeaderBytes() = %d, want %d", got, want) - } - if got, want := pk.HeaderSize(), allHdrSize; got != want { - t.Errorf("After pk.HeaderSize() = %d, want %d", got, want) - } - if got, want := pk.Size(), len(test.data); got != want { - t.Errorf("After pk.Size() = %d, want %d", got, want) - } - // After state of pk. - var ( - link = test.data[:test.link] - network = test.data[test.link:][:test.network] - transport = test.data[test.link+test.network:][:test.transport] - payload = test.data[allHdrSize:] - ) - checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), payload) - checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), test.data) - // Check the after values for each header. - checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), link) - checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), network) - checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), transport) - // Check the after values for PayloadSince. - checkViewEqual(t, "After PayloadSince(LinkHeader)", - PayloadSince(pk.LinkHeader()), - concatViews(link, network, transport, payload)) - checkViewEqual(t, "After PayloadSince(NetworkHeader)", - PayloadSince(pk.NetworkHeader()), - concatViews(network, transport, payload)) - checkViewEqual(t, "After PayloadSince(TransportHeader)", - PayloadSince(pk.TransportHeader()), - concatViews(transport, payload)) - }) - } -} - -func TestPacketHeaderConsumeDataTooShort(t *testing.T) { - data := makeView(10) - - pk := NewPacketBuffer(PacketBufferOptions{ - // Make a copy of data to make sure our truth data won't be taint by - // PacketBuffer. - Data: buffer.NewViewFromBytes(data).ToVectorisedView(), - }) - - // Consume should fail if pkt.Data is too short. - if _, ok := pk.LinkHeader().Consume(11); ok { - t.Fatalf("pk.LinkHeader().Consume() = _, true; want _, false") - } - if _, ok := pk.NetworkHeader().Consume(11); ok { - t.Fatalf("pk.NetworkHeader().Consume() = _, true; want _, false") - } - if _, ok := pk.TransportHeader().Consume(11); ok { - t.Fatalf("pk.TransportHeader().Consume() = _, true; want _, false") - } - - // Check packet should look the same as initial packet. - checkInitialPacketBuffer(t, pk, PacketBufferOptions{ - Data: buffer.View(data).ToVectorisedView(), - }) -} - -func TestPacketHeaderPushCalledAtMostOnce(t *testing.T) { - const headerSize = 10 - - pk := NewPacketBuffer(PacketBufferOptions{ - ReserveHeaderBytes: headerSize * int(numHeaderType), - }) - - for _, h := range []PacketHeader{ - pk.TransportHeader(), - pk.NetworkHeader(), - pk.LinkHeader(), - } { - t.Run("PushedTwice/"+h.typ.String(), func(t *testing.T) { - h.Push(headerSize) - - defer func() { recover() }() - h.Push(headerSize) - t.Fatal("Second push should have panicked") - }) - } -} - -func TestPacketHeaderConsumeCalledAtMostOnce(t *testing.T) { - const headerSize = 10 - - pk := NewPacketBuffer(PacketBufferOptions{ - Data: makeView(headerSize * int(numHeaderType)).ToVectorisedView(), - }) - - for _, h := range []PacketHeader{ - pk.LinkHeader(), - pk.NetworkHeader(), - pk.TransportHeader(), - } { - t.Run("ConsumedTwice/"+h.typ.String(), func(t *testing.T) { - if _, ok := h.Consume(headerSize); !ok { - t.Fatal("First consume should succeed") - } - - defer func() { recover() }() - h.Consume(headerSize) - t.Fatal("Second consume should have panicked") - }) - } -} - -func TestPacketHeaderPushThenConsumePanics(t *testing.T) { - const headerSize = 10 - - pk := NewPacketBuffer(PacketBufferOptions{ - ReserveHeaderBytes: headerSize * int(numHeaderType), - }) - - for _, h := range []PacketHeader{ - pk.TransportHeader(), - pk.NetworkHeader(), - pk.LinkHeader(), - } { - t.Run(h.typ.String(), func(t *testing.T) { - h.Push(headerSize) - - defer func() { recover() }() - h.Consume(headerSize) - t.Fatal("Consume should have panicked") - }) - } -} - -func TestPacketHeaderConsumeThenPushPanics(t *testing.T) { - const headerSize = 10 - - pk := NewPacketBuffer(PacketBufferOptions{ - Data: makeView(headerSize * int(numHeaderType)).ToVectorisedView(), - }) - - for _, h := range []PacketHeader{ - pk.LinkHeader(), - pk.NetworkHeader(), - pk.TransportHeader(), - } { - t.Run(h.typ.String(), func(t *testing.T) { - h.Consume(headerSize) - - defer func() { recover() }() - h.Push(headerSize) - t.Fatal("Push should have panicked") - }) - } -} - -func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferOptions) { - t.Helper() - reserved := opts.ReserveHeaderBytes - if got, want := pk.ReservedHeaderBytes(), reserved; got != want { - t.Errorf("Initial pk.ReservedHeaderBytes() = %d, want %d", got, want) - } - if got, want := pk.AvailableHeaderBytes(), reserved; got != want { - t.Errorf("Initial pk.AvailableHeaderBytes() = %d, want %d", got, want) - } - if got, want := pk.HeaderSize(), 0; got != want { - t.Errorf("Initial pk.HeaderSize() = %d, want %d", got, want) - } - data := opts.Data.ToView() - if got, want := pk.Size(), len(data); got != want { - t.Errorf("Initial pk.Size() = %d, want %d", got, want) - } - checkViewEqual(t, "Initial pk.Data.Views()", concatViews(pk.Data.Views()...), data) - checkViewEqual(t, "Initial pk.Views()", concatViews(pk.Views()...), data) - // Check the initial values for each header. - checkPacketHeader(t, "Initial pk.LinkHeader", pk.LinkHeader(), nil) - checkPacketHeader(t, "Initial pk.NetworkHeader", pk.NetworkHeader(), nil) - checkPacketHeader(t, "Initial pk.TransportHeader", pk.TransportHeader(), nil) - // Check the initial valies for PayloadSince. - checkViewEqual(t, "Initial PayloadSince(LinkHeader)", - PayloadSince(pk.LinkHeader()), data) - checkViewEqual(t, "Initial PayloadSince(NetworkHeader)", - PayloadSince(pk.NetworkHeader()), data) - checkViewEqual(t, "Initial PayloadSince(TransportHeader)", - PayloadSince(pk.TransportHeader()), data) -} - -func checkPacketHeader(t *testing.T, name string, h PacketHeader, want []byte) { - t.Helper() - checkViewEqual(t, name+".View()", h.View(), want) -} - -func checkViewEqual(t *testing.T, what string, got, want buffer.View) { - t.Helper() - if !bytes.Equal(got, want) { - t.Errorf("%s = %x, want %x", what, got, want) - } -} - -func makeView(size int) buffer.View { - b := byte(size) - return bytes.Repeat([]byte{b}, size) -} - -func concatViews(views ...buffer.View) buffer.View { - var all buffer.View - for _, v := range views { - all = append(all, v...) - } - return all -} diff --git a/pkg/tcpip/stack/stack_state_autogen.go b/pkg/tcpip/stack/stack_state_autogen.go new file mode 100644 index 000000000..462139b82 --- /dev/null +++ b/pkg/tcpip/stack/stack_state_autogen.go @@ -0,0 +1,678 @@ +// automatically generated by stateify. + +package stack + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (t *tuple) StateTypeName() string { + return "pkg/tcpip/stack.tuple" +} + +func (t *tuple) StateFields() []string { + return []string{ + "tupleEntry", + "tupleID", + "conn", + "direction", + } +} + +func (t *tuple) beforeSave() {} + +func (t *tuple) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.tupleEntry) + stateSinkObject.Save(1, &t.tupleID) + stateSinkObject.Save(2, &t.conn) + stateSinkObject.Save(3, &t.direction) +} + +func (t *tuple) afterLoad() {} + +func (t *tuple) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.tupleEntry) + stateSourceObject.Load(1, &t.tupleID) + stateSourceObject.Load(2, &t.conn) + stateSourceObject.Load(3, &t.direction) +} + +func (ti *tupleID) StateTypeName() string { + return "pkg/tcpip/stack.tupleID" +} + +func (ti *tupleID) StateFields() []string { + return []string{ + "srcAddr", + "srcPort", + "dstAddr", + "dstPort", + "transProto", + "netProto", + } +} + +func (ti *tupleID) beforeSave() {} + +func (ti *tupleID) StateSave(stateSinkObject state.Sink) { + ti.beforeSave() + stateSinkObject.Save(0, &ti.srcAddr) + stateSinkObject.Save(1, &ti.srcPort) + stateSinkObject.Save(2, &ti.dstAddr) + stateSinkObject.Save(3, &ti.dstPort) + stateSinkObject.Save(4, &ti.transProto) + stateSinkObject.Save(5, &ti.netProto) +} + +func (ti *tupleID) afterLoad() {} + +func (ti *tupleID) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &ti.srcAddr) + stateSourceObject.Load(1, &ti.srcPort) + stateSourceObject.Load(2, &ti.dstAddr) + stateSourceObject.Load(3, &ti.dstPort) + stateSourceObject.Load(4, &ti.transProto) + stateSourceObject.Load(5, &ti.netProto) +} + +func (cn *conn) StateTypeName() string { + return "pkg/tcpip/stack.conn" +} + +func (cn *conn) StateFields() []string { + return []string{ + "original", + "reply", + "manip", + "tcbHook", + "tcb", + "lastUsed", + } +} + +func (cn *conn) beforeSave() {} + +func (cn *conn) StateSave(stateSinkObject state.Sink) { + cn.beforeSave() + var lastUsedValue unixTime = cn.saveLastUsed() + stateSinkObject.SaveValue(5, lastUsedValue) + stateSinkObject.Save(0, &cn.original) + stateSinkObject.Save(1, &cn.reply) + stateSinkObject.Save(2, &cn.manip) + stateSinkObject.Save(3, &cn.tcbHook) + stateSinkObject.Save(4, &cn.tcb) +} + +func (cn *conn) afterLoad() {} + +func (cn *conn) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &cn.original) + stateSourceObject.Load(1, &cn.reply) + stateSourceObject.Load(2, &cn.manip) + stateSourceObject.Load(3, &cn.tcbHook) + stateSourceObject.Load(4, &cn.tcb) + stateSourceObject.LoadValue(5, new(unixTime), func(y interface{}) { cn.loadLastUsed(y.(unixTime)) }) +} + +func (ct *ConnTrack) StateTypeName() string { + return "pkg/tcpip/stack.ConnTrack" +} + +func (ct *ConnTrack) StateFields() []string { + return []string{ + "seed", + "buckets", + } +} + +func (ct *ConnTrack) StateSave(stateSinkObject state.Sink) { + ct.beforeSave() + stateSinkObject.Save(0, &ct.seed) + stateSinkObject.Save(1, &ct.buckets) +} + +func (ct *ConnTrack) afterLoad() {} + +func (ct *ConnTrack) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &ct.seed) + stateSourceObject.Load(1, &ct.buckets) +} + +func (b *bucket) StateTypeName() string { + return "pkg/tcpip/stack.bucket" +} + +func (b *bucket) StateFields() []string { + return []string{ + "tuples", + } +} + +func (b *bucket) beforeSave() {} + +func (b *bucket) StateSave(stateSinkObject state.Sink) { + b.beforeSave() + stateSinkObject.Save(0, &b.tuples) +} + +func (b *bucket) afterLoad() {} + +func (b *bucket) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &b.tuples) +} + +func (u *unixTime) StateTypeName() string { + return "pkg/tcpip/stack.unixTime" +} + +func (u *unixTime) StateFields() []string { + return []string{ + "second", + "nano", + } +} + +func (u *unixTime) beforeSave() {} + +func (u *unixTime) StateSave(stateSinkObject state.Sink) { + u.beforeSave() + stateSinkObject.Save(0, &u.second) + stateSinkObject.Save(1, &u.nano) +} + +func (u *unixTime) afterLoad() {} + +func (u *unixTime) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &u.second) + stateSourceObject.Load(1, &u.nano) +} + +func (it *IPTables) StateTypeName() string { + return "pkg/tcpip/stack.IPTables" +} + +func (it *IPTables) StateFields() []string { + return []string{ + "mu", + "v4Tables", + "v6Tables", + "modified", + "priorities", + "connections", + "reaperDone", + } +} + +func (it *IPTables) StateSave(stateSinkObject state.Sink) { + it.beforeSave() + stateSinkObject.Save(0, &it.mu) + stateSinkObject.Save(1, &it.v4Tables) + stateSinkObject.Save(2, &it.v6Tables) + stateSinkObject.Save(3, &it.modified) + stateSinkObject.Save(4, &it.priorities) + stateSinkObject.Save(5, &it.connections) + stateSinkObject.Save(6, &it.reaperDone) +} + +func (it *IPTables) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &it.mu) + stateSourceObject.Load(1, &it.v4Tables) + stateSourceObject.Load(2, &it.v6Tables) + stateSourceObject.Load(3, &it.modified) + stateSourceObject.Load(4, &it.priorities) + stateSourceObject.Load(5, &it.connections) + stateSourceObject.Load(6, &it.reaperDone) + stateSourceObject.AfterLoad(it.afterLoad) +} + +func (table *Table) StateTypeName() string { + return "pkg/tcpip/stack.Table" +} + +func (table *Table) StateFields() []string { + return []string{ + "Rules", + "BuiltinChains", + "Underflows", + } +} + +func (table *Table) beforeSave() {} + +func (table *Table) StateSave(stateSinkObject state.Sink) { + table.beforeSave() + stateSinkObject.Save(0, &table.Rules) + stateSinkObject.Save(1, &table.BuiltinChains) + stateSinkObject.Save(2, &table.Underflows) +} + +func (table *Table) afterLoad() {} + +func (table *Table) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &table.Rules) + stateSourceObject.Load(1, &table.BuiltinChains) + stateSourceObject.Load(2, &table.Underflows) +} + +func (r *Rule) StateTypeName() string { + return "pkg/tcpip/stack.Rule" +} + +func (r *Rule) StateFields() []string { + return []string{ + "Filter", + "Matchers", + "Target", + } +} + +func (r *Rule) beforeSave() {} + +func (r *Rule) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.Filter) + stateSinkObject.Save(1, &r.Matchers) + stateSinkObject.Save(2, &r.Target) +} + +func (r *Rule) afterLoad() {} + +func (r *Rule) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.Filter) + stateSourceObject.Load(1, &r.Matchers) + stateSourceObject.Load(2, &r.Target) +} + +func (fl *IPHeaderFilter) StateTypeName() string { + return "pkg/tcpip/stack.IPHeaderFilter" +} + +func (fl *IPHeaderFilter) StateFields() []string { + return []string{ + "Protocol", + "CheckProtocol", + "Dst", + "DstMask", + "DstInvert", + "Src", + "SrcMask", + "SrcInvert", + "InputInterface", + "InputInterfaceMask", + "InputInterfaceInvert", + "OutputInterface", + "OutputInterfaceMask", + "OutputInterfaceInvert", + } +} + +func (fl *IPHeaderFilter) beforeSave() {} + +func (fl *IPHeaderFilter) StateSave(stateSinkObject state.Sink) { + fl.beforeSave() + stateSinkObject.Save(0, &fl.Protocol) + stateSinkObject.Save(1, &fl.CheckProtocol) + stateSinkObject.Save(2, &fl.Dst) + stateSinkObject.Save(3, &fl.DstMask) + stateSinkObject.Save(4, &fl.DstInvert) + stateSinkObject.Save(5, &fl.Src) + stateSinkObject.Save(6, &fl.SrcMask) + stateSinkObject.Save(7, &fl.SrcInvert) + stateSinkObject.Save(8, &fl.InputInterface) + stateSinkObject.Save(9, &fl.InputInterfaceMask) + stateSinkObject.Save(10, &fl.InputInterfaceInvert) + stateSinkObject.Save(11, &fl.OutputInterface) + stateSinkObject.Save(12, &fl.OutputInterfaceMask) + stateSinkObject.Save(13, &fl.OutputInterfaceInvert) +} + +func (fl *IPHeaderFilter) afterLoad() {} + +func (fl *IPHeaderFilter) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fl.Protocol) + stateSourceObject.Load(1, &fl.CheckProtocol) + stateSourceObject.Load(2, &fl.Dst) + stateSourceObject.Load(3, &fl.DstMask) + stateSourceObject.Load(4, &fl.DstInvert) + stateSourceObject.Load(5, &fl.Src) + stateSourceObject.Load(6, &fl.SrcMask) + stateSourceObject.Load(7, &fl.SrcInvert) + stateSourceObject.Load(8, &fl.InputInterface) + stateSourceObject.Load(9, &fl.InputInterfaceMask) + stateSourceObject.Load(10, &fl.InputInterfaceInvert) + stateSourceObject.Load(11, &fl.OutputInterface) + stateSourceObject.Load(12, &fl.OutputInterfaceMask) + stateSourceObject.Load(13, &fl.OutputInterfaceInvert) +} + +func (l *neighborEntryList) StateTypeName() string { + return "pkg/tcpip/stack.neighborEntryList" +} + +func (l *neighborEntryList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *neighborEntryList) beforeSave() {} + +func (l *neighborEntryList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *neighborEntryList) afterLoad() {} + +func (l *neighborEntryList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *neighborEntryEntry) StateTypeName() string { + return "pkg/tcpip/stack.neighborEntryEntry" +} + +func (e *neighborEntryEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *neighborEntryEntry) beforeSave() {} + +func (e *neighborEntryEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *neighborEntryEntry) afterLoad() {} + +func (e *neighborEntryEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func (p *PacketBufferList) StateTypeName() string { + return "pkg/tcpip/stack.PacketBufferList" +} + +func (p *PacketBufferList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (p *PacketBufferList) beforeSave() {} + +func (p *PacketBufferList) StateSave(stateSinkObject state.Sink) { + p.beforeSave() + stateSinkObject.Save(0, &p.head) + stateSinkObject.Save(1, &p.tail) +} + +func (p *PacketBufferList) afterLoad() {} + +func (p *PacketBufferList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &p.head) + stateSourceObject.Load(1, &p.tail) +} + +func (e *PacketBufferEntry) StateTypeName() string { + return "pkg/tcpip/stack.PacketBufferEntry" +} + +func (e *PacketBufferEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *PacketBufferEntry) beforeSave() {} + +func (e *PacketBufferEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *PacketBufferEntry) afterLoad() {} + +func (e *PacketBufferEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func (t *TransportEndpointID) StateTypeName() string { + return "pkg/tcpip/stack.TransportEndpointID" +} + +func (t *TransportEndpointID) StateFields() []string { + return []string{ + "LocalPort", + "LocalAddress", + "RemotePort", + "RemoteAddress", + } +} + +func (t *TransportEndpointID) beforeSave() {} + +func (t *TransportEndpointID) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.LocalPort) + stateSinkObject.Save(1, &t.LocalAddress) + stateSinkObject.Save(2, &t.RemotePort) + stateSinkObject.Save(3, &t.RemoteAddress) +} + +func (t *TransportEndpointID) afterLoad() {} + +func (t *TransportEndpointID) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.LocalPort) + stateSourceObject.Load(1, &t.LocalAddress) + stateSourceObject.Load(2, &t.RemotePort) + stateSourceObject.Load(3, &t.RemoteAddress) +} + +func (g *GSOType) StateTypeName() string { + return "pkg/tcpip/stack.GSOType" +} + +func (g *GSOType) StateFields() []string { + return nil +} + +func (g *GSO) StateTypeName() string { + return "pkg/tcpip/stack.GSO" +} + +func (g *GSO) StateFields() []string { + return []string{ + "Type", + "NeedsCsum", + "CsumOffset", + "MSS", + "L3HdrLen", + "MaxSize", + } +} + +func (g *GSO) beforeSave() {} + +func (g *GSO) StateSave(stateSinkObject state.Sink) { + g.beforeSave() + stateSinkObject.Save(0, &g.Type) + stateSinkObject.Save(1, &g.NeedsCsum) + stateSinkObject.Save(2, &g.CsumOffset) + stateSinkObject.Save(3, &g.MSS) + stateSinkObject.Save(4, &g.L3HdrLen) + stateSinkObject.Save(5, &g.MaxSize) +} + +func (g *GSO) afterLoad() {} + +func (g *GSO) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &g.Type) + stateSourceObject.Load(1, &g.NeedsCsum) + stateSourceObject.Load(2, &g.CsumOffset) + stateSourceObject.Load(3, &g.MSS) + stateSourceObject.Load(4, &g.L3HdrLen) + stateSourceObject.Load(5, &g.MaxSize) +} + +func (t *TransportEndpointInfo) StateTypeName() string { + return "pkg/tcpip/stack.TransportEndpointInfo" +} + +func (t *TransportEndpointInfo) StateFields() []string { + return []string{ + "NetProto", + "TransProto", + "ID", + "BindNICID", + "BindAddr", + "RegisterNICID", + } +} + +func (t *TransportEndpointInfo) beforeSave() {} + +func (t *TransportEndpointInfo) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.NetProto) + stateSinkObject.Save(1, &t.TransProto) + stateSinkObject.Save(2, &t.ID) + stateSinkObject.Save(3, &t.BindNICID) + stateSinkObject.Save(4, &t.BindAddr) + stateSinkObject.Save(5, &t.RegisterNICID) +} + +func (t *TransportEndpointInfo) afterLoad() {} + +func (t *TransportEndpointInfo) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.NetProto) + stateSourceObject.Load(1, &t.TransProto) + stateSourceObject.Load(2, &t.ID) + stateSourceObject.Load(3, &t.BindNICID) + stateSourceObject.Load(4, &t.BindAddr) + stateSourceObject.Load(5, &t.RegisterNICID) +} + +func (ep *multiPortEndpoint) StateTypeName() string { + return "pkg/tcpip/stack.multiPortEndpoint" +} + +func (ep *multiPortEndpoint) StateFields() []string { + return []string{ + "demux", + "netProto", + "transProto", + "endpoints", + "flags", + } +} + +func (ep *multiPortEndpoint) beforeSave() {} + +func (ep *multiPortEndpoint) StateSave(stateSinkObject state.Sink) { + ep.beforeSave() + stateSinkObject.Save(0, &ep.demux) + stateSinkObject.Save(1, &ep.netProto) + stateSinkObject.Save(2, &ep.transProto) + stateSinkObject.Save(3, &ep.endpoints) + stateSinkObject.Save(4, &ep.flags) +} + +func (ep *multiPortEndpoint) afterLoad() {} + +func (ep *multiPortEndpoint) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &ep.demux) + stateSourceObject.Load(1, &ep.netProto) + stateSourceObject.Load(2, &ep.transProto) + stateSourceObject.Load(3, &ep.endpoints) + stateSourceObject.Load(4, &ep.flags) +} + +func (l *tupleList) StateTypeName() string { + return "pkg/tcpip/stack.tupleList" +} + +func (l *tupleList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *tupleList) beforeSave() {} + +func (l *tupleList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *tupleList) afterLoad() {} + +func (l *tupleList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *tupleEntry) StateTypeName() string { + return "pkg/tcpip/stack.tupleEntry" +} + +func (e *tupleEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *tupleEntry) beforeSave() {} + +func (e *tupleEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *tupleEntry) afterLoad() {} + +func (e *tupleEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func init() { + state.Register((*tuple)(nil)) + state.Register((*tupleID)(nil)) + state.Register((*conn)(nil)) + state.Register((*ConnTrack)(nil)) + state.Register((*bucket)(nil)) + state.Register((*unixTime)(nil)) + state.Register((*IPTables)(nil)) + state.Register((*Table)(nil)) + state.Register((*Rule)(nil)) + state.Register((*IPHeaderFilter)(nil)) + state.Register((*neighborEntryList)(nil)) + state.Register((*neighborEntryEntry)(nil)) + state.Register((*PacketBufferList)(nil)) + state.Register((*PacketBufferEntry)(nil)) + state.Register((*TransportEndpointID)(nil)) + state.Register((*GSOType)(nil)) + state.Register((*GSO)(nil)) + state.Register((*TransportEndpointInfo)(nil)) + state.Register((*multiPortEndpoint)(nil)) + state.Register((*tupleList)(nil)) + state.Register((*tupleEntry)(nil)) +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go deleted file mode 100644 index 92a0cb401..000000000 --- a/pkg/tcpip/stack/stack_test.go +++ /dev/null @@ -1,4463 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package stack_test contains tests for the stack. It is in its own package so -// that the tests can also validate that all definitions needed to implement -// transport and network protocols are properly exported by the stack package. -package stack_test - -import ( - "bytes" - "fmt" - "math" - "net" - "sort" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/rand" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" -) - -const ( - fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 - fakeNetHeaderLen = 12 - fakeDefaultPrefixLen = 8 - - // fakeControlProtocol is used for control packets that represent - // destination port unreachable. - fakeControlProtocol tcpip.TransportProtocolNumber = 2 - - // defaultMTU is the MTU, in bytes, used throughout the tests, except - // where another value is explicitly used. It is chosen to match the MTU - // of loopback interfaces on linux systems. - defaultMTU = 65536 - - dstAddrOffset = 0 - srcAddrOffset = 1 - protocolNumberOffset = 2 -) - -func checkGetMainNICAddress(s *stack.Stack, nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber, want tcpip.AddressWithPrefix) error { - if addr, ok := s.GetMainNICAddress(nicID, proto); !ok { - return fmt.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, false), want = (_, true)", nicID, proto) - } else if addr != want { - return fmt.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, true), want = (%s, true)", nicID, proto, addr, want) - } - return nil -} - -// fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and -// received packets; the counts of all endpoints are aggregated in the protocol -// descriptor. -// -// Headers of this protocol are fakeNetHeaderLen bytes, but we currently only -// use the first three: destination address, source address, and transport -// protocol. They're all one byte fields to simplify parsing. -type fakeNetworkEndpoint struct { - stack.AddressableEndpointState - - mu struct { - sync.RWMutex - - enabled bool - } - - nic stack.NetworkInterface - proto *fakeNetworkProtocol - dispatcher stack.TransportDispatcher -} - -func (f *fakeNetworkEndpoint) Enable() tcpip.Error { - f.mu.Lock() - defer f.mu.Unlock() - f.mu.enabled = true - return nil -} - -func (f *fakeNetworkEndpoint) Enabled() bool { - f.mu.RLock() - defer f.mu.RUnlock() - return f.mu.enabled -} - -func (f *fakeNetworkEndpoint) Disable() { - f.mu.Lock() - defer f.mu.Unlock() - f.mu.enabled = false -} - -func (f *fakeNetworkEndpoint) MTU() uint32 { - return f.nic.MTU() - uint32(f.MaxHeaderLength()) -} - -func (*fakeNetworkEndpoint) DefaultTTL() uint8 { - return 123 -} - -func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { - if _, _, ok := f.proto.Parse(pkt); !ok { - return - } - - // Increment the received packet count in the protocol descriptor. - netHdr := pkt.NetworkHeader().View() - - dst := tcpip.Address(netHdr[dstAddrOffset:][:1]) - addressEndpoint := f.AcquireAssignedAddress(dst, f.nic.Promiscuous(), stack.CanBePrimaryEndpoint) - if addressEndpoint == nil { - return - } - addressEndpoint.DecRef() - - f.proto.packetCount[int(dst[0])%len(f.proto.packetCount)]++ - - // Handle control packets. - if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) { - nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) - if !ok { - return - } - pkt.Data.TrimFront(fakeNetHeaderLen) - f.dispatcher.DeliverTransportError( - tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]), - tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]), - fakeNetNumber, - tcpip.TransportProtocolNumber(nb[protocolNumberOffset]), - // Nothing checks the error. - nil, /* transport error */ - pkt, - ) - return - } - - // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) -} - -func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { - return f.nic.MaxHeaderLength() + fakeNetHeaderLen -} - -func (*fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { - return 0 -} - -func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { - return f.proto.Number() -} - -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error { - // Increment the sent packet count in the protocol descriptor. - f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ - - // Add the protocol's header to the packet and send it to the link - // endpoint. - hdr := pkt.NetworkHeader().Push(fakeNetHeaderLen) - pkt.NetworkProtocolNumber = fakeNetNumber - hdr[dstAddrOffset] = r.RemoteAddress[0] - hdr[srcAddrOffset] = r.LocalAddress[0] - hdr[protocolNumberOffset] = byte(params.Protocol) - - if r.Loop&stack.PacketLoop != 0 { - f.HandlePacket(pkt.Clone()) - } - if r.Loop&stack.PacketOut == 0 { - return nil - } - - return f.nic.WritePacket(r, gso, fakeNetNumber, pkt) -} - -// WritePackets implements stack.LinkEndpoint.WritePackets. -func (*fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { - panic("not implemented") -} - -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error { - return &tcpip.ErrNotSupported{} -} - -func (f *fakeNetworkEndpoint) Close() { - f.AddressableEndpointState.Cleanup() -} - -// Stats implements NetworkEndpoint. -func (*fakeNetworkEndpoint) Stats() stack.NetworkEndpointStats { - return &fakeNetworkEndpointStats{} -} - -var _ stack.NetworkEndpointStats = (*fakeNetworkEndpointStats)(nil) - -type fakeNetworkEndpointStats struct{} - -// IsNetworkEndpointStats implements stack.NetworkEndpointStats. -func (*fakeNetworkEndpointStats) IsNetworkEndpointStats() {} - -// fakeNetworkProtocol is a network-layer protocol descriptor. It aggregates the -// number of packets sent and received via endpoints of this protocol. The index -// where packets are added is given by the packet's destination address MOD 10. -type fakeNetworkProtocol struct { - packetCount [10]int - sendPacketCount [10]int - defaultTTL uint8 - - mu struct { - sync.RWMutex - forwarding bool - } -} - -func (*fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber { - return fakeNetNumber -} - -func (*fakeNetworkProtocol) MinimumPacketSize() int { - return fakeNetHeaderLen -} - -func (*fakeNetworkProtocol) DefaultPrefixLen() int { - return fakeDefaultPrefixLen -} - -func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int { - return f.packetCount[int(intfAddr)%len(f.packetCount)] -} - -func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { - return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) -} - -func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { - e := &fakeNetworkEndpoint{ - nic: nic, - proto: f, - dispatcher: dispatcher, - } - e.AddressableEndpointState.Init(e) - return e -} - -func (f *fakeNetworkProtocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error { - switch v := option.(type) { - case *tcpip.DefaultTTLOption: - f.defaultTTL = uint8(*v) - return nil - default: - return &tcpip.ErrUnknownProtocolOption{} - } -} - -func (f *fakeNetworkProtocol) Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error { - switch v := option.(type) { - case *tcpip.DefaultTTLOption: - *v = tcpip.DefaultTTLOption(f.defaultTTL) - return nil - default: - return &tcpip.ErrUnknownProtocolOption{} - } -} - -// Close implements NetworkProtocol.Close. -func (*fakeNetworkProtocol) Close() {} - -// Wait implements NetworkProtocol.Wait. -func (*fakeNetworkProtocol) Wait() {} - -// Parse implements NetworkProtocol.Parse. -func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { - hdr, ok := pkt.NetworkHeader().Consume(fakeNetHeaderLen) - if !ok { - return 0, false, false - } - pkt.NetworkProtocolNumber = fakeNetNumber - return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true -} - -// Forwarding implements stack.ForwardingNetworkProtocol. -func (f *fakeNetworkProtocol) Forwarding() bool { - f.mu.RLock() - defer f.mu.RUnlock() - return f.mu.forwarding -} - -// SetForwarding implements stack.ForwardingNetworkProtocol. -func (f *fakeNetworkProtocol) SetForwarding(v bool) { - f.mu.Lock() - defer f.mu.Unlock() - f.mu.forwarding = v -} - -func fakeNetFactory(*stack.Stack) stack.NetworkProtocol { - return &fakeNetworkProtocol{} -} - -// linkEPWithMockedAttach is a stack.LinkEndpoint that tests can use to verify -// that LinkEndpoint.Attach was called. -type linkEPWithMockedAttach struct { - stack.LinkEndpoint - attached bool -} - -// Attach implements stack.LinkEndpoint.Attach. -func (l *linkEPWithMockedAttach) Attach(d stack.NetworkDispatcher) { - l.LinkEndpoint.Attach(d) - l.attached = d != nil -} - -func (l *linkEPWithMockedAttach) isAttached() bool { - return l.attached -} - -// Checks to see if list contains an address. -func containsAddr(list []tcpip.ProtocolAddress, item tcpip.ProtocolAddress) bool { - for _, i := range list { - if i == item { - return true - } - } - - return false -} - -func TestNetworkReceive(t *testing.T) { - // Create a stack with the fake network protocol, one nic, and two - // addresses attached to it: 1 & 2. - ep := channel.New(10, defaultMTU, "") - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - // Make sure packet with wrong address is not delivered. - buf[dstAddrOffset] = 3 - ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if fakeNet.packetCount[1] != 0 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) - } - if fakeNet.packetCount[2] != 0 { - t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 0) - } - - // Make sure packet is delivered to first endpoint. - buf[dstAddrOffset] = 1 - ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } - if fakeNet.packetCount[2] != 0 { - t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 0) - } - - // Make sure packet is delivered to second endpoint. - buf[dstAddrOffset] = 2 - ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } - if fakeNet.packetCount[2] != 1 { - t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1) - } - - // Make sure packet is not delivered if protocol number is wrong. - ep.InjectInbound(fakeNetNumber-1, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } - if fakeNet.packetCount[2] != 1 { - t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1) - } - - // Make sure packet that is too small is dropped. - buf.CapLength(2) - ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } - if fakeNet.packetCount[2] != 1 { - t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1) - } -} - -func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) tcpip.Error { - r, err := s.FindRoute(0, "", addr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - return err - } - defer r.Release() - return send(r, payload) -} - -func send(r *stack.Route, payload buffer.View) tcpip.Error { - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()), - Data: payload.ToVectorisedView(), - })) -} - -func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) { - t.Helper() - ep.Drain() - if err := sendTo(s, addr, payload); err != nil { - t.Error("sendTo failed:", err) - } - if got, want := ep.Drain(), 1; got != want { - t.Errorf("sendTo packet count: got = %d, want %d", got, want) - } -} - -func testSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View) { - t.Helper() - ep.Drain() - if err := send(r, payload); err != nil { - t.Error("send failed:", err) - } - if got, want := ep.Drain(), 1; got != want { - t.Errorf("send packet count: got = %d, want %d", got, want) - } -} - -func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) { - t.Helper() - if gotErr := send(r, payload); gotErr != wantErr { - t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) - } -} - -func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) { - t.Helper() - if gotErr := sendTo(s, addr, payload); gotErr != wantErr { - t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr) - } -} - -func testRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) { - t.Helper() - // testRecvInternal injects one packet, and we expect to receive it. - want := fakeNet.PacketCount(localAddrByte) + 1 - testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want) -} - -func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) { - t.Helper() - // testRecvInternal injects one packet, and we do NOT expect to receive it. - want := fakeNet.PacketCount(localAddrByte) - testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want) -} - -func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { - t.Helper() - ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if got := fakeNet.PacketCount(localAddrByte); got != want { - t.Errorf("receive packet count: got = %d, want %d", got, want) - } -} - -func TestNetworkSend(t *testing.T) { - // Create a stack with the fake network protocol, one nic, and one - // address: 1. The route table sends all packets through the only - // existing nic. - ep := channel.New(10, defaultMTU, "") - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("NewNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - // Make sure that the link-layer endpoint received the outbound packet. - testSendTo(t, s, "\x03", ep, nil) -} - -func TestNetworkSendMultiRoute(t *testing.T) { - // Create a stack with the fake network protocol, two nics, and two - // addresses per nic, the first nic has odd address, the second one has - // even addresses. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - ep2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, ep2); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - // Set a route table that sends all packets with odd destination - // addresses through the first NIC, and all even destination address - // through the second one. - { - subnet0, err := tcpip.NewSubnet("\x00", "\x01") - if err != nil { - t.Fatal(err) - } - subnet1, err := tcpip.NewSubnet("\x01", "\x01") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{ - {Destination: subnet1, Gateway: "\x00", NIC: 1}, - {Destination: subnet0, Gateway: "\x00", NIC: 2}, - }) - } - - // Send a packet to an odd destination. - testSendTo(t, s, "\x05", ep1, nil) - - // Send a packet to an even destination. - testSendTo(t, s, "\x06", ep2, nil) -} - -func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) { - r, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatal("FindRoute failed:", err) - } - - defer r.Release() - - if r.LocalAddress != expectedSrcAddr { - t.Fatalf("Bad source address: expected %v, got %v", expectedSrcAddr, r.LocalAddress) - } - - if r.RemoteAddress != dstAddr { - t.Fatalf("Bad destination address: expected %v, got %v", dstAddr, r.RemoteAddress) - } -} - -func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr tcpip.Address) { - _, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Fatalf("FindRoute returned unexpected error, got = %v, want = %s", err, &tcpip.ErrNoRoute{}) - } -} - -// TestAttachToLinkEndpointImmediately tests that a LinkEndpoint is attached to -// a NetworkDispatcher when the NIC is created. -func TestAttachToLinkEndpointImmediately(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - nicOpts stack.NICOptions - }{ - { - name: "Create enabled NIC", - nicOpts: stack.NICOptions{Disabled: false}, - }, - { - name: "Create disabled NIC", - nicOpts: stack.NICOptions{Disabled: true}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - e := linkEPWithMockedAttach{ - LinkEndpoint: loopback.New(), - } - - if err := s.CreateNICWithOptions(nicID, &e, test.nicOpts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, test.nicOpts, err) - } - if !e.isAttached() { - t.Fatal("link endpoint not attached to a network dispatcher") - } - }) - } -} - -func TestDisableUnknownNIC(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - err := s.DisableNIC(1) - if _, ok := err.(*tcpip.ErrUnknownNICID); !ok { - t.Fatalf("got s.DisableNIC(1) = %v, want = %s", err, &tcpip.ErrUnknownNICID{}) - } -} - -func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) { - const nicID = 1 - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - e := loopback.New() - nicOpts := stack.NICOptions{Disabled: true} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err) - } - - checkNIC := func(enabled bool) { - t.Helper() - - allNICInfo := s.NICInfo() - nicInfo, ok := allNICInfo[nicID] - if !ok { - t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo) - } else if nicInfo.Flags.Running != enabled { - t.Errorf("got nicInfo.Flags.Running = %t, want = %t", nicInfo.Flags.Running, enabled) - } - - if got := s.CheckNIC(nicID); got != enabled { - t.Errorf("got s.CheckNIC(%d) = %t, want = %t", nicID, got, enabled) - } - } - - // NIC should initially report itself as disabled. - checkNIC(false) - - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - checkNIC(true) - - // If the NIC is not reporting a correct enabled status, we cannot trust the - // next check so end the test here. - if t.Failed() { - t.FailNow() - } - - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - checkNIC(false) -} - -func TestRemoveUnknownNIC(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - err := s.RemoveNIC(1) - if _, ok := err.(*tcpip.ErrUnknownNICID); !ok { - t.Fatalf("got s.RemoveNIC(1) = %v, want = %s", err, &tcpip.ErrUnknownNICID{}) - } -} - -func TestRemoveNIC(t *testing.T) { - const nicID = 1 - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - e := linkEPWithMockedAttach{ - LinkEndpoint: loopback.New(), - } - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - // NIC should be present in NICInfo and attached to a NetworkDispatcher. - allNICInfo := s.NICInfo() - if _, ok := allNICInfo[nicID]; !ok { - t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo) - } - if !e.isAttached() { - t.Fatal("link endpoint not attached to a network dispatcher") - } - - // Removing a NIC should remove it from NICInfo and e should be detached from - // the NetworkDispatcher. - if err := s.RemoveNIC(nicID); err != nil { - t.Fatalf("s.RemoveNIC(%d): %s", nicID, err) - } - if nicInfo, ok := s.NICInfo()[nicID]; ok { - t.Errorf("got unexpected NICInfo entry for deleted NIC %d = %+v", nicID, nicInfo) - } - if e.isAttached() { - t.Error("link endpoint for removed NIC still attached to a network dispatcher") - } -} - -func TestRouteWithDownNIC(t *testing.T) { - tests := []struct { - name string - downFn func(s *stack.Stack, nicID tcpip.NICID) tcpip.Error - upFn func(s *stack.Stack, nicID tcpip.NICID) tcpip.Error - }{ - { - name: "Disabled NIC", - downFn: (*stack.Stack).DisableNIC, - upFn: (*stack.Stack).EnableNIC, - }, - - // Once a NIC is removed, it cannot be brought up. - { - name: "Removed NIC", - downFn: (*stack.Stack).RemoveNIC, - }, - } - - const unspecifiedNIC = 0 - const nicID1 = 1 - const nicID2 = 2 - const addr1 = tcpip.Address("\x01") - const addr2 = tcpip.Address("\x02") - const nic1Dst = tcpip.Address("\x05") - const nic2Dst = tcpip.Address("\x06") - - setup := func(t *testing.T) (*stack.Stack, *channel.Endpoint, *channel.Endpoint) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep1 := channel.New(1, defaultMTU, "") - if err := s.CreateNIC(nicID1, ep1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) - } - - if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err) - } - - ep2 := channel.New(1, defaultMTU, "") - if err := s.CreateNIC(nicID2, ep2); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) - } - - if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err) - } - - // Set a route table that sends all packets with odd destination - // addresses through the first NIC, and all even destination address - // through the second one. - { - subnet0, err := tcpip.NewSubnet("\x00", "\x01") - if err != nil { - t.Fatal(err) - } - subnet1, err := tcpip.NewSubnet("\x01", "\x01") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{ - {Destination: subnet1, Gateway: "\x00", NIC: nicID1}, - {Destination: subnet0, Gateway: "\x00", NIC: nicID2}, - }) - } - - return s, ep1, ep2 - } - - // Tests that routes through a down NIC are not used when looking up a route - // for a destination. - t.Run("Find", func(t *testing.T) { - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s, _, _ := setup(t) - - // Test routes to odd address. - testRoute(t, s, unspecifiedNIC, "", "\x05", addr1) - testRoute(t, s, unspecifiedNIC, addr1, "\x05", addr1) - testRoute(t, s, nicID1, addr1, "\x05", addr1) - - // Test routes to even address. - testRoute(t, s, unspecifiedNIC, "", "\x06", addr2) - testRoute(t, s, unspecifiedNIC, addr2, "\x06", addr2) - testRoute(t, s, nicID2, addr2, "\x06", addr2) - - // Bringing NIC1 down should result in no routes to odd addresses. Routes to - // even addresses should continue to be available as NIC2 is still up. - if err := test.downFn(s, nicID1); err != nil { - t.Fatalf("test.downFn(_, %d): %s", nicID1, err) - } - testNoRoute(t, s, unspecifiedNIC, "", nic1Dst) - testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst) - testNoRoute(t, s, nicID1, addr1, nic1Dst) - testRoute(t, s, unspecifiedNIC, "", nic2Dst, addr2) - testRoute(t, s, unspecifiedNIC, addr2, nic2Dst, addr2) - testRoute(t, s, nicID2, addr2, nic2Dst, addr2) - - // Bringing NIC2 down should result in no routes to even addresses. No - // route should be available to any address as routes to odd addresses - // were made unavailable by bringing NIC1 down above. - if err := test.downFn(s, nicID2); err != nil { - t.Fatalf("test.downFn(_, %d): %s", nicID2, err) - } - testNoRoute(t, s, unspecifiedNIC, "", nic1Dst) - testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst) - testNoRoute(t, s, nicID1, addr1, nic1Dst) - testNoRoute(t, s, unspecifiedNIC, "", nic2Dst) - testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst) - testNoRoute(t, s, nicID2, addr2, nic2Dst) - - if upFn := test.upFn; upFn != nil { - // Bringing NIC1 up should make routes to odd addresses available - // again. Routes to even addresses should continue to be unavailable - // as NIC2 is still down. - if err := upFn(s, nicID1); err != nil { - t.Fatalf("test.upFn(_, %d): %s", nicID1, err) - } - testRoute(t, s, unspecifiedNIC, "", nic1Dst, addr1) - testRoute(t, s, unspecifiedNIC, addr1, nic1Dst, addr1) - testRoute(t, s, nicID1, addr1, nic1Dst, addr1) - testNoRoute(t, s, unspecifiedNIC, "", nic2Dst) - testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst) - testNoRoute(t, s, nicID2, addr2, nic2Dst) - } - }) - } - }) - - // Tests that writing a packet using a Route through a down NIC fails. - t.Run("WritePacket", func(t *testing.T) { - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s, ep1, ep2 := setup(t) - - r1, err := s.FindRoute(nicID1, addr1, nic1Dst, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID1, addr1, nic1Dst, fakeNetNumber, err) - } - defer r1.Release() - - r2, err := s.FindRoute(nicID2, addr2, nic2Dst, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID2, addr2, nic2Dst, fakeNetNumber, err) - } - defer r2.Release() - - // If we failed to get routes r1 or r2, we cannot proceed with the test. - if t.Failed() { - t.FailNow() - } - - buf := buffer.View([]byte{1}) - testSend(t, r1, ep1, buf) - testSend(t, r2, ep2, buf) - - // Writes with Routes that use NIC1 after being brought down should fail. - if err := test.downFn(s, nicID1); err != nil { - t.Fatalf("test.downFn(_, %d): %s", nicID1, err) - } - testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{}) - testSend(t, r2, ep2, buf) - - // Writes with Routes that use NIC2 after being brought down should fail. - if err := test.downFn(s, nicID2); err != nil { - t.Fatalf("test.downFn(_, %d): %s", nicID2, err) - } - testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{}) - testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{}) - - if upFn := test.upFn; upFn != nil { - // Writes with Routes that use NIC1 after being brought up should - // succeed. - // - // TODO(gvisor.dev/issue/1491): Should we instead completely - // invalidate all Routes that were bound to a NIC that was brought - // down at some point? - if err := upFn(s, nicID1); err != nil { - t.Fatalf("test.upFn(_, %d): %s", nicID1, err) - } - testSend(t, r1, ep1, buf) - testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{}) - } - }) - } - }) -} - -func TestRoutes(t *testing.T) { - // Create a stack with the fake network protocol, two nics, and two - // addresses per nic, the first nic has odd address, the second one has - // even addresses. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - ep2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, ep2); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - // Set a route table that sends all packets with odd destination - // addresses through the first NIC, and all even destination address - // through the second one. - { - subnet0, err := tcpip.NewSubnet("\x00", "\x01") - if err != nil { - t.Fatal(err) - } - subnet1, err := tcpip.NewSubnet("\x01", "\x01") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{ - {Destination: subnet1, Gateway: "\x00", NIC: 1}, - {Destination: subnet0, Gateway: "\x00", NIC: 2}, - }) - } - - // Test routes to odd address. - testRoute(t, s, 0, "", "\x05", "\x01") - testRoute(t, s, 0, "\x01", "\x05", "\x01") - testRoute(t, s, 1, "\x01", "\x05", "\x01") - testRoute(t, s, 0, "\x03", "\x05", "\x03") - testRoute(t, s, 1, "\x03", "\x05", "\x03") - - // Test routes to even address. - testRoute(t, s, 0, "", "\x06", "\x02") - testRoute(t, s, 0, "\x02", "\x06", "\x02") - testRoute(t, s, 2, "\x02", "\x06", "\x02") - testRoute(t, s, 0, "\x04", "\x06", "\x04") - testRoute(t, s, 2, "\x04", "\x06", "\x04") - - // Try to send to odd numbered address from even numbered ones, then - // vice-versa. - testNoRoute(t, s, 0, "\x02", "\x05") - testNoRoute(t, s, 2, "\x02", "\x05") - testNoRoute(t, s, 0, "\x04", "\x05") - testNoRoute(t, s, 2, "\x04", "\x05") - - testNoRoute(t, s, 0, "\x01", "\x06") - testNoRoute(t, s, 1, "\x01", "\x06") - testNoRoute(t, s, 0, "\x03", "\x06") - testNoRoute(t, s, 1, "\x03", "\x06") -} - -func TestAddressRemoval(t *testing.T) { - const localAddrByte byte = 0x01 - localAddr := tcpip.Address([]byte{localAddrByte}) - remoteAddr := tcpip.Address("\x02") - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) - } - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - // Send and receive packets, and verify they are received. - buf[dstAddrOffset] = localAddrByte - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSendTo(t, s, remoteAddr, ep, nil) - - // Remove the address, then check that send/receive doesn't work anymore. - if err := s.RemoveAddress(1, localAddr); err != nil { - t.Fatal("RemoveAddress failed:", err) - } - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) - - // Check that removing the same address fails. - err := s.RemoveAddress(1, localAddr) - if _, ok := err.(*tcpip.ErrBadLocalAddress); !ok { - t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, &tcpip.ErrBadLocalAddress{}) - } -} - -func TestAddressRemovalWithRouteHeld(t *testing.T) { - const localAddrByte byte = 0x01 - localAddr := tcpip.Address([]byte{localAddrByte}) - remoteAddr := tcpip.Address("\x02") - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - buf := buffer.NewView(30) - - if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) - } - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatal("FindRoute failed:", err) - } - - // Send and receive packets, and verify they are received. - buf[dstAddrOffset] = localAddrByte - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSend(t, r, ep, nil) - testSendTo(t, s, remoteAddr, ep, nil) - - // Remove the address, then check that send/receive doesn't work anymore. - if err := s.RemoveAddress(1, localAddr); err != nil { - t.Fatal("RemoveAddress failed:", err) - } - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{}) - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) - - // Check that removing the same address fails. - { - err := s.RemoveAddress(1, localAddr) - if _, ok := err.(*tcpip.ErrBadLocalAddress); !ok { - t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, &tcpip.ErrBadLocalAddress{}) - } - } -} - -func verifyAddress(t *testing.T, s *stack.Stack, nicID tcpip.NICID, addr tcpip.Address) { - t.Helper() - info, ok := s.NICInfo()[nicID] - if !ok { - t.Fatalf("NICInfo() failed to find nicID=%d", nicID) - } - if len(addr) == 0 { - // No address given, verify that there is no address assigned to the NIC. - for _, a := range info.ProtocolAddresses { - if a.Protocol == fakeNetNumber && a.AddressWithPrefix != (tcpip.AddressWithPrefix{}) { - t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, (tcpip.AddressWithPrefix{})) - } - } - return - } - // Address given, verify the address is assigned to the NIC and no other - // address is. - found := false - for _, a := range info.ProtocolAddresses { - if a.Protocol == fakeNetNumber { - if a.AddressWithPrefix.Address == addr { - found = true - } else { - t.Errorf("verify address: got = %s, want = %s", a.AddressWithPrefix.Address, addr) - } - } - } - if !found { - t.Errorf("verify address: couldn't find %s on the NIC", addr) - } -} - -func TestEndpointExpiration(t *testing.T) { - const ( - localAddrByte byte = 0x01 - remoteAddr tcpip.Address = "\x03" - noAddr tcpip.Address = "" - nicID tcpip.NICID = 1 - ) - localAddr := tcpip.Address([]byte{localAddrByte}) - - for _, promiscuous := range []bool{true, false} { - for _, spoofing := range []bool{true, false} { - t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - buf := buffer.NewView(30) - buf[dstAddrOffset] = localAddrByte - - if promiscuous { - if err := s.SetPromiscuousMode(nicID, true); err != nil { - t.Fatal("SetPromiscuousMode failed:", err) - } - } - - if spoofing { - if err := s.SetSpoofing(nicID, true); err != nil { - t.Fatal("SetSpoofing failed:", err) - } - } - - // 1. No Address yet, send should only work for spoofing, receive for - // promiscuous mode. - //----------------------- - verifyAddress(t, s, nicID, noAddr) - if promiscuous { - testRecv(t, fakeNet, localAddrByte, ep, buf) - } else { - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - } - if spoofing { - // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, ep, nil) - } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) - } - - // 2. Add Address, everything should work. - //----------------------- - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) - } - verifyAddress(t, s, nicID, localAddr) - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSendTo(t, s, remoteAddr, ep, nil) - - // 3. Remove the address, send should only work for spoofing, receive - // for promiscuous mode. - //----------------------- - if err := s.RemoveAddress(nicID, localAddr); err != nil { - t.Fatal("RemoveAddress failed:", err) - } - verifyAddress(t, s, nicID, noAddr) - if promiscuous { - testRecv(t, fakeNet, localAddrByte, ep, buf) - } else { - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - } - if spoofing { - // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, ep, nil) - } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) - } - - // 4. Add Address back, everything should work again. - //----------------------- - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) - } - verifyAddress(t, s, nicID, localAddr) - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSendTo(t, s, remoteAddr, ep, nil) - - // 5. Take a reference to the endpoint by getting a route. Verify that - // we can still send/receive, including sending using the route. - //----------------------- - r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatal("FindRoute failed:", err) - } - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSendTo(t, s, remoteAddr, ep, nil) - testSend(t, r, ep, nil) - - // 6. Remove the address. Send should only work for spoofing, receive - // for promiscuous mode. - //----------------------- - if err := s.RemoveAddress(nicID, localAddr); err != nil { - t.Fatal("RemoveAddress failed:", err) - } - verifyAddress(t, s, nicID, noAddr) - if promiscuous { - testRecv(t, fakeNet, localAddrByte, ep, buf) - } else { - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - } - if spoofing { - testSend(t, r, ep, nil) - testSendTo(t, s, remoteAddr, ep, nil) - } else { - testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{}) - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) - } - - // 7. Add Address back, everything should work again. - //----------------------- - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) - } - verifyAddress(t, s, nicID, localAddr) - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSendTo(t, s, remoteAddr, ep, nil) - testSend(t, r, ep, nil) - - // 8. Remove the route, sendTo/recv should still work. - //----------------------- - r.Release() - verifyAddress(t, s, nicID, localAddr) - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSendTo(t, s, remoteAddr, ep, nil) - - // 9. Remove the address. Send should only work for spoofing, receive - // for promiscuous mode. - //----------------------- - if err := s.RemoveAddress(nicID, localAddr); err != nil { - t.Fatal("RemoveAddress failed:", err) - } - verifyAddress(t, s, nicID, noAddr) - if promiscuous { - testRecv(t, fakeNet, localAddrByte, ep, buf) - } else { - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - } - if spoofing { - // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, ep, nil) - } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) - } - }) - } - } -} - -func TestPromiscuousMode(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - // Write a packet, and check that it doesn't get delivered as we don't - // have a matching endpoint. - const localAddrByte byte = 0x01 - buf[dstAddrOffset] = localAddrByte - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - - // Set promiscuous mode, then check that packet is delivered. - if err := s.SetPromiscuousMode(1, true); err != nil { - t.Fatal("SetPromiscuousMode failed:", err) - } - testRecv(t, fakeNet, localAddrByte, ep, buf) - - // Check that we can't get a route as there is no local address. - _, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Fatalf("FindRoute returned unexpected error: got = %v, want = %s", err, &tcpip.ErrNoRoute{}) - } - - // Set promiscuous mode to false, then check that packet can't be - // delivered anymore. - if err := s.SetPromiscuousMode(1, false); err != nil { - t.Fatal("SetPromiscuousMode failed:", err) - } - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) -} - -// TestExternalSendWithHandleLocal tests that the stack creates a non-local -// route when spoofing or promiscuous mode are enabled. -// -// This test makes sure that packets are transmitted from the stack. -func TestExternalSendWithHandleLocal(t *testing.T) { - const ( - unspecifiedNICID = 0 - nicID = 1 - - localAddr = tcpip.Address("\x01") - dstAddr = tcpip.Address("\x03") - ) - - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - - tests := []struct { - name string - configureStack func(*testing.T, *stack.Stack) - }{ - { - name: "Default", - configureStack: func(*testing.T, *stack.Stack) {}, - }, - { - name: "Spoofing", - configureStack: func(t *testing.T, s *stack.Stack) { - if err := s.SetSpoofing(nicID, true); err != nil { - t.Fatalf("s.SetSpoofing(%d, true): %s", nicID, err) - } - }, - }, - { - name: "Promiscuous", - configureStack: func(t *testing.T, s *stack.Stack) { - if err := s.SetPromiscuousMode(nicID, true); err != nil { - t.Fatalf("s.SetPromiscuousMode(%d, true): %s", nicID, err) - } - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - for _, handleLocal := range []bool{true, false} { - t.Run(fmt.Sprintf("HandleLocal=%t", handleLocal), func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - HandleLocal: handleLocal, - }) - - ep := channel.New(1, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err) - } - - s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: nicID}}) - - test.configureStack(t, s) - - r, err := s.FindRoute(unspecifiedNICID, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", unspecifiedNICID, localAddr, dstAddr, fakeNetNumber, err) - } - defer r.Release() - - if r.LocalAddress != localAddr { - t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, localAddr) - } - if r.RemoteAddress != dstAddr { - t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) - } - - if n := ep.Drain(); n != 0 { - t.Fatalf("got ep.Drain() = %d, want = 0", n) - } - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ - Protocol: fakeTransNumber, - TTL: 123, - TOS: stack.DefaultTOS, - }, stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()), - Data: buffer.NewView(10).ToVectorisedView(), - })); err != nil { - t.Fatalf("r.WritePacket(nil, _, _): %s", err) - } - if n := ep.Drain(); n != 1 { - t.Fatalf("got ep.Drain() = %d, want = 1", n) - } - }) - } - }) - } -} - -func TestSpoofingWithAddress(t *testing.T) { - localAddr := tcpip.Address("\x01") - nonExistentLocalAddr := tcpip.Address("\x02") - dstAddr := tcpip.Address("\x03") - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - // With address spoofing disabled, FindRoute does not permit an address - // that was not added to the NIC to be used as the source. - r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err == nil { - t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) - } - - // With address spoofing enabled, FindRoute permits any address to be used - // as the source. - if err := s.SetSpoofing(1, true); err != nil { - t.Fatal("SetSpoofing failed:", err) - } - r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatal("FindRoute failed:", err) - } - if r.LocalAddress != nonExistentLocalAddr { - t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr) - } - if r.RemoteAddress != dstAddr { - t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) - } - // Sending a packet works. - testSendTo(t, s, dstAddr, ep, nil) - testSend(t, r, ep, nil) - - // FindRoute should also work with a local address that exists on the NIC. - r, err = s.FindRoute(0, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatal("FindRoute failed:", err) - } - if r.LocalAddress != localAddr { - t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr) - } - if r.RemoteAddress != dstAddr { - t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) - } - // Sending a packet using the route works. - testSend(t, r, ep, nil) -} - -func TestSpoofingNoAddress(t *testing.T) { - nonExistentLocalAddr := tcpip.Address("\x01") - dstAddr := tcpip.Address("\x02") - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - // With address spoofing disabled, FindRoute does not permit an address - // that was not added to the NIC to be used as the source. - r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err == nil { - t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) - } - // Sending a packet fails. - testFailingSendTo(t, s, dstAddr, ep, nil, &tcpip.ErrNoRoute{}) - - // With address spoofing enabled, FindRoute permits any address to be used - // as the source. - if err := s.SetSpoofing(1, true); err != nil { - t.Fatal("SetSpoofing failed:", err) - } - r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatal("FindRoute failed:", err) - } - if r.LocalAddress != nonExistentLocalAddr { - t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr) - } - if r.RemoteAddress != dstAddr { - t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) - } - // Sending a packet works. - // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, ep, nil) -} - -func verifyRoute(gotRoute, wantRoute *stack.Route) error { - if gotRoute.LocalAddress != wantRoute.LocalAddress { - return fmt.Errorf("bad local address: got %s, want = %s", gotRoute.LocalAddress, wantRoute.LocalAddress) - } - if gotRoute.RemoteAddress != wantRoute.RemoteAddress { - return fmt.Errorf("bad remote address: got %s, want = %s", gotRoute.RemoteAddress, wantRoute.RemoteAddress) - } - if got, want := gotRoute.RemoteLinkAddress(), wantRoute.RemoteLinkAddress(); got != want { - return fmt.Errorf("bad remote link address: got %s, want = %s", got, want) - } - if gotRoute.NextHop != wantRoute.NextHop { - return fmt.Errorf("bad next-hop address: got %s, want = %s", gotRoute.NextHop, wantRoute.NextHop) - } - return nil -} - -func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - s.SetRouteTable([]tcpip.Route{}) - - // If there is no endpoint, it won't work. - { - _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) - if _, ok := err.(*tcpip.ErrNetworkUnreachable); !ok { - t.Fatalf("got FindRoute(1, %s, %s, %d) = %s, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, &tcpip.ErrNetworkUnreachable{}) - } - } - - protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}} - if err := s.AddProtocolAddress(1, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err) - } - r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) - } - var wantRoute stack.Route - wantRoute.LocalAddress = header.IPv4Any - wantRoute.RemoteAddress = header.IPv4Broadcast - if err := verifyRoute(r, &wantRoute); err != nil { - t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) - } - - // If the NIC doesn't exist, it won't work. - { - _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) - if _, ok := err.(*tcpip.ErrNetworkUnreachable); !ok { - t.Fatalf("got FindRoute(2, %v, %v, %d) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, &tcpip.ErrNetworkUnreachable{}) - } - } -} - -func TestOutgoingBroadcastWithRouteTable(t *testing.T) { - defaultAddr := tcpip.AddressWithPrefix{header.IPv4Any, 0} - // Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1. - nic1Addr := tcpip.AddressWithPrefix{"\xc0\xa8\x01\x3a", 24} - nic1Gateway := tcpip.Address("\xc0\xa8\x01\x01") - // Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1. - nic2Addr := tcpip.AddressWithPrefix{"\x0a\x0a\x0a\x05", 24} - nic2Gateway := tcpip.Address("\x0a\x0a\x0a\x01") - - // Create a new stack with two NICs. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatalf("CreateNIC failed: %s", err) - } - if err := s.CreateNIC(2, ep); err != nil { - t.Fatalf("CreateNIC failed: %s", err) - } - nic1ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic1Addr} - if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err) - } - - nic2ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic2Addr} - if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil { - t.Fatalf("AddAddress(2, %v) failed: %v", nic2ProtoAddr, err) - } - - // Set the initial route table. - rt := []tcpip.Route{ - {Destination: nic1Addr.Subnet(), NIC: 1}, - {Destination: nic2Addr.Subnet(), NIC: 2}, - {Destination: defaultAddr.Subnet(), Gateway: nic2Gateway, NIC: 2}, - {Destination: defaultAddr.Subnet(), Gateway: nic1Gateway, NIC: 1}, - } - s.SetRouteTable(rt) - - // When an interface is given, the route for a broadcast goes through it. - r, err := s.FindRoute(1, nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) - } - var wantRoute stack.Route - wantRoute.LocalAddress = nic1Addr.Address - wantRoute.RemoteAddress = header.IPv4Broadcast - if err := verifyRoute(r, &wantRoute); err != nil { - t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) - } - - // When an interface is not given, it consults the route table. - // 1. Case: Using the default route. - r, err = s.FindRoute(0, "", header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) - } - wantRoute = stack.Route{} - wantRoute.LocalAddress = nic2Addr.Address - wantRoute.RemoteAddress = header.IPv4Broadcast - if err := verifyRoute(r, &wantRoute); err != nil { - t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) - } - - // 2. Case: Having an explicit route for broadcast will select that one. - rt = append( - []tcpip.Route{ - {Destination: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, - }, - rt..., - ) - s.SetRouteTable(rt) - r, err = s.FindRoute(0, "", header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) - } - wantRoute = stack.Route{} - wantRoute.LocalAddress = nic1Addr.Address - wantRoute.RemoteAddress = header.IPv4Broadcast - if err := verifyRoute(r, &wantRoute); err != nil { - t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) - } -} - -func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { - for _, tc := range []struct { - name string - routeNeeded bool - address tcpip.Address - }{ - // IPv4 multicast address range: 224.0.0.0 - 239.255.255.255 - // <=> 0xe0.0x00.0x00.0x00 - 0xef.0xff.0xff.0xff - {"IPv4 Multicast 1", false, "\xe0\x00\x00\x00"}, - {"IPv4 Multicast 2", false, "\xef\xff\xff\xff"}, - {"IPv4 Unicast 1", true, "\xdf\xff\xff\xff"}, - {"IPv4 Unicast 2", true, "\xf0\x00\x00\x00"}, - {"IPv4 Unicast 3", true, "\x00\x00\x00\x00"}, - - // IPv6 multicast address is 0xff[8] + flags[4] + scope[4] + groupId[112] - {"IPv6 Multicast 1", false, "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Multicast 2", false, "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Multicast 3", false, "\xff\x0f\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"}, - - // IPv6 link-local address starts with fe80::/10. - {"IPv6 Unicast Link-Local 1", false, "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Link-Local 2", false, "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"}, - {"IPv6 Unicast Link-Local 3", false, "\xfe\x80\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff"}, - {"IPv6 Unicast Link-Local 4", false, "\xfe\xbf\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Link-Local 5", false, "\xfe\xbf\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"}, - - // IPv6 addresses that are neither multicast nor link-local. - {"IPv6 Unicast Not Link-Local 1", true, "\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Not Link-Local 2", true, "\xf0\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"}, - {"IPv6 Unicast Not Link-local 3", true, "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Not Link-Local 4", true, "\xfe\xc0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Not Link-Local 5", true, "\xfe\xdf\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Not Link-Local 6", true, "\xfd\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Not Link-Local 7", true, "\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - } { - t.Run(tc.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - s.SetRouteTable([]tcpip.Route{}) - - var anyAddr tcpip.Address - if len(tc.address) == header.IPv4AddressSize { - anyAddr = header.IPv4Any - } else { - anyAddr = header.IPv6Any - } - - var want tcpip.Error = &tcpip.ErrNetworkUnreachable{} - if tc.routeNeeded { - want = &tcpip.ErrNoRoute{} - } - - // If there is no endpoint, it won't work. - if _, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); err != want { - t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, want) - } - - if err := s.AddAddress(1, fakeNetNumber, anyAddr); err != nil { - t.Fatalf("AddAddress(%v, %v) failed: %v", fakeNetNumber, anyAddr, err) - } - - if r, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); tc.routeNeeded { - // Route table is empty but we need a route, this should cause an error. - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, &tcpip.ErrNoRoute{}) - } - } else { - if err != nil { - t.Fatalf("FindRoute(1, %v, %v, %v) failed: %v", anyAddr, tc.address, fakeNetNumber, err) - } - if r.LocalAddress != anyAddr { - t.Errorf("Bad local address: got %v, want = %v", r.LocalAddress, anyAddr) - } - if r.RemoteAddress != tc.address { - t.Errorf("Bad remote address: got %v, want = %v", r.RemoteAddress, tc.address) - } - } - // If the NIC doesn't exist, it won't work. - if _, err := s.FindRoute(2, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); err != want { - t.Fatalf("got FindRoute(2, %v, %v, %v) = %v want = %v", anyAddr, tc.address, fakeNetNumber, err, want) - } - }) - } -} - -func TestNetworkOption(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - TransportProtocols: []stack.TransportProtocolFactory{}, - }) - - opt := tcpip.DefaultTTLOption(5) - if err := s.SetNetworkProtocolOption(fakeNetNumber, &opt); err != nil { - t.Fatalf("s.SetNetworkProtocolOption(%d, &%T(%d)): %s", fakeNetNumber, opt, opt, err) - } - - var optGot tcpip.DefaultTTLOption - if err := s.NetworkProtocolOption(fakeNetNumber, &optGot); err != nil { - t.Fatalf("s.NetworkProtocolOption(%d, &%T): %s", fakeNetNumber, optGot, err) - } - - if opt != optGot { - t.Errorf("got optGot = %d, want = %d", optGot, opt) - } -} - -func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { - for _, addrLen := range []int{4, 16} { - t.Run(fmt.Sprintf("addrLen=%d", addrLen), func(t *testing.T) { - for canBe := 0; canBe < 3; canBe++ { - t.Run(fmt.Sprintf("canBe=%d", canBe), func(t *testing.T) { - for never := 0; never < 3; never++ { - t.Run(fmt.Sprintf("never=%d", never), func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - // Insert <canBe> primary and <never> never-primary addresses. - // Each one will add a network endpoint to the NIC. - primaryAddrAdded := make(map[tcpip.AddressWithPrefix]struct{}) - for i := 0; i < canBe+never; i++ { - var behavior stack.PrimaryEndpointBehavior - if i < canBe { - behavior = stack.CanBePrimaryEndpoint - } else { - behavior = stack.NeverPrimaryEndpoint - } - // Add an address and in case of a primary one include a - // prefixLen. - address := tcpip.Address(bytes.Repeat([]byte{byte(i)}, addrLen)) - if behavior == stack.CanBePrimaryEndpoint { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: address, - PrefixLen: addrLen * 8, - }, - } - if err := s.AddProtocolAddressWithOptions(1, protocolAddress, behavior); err != nil { - t.Fatal("AddProtocolAddressWithOptions failed:", err) - } - // Remember the address/prefix. - primaryAddrAdded[protocolAddress.AddressWithPrefix] = struct{}{} - } else { - if err := s.AddAddressWithOptions(1, fakeNetNumber, address, behavior); err != nil { - t.Fatal("AddAddressWithOptions failed:", err) - } - } - } - // Check that GetMainNICAddress returns an address if at least - // one primary address was added. In that case make sure the - // address/prefixLen matches what we added. - gotAddr, ok := s.GetMainNICAddress(1, fakeNetNumber) - if !ok { - t.Fatalf("got GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber) - } - if len(primaryAddrAdded) == 0 { - // No primary addresses present. - if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr { - t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, gotAddr, wantAddr) - } - } else { - // At least one primary address was added, verify the returned - // address is in the list of primary addresses we added. - if _, ok := primaryAddrAdded[gotAddr]; !ok { - t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, gotAddr, primaryAddrAdded) - } - } - }) - } - }) - } - }) - } -} - -func TestGetMainNICAddressAddRemove(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - for _, tc := range []struct { - name string - address tcpip.Address - prefixLen int - }{ - {"IPv4", "\x01\x01\x01\x01", 24}, - {"IPv6", "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", 116}, - } { - t.Run(tc.name, func(t *testing.T) { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tc.address, - PrefixLen: tc.prefixLen, - }, - } - if err := s.AddProtocolAddress(1, protocolAddress); err != nil { - t.Fatal("AddProtocolAddress failed:", err) - } - - // Check that we get the right initial address and prefix length. - if err := checkGetMainNICAddress(s, 1, fakeNetNumber, protocolAddress.AddressWithPrefix); err != nil { - t.Fatal(err) - } - - if err := s.RemoveAddress(1, protocolAddress.AddressWithPrefix.Address); err != nil { - t.Fatal("RemoveAddress failed:", err) - } - - // Check that we get no address after removal. - if err := checkGetMainNICAddress(s, 1, fakeNetNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - }) - } -} - -// Simple network address generator. Good for 255 addresses. -type addressGenerator struct{ cnt byte } - -func (g *addressGenerator) next(addrLen int) tcpip.Address { - g.cnt++ - return tcpip.Address(bytes.Repeat([]byte{g.cnt}, addrLen)) -} - -func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.ProtocolAddress) { - t.Helper() - - if len(gotAddresses) != len(expectedAddresses) { - t.Fatalf("got len(addresses) = %d, want = %d", len(gotAddresses), len(expectedAddresses)) - } - - sort.Slice(gotAddresses, func(i, j int) bool { - return gotAddresses[i].AddressWithPrefix.Address < gotAddresses[j].AddressWithPrefix.Address - }) - sort.Slice(expectedAddresses, func(i, j int) bool { - return expectedAddresses[i].AddressWithPrefix.Address < expectedAddresses[j].AddressWithPrefix.Address - }) - - for i, gotAddr := range gotAddresses { - expectedAddr := expectedAddresses[i] - if gotAddr != expectedAddr { - t.Errorf("got address = %+v, wanted = %+v", gotAddr, expectedAddr) - } - } -} - -func TestAddAddress(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - var addrGen addressGenerator - expectedAddresses := make([]tcpip.ProtocolAddress, 0, 2) - for _, addrLen := range []int{4, 16} { - address := addrGen.next(addrLen) - if err := s.AddAddress(nicID, fakeNetNumber, address); err != nil { - t.Fatalf("AddAddress(address=%s) failed: %s", address, err) - } - expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen}, - }) - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - -func TestAddProtocolAddress(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - var addrGen addressGenerator - addrLenRange := []int{4, 16} - prefixLenRange := []int{8, 13, 20, 32} - expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)) - for _, addrLen := range addrLenRange { - for _, prefixLen := range prefixLenRange { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addrGen.next(addrLen), - PrefixLen: prefixLen, - }, - } - if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil { - t.Errorf("AddProtocolAddress(%+v) failed: %s", protocolAddress, err) - } - expectedAddresses = append(expectedAddresses, protocolAddress) - } - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - -func TestAddAddressWithOptions(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - addrLenRange := []int{4, 16} - behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint} - expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(behaviorRange)) - var addrGen addressGenerator - for _, addrLen := range addrLenRange { - for _, behavior := range behaviorRange { - address := addrGen.next(addrLen) - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil { - t.Fatalf("AddAddressWithOptions(address=%s, behavior=%d) failed: %s", address, behavior, err) - } - expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen}, - }) - } - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - -func TestAddProtocolAddressWithOptions(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - addrLenRange := []int{4, 16} - prefixLenRange := []int{8, 13, 20, 32} - behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint} - expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)*len(behaviorRange)) - var addrGen addressGenerator - for _, addrLen := range addrLenRange { - for _, prefixLen := range prefixLenRange { - for _, behavior := range behaviorRange { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addrGen.next(addrLen), - PrefixLen: prefixLen, - }, - } - if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%+v, %d) failed: %s", protocolAddress, behavior, err) - } - expectedAddresses = append(expectedAddresses, protocolAddress) - } - } - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - -func TestCreateNICWithOptions(t *testing.T) { - type callArgsAndExpect struct { - nicID tcpip.NICID - opts stack.NICOptions - err tcpip.Error - } - - tests := []struct { - desc string - calls []callArgsAndExpect - }{ - { - desc: "DuplicateNICID", - calls: []callArgsAndExpect{ - { - nicID: tcpip.NICID(1), - opts: stack.NICOptions{Name: "eth1"}, - err: nil, - }, - { - nicID: tcpip.NICID(1), - opts: stack.NICOptions{Name: "eth2"}, - err: &tcpip.ErrDuplicateNICID{}, - }, - }, - }, - { - desc: "DuplicateName", - calls: []callArgsAndExpect{ - { - nicID: tcpip.NICID(1), - opts: stack.NICOptions{Name: "lo"}, - err: nil, - }, - { - nicID: tcpip.NICID(2), - opts: stack.NICOptions{Name: "lo"}, - err: &tcpip.ErrDuplicateNICID{}, - }, - }, - }, - { - desc: "Unnamed", - calls: []callArgsAndExpect{ - { - nicID: tcpip.NICID(1), - opts: stack.NICOptions{}, - err: nil, - }, - { - nicID: tcpip.NICID(2), - opts: stack.NICOptions{}, - err: nil, - }, - }, - }, - { - desc: "UnnamedDuplicateNICID", - calls: []callArgsAndExpect{ - { - nicID: tcpip.NICID(1), - opts: stack.NICOptions{}, - err: nil, - }, - { - nicID: tcpip.NICID(1), - opts: stack.NICOptions{}, - err: &tcpip.ErrDuplicateNICID{}, - }, - }, - }, - } - for _, test := range tests { - t.Run(test.desc, func(t *testing.T) { - s := stack.New(stack.Options{}) - ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")) - for _, call := range test.calls { - if got, want := s.CreateNICWithOptions(call.nicID, ep, call.opts), call.err; got != want { - t.Fatalf("CreateNICWithOptions(%v, _, %+v) = %v, want %v", call.nicID, call.opts, got, want) - } - } - }) - } -} - -func TestNICStats(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC failed: ", err) - } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) - } - // Route all packets for address \x01 to NIC 1. - { - subnet, err := tcpip.NewSubnet("\x01", "\xff") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - // Send a packet to address 1. - buf := buffer.NewView(30) - ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { - t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) - } - - if got, want := s.NICInfo()[1].Stats.Rx.Bytes.Value(), uint64(len(buf)); got != want { - t.Errorf("got Rx.Bytes.Value() = %d, want = %d", got, want) - } - - payload := buffer.NewView(10) - // Write a packet out via the address for NIC 1 - if err := sendTo(s, "\x01", payload); err != nil { - t.Fatal("sendTo failed: ", err) - } - want := uint64(ep1.Drain()) - if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want { - t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want) - } - - if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)+fakeNetHeaderLen); got != want { - t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) - } -} - -// TestNICContextPreservation tests that you can read out via stack.NICInfo the -// Context data you pass via NICContext.Context in stack.CreateNICWithOptions. -func TestNICContextPreservation(t *testing.T) { - var ctx *int - tests := []struct { - name string - opts stack.NICOptions - want stack.NICContext - }{ - { - "context_set", - stack.NICOptions{Context: ctx}, - ctx, - }, - { - "context_not_set", - stack.NICOptions{}, - nil, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{}) - id := tcpip.NICID(1) - ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")) - if err := s.CreateNICWithOptions(id, ep, test.opts); err != nil { - t.Fatalf("got stack.CreateNICWithOptions(%d, %+v, %+v) = %s, want nil", id, ep, test.opts, err) - } - nicinfos := s.NICInfo() - nicinfo, ok := nicinfos[id] - if !ok { - t.Fatalf("got nicinfos[%d] = _, %t, want _, true; nicinfos = %+v", id, ok, nicinfos) - } - if got, want := nicinfo.Context == test.want, true; got != want { - t.Fatalf("got nicinfo.Context == ctx = %t, want %t; nicinfo.Context = %p, ctx = %p", got, want, nicinfo.Context, test.want) - } - }) - } -} - -// TestNICAutoGenLinkLocalAddr tests the auto-generation of IPv6 link-local -// addresses. -func TestNICAutoGenLinkLocalAddr(t *testing.T) { - const nicID = 1 - - var secretKey [header.OpaqueIIDSecretKeyMinBytes]byte - n, err := rand.Read(secretKey[:]) - if err != nil { - t.Fatalf("rand.Read(_): %s", err) - } - if n != header.OpaqueIIDSecretKeyMinBytes { - t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", header.OpaqueIIDSecretKeyMinBytes, n) - } - - nicNameFunc := func(_ tcpip.NICID, name string) string { - return name - } - - tests := []struct { - name string - nicName string - autoGen bool - linkAddr tcpip.LinkAddress - iidOpts ipv6.OpaqueInterfaceIdentifierOptions - shouldGen bool - expectedAddr tcpip.Address - }{ - { - name: "Disabled", - nicName: "nic1", - autoGen: false, - linkAddr: linkAddr1, - shouldGen: false, - }, - { - name: "Disabled without OIID options", - nicName: "nic1", - autoGen: false, - linkAddr: linkAddr1, - iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: nicNameFunc, - SecretKey: secretKey[:], - }, - shouldGen: false, - }, - - // Tests for EUI64 based addresses. - { - name: "EUI64 Enabled", - autoGen: true, - linkAddr: linkAddr1, - shouldGen: true, - expectedAddr: header.LinkLocalAddr(linkAddr1), - }, - { - name: "EUI64 Empty MAC", - autoGen: true, - shouldGen: false, - }, - { - name: "EUI64 Invalid MAC", - autoGen: true, - linkAddr: "\x01\x02\x03", - shouldGen: false, - }, - { - name: "EUI64 Multicast MAC", - autoGen: true, - linkAddr: "\x01\x02\x03\x04\x05\x06", - shouldGen: false, - }, - { - name: "EUI64 Unspecified MAC", - autoGen: true, - linkAddr: "\x00\x00\x00\x00\x00\x00", - shouldGen: false, - }, - - // Tests for Opaque IID based addresses. - { - name: "OIID Enabled", - nicName: "nic1", - autoGen: true, - linkAddr: linkAddr1, - iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: nicNameFunc, - SecretKey: secretKey[:], - }, - shouldGen: true, - expectedAddr: header.LinkLocalAddrWithOpaqueIID("nic1", 0, secretKey[:]), - }, - // These are all cases where we would not have generated a - // link-local address if opaque IIDs were disabled. - { - name: "OIID Empty MAC and empty nicName", - autoGen: true, - iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: nicNameFunc, - SecretKey: secretKey[:1], - }, - shouldGen: true, - expectedAddr: header.LinkLocalAddrWithOpaqueIID("", 0, secretKey[:1]), - }, - { - name: "OIID Invalid MAC", - nicName: "test", - autoGen: true, - linkAddr: "\x01\x02\x03", - iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: nicNameFunc, - SecretKey: secretKey[:2], - }, - shouldGen: true, - expectedAddr: header.LinkLocalAddrWithOpaqueIID("test", 0, secretKey[:2]), - }, - { - name: "OIID Multicast MAC", - nicName: "test2", - autoGen: true, - linkAddr: "\x01\x02\x03\x04\x05\x06", - iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: nicNameFunc, - SecretKey: secretKey[:3], - }, - shouldGen: true, - expectedAddr: header.LinkLocalAddrWithOpaqueIID("test2", 0, secretKey[:3]), - }, - { - name: "OIID Unspecified MAC and nil SecretKey", - nicName: "test3", - autoGen: true, - linkAddr: "\x00\x00\x00\x00\x00\x00", - iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: nicNameFunc, - }, - shouldGen: true, - expectedAddr: header.LinkLocalAddrWithOpaqueIID("test3", 0, nil), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenLinkLocal: test.autoGen, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: test.iidOpts, - })}, - } - - e := channel.New(0, 1280, test.linkAddr) - s := stack.New(opts) - nicOpts := stack.NICOptions{Name: test.nicName, Disabled: true} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) - } - - // A new disabled NIC should not have any address, even if auto generation - // was enabled. - allStackAddrs := s.AllAddresses() - allNICAddrs, ok := allStackAddrs[nicID] - if !ok { - t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) - } - if l := len(allNICAddrs); l != 0 { - t.Fatalf("got len(allNICAddrs) = %d, want = 0", l) - } - - // Enabling the NIC should attempt auto-generation of a link-local - // address. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - - var expectedMainAddr tcpip.AddressWithPrefix - if test.shouldGen { - expectedMainAddr = tcpip.AddressWithPrefix{ - Address: test.expectedAddr, - PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen, - } - - // Should have auto-generated an address and resolved immediately (DAD - // is disabled). - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, expectedMainAddr, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } else { - // Should not have auto-generated an address. - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address") - default: - } - } - - // Check that we get no address after removal. - if err := checkGetMainNICAddress(s, 1, fakeNetNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - if err := checkGetMainNICAddress(s, 1, header.IPv6ProtocolNumber, expectedMainAddr); err != nil { - t.Fatal(err) - } - }) - } -} - -// TestNoLinkLocalAutoGenForLoopbackNIC tests that IPv6 link-local addresses are -// not auto-generated for loopback NICs. -func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) { - const nicID = 1 - const nicName = "nicName" - - tests := []struct { - name string - opaqueIIDOpts ipv6.OpaqueInterfaceIdentifierOptions - }{ - { - name: "IID From MAC", - opaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{}, - }, - { - name: "Opaque IID", - opaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(_ tcpip.NICID, nicName string) string { - return nicName - }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenLinkLocal: true, - OpaqueIIDOpts: test.opaqueIIDOpts, - })}, - } - - e := loopback.New() - s := stack.New(opts) - nicOpts := stack.NICOptions{Name: nicName} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err) - } - - if err := checkGetMainNICAddress(s, 1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - }) - } -} - -// TestNICAutoGenAddrDoesDAD tests that the successful auto-generation of IPv6 -// link-local addresses will only be assigned after the DAD process resolves. -func TestNICAutoGenAddrDoesDAD(t *testing.T) { - const nicID = 1 - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), - } - dadConfigs := stack.DefaultDADConfigurations() - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenLinkLocal: true, - NDPDisp: &ndpDisp, - DADConfigs: dadConfigs, - })}, - } - - e := channel.New(int(dadConfigs.DupAddrDetectTransmits), 1280, linkAddr1) - s := stack.New(opts) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - // Address should not be considered bound to the - // NIC yet (DAD ongoing). - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - - linkLocalAddr := header.LinkLocalAddr(linkAddr1) - - // Wait for DAD to resolve. - select { - case <-time.After(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer + time.Second): - // We should get a resolution event after 1s (default time to - // resolve as per default NDP configurations). Waiting for that - // resolution time + an extra 1s without a resolution event - // means something is wrong. - t.Fatal("timed out waiting for DAD resolution") - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, linkLocalAddr, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } - } - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); err != nil { - t.Fatal(err) - } -} - -// TestNewPEB tests that a new PrimaryEndpointBehavior value (peb) is respected -// when an address's kind gets "promoted" to permanent from permanentExpired. -func TestNewPEBOnPromotionToPermanent(t *testing.T) { - pebs := []stack.PrimaryEndpointBehavior{ - stack.NeverPrimaryEndpoint, - stack.CanBePrimaryEndpoint, - stack.FirstPrimaryEndpoint, - } - - for _, pi := range pebs { - for _, ps := range pebs { - t.Run(fmt.Sprintf("%d-to-%d", pi, ps), func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - // Add a permanent address with initial - // PrimaryEndpointBehavior (peb), pi. If pi is - // NeverPrimaryEndpoint, the address should not - // be returned by a call to GetMainNICAddress; - // else, it should. - if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", pi); err != nil { - t.Fatal("AddAddressWithOptions failed:", err) - } - addr, ok := s.GetMainNICAddress(1, fakeNetNumber) - if !ok { - t.Fatalf("GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber) - } - if pi == stack.NeverPrimaryEndpoint { - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, addr, want) - - } - } else if addr.Address != "\x01" { - t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (1, true)", fakeNetNumber, addr.Address) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatalf("NewSubnet failed: %v", err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - // Take a route through the address so its ref - // count gets incremented and does not actually - // get deleted when RemoveAddress is called - // below. This is because we want to test that a - // new peb is respected when an address gets - // "promoted" to permanent from a - // permanentExpired kind. - r, err := s.FindRoute(1, "\x01", "\x02", fakeNetNumber, false) - if err != nil { - t.Fatalf("FindRoute failed: %v", err) - } - defer r.Release() - if err := s.RemoveAddress(1, "\x01"); err != nil { - t.Fatalf("RemoveAddress failed: %v", err) - } - - // - // At this point, the address should still be - // known by the NIC, but have its - // kind = permanentExpired. - // - - // Add some other address with peb set to - // FirstPrimaryEndpoint. - if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x03", stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddAddressWithOptions failed: %v", err) - - } - - // Add back the address we removed earlier and - // make sure the new peb was respected. - // (The address should just be promoted now). - if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", ps); err != nil { - t.Fatalf("AddAddressWithOptions failed: %v", err) - } - var primaryAddrs []tcpip.Address - for _, pa := range s.NICInfo()[1].ProtocolAddresses { - primaryAddrs = append(primaryAddrs, pa.AddressWithPrefix.Address) - } - var expectedList []tcpip.Address - switch ps { - case stack.FirstPrimaryEndpoint: - expectedList = []tcpip.Address{ - "\x01", - "\x03", - } - case stack.CanBePrimaryEndpoint: - expectedList = []tcpip.Address{ - "\x03", - "\x01", - } - case stack.NeverPrimaryEndpoint: - expectedList = []tcpip.Address{ - "\x03", - } - } - if !cmp.Equal(primaryAddrs, expectedList) { - t.Fatalf("got NIC's primary addresses = %v, want = %v", primaryAddrs, expectedList) - } - - // Once we remove the other address, if the new - // peb, ps, was NeverPrimaryEndpoint, no address - // should be returned by a call to - // GetMainNICAddress; else, our original address - // should be returned. - if err := s.RemoveAddress(1, "\x03"); err != nil { - t.Fatalf("RemoveAddress failed: %v", err) - } - addr, ok = s.GetMainNICAddress(1, fakeNetNumber) - if !ok { - t.Fatalf("got GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber) - } - if ps == stack.NeverPrimaryEndpoint { - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, addr, want) - } - } else { - if addr.Address != "\x01" { - t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (1, true)", fakeNetNumber, addr.Address) - } - } - }) - } - } -} - -func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { - const ( - linkLocalAddr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - linkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - globalAddr3 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") - ipv4MappedIPv6Addr1 = tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x01") - ipv4MappedIPv6Addr2 = tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x02") - toredoAddr1 = tcpip.Address("\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - toredoAddr2 = tcpip.Address("\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - ipv6ToIPv4Addr1 = tcpip.Address("\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - ipv6ToIPv4Addr2 = tcpip.Address("\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - - nicID = 1 - lifetimeSeconds = 9999 - ) - - prefix1, _, stableGlobalAddr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, stableGlobalAddr2 := prefixSubnetAddr(1, linkAddr1) - - var tempIIDHistory [header.IIDSize]byte - header.InitialTempIID(tempIIDHistory[:], nil, nicID) - tempGlobalAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableGlobalAddr1.Address).Address - tempGlobalAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableGlobalAddr2.Address).Address - - // Rule 3 is not tested here, and is instead tested by NDP's AutoGenAddr test. - tests := []struct { - name string - slaacPrefixForTempAddrBeforeNICAddrAdd tcpip.AddressWithPrefix - nicAddrs []tcpip.Address - slaacPrefixForTempAddrAfterNICAddrAdd tcpip.AddressWithPrefix - remoteAddr tcpip.Address - expectedLocalAddr tcpip.Address - }{ - // Test Rule 1 of RFC 6724 section 5 (prefer same address). - { - name: "Same Global most preferred (last address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1}, - remoteAddr: globalAddr1, - expectedLocalAddr: globalAddr1, - }, - { - name: "Same Global most preferred (first address)", - nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1}, - remoteAddr: globalAddr1, - expectedLocalAddr: globalAddr1, - }, - { - name: "Same Link Local most preferred (last address)", - nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1}, - remoteAddr: linkLocalAddr1, - expectedLocalAddr: linkLocalAddr1, - }, - { - name: "Same Link Local most preferred (first address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1}, - remoteAddr: linkLocalAddr1, - expectedLocalAddr: linkLocalAddr1, - }, - { - name: "Same Unique Local most preferred (last address)", - nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1}, - remoteAddr: uniqueLocalAddr1, - expectedLocalAddr: uniqueLocalAddr1, - }, - { - name: "Same Unique Local most preferred (first address)", - nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1}, - remoteAddr: uniqueLocalAddr1, - expectedLocalAddr: uniqueLocalAddr1, - }, - - // Test Rule 2 of RFC 6724 section 5 (prefer appropriate scope). - { - name: "Global most preferred (last address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1}, - remoteAddr: globalAddr2, - expectedLocalAddr: globalAddr1, - }, - { - name: "Global most preferred (first address)", - nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1}, - remoteAddr: globalAddr2, - expectedLocalAddr: globalAddr1, - }, - { - name: "Link Local most preferred (last address)", - nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1}, - remoteAddr: linkLocalAddr2, - expectedLocalAddr: linkLocalAddr1, - }, - { - name: "Link Local most preferred (first address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1}, - remoteAddr: linkLocalAddr2, - expectedLocalAddr: linkLocalAddr1, - }, - { - name: "Link Local most preferred for link local multicast (last address)", - nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1}, - remoteAddr: linkLocalMulticastAddr, - expectedLocalAddr: linkLocalAddr1, - }, - { - name: "Link Local most preferred for link local multicast (first address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1}, - remoteAddr: linkLocalMulticastAddr, - expectedLocalAddr: linkLocalAddr1, - }, - - // Test Rule 6 of 6724 section 5 (prefer matching label). - { - name: "Unique Local most preferred (last address)", - nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1, toredoAddr1, ipv6ToIPv4Addr1}, - remoteAddr: uniqueLocalAddr2, - expectedLocalAddr: uniqueLocalAddr1, - }, - { - name: "Unique Local most preferred (first address)", - nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, toredoAddr1, ipv6ToIPv4Addr1, uniqueLocalAddr1}, - remoteAddr: uniqueLocalAddr2, - expectedLocalAddr: uniqueLocalAddr1, - }, - { - name: "Toredo most preferred (first address)", - nicAddrs: []tcpip.Address{toredoAddr1, uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1}, - remoteAddr: toredoAddr2, - expectedLocalAddr: toredoAddr1, - }, - { - name: "Toredo most preferred (last address)", - nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1, uniqueLocalAddr1, toredoAddr1}, - remoteAddr: toredoAddr2, - expectedLocalAddr: toredoAddr1, - }, - { - name: "6To4 most preferred (first address)", - nicAddrs: []tcpip.Address{ipv6ToIPv4Addr1, toredoAddr1, uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1}, - remoteAddr: ipv6ToIPv4Addr2, - expectedLocalAddr: ipv6ToIPv4Addr1, - }, - { - name: "6To4 most preferred (last address)", - nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, uniqueLocalAddr1, toredoAddr1, ipv6ToIPv4Addr1}, - remoteAddr: ipv6ToIPv4Addr2, - expectedLocalAddr: ipv6ToIPv4Addr1, - }, - { - name: "IPv4 mapped IPv6 most preferred (first address)", - nicAddrs: []tcpip.Address{ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1, toredoAddr1, uniqueLocalAddr1, globalAddr1}, - remoteAddr: ipv4MappedIPv6Addr2, - expectedLocalAddr: ipv4MappedIPv6Addr1, - }, - { - name: "IPv4 mapped IPv6 most preferred (last address)", - nicAddrs: []tcpip.Address{globalAddr1, ipv6ToIPv4Addr1, uniqueLocalAddr1, toredoAddr1, ipv4MappedIPv6Addr1}, - remoteAddr: ipv4MappedIPv6Addr2, - expectedLocalAddr: ipv4MappedIPv6Addr1, - }, - - // Test Rule 7 of RFC 6724 section 5 (prefer temporary addresses). - { - name: "Temp Global most preferred (last address)", - slaacPrefixForTempAddrBeforeNICAddrAdd: prefix1, - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - remoteAddr: globalAddr2, - expectedLocalAddr: tempGlobalAddr1, - }, - { - name: "Temp Global most preferred (first address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - slaacPrefixForTempAddrAfterNICAddrAdd: prefix1, - remoteAddr: globalAddr2, - expectedLocalAddr: tempGlobalAddr1, - }, - - // Test Rule 8 of RFC 6724 section 5 (use longest matching prefix). - { - name: "Longest prefix matched most preferred (first address)", - nicAddrs: []tcpip.Address{globalAddr2, globalAddr1}, - remoteAddr: globalAddr3, - expectedLocalAddr: globalAddr2, - }, - { - name: "Longest prefix matched most preferred (last address)", - nicAddrs: []tcpip.Address{globalAddr1, globalAddr2}, - remoteAddr: globalAddr3, - expectedLocalAddr: globalAddr2, - }, - - // Test returning the endpoint that is closest to the front when - // candidate addresses are "equal" from the perspective of RFC 6724 - // section 5. - { - name: "Unique Local for Global", - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, uniqueLocalAddr2}, - remoteAddr: globalAddr2, - expectedLocalAddr: uniqueLocalAddr1, - }, - { - name: "Link Local for Global", - nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2}, - remoteAddr: globalAddr2, - expectedLocalAddr: linkLocalAddr1, - }, - { - name: "Link Local for Unique Local", - nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2}, - remoteAddr: uniqueLocalAddr2, - expectedLocalAddr: linkLocalAddr1, - }, - { - name: "Temp Global for Global", - slaacPrefixForTempAddrBeforeNICAddrAdd: prefix1, - slaacPrefixForTempAddrAfterNICAddrAdd: prefix2, - remoteAddr: globalAddr1, - expectedLocalAddr: tempGlobalAddr2, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDispatcher{}, - })}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - if test.slaacPrefixForTempAddrBeforeNICAddrAdd != (tcpip.AddressWithPrefix{}) { - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, test.slaacPrefixForTempAddrBeforeNICAddrAdd, true, true, lifetimeSeconds, lifetimeSeconds)) - } - - for _, a := range test.nicAddrs { - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, a); err != nil { - t.Errorf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, a, err) - } - } - - if test.slaacPrefixForTempAddrAfterNICAddrAdd != (tcpip.AddressWithPrefix{}) { - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, test.slaacPrefixForTempAddrAfterNICAddrAdd, true, true, lifetimeSeconds, lifetimeSeconds)) - } - - if t.Failed() { - t.FailNow() - } - - netEP, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } - - addressableEndpoint, ok := netEP.(stack.AddressableEndpoint) - if !ok { - t.Fatal("network endpoint is not addressable") - } - - addressEP := addressableEndpoint.AcquireOutgoingPrimaryAddress(test.remoteAddr, false /* allowExpired */) - if addressEP == nil { - t.Fatal("expected a non-nil address endpoint") - } - defer addressEP.DecRef() - - if got := addressEP.AddressWithPrefix().Address; got != test.expectedLocalAddr { - t.Errorf("got local address = %s, want = %s", got, test.expectedLocalAddr) - } - }) - } -} - -func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) { - const nicID = 1 - broadcastAddr := tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: header.IPv4Broadcast, - PrefixLen: 32, - }, - } - - e := loopback.New() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - }) - nicOpts := stack.NICOptions{Disabled: true} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) - } - - { - allStackAddrs := s.AllAddresses() - if allNICAddrs, ok := allStackAddrs[nicID]; !ok { - t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) - } else if containsAddr(allNICAddrs, broadcastAddr) { - t.Fatalf("got allNICAddrs = %+v, don't want = %+v", allNICAddrs, broadcastAddr) - } - } - - // Enabling the NIC should add the IPv4 broadcast address. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - - { - allStackAddrs := s.AllAddresses() - if allNICAddrs, ok := allStackAddrs[nicID]; !ok { - t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) - } else if !containsAddr(allNICAddrs, broadcastAddr) { - t.Fatalf("got allNICAddrs = %+v, want = %+v", allNICAddrs, broadcastAddr) - } - } - - // Disabling the NIC should remove the IPv4 broadcast address. - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - - { - allStackAddrs := s.AllAddresses() - if allNICAddrs, ok := allStackAddrs[nicID]; !ok { - t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) - } else if containsAddr(allNICAddrs, broadcastAddr) { - t.Fatalf("got allNICAddrs = %+v, don't want = %+v", allNICAddrs, broadcastAddr) - } - } -} - -// TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval tests that removing an IPv6 -// address after leaving its solicited node multicast address does not result in -// an error. -func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) { - const nicID = 1 - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - }) - e := channel.New(10, 1280, linkAddr1) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, addr1, err) - } - - // The NIC should have joined addr1's solicited node multicast address. - snmc := header.SolicitedNodeAddr(addr1) - in, err := s.IsInGroup(nicID, snmc) - if err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, snmc, err) - } - if !in { - t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, snmc) - } - - if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, snmc); err != nil { - t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, snmc, err) - } - in, err = s.IsInGroup(nicID, snmc) - if err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, snmc, err) - } - if in { - t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, snmc) - } - - if err := s.RemoveAddress(nicID, addr1); err != nil { - t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err) - } -} - -func TestJoinLeaveMulticastOnNICEnableDisable(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - proto tcpip.NetworkProtocolNumber - addr tcpip.Address - }{ - { - name: "IPv6 All-Nodes", - proto: header.IPv6ProtocolNumber, - addr: header.IPv6AllNodesMulticastAddress, - }, - { - name: "IPv4 All-Systems", - proto: header.IPv4ProtocolNumber, - addr: header.IPv4AllSystems, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e := loopback.New() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - }) - nicOpts := stack.NICOptions{Disabled: true} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) - } - - // Should not be in the multicast group yet because the NIC has not been - // enabled yet. - if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) - } else if isInGroup { - t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr) - } - - // The all-nodes multicast group should be joined when the NIC is enabled. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - - if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) - } else if !isInGroup { - t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.addr) - } - - // The multicast group should be left when the NIC is disabled. - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - - if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) - } else if isInGroup { - t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr) - } - - // The all-nodes multicast group should be joined when the NIC is enabled. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - - if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) - } else if !isInGroup { - t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.addr) - } - - // Leaving the group before disabling the NIC should not cause an error. - if err := s.LeaveGroup(test.proto, nicID, test.addr); err != nil { - t.Fatalf("s.LeaveGroup(%d, %d, %s): %s", test.proto, nicID, test.addr, err) - } - - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - - if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) - } else if isInGroup { - t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr) - } - }) - } -} - -// TestDoDADWhenNICEnabled tests that IPv6 endpoints that were added while a NIC -// was disabled have DAD performed on them when the NIC is enabled. -func TestDoDADWhenNICEnabled(t *testing.T) { - const dadTransmits = 1 - const retransmitTimer = time.Second - const nicID = 1 - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), - } - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - DADConfigs: stack.DADConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - }, - NDPDisp: &ndpDisp, - })}, - } - - e := channel.New(dadTransmits, 1280, linkAddr1) - s := stack.New(opts) - nicOpts := stack.NICOptions{Disabled: true} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) - } - - addr := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: llAddr1, - PrefixLen: 128, - }, - } - if err := s.AddProtocolAddress(nicID, addr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err) - } - - // Address should be in the list of all addresses. - if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { - t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) - } - - // Address should be tentative so it should not be a main address. - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - - // Enabling the NIC should start DAD for the address. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { - t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) - } - - // Address should not be considered bound to the NIC yet (DAD ongoing). - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - - // Wait for DAD to resolve. - select { - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for DAD resolution") - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } - } - if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { - t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) - } - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr.AddressWithPrefix); err != nil { - t.Fatal(err) - } - - // Enabling the NIC again should be a no-op. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { - t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) - } - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr.AddressWithPrefix); err != nil { - t.Fatal(err) - } -} - -func TestStackReceiveBufferSizeOption(t *testing.T) { - const sMin = stack.MinBufferSize - testCases := []struct { - name string - rs stack.ReceiveBufferSizeOption - err tcpip.Error - }{ - // Invalid configurations. - {"min_below_zero", stack.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, - {"min_zero", stack.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, - {"default_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, - {"default_above_max", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, - {"max_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, - - // Valid Configurations - {"in_ascending_order", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, - {"all_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, - {"min_default_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, - {"default_max_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - s := stack.New(stack.Options{}) - defer s.Close() - if err := s.SetOption(tc.rs); err != tc.err { - t.Fatalf("s.SetOption(%#v) = %v, want: %v", tc.rs, err, tc.err) - } - var rs stack.ReceiveBufferSizeOption - if tc.err == nil { - if err := s.Option(&rs); err != nil { - t.Fatalf("s.Option(%#v) = %v, want: nil", rs, err) - } - if got, want := rs, tc.rs; got != want { - t.Fatalf("s.Option(..) returned unexpected value got: %#v, want: %#v", got, want) - } - } - }) - } -} - -func TestStackSendBufferSizeOption(t *testing.T) { - const sMin = stack.MinBufferSize - testCases := []struct { - name string - ss tcpip.SendBufferSizeOption - err tcpip.Error - }{ - // Invalid configurations. - {"min_below_zero", tcpip.SendBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, - {"min_zero", tcpip.SendBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, - {"default_below_min", tcpip.SendBufferSizeOption{Min: 0, Default: sMin - 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, - {"default_above_max", tcpip.SendBufferSizeOption{Min: 0, Default: sMin + 1, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, - {"max_below_min", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, - - // Valid Configurations - {"in_ascending_order", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, - {"all_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, - {"min_default_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, - {"default_max_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - s := stack.New(stack.Options{}) - defer s.Close() - err := s.SetOption(tc.ss) - if diff := cmp.Diff(tc.err, err); diff != "" { - t.Fatalf("unexpected error from s.SetOption(%+v), (-want, +got):\n%s", tc.ss, diff) - } - if tc.err == nil { - var ss tcpip.SendBufferSizeOption - if err := s.Option(&ss); err != nil { - t.Fatalf("s.Option(%+v) = %v, want: nil", ss, err) - } - if got, want := ss, tc.ss; got != want { - t.Fatalf("s.Option(..) returned unexpected value got: %#v, want: %#v", got, want) - } - } - }) - } -} - -func TestOutgoingSubnetBroadcast(t *testing.T) { - const ( - unspecifiedNICID = 0 - nicID1 = 1 - ) - - defaultAddr := tcpip.AddressWithPrefix{ - Address: header.IPv4Any, - PrefixLen: 0, - } - defaultSubnet := defaultAddr.Subnet() - ipv4Addr := tcpip.AddressWithPrefix{ - Address: "\xc0\xa8\x01\x3a", - PrefixLen: 24, - } - ipv4Subnet := ipv4Addr.Subnet() - ipv4SubnetBcast := ipv4Subnet.Broadcast() - ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01") - ipv4AddrPrefix31 := tcpip.AddressWithPrefix{ - Address: "\xc0\xa8\x01\x3a", - PrefixLen: 31, - } - ipv4Subnet31 := ipv4AddrPrefix31.Subnet() - ipv4Subnet31Bcast := ipv4Subnet31.Broadcast() - ipv4AddrPrefix32 := tcpip.AddressWithPrefix{ - Address: "\xc0\xa8\x01\x3a", - PrefixLen: 32, - } - ipv4Subnet32 := ipv4AddrPrefix32.Subnet() - ipv4Subnet32Bcast := ipv4Subnet32.Broadcast() - ipv6Addr := tcpip.AddressWithPrefix{ - Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - PrefixLen: 64, - } - ipv6Subnet := ipv6Addr.Subnet() - ipv6SubnetBcast := ipv6Subnet.Broadcast() - remNetAddr := tcpip.AddressWithPrefix{ - Address: "\x64\x0a\x7b\x18", - PrefixLen: 24, - } - remNetSubnet := remNetAddr.Subnet() - remNetSubnetBcast := remNetSubnet.Broadcast() - - tests := []struct { - name string - nicAddr tcpip.ProtocolAddress - routes []tcpip.Route - remoteAddr tcpip.Address - expectedLocalAddress tcpip.Address - expectedRemoteAddress tcpip.Address - expectedRemoteLinkAddress tcpip.LinkAddress - expectedNextHop tcpip.Address - expectedNetProto tcpip.NetworkProtocolNumber - expectedLoop stack.PacketLooping - }{ - // Broadcast to a locally attached subnet populates the broadcast MAC. - { - name: "IPv4 Broadcast to local subnet", - nicAddr: tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: ipv4Addr, - }, - routes: []tcpip.Route{ - { - Destination: ipv4Subnet, - NIC: nicID1, - }, - }, - remoteAddr: ipv4SubnetBcast, - expectedLocalAddress: ipv4Addr.Address, - expectedRemoteAddress: ipv4SubnetBcast, - expectedRemoteLinkAddress: header.EthernetBroadcastAddress, - expectedNetProto: header.IPv4ProtocolNumber, - expectedLoop: stack.PacketOut | stack.PacketLoop, - }, - // Broadcast to a locally attached /31 subnet does not populate the - // broadcast MAC. - { - name: "IPv4 Broadcast to local /31 subnet", - nicAddr: tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: ipv4AddrPrefix31, - }, - routes: []tcpip.Route{ - { - Destination: ipv4Subnet31, - NIC: nicID1, - }, - }, - remoteAddr: ipv4Subnet31Bcast, - expectedLocalAddress: ipv4AddrPrefix31.Address, - expectedRemoteAddress: ipv4Subnet31Bcast, - expectedNetProto: header.IPv4ProtocolNumber, - expectedLoop: stack.PacketOut, - }, - // Broadcast to a locally attached /32 subnet does not populate the - // broadcast MAC. - { - name: "IPv4 Broadcast to local /32 subnet", - nicAddr: tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: ipv4AddrPrefix32, - }, - routes: []tcpip.Route{ - { - Destination: ipv4Subnet32, - NIC: nicID1, - }, - }, - remoteAddr: ipv4Subnet32Bcast, - expectedLocalAddress: ipv4AddrPrefix32.Address, - expectedRemoteAddress: ipv4Subnet32Bcast, - expectedNetProto: header.IPv4ProtocolNumber, - expectedLoop: stack.PacketOut, - }, - // IPv6 has no notion of a broadcast. - { - name: "IPv6 'Broadcast' to local subnet", - nicAddr: tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: ipv6Addr, - }, - routes: []tcpip.Route{ - { - Destination: ipv6Subnet, - NIC: nicID1, - }, - }, - remoteAddr: ipv6SubnetBcast, - expectedLocalAddress: ipv6Addr.Address, - expectedRemoteAddress: ipv6SubnetBcast, - expectedNetProto: header.IPv6ProtocolNumber, - expectedLoop: stack.PacketOut, - }, - // Broadcast to a remote subnet in the route table is send to the next-hop - // gateway. - { - name: "IPv4 Broadcast to remote subnet", - nicAddr: tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: ipv4Addr, - }, - routes: []tcpip.Route{ - { - Destination: remNetSubnet, - Gateway: ipv4Gateway, - NIC: nicID1, - }, - }, - remoteAddr: remNetSubnetBcast, - expectedLocalAddress: ipv4Addr.Address, - expectedRemoteAddress: remNetSubnetBcast, - expectedNextHop: ipv4Gateway, - expectedNetProto: header.IPv4ProtocolNumber, - expectedLoop: stack.PacketOut, - }, - // Broadcast to an unknown subnet follows the default route. Note that this - // is essentially just routing an unknown destination IP, because w/o any - // subnet prefix information a subnet broadcast address is just a normal IP. - { - name: "IPv4 Broadcast to unknown subnet", - nicAddr: tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: ipv4Addr, - }, - routes: []tcpip.Route{ - { - Destination: defaultSubnet, - Gateway: ipv4Gateway, - NIC: nicID1, - }, - }, - remoteAddr: remNetSubnetBcast, - expectedLocalAddress: ipv4Addr.Address, - expectedRemoteAddress: remNetSubnetBcast, - expectedNextHop: ipv4Gateway, - expectedNetProto: header.IPv4ProtocolNumber, - expectedLoop: stack.PacketOut, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - }) - ep := channel.New(0, defaultMTU, "") - ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID1, ep); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) - } - if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err) - } - - s.SetRouteTable(test.routes) - - var netProto tcpip.NetworkProtocolNumber - switch l := len(test.remoteAddr); l { - case header.IPv4AddressSize: - netProto = header.IPv4ProtocolNumber - case header.IPv6AddressSize: - netProto = header.IPv6ProtocolNumber - default: - t.Fatalf("got unexpected address length = %d bytes", l) - } - - r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, test.remoteAddr, netProto, err) - } - if r.LocalAddress != test.expectedLocalAddress { - t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, test.expectedLocalAddress) - } - if r.RemoteAddress != test.expectedRemoteAddress { - t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, test.expectedRemoteAddress) - } - if got := r.RemoteLinkAddress(); got != test.expectedRemoteLinkAddress { - t.Errorf("got r.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddress) - } - if r.NextHop != test.expectedNextHop { - t.Errorf("got r.NextHop = %s, want = %s", r.NextHop, test.expectedNextHop) - } - if r.NetProto != test.expectedNetProto { - t.Errorf("got r.NetProto = %d, want = %d", r.NetProto, test.expectedNetProto) - } - if r.Loop != test.expectedLoop { - t.Errorf("got r.Loop = %x, want = %x", r.Loop, test.expectedLoop) - } - }) - } -} - -func TestResolveWith(t *testing.T) { - const ( - unspecifiedNICID = 0 - nicID = 1 - ) - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, - }) - ep := channel.New(0, defaultMTU, "") - ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - addr := tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address([]byte{192, 168, 1, 58}), - PrefixLen: 24, - }, - } - if err := s.AddProtocolAddress(nicID, addr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, addr, err) - } - - s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}}) - - remoteAddr := tcpip.Address([]byte{192, 168, 1, 59}) - r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, remoteAddr, header.IPv4ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, remoteAddr, header.IPv4ProtocolNumber, err) - } - defer r.Release() - - // Should initially require resolution. - if !r.IsResolutionRequired() { - t.Fatal("got r.IsResolutionRequired() = false, want = true") - } - - // Manually resolving the route should no longer require resolution. - r.ResolveWith("\x01") - if r.IsResolutionRequired() { - t.Fatal("got r.IsResolutionRequired() = true, want = false") - } -} - -// TestRouteReleaseAfterAddrRemoval tests that releasing a Route after its -// associated address is removed should not cause a panic. -func TestRouteReleaseAfterAddrRemoval(t *testing.T) { - const ( - nicID = 1 - localAddr = tcpip.Address("\x01") - remoteAddr = tcpip.Address("\x02") - ) - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep := channel.New(0, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err) - } - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - r, err := s.FindRoute(nicID, localAddr, remoteAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, localAddr, remoteAddr, fakeNetNumber, err) - } - // Should not panic. - defer r.Release() - - // Check that removing the same address fails. - if err := s.RemoveAddress(nicID, localAddr); err != nil { - t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, localAddr, err) - } -} - -func TestGetNetworkEndpoint(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - protoFactory stack.NetworkProtocolFactory - protoNum tcpip.NetworkProtocolNumber - }{ - { - name: "IPv4", - protoFactory: ipv4.NewProtocol, - protoNum: ipv4.ProtocolNumber, - }, - { - name: "IPv6", - protoFactory: ipv6.NewProtocol, - protoNum: ipv6.ProtocolNumber, - }, - } - - factories := make([]stack.NetworkProtocolFactory, 0, len(tests)) - for _, test := range tests { - factories = append(factories, test.protoFactory) - } - - s := stack.New(stack.Options{ - NetworkProtocols: factories, - }) - - if err := s.CreateNIC(nicID, channel.New(0, defaultMTU, "")); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep, err := s.GetNetworkEndpoint(nicID, test.protoNum) - if err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, test.protoNum, err) - } - - if got := ep.NetworkProtocolNumber(); got != test.protoNum { - t.Fatalf("got ep.NetworkProtocolNumber() = %d, want = %d", got, test.protoNum) - } - }) - } -} - -func TestGetMainNICAddressWhenNICDisabled(t *testing.T) { - const nicID = 1 - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - if err := s.CreateNIC(nicID, channel.New(0, defaultMTU, "")); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - - protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: "\x01", - PrefixLen: 8, - }, - } - if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protocolAddress, err) - } - - // Check that we get the right initial address and prefix length. - if err := checkGetMainNICAddress(s, nicID, fakeNetNumber, protocolAddress.AddressWithPrefix); err != nil { - t.Fatal(err) - } - - // Should still get the address when the NIC is diabled. - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("DisableNIC(%d): %s", nicID, err) - } - if err := checkGetMainNICAddress(s, nicID, fakeNetNumber, protocolAddress.AddressWithPrefix); err != nil { - t.Fatal(err) - } -} - -// TestAddRoute tests Stack.AddRoute -func TestAddRoute(t *testing.T) { - const nicID = 1 - - s := stack.New(stack.Options{}) - - subnet1, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - - subnet2, err := tcpip.NewSubnet("\x01", "\x01") - if err != nil { - t.Fatal(err) - } - - expected := []tcpip.Route{ - {Destination: subnet1, Gateway: "\x00", NIC: 1}, - {Destination: subnet2, Gateway: "\x00", NIC: 1}, - } - - // Initialize the route table with one route. - s.SetRouteTable([]tcpip.Route{expected[0]}) - - // Add another route. - s.AddRoute(expected[1]) - - rt := s.GetRouteTable() - if got, want := len(rt), len(expected); got != want { - t.Fatalf("Unexpected route table length got = %d, want = %d", got, want) - } - for i, route := range rt { - if got, want := route, expected[i]; got != want { - t.Fatalf("Unexpected route got = %#v, want = %#v", got, want) - } - } -} - -// TestRemoveRoutes tests Stack.RemoveRoutes -func TestRemoveRoutes(t *testing.T) { - const nicID = 1 - - s := stack.New(stack.Options{}) - - addressToRemove := tcpip.Address("\x01") - subnet1, err := tcpip.NewSubnet(addressToRemove, "\x01") - if err != nil { - t.Fatal(err) - } - - subnet2, err := tcpip.NewSubnet(addressToRemove, "\x01") - if err != nil { - t.Fatal(err) - } - - subnet3, err := tcpip.NewSubnet("\x02", "\x02") - if err != nil { - t.Fatal(err) - } - - // Initialize the route table with three routes. - s.SetRouteTable([]tcpip.Route{ - {Destination: subnet1, Gateway: "\x00", NIC: 1}, - {Destination: subnet2, Gateway: "\x00", NIC: 1}, - {Destination: subnet3, Gateway: "\x00", NIC: 1}, - }) - - // Remove routes with the specific address. - s.RemoveRoutes(func(r tcpip.Route) bool { - return r.Destination.ID() == addressToRemove - }) - - expected := []tcpip.Route{{Destination: subnet3, Gateway: "\x00", NIC: 1}} - rt := s.GetRouteTable() - if got, want := len(rt), len(expected); got != want { - t.Fatalf("Unexpected route table length got = %d, want = %d", got, want) - } - for i, route := range rt { - if got, want := route, expected[i]; got != want { - t.Fatalf("Unexpected route got = %#v, want = %#v", got, want) - } - } -} - -func TestFindRouteWithForwarding(t *testing.T) { - const ( - nicID1 = 1 - nicID2 = 2 - - nic1Addr = tcpip.Address("\x01") - nic2Addr = tcpip.Address("\x02") - remoteAddr = tcpip.Address("\x03") - ) - - type netCfg struct { - proto tcpip.NetworkProtocolNumber - factory stack.NetworkProtocolFactory - nic1Addr tcpip.Address - nic2Addr tcpip.Address - remoteAddr tcpip.Address - } - - fakeNetCfg := netCfg{ - proto: fakeNetNumber, - factory: fakeNetFactory, - nic1Addr: nic1Addr, - nic2Addr: nic2Addr, - remoteAddr: remoteAddr, - } - - globalIPv6Addr1 := tcpip.Address(net.ParseIP("a::1").To16()) - globalIPv6Addr2 := tcpip.Address(net.ParseIP("a::2").To16()) - - ipv6LinkLocalNIC1WithGlobalRemote := netCfg{ - proto: ipv6.ProtocolNumber, - factory: ipv6.NewProtocol, - nic1Addr: llAddr1, - nic2Addr: globalIPv6Addr2, - remoteAddr: globalIPv6Addr1, - } - ipv6GlobalNIC1WithLinkLocalRemote := netCfg{ - proto: ipv6.ProtocolNumber, - factory: ipv6.NewProtocol, - nic1Addr: globalIPv6Addr1, - nic2Addr: llAddr1, - remoteAddr: llAddr2, - } - ipv6GlobalNIC1WithLinkLocalMulticastRemote := netCfg{ - proto: ipv6.ProtocolNumber, - factory: ipv6.NewProtocol, - nic1Addr: globalIPv6Addr1, - nic2Addr: globalIPv6Addr2, - remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - } - - tests := []struct { - name string - - netCfg netCfg - forwardingEnabled bool - - addrNIC tcpip.NICID - localAddr tcpip.Address - - findRouteErr tcpip.Error - dependentOnForwarding bool - }{ - { - name: "forwarding disabled and localAddr not on specified NIC but route from different NIC", - netCfg: fakeNetCfg, - forwardingEnabled: false, - addrNIC: nicID1, - localAddr: fakeNetCfg.nic2Addr, - findRouteErr: &tcpip.ErrNoRoute{}, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and localAddr not on specified NIC but route from different NIC", - netCfg: fakeNetCfg, - forwardingEnabled: true, - addrNIC: nicID1, - localAddr: fakeNetCfg.nic2Addr, - findRouteErr: &tcpip.ErrNoRoute{}, - dependentOnForwarding: false, - }, - { - name: "forwarding disabled and localAddr on specified NIC but route from different NIC", - netCfg: fakeNetCfg, - forwardingEnabled: false, - addrNIC: nicID1, - localAddr: fakeNetCfg.nic1Addr, - findRouteErr: &tcpip.ErrNoRoute{}, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and localAddr on specified NIC but route from different NIC", - netCfg: fakeNetCfg, - forwardingEnabled: true, - addrNIC: nicID1, - localAddr: fakeNetCfg.nic1Addr, - findRouteErr: nil, - dependentOnForwarding: true, - }, - { - name: "forwarding disabled and localAddr on specified NIC and route from same NIC", - netCfg: fakeNetCfg, - forwardingEnabled: false, - addrNIC: nicID2, - localAddr: fakeNetCfg.nic2Addr, - findRouteErr: nil, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and localAddr on specified NIC and route from same NIC", - netCfg: fakeNetCfg, - forwardingEnabled: true, - addrNIC: nicID2, - localAddr: fakeNetCfg.nic2Addr, - findRouteErr: nil, - dependentOnForwarding: false, - }, - { - name: "forwarding disabled and localAddr not on specified NIC but route from same NIC", - netCfg: fakeNetCfg, - forwardingEnabled: false, - addrNIC: nicID2, - localAddr: fakeNetCfg.nic1Addr, - findRouteErr: &tcpip.ErrNoRoute{}, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and localAddr not on specified NIC but route from same NIC", - netCfg: fakeNetCfg, - forwardingEnabled: true, - addrNIC: nicID2, - localAddr: fakeNetCfg.nic1Addr, - findRouteErr: &tcpip.ErrNoRoute{}, - dependentOnForwarding: false, - }, - { - name: "forwarding disabled and localAddr on same NIC as route", - netCfg: fakeNetCfg, - forwardingEnabled: false, - localAddr: fakeNetCfg.nic2Addr, - findRouteErr: nil, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and localAddr on same NIC as route", - netCfg: fakeNetCfg, - forwardingEnabled: false, - localAddr: fakeNetCfg.nic2Addr, - findRouteErr: nil, - dependentOnForwarding: false, - }, - { - name: "forwarding disabled and localAddr on different NIC as route", - netCfg: fakeNetCfg, - forwardingEnabled: false, - localAddr: fakeNetCfg.nic1Addr, - findRouteErr: &tcpip.ErrNoRoute{}, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and localAddr on different NIC as route", - netCfg: fakeNetCfg, - forwardingEnabled: true, - localAddr: fakeNetCfg.nic1Addr, - findRouteErr: nil, - dependentOnForwarding: true, - }, - { - name: "forwarding disabled and specified NIC only has link-local addr with route on different NIC", - netCfg: ipv6LinkLocalNIC1WithGlobalRemote, - forwardingEnabled: false, - addrNIC: nicID1, - findRouteErr: &tcpip.ErrNoRoute{}, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and specified NIC only has link-local addr with route on different NIC", - netCfg: ipv6LinkLocalNIC1WithGlobalRemote, - forwardingEnabled: true, - addrNIC: nicID1, - findRouteErr: &tcpip.ErrNoRoute{}, - dependentOnForwarding: false, - }, - { - name: "forwarding disabled and link-local local addr with route on different NIC", - netCfg: ipv6LinkLocalNIC1WithGlobalRemote, - forwardingEnabled: false, - localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, - findRouteErr: &tcpip.ErrNoRoute{}, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and link-local local addr with route on same NIC", - netCfg: ipv6LinkLocalNIC1WithGlobalRemote, - forwardingEnabled: true, - localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, - findRouteErr: &tcpip.ErrNoRoute{}, - dependentOnForwarding: false, - }, - { - name: "forwarding disabled and global local addr with route on same NIC", - netCfg: ipv6LinkLocalNIC1WithGlobalRemote, - forwardingEnabled: true, - localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic2Addr, - findRouteErr: nil, - dependentOnForwarding: false, - }, - { - name: "forwarding disabled and link-local local addr with route on same NIC", - netCfg: ipv6GlobalNIC1WithLinkLocalRemote, - forwardingEnabled: false, - localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr, - findRouteErr: nil, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and link-local local addr with route on same NIC", - netCfg: ipv6GlobalNIC1WithLinkLocalRemote, - forwardingEnabled: true, - localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr, - findRouteErr: nil, - dependentOnForwarding: false, - }, - { - name: "forwarding disabled and global local addr with link-local remote on different NIC", - netCfg: ipv6GlobalNIC1WithLinkLocalRemote, - forwardingEnabled: false, - localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, - findRouteErr: &tcpip.ErrNetworkUnreachable{}, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and global local addr with link-local remote on different NIC", - netCfg: ipv6GlobalNIC1WithLinkLocalRemote, - forwardingEnabled: true, - localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, - findRouteErr: &tcpip.ErrNetworkUnreachable{}, - dependentOnForwarding: false, - }, - { - name: "forwarding disabled and global local addr with link-local multicast remote on different NIC", - netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, - forwardingEnabled: false, - localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, - findRouteErr: &tcpip.ErrNetworkUnreachable{}, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and global local addr with link-local multicast remote on different NIC", - netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, - forwardingEnabled: true, - localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, - findRouteErr: &tcpip.ErrNetworkUnreachable{}, - dependentOnForwarding: false, - }, - { - name: "forwarding disabled and global local addr with link-local multicast remote on same NIC", - netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, - forwardingEnabled: false, - localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr, - findRouteErr: nil, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and global local addr with link-local multicast remote on same NIC", - netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, - forwardingEnabled: true, - localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr, - findRouteErr: nil, - dependentOnForwarding: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{test.netCfg.factory}, - }) - - ep1 := channel.New(1, defaultMTU, "") - if err := s.CreateNIC(nicID1, ep1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s:", nicID1, err) - } - - ep2 := channel.New(1, defaultMTU, "") - if err := s.CreateNIC(nicID2, ep2); err != nil { - t.Fatalf("CreateNIC(%d, _): %s:", nicID2, err) - } - - if err := s.AddAddress(nicID1, test.netCfg.proto, test.netCfg.nic1Addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, test.netCfg.proto, test.netCfg.nic1Addr, err) - } - - if err := s.AddAddress(nicID2, test.netCfg.proto, test.netCfg.nic2Addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, test.netCfg.proto, test.netCfg.nic2Addr, err) - } - - if err := s.SetForwarding(test.netCfg.proto, test.forwardingEnabled); err != nil { - t.Fatalf("SetForwarding(%d, %t): %s", test.netCfg.proto, test.forwardingEnabled, err) - } - - s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}}) - - r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */) - if r != nil { - defer r.Release() - } - if diff := cmp.Diff(test.findRouteErr, err); diff != "" { - t.Fatalf("unexpected error from FindRoute(%d, %s, %s, %d, false), (-want, +got):\n%s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, diff) - } - - if test.findRouteErr != nil { - return - } - - if r.LocalAddress != test.localAddr { - t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, test.localAddr) - } - if r.RemoteAddress != test.netCfg.remoteAddr { - t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, test.netCfg.remoteAddr) - } - - if t.Failed() { - t.FailNow() - } - - // Sending a packet should always go through NIC2 since we only install a - // route to test.netCfg.remoteAddr through NIC2. - data := buffer.View([]byte{1, 2, 3, 4}) - if err := send(r, data); err != nil { - t.Fatalf("send(_, _): %s", err) - } - if n := ep1.Drain(); n != 0 { - t.Errorf("got %d unexpected packets from ep1", n) - } - pkt, ok := ep2.Read() - if !ok { - t.Fatal("packet not sent through ep2") - } - if pkt.Route.LocalAddress != test.localAddr { - t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddr) - } - if pkt.Route.RemoteAddress != test.netCfg.remoteAddr { - t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.netCfg.remoteAddr) - } - - if !test.forwardingEnabled || !test.dependentOnForwarding { - return - } - - // Disabling forwarding when the route is dependent on forwarding being - // enabled should make the route invalid. - if err := s.SetForwarding(test.netCfg.proto, false); err != nil { - t.Fatalf("SetForwarding(%d, false): %s", test.netCfg.proto, err) - } - { - err := send(r, data) - if _, ok := err.(*tcpip.ErrInvalidEndpointState); !ok { - t.Fatalf("got send(_, _) = %s, want = %s", err, &tcpip.ErrInvalidEndpointState{}) - } - } - if n := ep1.Drain(); n != 0 { - t.Errorf("got %d unexpected packets from ep1", n) - } - if n := ep2.Drain(); n != 0 { - t.Errorf("got %d unexpected packets from ep2", n) - } - }) - } -} - -func TestWritePacketToRemote(t *testing.T) { - const nicID = 1 - const MTU = 1280 - e := channel.New(1, MTU, linkAddr1) - s := stack.New(stack.Options{}) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("CreateNIC(%d) = %s", nicID, err) - } - tests := []struct { - name string - protocol tcpip.NetworkProtocolNumber - payload []byte - }{ - { - name: "SuccessIPv4", - protocol: header.IPv4ProtocolNumber, - payload: []byte{1, 2, 3, 4}, - }, - { - name: "SuccessIPv6", - protocol: header.IPv6ProtocolNumber, - payload: []byte{5, 6, 7, 8}, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if err := s.WritePacketToRemote(nicID, linkAddr2, test.protocol, buffer.View(test.payload).ToVectorisedView()); err != nil { - t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s", err) - } - - pkt, ok := e.Read() - if got, want := ok, true; got != want { - t.Fatalf("e.Read() = %t, want %t", got, want) - } - if got, want := pkt.Proto, test.protocol; got != want { - t.Fatalf("pkt.Proto = %d, want %d", got, want) - } - if pkt.Route.RemoteLinkAddress != linkAddr2 { - t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", pkt.Route.RemoteLinkAddress, linkAddr2) - } - if diff := cmp.Diff(pkt.Pkt.Data.ToView(), buffer.View(test.payload)); diff != "" { - t.Errorf("pkt.Pkt.Data mismatch (-want +got):\n%s", diff) - } - }) - } - - t.Run("InvalidNICID", func(t *testing.T) { - err := s.WritePacketToRemote(234, linkAddr2, header.IPv4ProtocolNumber, buffer.View([]byte{1}).ToVectorisedView()) - if _, ok := err.(*tcpip.ErrUnknownDevice); !ok { - t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s, want = %s", err, &tcpip.ErrUnknownDevice{}) - } - pkt, ok := e.Read() - if got, want := ok, false; got != want { - t.Fatalf("e.Read() = %t, %v; want %t", got, pkt, want) - } - }) -} - -func TestClearNeighborCacheOnNICDisable(t *testing.T) { - const ( - nicID = 1 - - ipv4Addr = tcpip.Address("\x01\x02\x03\x04") - ipv6Addr = tcpip.Address("\x01\x02\x03\x04\x01\x02\x03\x04\x01\x02\x03\x04\x01\x02\x03\x04") - linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - ) - - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - Clock: clock, - }) - e := channel.New(0, 0, "") - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - addrs := []struct { - proto tcpip.NetworkProtocolNumber - addr tcpip.Address - }{ - { - proto: ipv4.ProtocolNumber, - addr: ipv4Addr, - }, - { - proto: ipv6.ProtocolNumber, - addr: ipv6Addr, - }, - } - for _, addr := range addrs { - if err := s.AddStaticNeighbor(nicID, addr.proto, addr.addr, linkAddr); err != nil { - t.Fatalf("s.AddStaticNeighbor(%d, %d, %s, %s): %s", nicID, addr.proto, addr.addr, linkAddr, err) - } - - if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil { - t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err) - } else if diff := cmp.Diff( - []stack.NeighborEntry{{Addr: addr.addr, LinkAddr: linkAddr, State: stack.Static, UpdatedAtNanos: clock.NowNanoseconds()}}, - neighbors, - ); diff != "" { - t.Fatalf("proto=%d neighbors mismatch (-want +got):\n%s", addr.proto, diff) - } - } - - // Disabling the NIC should clear the neighbor table. - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - for _, addr := range addrs { - if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil { - t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err) - } else if len(neighbors) != 0 { - t.Fatalf("got proto=%d len(neighbors) = %d, want = 0; neighbors = %#v", addr.proto, len(neighbors), neighbors) - } - } - - // Enabling the NIC should have an empty neighbor table. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - for _, addr := range addrs { - if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil { - t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err) - } else if len(neighbors) != 0 { - t.Fatalf("got proto=%d len(neighbors) = %d, want = 0; neighbors = %#v", addr.proto, len(neighbors), neighbors) - } - } -} - -func TestGetLinkAddressErrors(t *testing.T) { - const ( - nicID = 1 - unknownNICID = nicID + 1 - ) - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - }) - if err := s.CreateNIC(nicID, channel.New(0, 0, "")); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - { - err := s.GetLinkAddress(unknownNICID, "", "", ipv4.ProtocolNumber, nil) - if _, ok := err.(*tcpip.ErrUnknownNICID); !ok { - t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = %s, want = %s", unknownNICID, ipv4.ProtocolNumber, err, &tcpip.ErrUnknownNICID{}) - } - } - { - err := s.GetLinkAddress(nicID, "", "", ipv4.ProtocolNumber, nil) - if _, ok := err.(*tcpip.ErrNotSupported); !ok { - t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = %s, want = %s", unknownNICID, ipv4.ProtocolNumber, err, &tcpip.ErrNotSupported{}) - } - } -} - -func TestStaticGetLinkAddress(t *testing.T) { - const ( - nicID = 1 - ) - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - }) - e := channel.New(0, 0, "") - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - tests := []struct { - name string - proto tcpip.NetworkProtocolNumber - addr tcpip.Address - expectedLinkAddr tcpip.LinkAddress - }{ - { - name: "IPv4", - proto: ipv4.ProtocolNumber, - addr: header.IPv4Broadcast, - expectedLinkAddr: header.EthernetBroadcastAddress, - }, - { - name: "IPv6", - proto: ipv6.ProtocolNumber, - addr: header.IPv6AllNodesMulticastAddress, - expectedLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ch := make(chan stack.LinkResolutionResult, 1) - if err := s.GetLinkAddress(nicID, test.addr, "", test.proto, func(r stack.LinkResolutionResult) { - ch <- r - }); err != nil { - t.Fatalf("s.GetLinkAddress(%d, %s, '', %d, _): %s", nicID, test.addr, test.proto, err) - } - - if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: true}, <-ch); diff != "" { - t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff) - } - }) - } -} diff --git a/pkg/tcpip/stack/stack_unsafe_state_autogen.go b/pkg/tcpip/stack/stack_unsafe_state_autogen.go new file mode 100644 index 000000000..758ab3457 --- /dev/null +++ b/pkg/tcpip/stack/stack_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package stack diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go deleted file mode 100644 index 10cbbe589..000000000 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ /dev/null @@ -1,385 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack_test - -import ( - "io/ioutil" - "math" - "math/rand" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/ports" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - testSrcAddrV6 = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - testDstAddrV6 = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - - testSrcAddrV4 = "\x0a\x00\x00\x01" - testDstAddrV4 = "\x0a\x00\x00\x02" - - testDstPort = 1234 - testSrcPort = 4096 -) - -type testContext struct { - linkEps map[tcpip.NICID]*channel.Endpoint - s *stack.Stack - wq waiter.Queue -} - -// newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs. -func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - linkEps := make(map[tcpip.NICID]*channel.Endpoint) - for _, linkEpID := range linkEpIDs { - channelEp := channel.New(256, mtu, "") - if err := s.CreateNIC(linkEpID, channelEp); err != nil { - t.Fatalf("CreateNIC failed: %s", err) - } - linkEps[linkEpID] = channelEp - - if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, testDstAddrV4); err != nil { - t.Fatalf("AddAddress IPv4 failed: %s", err) - } - - if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, testDstAddrV6); err != nil { - t.Fatalf("AddAddress IPv6 failed: %s", err) - } - } - - s.SetRouteTable([]tcpip.Route{ - {Destination: header.IPv4EmptySubnet, NIC: 1}, - {Destination: header.IPv6EmptySubnet, NIC: 1}, - }) - - return &testContext{ - s: s, - linkEps: linkEps, - } -} - -type headers struct { - srcPort, dstPort uint16 -} - -func newPayload() []byte { - b := make([]byte, 30+rand.Intn(100)) - for i := range b { - b[i] = byte(rand.Intn(256)) - } - return b -} - -func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NICID) { - buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) - payloadStart := len(buf) - len(payload) - copy(buf[payloadStart:], payload) - - // Initialize the IP header. - ip := header.IPv4(buf) - ip.Encode(&header.IPv4Fields{ - TOS: 0x80, - TotalLength: uint16(len(buf)), - TTL: 65, - Protocol: uint8(udp.ProtocolNumber), - SrcAddr: testSrcAddrV4, - DstAddr: testDstAddrV4, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - // Initialize the UDP header. - u := header.UDP(buf[header.IPv4MinimumSize:]) - u.Encode(&header.UDPFields{ - SrcPort: h.srcPort, - DstPort: h.dstPort, - Length: uint16(header.UDPMinimumSize + len(payload)), - }) - - // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV4, testDstAddrV4, uint16(len(u))) - - // Calculate the UDP checksum and set it. - xsum = header.Checksum(payload, xsum) - u.SetChecksum(^u.CalculateChecksum(xsum)) - - // Inject packet. - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - }) - c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, pkt) -} - -func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) { - // Allocate a buffer for data and headers. - buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) - copy(buf[len(buf)-len(payload):], payload) - - // Initialize the IP header. - ip := header.IPv6(buf) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(header.UDPMinimumSize + len(payload)), - TransportProtocol: udp.ProtocolNumber, - HopLimit: 65, - SrcAddr: testSrcAddrV6, - DstAddr: testDstAddrV6, - }) - - // Initialize the UDP header. - u := header.UDP(buf[header.IPv6MinimumSize:]) - u.Encode(&header.UDPFields{ - SrcPort: h.srcPort, - DstPort: h.dstPort, - Length: uint16(header.UDPMinimumSize + len(payload)), - }) - - // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV6, testDstAddrV6, uint16(len(u))) - - // Calculate the UDP checksum and set it. - xsum = header.Checksum(payload, xsum) - u.SetChecksum(^u.CalculateChecksum(xsum)) - - // Inject packet. - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - }) - c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, pkt) -} - -func TestTransportDemuxerRegister(t *testing.T) { - for _, test := range []struct { - name string - proto tcpip.NetworkProtocolNumber - want tcpip.Error - }{ - {"failure", ipv6.ProtocolNumber, &tcpip.ErrUnknownProtocol{}}, - {"success", ipv4.ProtocolNumber, nil}, - } { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - var wq waiter.Queue - ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatal(err) - } - tEP, ok := ep.(stack.TransportEndpoint) - if !ok { - t.Fatalf("%T does not implement stack.TransportEndpoint", ep) - } - if got, want := s.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, ports.Flags{}, 0), test.want; got != want { - t.Fatalf("s.RegisterTransportEndpoint(...) = %s, want %s", got, want) - } - }) - } -} - -// TestBindToDeviceDistribution injects varied packets on input devices and checks that -// the distribution of packets received matches expectations. -func TestBindToDeviceDistribution(t *testing.T) { - type endpointSockopts struct { - reuse bool - bindToDevice tcpip.NICID - } - for _, test := range []struct { - name string - // endpoints will received the inject packets. - endpoints []endpointSockopts - // wantDistributions is the want ratio of packets received on each - // endpoint for each NIC on which packets are injected. - wantDistributions map[tcpip.NICID][]float64 - }{ - { - "BindPortReuse", - // 5 endpoints that all have reuse set. - []endpointSockopts{ - {reuse: true, bindToDevice: 0}, - {reuse: true, bindToDevice: 0}, - {reuse: true, bindToDevice: 0}, - {reuse: true, bindToDevice: 0}, - {reuse: true, bindToDevice: 0}, - }, - map[tcpip.NICID][]float64{ - // Injected packets on dev0 get distributed evenly. - 1: {0.2, 0.2, 0.2, 0.2, 0.2}, - }, - }, - { - "BindToDevice", - // 3 endpoints with various bindings. - []endpointSockopts{ - {reuse: false, bindToDevice: 1}, - {reuse: false, bindToDevice: 2}, - {reuse: false, bindToDevice: 3}, - }, - map[tcpip.NICID][]float64{ - // Injected packets on dev0 go only to the endpoint bound to dev0. - 1: {1, 0, 0}, - // Injected packets on dev1 go only to the endpoint bound to dev1. - 2: {0, 1, 0}, - // Injected packets on dev2 go only to the endpoint bound to dev2. - 3: {0, 0, 1}, - }, - }, - { - "ReuseAndBindToDevice", - // 6 endpoints with various bindings. - []endpointSockopts{ - {reuse: true, bindToDevice: 1}, - {reuse: true, bindToDevice: 1}, - {reuse: true, bindToDevice: 2}, - {reuse: true, bindToDevice: 2}, - {reuse: true, bindToDevice: 2}, - {reuse: true, bindToDevice: 0}, - }, - map[tcpip.NICID][]float64{ - // Injected packets on dev0 get distributed among endpoints bound to - // dev0. - 1: {0.5, 0.5, 0, 0, 0, 0}, - // Injected packets on dev1 get distributed among endpoints bound to - // dev1 or unbound. - 2: {0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, - // Injected packets on dev999 go only to the unbound. - 1000: {0, 0, 0, 0, 0, 1}, - }, - }, - } { - for protoName, netProtoNum := range map[string]tcpip.NetworkProtocolNumber{ - "IPv4": ipv4.ProtocolNumber, - "IPv6": ipv6.ProtocolNumber, - } { - for device, wantDistribution := range test.wantDistributions { - t.Run(test.name+protoName+string(device), func(t *testing.T) { - var devices []tcpip.NICID - for d := range test.wantDistributions { - devices = append(devices, d) - } - c := newDualTestContextMultiNIC(t, defaultMTU, devices) - - eps := make(map[tcpip.Endpoint]int) - - pollChannel := make(chan tcpip.Endpoint) - for i, endpoint := range test.endpoints { - // Try to receive the data. - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - defer close(ch) - - var err tcpip.Error - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, netProtoNum, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - eps[ep] = i - - go func(ep tcpip.Endpoint) { - for range ch { - pollChannel <- ep - } - }(ep) - - defer ep.Close() - ep.SocketOptions().SetReusePort(endpoint.reuse) - if err := ep.SocketOptions().SetBindToDevice(int32(endpoint.bindToDevice)); err != nil { - t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", endpoint.bindToDevice, endpoint.bindToDevice, i, err) - } - - var dstAddr tcpip.Address - switch netProtoNum { - case ipv4.ProtocolNumber: - dstAddr = testDstAddrV4 - case ipv6.ProtocolNumber: - dstAddr = testDstAddrV6 - default: - t.Fatalf("unexpected protocol number: %d", netProtoNum) - } - if err := ep.Bind(tcpip.FullAddress{Addr: dstAddr, Port: testDstPort}); err != nil { - t.Fatalf("ep.Bind(...) on endpoint %d failed: %s", i, err) - } - } - - npackets := 100000 - nports := 10000 - if got, want := len(test.endpoints), len(wantDistribution); got != want { - t.Fatalf("got len(test.endpoints) = %d, want %d", got, want) - } - ports := make(map[uint16]tcpip.Endpoint) - stats := make(map[tcpip.Endpoint]int) - for i := 0; i < npackets; i++ { - // Send a packet. - port := uint16(i % nports) - payload := newPayload() - hdrs := &headers{ - srcPort: testSrcPort + port, - dstPort: testDstPort, - } - switch netProtoNum { - case ipv4.ProtocolNumber: - c.sendV4Packet(payload, hdrs, device) - case ipv6.ProtocolNumber: - c.sendV6Packet(payload, hdrs, device) - default: - t.Fatalf("unexpected protocol number: %d", netProtoNum) - } - - ep := <-pollChannel - if _, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != nil { - t.Fatalf("Read on endpoint %d failed: %s", eps[ep], err) - } - stats[ep]++ - if i < nports { - ports[uint16(i)] = ep - } else { - // Check that all packets from one client are handled by the same - // socket. - if want, got := ports[port], ep; want != got { - t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got]) - } - } - } - - // Check that a packet distribution is as expected. - for ep, i := range eps { - wantRatio := wantDistribution[i] - wantRecv := wantRatio * float64(npackets) - actualRecv := stats[ep] - actualRatio := float64(stats[ep]) / float64(npackets) - // The deviation is less than 10%. - if math.Abs(actualRatio-wantRatio) > 0.05 { - t.Errorf("want about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantRatio*100, wantRecv, npackets, i, actualRatio*100, actualRecv, npackets) - } - } - }) - } - } - } -} diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go deleted file mode 100644 index bebf4e6b5..000000000 --- a/pkg/tcpip/stack/transport_test.go +++ /dev/null @@ -1,555 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack_test - -import ( - "bytes" - "io" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/ports" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - fakeTransNumber tcpip.TransportProtocolNumber = 1 - fakeTransHeaderLen int = 3 -) - -// fakeTransportEndpoint is a transport-layer protocol endpoint. It counts -// received packets; the counts of all endpoints are aggregated in the protocol -// descriptor. -// -// Headers of this protocol are fakeTransHeaderLen bytes, but we currently don't -// use it. -type fakeTransportEndpoint struct { - stack.TransportEndpointInfo - tcpip.DefaultSocketOptionsHandler - - proto *fakeTransportProtocol - peerAddr tcpip.Address - route *stack.Route - uniqueID uint64 - - // acceptQueue is non-nil iff bound. - acceptQueue []*fakeTransportEndpoint - - // ops is used to set and get socket options. - ops tcpip.SocketOptions -} - -func (f *fakeTransportEndpoint) Info() tcpip.EndpointInfo { - return &f.TransportEndpointInfo -} - -func (*fakeTransportEndpoint) Stats() tcpip.EndpointStats { - return nil -} - -func (*fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {} - -func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions { - return &f.ops -} - -func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, s *stack.Stack) tcpip.Endpoint { - ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: s.UniqueID()} - ep.ops.InitHandler(ep, s, tcpip.GetStackSendBufferLimits) - return ep -} - -func (f *fakeTransportEndpoint) Abort() { - f.Close() -} - -func (f *fakeTransportEndpoint) Close() { - // TODO(gvisor.dev/issue/5153): Consider retaining the route. - f.route.Release() -} - -func (*fakeTransportEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask { - return mask -} - -func (*fakeTransportEndpoint) Read(io.Writer, tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { - return tcpip.ReadResult{}, nil -} - -func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { - if len(f.route.RemoteAddress) == 0 { - return 0, &tcpip.ErrNoRoute{} - } - - v := make([]byte, p.Len()) - if _, err := io.ReadFull(p, v); err != nil { - return 0, &tcpip.ErrBadBuffer{} - } - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(f.route.MaxHeaderLength()) + fakeTransHeaderLen, - Data: buffer.View(v).ToVectorisedView(), - }) - _ = pkt.TransportHeader().Push(fakeTransHeaderLen) - if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, pkt); err != nil { - return 0, err - } - - return int64(len(v)), nil -} - -// SetSockOpt sets a socket option. Currently not supported. -func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error { - return &tcpip.ErrInvalidEndpointState{} -} - -// SetSockOptInt sets a socket option. Currently not supported. -func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error { - return &tcpip.ErrInvalidEndpointState{} -} - -// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { - return -1, &tcpip.ErrUnknownProtocolOption{} -} - -// GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (*fakeTransportEndpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { - return &tcpip.ErrInvalidEndpointState{} -} - -// Disconnect implements tcpip.Endpoint.Disconnect. -func (*fakeTransportEndpoint) Disconnect() tcpip.Error { - return &tcpip.ErrNotSupported{} -} - -func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) tcpip.Error { - f.peerAddr = addr.Addr - - // Find the route. - r, err := f.proto.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - return &tcpip.ErrNoRoute{} - } - - // Try to register so that we can start receiving packets. - f.ID.RemoteAddress = addr.Addr - err = f.proto.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */) - if err != nil { - r.Release() - return err - } - - f.route = r - - return nil -} - -func (f *fakeTransportEndpoint) UniqueID() uint64 { - return f.uniqueID -} - -func (*fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) tcpip.Error { - return nil -} - -func (*fakeTransportEndpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error { - return nil -} - -func (*fakeTransportEndpoint) Reset() { -} - -func (*fakeTransportEndpoint) Listen(int) tcpip.Error { - return nil -} - -func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) { - if len(f.acceptQueue) == 0 { - return nil, nil, nil - } - a := f.acceptQueue[0] - f.acceptQueue = f.acceptQueue[1:] - return a, nil, nil -} - -func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) tcpip.Error { - if err := f.proto.stack.RegisterTransportEndpoint( - []tcpip.NetworkProtocolNumber{fakeNetNumber}, - fakeTransNumber, - stack.TransportEndpointID{LocalAddress: a.Addr}, - f, - ports.Flags{}, - 0, /* bindtoDevice */ - ); err != nil { - return err - } - f.acceptQueue = []*fakeTransportEndpoint{} - return nil -} - -func (*fakeTransportEndpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { - return tcpip.FullAddress{}, nil -} - -func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { - return tcpip.FullAddress{}, nil -} - -func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) { - // Increment the number of received packets. - f.proto.packetCount++ - if f.acceptQueue == nil { - return - } - - netHdr := pkt.NetworkHeader().View() - route, err := f.proto.stack.FindRoute(pkt.NICID, tcpip.Address(netHdr[dstAddrOffset]), tcpip.Address(netHdr[srcAddrOffset]), pkt.NetworkProtocolNumber, false /* multicastLoop */) - if err != nil { - return - } - - ep := &fakeTransportEndpoint{ - TransportEndpointInfo: stack.TransportEndpointInfo{ - ID: f.ID, - NetProto: f.NetProto, - }, - proto: f.proto, - peerAddr: route.RemoteAddress, - route: route, - } - ep.ops.InitHandler(ep, f.proto.stack, tcpip.GetStackSendBufferLimits) - f.acceptQueue = append(f.acceptQueue, ep) -} - -func (f *fakeTransportEndpoint) HandleError(stack.TransportError, *stack.PacketBuffer) { - // Increment the number of received control packets. - f.proto.controlCount++ -} - -func (*fakeTransportEndpoint) State() uint32 { - return 0 -} - -func (*fakeTransportEndpoint) ModerateRecvBuf(copied int) {} - -func (*fakeTransportEndpoint) Resume(*stack.Stack) {} - -func (*fakeTransportEndpoint) Wait() {} - -func (*fakeTransportEndpoint) LastError() tcpip.Error { - return nil -} - -type fakeTransportGoodOption bool - -type fakeTransportBadOption bool - -type fakeTransportInvalidValueOption int - -type fakeTransportProtocolOptions struct { - good bool -} - -// fakeTransportProtocol is a transport-layer protocol descriptor. It -// aggregates the number of packets received via endpoints of this protocol. -type fakeTransportProtocol struct { - stack *stack.Stack - - packetCount int - controlCount int - opts fakeTransportProtocolOptions -} - -func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber { - return fakeTransNumber -} - -func (f *fakeTransportProtocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { - return newFakeTransportEndpoint(f, netProto, f.stack), nil -} - -func (*fakeTransportProtocol) NewRawEndpoint(tcpip.NetworkProtocolNumber, *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { - return nil, &tcpip.ErrUnknownProtocol{} -} - -func (*fakeTransportProtocol) MinimumPacketSize() int { - return fakeTransHeaderLen -} - -func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err tcpip.Error) { - return 0, 0, nil -} - -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { - return stack.UnknownDestinationPacketHandled -} - -func (f *fakeTransportProtocol) SetOption(option tcpip.SettableTransportProtocolOption) tcpip.Error { - switch v := option.(type) { - case *tcpip.TCPModerateReceiveBufferOption: - f.opts.good = bool(*v) - return nil - default: - return &tcpip.ErrUnknownProtocolOption{} - } -} - -func (f *fakeTransportProtocol) Option(option tcpip.GettableTransportProtocolOption) tcpip.Error { - switch v := option.(type) { - case *tcpip.TCPModerateReceiveBufferOption: - *v = tcpip.TCPModerateReceiveBufferOption(f.opts.good) - return nil - default: - return &tcpip.ErrUnknownProtocolOption{} - } -} - -// Abort implements TransportProtocol.Abort. -func (*fakeTransportProtocol) Abort() {} - -// Close implements tcpip.Endpoint.Close. -func (*fakeTransportProtocol) Close() {} - -// Wait implements TransportProtocol.Wait. -func (*fakeTransportProtocol) Wait() {} - -// Parse implements TransportProtocol.Parse. -func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool { - _, ok := pkt.TransportHeader().Consume(fakeTransHeaderLen) - return ok -} - -func fakeTransFactory(s *stack.Stack) stack.TransportProtocol { - return &fakeTransportProtocol{stack: s} -} - -func TestTransportReceive(t *testing.T) { - linkEP := channel.New(10, defaultMTU, "") - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, - }) - if err := s.CreateNIC(1, linkEP); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - // Create endpoint and connect to remote address. - wq := waiter.Queue{} - ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil { - t.Fatalf("Connect failed: %v", err) - } - - fakeTrans := s.TransportProtocolInstance(fakeTransNumber).(*fakeTransportProtocol) - - // Create buffer that will hold the packet. - buf := buffer.NewView(30) - - // Make sure packet with wrong protocol is not delivered. - buf[0] = 1 - buf[2] = 0 - linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if fakeTrans.packetCount != 0 { - t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) - } - - // Make sure packet from the wrong source is not delivered. - buf[0] = 1 - buf[1] = 3 - buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if fakeTrans.packetCount != 0 { - t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) - } - - // Make sure packet is delivered. - buf[0] = 1 - buf[1] = 2 - buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if fakeTrans.packetCount != 1 { - t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1) - } -} - -func TestTransportControlReceive(t *testing.T) { - linkEP := channel.New(10, defaultMTU, "") - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, - }) - if err := s.CreateNIC(1, linkEP); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - // Create endpoint and connect to remote address. - wq := waiter.Queue{} - ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil { - t.Fatalf("Connect failed: %v", err) - } - - fakeTrans := s.TransportProtocolInstance(fakeTransNumber).(*fakeTransportProtocol) - - // Create buffer that will hold the control packet. - buf := buffer.NewView(2*fakeNetHeaderLen + 30) - - // Outer packet contains the control protocol number. - buf[0] = 1 - buf[1] = 0xfe - buf[2] = uint8(fakeControlProtocol) - - // Make sure packet with wrong protocol is not delivered. - buf[fakeNetHeaderLen+0] = 0 - buf[fakeNetHeaderLen+1] = 1 - buf[fakeNetHeaderLen+2] = 0 - linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if fakeTrans.controlCount != 0 { - t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) - } - - // Make sure packet from the wrong source is not delivered. - buf[fakeNetHeaderLen+0] = 3 - buf[fakeNetHeaderLen+1] = 1 - buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if fakeTrans.controlCount != 0 { - t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) - } - - // Make sure packet is delivered. - buf[fakeNetHeaderLen+0] = 2 - buf[fakeNetHeaderLen+1] = 1 - buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if fakeTrans.controlCount != 1 { - t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1) - } -} - -func TestTransportSend(t *testing.T) { - linkEP := channel.New(10, defaultMTU, "") - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, - }) - if err := s.CreateNIC(1, linkEP); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - // Create endpoint and bind it. - wq := waiter.Queue{} - ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil { - t.Fatalf("Connect failed: %v", err) - } - - // Create buffer that will hold the payload. - b := make([]byte, 30) - var r bytes.Reader - r.Reset(b) - if _, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("write failed: %v", err) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - if fakeNet.sendPacketCount[2] != 1 { - t.Errorf("sendPacketCount = %d, want %d", fakeNet.sendPacketCount[2], 1) - } -} - -func TestTransportOptions(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, - }) - - v := tcpip.TCPModerateReceiveBufferOption(true) - if err := s.SetTransportProtocolOption(fakeTransNumber, &v); err != nil { - t.Errorf("s.SetTransportProtocolOption(fakeTrans, &%T(%t)): %s", v, v, err) - } - v = false - if err := s.TransportProtocolOption(fakeTransNumber, &v); err != nil { - t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &%T): %s", v, err) - } - if !v { - t.Fatalf("got tcpip.TCPModerateReceiveBufferOption = false, want = true") - } -} diff --git a/pkg/tcpip/stack/tuple_list.go b/pkg/tcpip/stack/tuple_list.go new file mode 100644 index 000000000..31d0feefa --- /dev/null +++ b/pkg/tcpip/stack/tuple_list.go @@ -0,0 +1,221 @@ +package stack + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type tupleElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (tupleElementMapper) linkerFor(elem *tuple) *tuple { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type tupleList struct { + head *tuple + tail *tuple +} + +// Reset resets list l to the empty state. +func (l *tupleList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *tupleList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *tupleList) Front() *tuple { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *tupleList) Back() *tuple { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *tupleList) Len() (count int) { + for e := l.Front(); e != nil; e = (tupleElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *tupleList) PushFront(e *tuple) { + linker := tupleElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + tupleElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *tupleList) PushBack(e *tuple) { + linker := tupleElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + tupleElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *tupleList) PushBackList(m *tupleList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + tupleElementMapper{}.linkerFor(l.tail).SetNext(m.head) + tupleElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *tupleList) InsertAfter(b, e *tuple) { + bLinker := tupleElementMapper{}.linkerFor(b) + eLinker := tupleElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + tupleElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *tupleList) InsertBefore(a, e *tuple) { + aLinker := tupleElementMapper{}.linkerFor(a) + eLinker := tupleElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + tupleElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *tupleList) Remove(e *tuple) { + linker := tupleElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + tupleElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + tupleElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type tupleEntry struct { + next *tuple + prev *tuple +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *tupleEntry) Next() *tuple { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *tupleEntry) Prev() *tuple { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *tupleEntry) SetNext(elem *tuple) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *tupleEntry) SetPrev(elem *tuple) { + e.prev = elem +} diff --git a/pkg/tcpip/tcpip_state_autogen.go b/pkg/tcpip/tcpip_state_autogen.go new file mode 100644 index 000000000..e628e662d --- /dev/null +++ b/pkg/tcpip/tcpip_state_autogen.go @@ -0,0 +1,1151 @@ +// automatically generated by stateify. + +package tcpip + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (e *ErrAborted) StateTypeName() string { + return "pkg/tcpip.ErrAborted" +} + +func (e *ErrAborted) StateFields() []string { + return []string{} +} + +func (e *ErrAborted) beforeSave() {} + +func (e *ErrAborted) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrAborted) afterLoad() {} + +func (e *ErrAborted) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrAddressFamilyNotSupported) StateTypeName() string { + return "pkg/tcpip.ErrAddressFamilyNotSupported" +} + +func (e *ErrAddressFamilyNotSupported) StateFields() []string { + return []string{} +} + +func (e *ErrAddressFamilyNotSupported) beforeSave() {} + +func (e *ErrAddressFamilyNotSupported) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrAddressFamilyNotSupported) afterLoad() {} + +func (e *ErrAddressFamilyNotSupported) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrAlreadyBound) StateTypeName() string { + return "pkg/tcpip.ErrAlreadyBound" +} + +func (e *ErrAlreadyBound) StateFields() []string { + return []string{} +} + +func (e *ErrAlreadyBound) beforeSave() {} + +func (e *ErrAlreadyBound) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrAlreadyBound) afterLoad() {} + +func (e *ErrAlreadyBound) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrAlreadyConnected) StateTypeName() string { + return "pkg/tcpip.ErrAlreadyConnected" +} + +func (e *ErrAlreadyConnected) StateFields() []string { + return []string{} +} + +func (e *ErrAlreadyConnected) beforeSave() {} + +func (e *ErrAlreadyConnected) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrAlreadyConnected) afterLoad() {} + +func (e *ErrAlreadyConnected) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrAlreadyConnecting) StateTypeName() string { + return "pkg/tcpip.ErrAlreadyConnecting" +} + +func (e *ErrAlreadyConnecting) StateFields() []string { + return []string{} +} + +func (e *ErrAlreadyConnecting) beforeSave() {} + +func (e *ErrAlreadyConnecting) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrAlreadyConnecting) afterLoad() {} + +func (e *ErrAlreadyConnecting) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrBadAddress) StateTypeName() string { + return "pkg/tcpip.ErrBadAddress" +} + +func (e *ErrBadAddress) StateFields() []string { + return []string{} +} + +func (e *ErrBadAddress) beforeSave() {} + +func (e *ErrBadAddress) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrBadAddress) afterLoad() {} + +func (e *ErrBadAddress) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrBadBuffer) StateTypeName() string { + return "pkg/tcpip.ErrBadBuffer" +} + +func (e *ErrBadBuffer) StateFields() []string { + return []string{} +} + +func (e *ErrBadBuffer) beforeSave() {} + +func (e *ErrBadBuffer) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrBadBuffer) afterLoad() {} + +func (e *ErrBadBuffer) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrBadLocalAddress) StateTypeName() string { + return "pkg/tcpip.ErrBadLocalAddress" +} + +func (e *ErrBadLocalAddress) StateFields() []string { + return []string{} +} + +func (e *ErrBadLocalAddress) beforeSave() {} + +func (e *ErrBadLocalAddress) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrBadLocalAddress) afterLoad() {} + +func (e *ErrBadLocalAddress) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrBroadcastDisabled) StateTypeName() string { + return "pkg/tcpip.ErrBroadcastDisabled" +} + +func (e *ErrBroadcastDisabled) StateFields() []string { + return []string{} +} + +func (e *ErrBroadcastDisabled) beforeSave() {} + +func (e *ErrBroadcastDisabled) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrBroadcastDisabled) afterLoad() {} + +func (e *ErrBroadcastDisabled) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrClosedForReceive) StateTypeName() string { + return "pkg/tcpip.ErrClosedForReceive" +} + +func (e *ErrClosedForReceive) StateFields() []string { + return []string{} +} + +func (e *ErrClosedForReceive) beforeSave() {} + +func (e *ErrClosedForReceive) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrClosedForReceive) afterLoad() {} + +func (e *ErrClosedForReceive) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrClosedForSend) StateTypeName() string { + return "pkg/tcpip.ErrClosedForSend" +} + +func (e *ErrClosedForSend) StateFields() []string { + return []string{} +} + +func (e *ErrClosedForSend) beforeSave() {} + +func (e *ErrClosedForSend) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrClosedForSend) afterLoad() {} + +func (e *ErrClosedForSend) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrConnectStarted) StateTypeName() string { + return "pkg/tcpip.ErrConnectStarted" +} + +func (e *ErrConnectStarted) StateFields() []string { + return []string{} +} + +func (e *ErrConnectStarted) beforeSave() {} + +func (e *ErrConnectStarted) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrConnectStarted) afterLoad() {} + +func (e *ErrConnectStarted) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrConnectionAborted) StateTypeName() string { + return "pkg/tcpip.ErrConnectionAborted" +} + +func (e *ErrConnectionAborted) StateFields() []string { + return []string{} +} + +func (e *ErrConnectionAborted) beforeSave() {} + +func (e *ErrConnectionAborted) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrConnectionAborted) afterLoad() {} + +func (e *ErrConnectionAborted) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrConnectionRefused) StateTypeName() string { + return "pkg/tcpip.ErrConnectionRefused" +} + +func (e *ErrConnectionRefused) StateFields() []string { + return []string{} +} + +func (e *ErrConnectionRefused) beforeSave() {} + +func (e *ErrConnectionRefused) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrConnectionRefused) afterLoad() {} + +func (e *ErrConnectionRefused) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrConnectionReset) StateTypeName() string { + return "pkg/tcpip.ErrConnectionReset" +} + +func (e *ErrConnectionReset) StateFields() []string { + return []string{} +} + +func (e *ErrConnectionReset) beforeSave() {} + +func (e *ErrConnectionReset) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrConnectionReset) afterLoad() {} + +func (e *ErrConnectionReset) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrDestinationRequired) StateTypeName() string { + return "pkg/tcpip.ErrDestinationRequired" +} + +func (e *ErrDestinationRequired) StateFields() []string { + return []string{} +} + +func (e *ErrDestinationRequired) beforeSave() {} + +func (e *ErrDestinationRequired) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrDestinationRequired) afterLoad() {} + +func (e *ErrDestinationRequired) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrDuplicateAddress) StateTypeName() string { + return "pkg/tcpip.ErrDuplicateAddress" +} + +func (e *ErrDuplicateAddress) StateFields() []string { + return []string{} +} + +func (e *ErrDuplicateAddress) beforeSave() {} + +func (e *ErrDuplicateAddress) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrDuplicateAddress) afterLoad() {} + +func (e *ErrDuplicateAddress) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrDuplicateNICID) StateTypeName() string { + return "pkg/tcpip.ErrDuplicateNICID" +} + +func (e *ErrDuplicateNICID) StateFields() []string { + return []string{} +} + +func (e *ErrDuplicateNICID) beforeSave() {} + +func (e *ErrDuplicateNICID) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrDuplicateNICID) afterLoad() {} + +func (e *ErrDuplicateNICID) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrInvalidEndpointState) StateTypeName() string { + return "pkg/tcpip.ErrInvalidEndpointState" +} + +func (e *ErrInvalidEndpointState) StateFields() []string { + return []string{} +} + +func (e *ErrInvalidEndpointState) beforeSave() {} + +func (e *ErrInvalidEndpointState) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrInvalidEndpointState) afterLoad() {} + +func (e *ErrInvalidEndpointState) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrInvalidOptionValue) StateTypeName() string { + return "pkg/tcpip.ErrInvalidOptionValue" +} + +func (e *ErrInvalidOptionValue) StateFields() []string { + return []string{} +} + +func (e *ErrInvalidOptionValue) beforeSave() {} + +func (e *ErrInvalidOptionValue) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrInvalidOptionValue) afterLoad() {} + +func (e *ErrInvalidOptionValue) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrMalformedHeader) StateTypeName() string { + return "pkg/tcpip.ErrMalformedHeader" +} + +func (e *ErrMalformedHeader) StateFields() []string { + return []string{} +} + +func (e *ErrMalformedHeader) beforeSave() {} + +func (e *ErrMalformedHeader) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrMalformedHeader) afterLoad() {} + +func (e *ErrMalformedHeader) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrMessageTooLong) StateTypeName() string { + return "pkg/tcpip.ErrMessageTooLong" +} + +func (e *ErrMessageTooLong) StateFields() []string { + return []string{} +} + +func (e *ErrMessageTooLong) beforeSave() {} + +func (e *ErrMessageTooLong) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrMessageTooLong) afterLoad() {} + +func (e *ErrMessageTooLong) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrNetworkUnreachable) StateTypeName() string { + return "pkg/tcpip.ErrNetworkUnreachable" +} + +func (e *ErrNetworkUnreachable) StateFields() []string { + return []string{} +} + +func (e *ErrNetworkUnreachable) beforeSave() {} + +func (e *ErrNetworkUnreachable) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrNetworkUnreachable) afterLoad() {} + +func (e *ErrNetworkUnreachable) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrNoBufferSpace) StateTypeName() string { + return "pkg/tcpip.ErrNoBufferSpace" +} + +func (e *ErrNoBufferSpace) StateFields() []string { + return []string{} +} + +func (e *ErrNoBufferSpace) beforeSave() {} + +func (e *ErrNoBufferSpace) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrNoBufferSpace) afterLoad() {} + +func (e *ErrNoBufferSpace) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrNoPortAvailable) StateTypeName() string { + return "pkg/tcpip.ErrNoPortAvailable" +} + +func (e *ErrNoPortAvailable) StateFields() []string { + return []string{} +} + +func (e *ErrNoPortAvailable) beforeSave() {} + +func (e *ErrNoPortAvailable) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrNoPortAvailable) afterLoad() {} + +func (e *ErrNoPortAvailable) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrNoRoute) StateTypeName() string { + return "pkg/tcpip.ErrNoRoute" +} + +func (e *ErrNoRoute) StateFields() []string { + return []string{} +} + +func (e *ErrNoRoute) beforeSave() {} + +func (e *ErrNoRoute) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrNoRoute) afterLoad() {} + +func (e *ErrNoRoute) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrNoSuchFile) StateTypeName() string { + return "pkg/tcpip.ErrNoSuchFile" +} + +func (e *ErrNoSuchFile) StateFields() []string { + return []string{} +} + +func (e *ErrNoSuchFile) beforeSave() {} + +func (e *ErrNoSuchFile) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrNoSuchFile) afterLoad() {} + +func (e *ErrNoSuchFile) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrNotConnected) StateTypeName() string { + return "pkg/tcpip.ErrNotConnected" +} + +func (e *ErrNotConnected) StateFields() []string { + return []string{} +} + +func (e *ErrNotConnected) beforeSave() {} + +func (e *ErrNotConnected) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrNotConnected) afterLoad() {} + +func (e *ErrNotConnected) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrNotPermitted) StateTypeName() string { + return "pkg/tcpip.ErrNotPermitted" +} + +func (e *ErrNotPermitted) StateFields() []string { + return []string{} +} + +func (e *ErrNotPermitted) beforeSave() {} + +func (e *ErrNotPermitted) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrNotPermitted) afterLoad() {} + +func (e *ErrNotPermitted) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrNotSupported) StateTypeName() string { + return "pkg/tcpip.ErrNotSupported" +} + +func (e *ErrNotSupported) StateFields() []string { + return []string{} +} + +func (e *ErrNotSupported) beforeSave() {} + +func (e *ErrNotSupported) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrNotSupported) afterLoad() {} + +func (e *ErrNotSupported) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrPortInUse) StateTypeName() string { + return "pkg/tcpip.ErrPortInUse" +} + +func (e *ErrPortInUse) StateFields() []string { + return []string{} +} + +func (e *ErrPortInUse) beforeSave() {} + +func (e *ErrPortInUse) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrPortInUse) afterLoad() {} + +func (e *ErrPortInUse) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrQueueSizeNotSupported) StateTypeName() string { + return "pkg/tcpip.ErrQueueSizeNotSupported" +} + +func (e *ErrQueueSizeNotSupported) StateFields() []string { + return []string{} +} + +func (e *ErrQueueSizeNotSupported) beforeSave() {} + +func (e *ErrQueueSizeNotSupported) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrQueueSizeNotSupported) afterLoad() {} + +func (e *ErrQueueSizeNotSupported) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrTimeout) StateTypeName() string { + return "pkg/tcpip.ErrTimeout" +} + +func (e *ErrTimeout) StateFields() []string { + return []string{} +} + +func (e *ErrTimeout) beforeSave() {} + +func (e *ErrTimeout) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrTimeout) afterLoad() {} + +func (e *ErrTimeout) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrUnknownDevice) StateTypeName() string { + return "pkg/tcpip.ErrUnknownDevice" +} + +func (e *ErrUnknownDevice) StateFields() []string { + return []string{} +} + +func (e *ErrUnknownDevice) beforeSave() {} + +func (e *ErrUnknownDevice) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrUnknownDevice) afterLoad() {} + +func (e *ErrUnknownDevice) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrUnknownNICID) StateTypeName() string { + return "pkg/tcpip.ErrUnknownNICID" +} + +func (e *ErrUnknownNICID) StateFields() []string { + return []string{} +} + +func (e *ErrUnknownNICID) beforeSave() {} + +func (e *ErrUnknownNICID) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrUnknownNICID) afterLoad() {} + +func (e *ErrUnknownNICID) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrUnknownProtocol) StateTypeName() string { + return "pkg/tcpip.ErrUnknownProtocol" +} + +func (e *ErrUnknownProtocol) StateFields() []string { + return []string{} +} + +func (e *ErrUnknownProtocol) beforeSave() {} + +func (e *ErrUnknownProtocol) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrUnknownProtocol) afterLoad() {} + +func (e *ErrUnknownProtocol) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrUnknownProtocolOption) StateTypeName() string { + return "pkg/tcpip.ErrUnknownProtocolOption" +} + +func (e *ErrUnknownProtocolOption) StateFields() []string { + return []string{} +} + +func (e *ErrUnknownProtocolOption) beforeSave() {} + +func (e *ErrUnknownProtocolOption) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrUnknownProtocolOption) afterLoad() {} + +func (e *ErrUnknownProtocolOption) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrWouldBlock) StateTypeName() string { + return "pkg/tcpip.ErrWouldBlock" +} + +func (e *ErrWouldBlock) StateFields() []string { + return []string{} +} + +func (e *ErrWouldBlock) beforeSave() {} + +func (e *ErrWouldBlock) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrWouldBlock) afterLoad() {} + +func (e *ErrWouldBlock) StateLoad(stateSourceObject state.Source) { +} + +func (l *sockErrorList) StateTypeName() string { + return "pkg/tcpip.sockErrorList" +} + +func (l *sockErrorList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *sockErrorList) beforeSave() {} + +func (l *sockErrorList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *sockErrorList) afterLoad() {} + +func (l *sockErrorList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *sockErrorEntry) StateTypeName() string { + return "pkg/tcpip.sockErrorEntry" +} + +func (e *sockErrorEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *sockErrorEntry) beforeSave() {} + +func (e *sockErrorEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *sockErrorEntry) afterLoad() {} + +func (e *sockErrorEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func (so *SocketOptions) StateTypeName() string { + return "pkg/tcpip.SocketOptions" +} + +func (so *SocketOptions) StateFields() []string { + return []string{ + "handler", + "broadcastEnabled", + "passCredEnabled", + "noChecksumEnabled", + "reuseAddressEnabled", + "reusePortEnabled", + "keepAliveEnabled", + "multicastLoopEnabled", + "receiveTOSEnabled", + "receiveTClassEnabled", + "receivePacketInfoEnabled", + "hdrIncludedEnabled", + "v6OnlyEnabled", + "quickAckEnabled", + "delayOptionEnabled", + "corkOptionEnabled", + "receiveOriginalDstAddress", + "recvErrEnabled", + "errQueue", + "bindToDevice", + "sendBufferSize", + "linger", + } +} + +func (so *SocketOptions) beforeSave() {} + +func (so *SocketOptions) StateSave(stateSinkObject state.Sink) { + so.beforeSave() + stateSinkObject.Save(0, &so.handler) + stateSinkObject.Save(1, &so.broadcastEnabled) + stateSinkObject.Save(2, &so.passCredEnabled) + stateSinkObject.Save(3, &so.noChecksumEnabled) + stateSinkObject.Save(4, &so.reuseAddressEnabled) + stateSinkObject.Save(5, &so.reusePortEnabled) + stateSinkObject.Save(6, &so.keepAliveEnabled) + stateSinkObject.Save(7, &so.multicastLoopEnabled) + stateSinkObject.Save(8, &so.receiveTOSEnabled) + stateSinkObject.Save(9, &so.receiveTClassEnabled) + stateSinkObject.Save(10, &so.receivePacketInfoEnabled) + stateSinkObject.Save(11, &so.hdrIncludedEnabled) + stateSinkObject.Save(12, &so.v6OnlyEnabled) + stateSinkObject.Save(13, &so.quickAckEnabled) + stateSinkObject.Save(14, &so.delayOptionEnabled) + stateSinkObject.Save(15, &so.corkOptionEnabled) + stateSinkObject.Save(16, &so.receiveOriginalDstAddress) + stateSinkObject.Save(17, &so.recvErrEnabled) + stateSinkObject.Save(18, &so.errQueue) + stateSinkObject.Save(19, &so.bindToDevice) + stateSinkObject.Save(20, &so.sendBufferSize) + stateSinkObject.Save(21, &so.linger) +} + +func (so *SocketOptions) afterLoad() {} + +func (so *SocketOptions) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &so.handler) + stateSourceObject.Load(1, &so.broadcastEnabled) + stateSourceObject.Load(2, &so.passCredEnabled) + stateSourceObject.Load(3, &so.noChecksumEnabled) + stateSourceObject.Load(4, &so.reuseAddressEnabled) + stateSourceObject.Load(5, &so.reusePortEnabled) + stateSourceObject.Load(6, &so.keepAliveEnabled) + stateSourceObject.Load(7, &so.multicastLoopEnabled) + stateSourceObject.Load(8, &so.receiveTOSEnabled) + stateSourceObject.Load(9, &so.receiveTClassEnabled) + stateSourceObject.Load(10, &so.receivePacketInfoEnabled) + stateSourceObject.Load(11, &so.hdrIncludedEnabled) + stateSourceObject.Load(12, &so.v6OnlyEnabled) + stateSourceObject.Load(13, &so.quickAckEnabled) + stateSourceObject.Load(14, &so.delayOptionEnabled) + stateSourceObject.Load(15, &so.corkOptionEnabled) + stateSourceObject.Load(16, &so.receiveOriginalDstAddress) + stateSourceObject.Load(17, &so.recvErrEnabled) + stateSourceObject.Load(18, &so.errQueue) + stateSourceObject.Load(19, &so.bindToDevice) + stateSourceObject.Load(20, &so.sendBufferSize) + stateSourceObject.Load(21, &so.linger) +} + +func (l *LocalSockError) StateTypeName() string { + return "pkg/tcpip.LocalSockError" +} + +func (l *LocalSockError) StateFields() []string { + return []string{ + "info", + } +} + +func (l *LocalSockError) beforeSave() {} + +func (l *LocalSockError) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.info) +} + +func (l *LocalSockError) afterLoad() {} + +func (l *LocalSockError) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.info) +} + +func (s *SockError) StateTypeName() string { + return "pkg/tcpip.SockError" +} + +func (s *SockError) StateFields() []string { + return []string{ + "sockErrorEntry", + "Err", + "Cause", + "Payload", + "Dst", + "Offender", + "NetProto", + } +} + +func (s *SockError) beforeSave() {} + +func (s *SockError) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.sockErrorEntry) + stateSinkObject.Save(1, &s.Err) + stateSinkObject.Save(2, &s.Cause) + stateSinkObject.Save(3, &s.Payload) + stateSinkObject.Save(4, &s.Dst) + stateSinkObject.Save(5, &s.Offender) + stateSinkObject.Save(6, &s.NetProto) +} + +func (s *SockError) afterLoad() {} + +func (s *SockError) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.sockErrorEntry) + stateSourceObject.Load(1, &s.Err) + stateSourceObject.Load(2, &s.Cause) + stateSourceObject.Load(3, &s.Payload) + stateSourceObject.Load(4, &s.Dst) + stateSourceObject.Load(5, &s.Offender) + stateSourceObject.Load(6, &s.NetProto) +} + +func (f *FullAddress) StateTypeName() string { + return "pkg/tcpip.FullAddress" +} + +func (f *FullAddress) StateFields() []string { + return []string{ + "NIC", + "Addr", + "Port", + } +} + +func (f *FullAddress) beforeSave() {} + +func (f *FullAddress) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.NIC) + stateSinkObject.Save(1, &f.Addr) + stateSinkObject.Save(2, &f.Port) +} + +func (f *FullAddress) afterLoad() {} + +func (f *FullAddress) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.NIC) + stateSourceObject.Load(1, &f.Addr) + stateSourceObject.Load(2, &f.Port) +} + +func (c *ControlMessages) StateTypeName() string { + return "pkg/tcpip.ControlMessages" +} + +func (c *ControlMessages) StateFields() []string { + return []string{ + "HasTimestamp", + "Timestamp", + "HasInq", + "Inq", + "HasTOS", + "TOS", + "HasTClass", + "TClass", + "HasIPPacketInfo", + "PacketInfo", + "HasOriginalDstAddress", + "OriginalDstAddress", + "SockErr", + } +} + +func (c *ControlMessages) beforeSave() {} + +func (c *ControlMessages) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + stateSinkObject.Save(0, &c.HasTimestamp) + stateSinkObject.Save(1, &c.Timestamp) + stateSinkObject.Save(2, &c.HasInq) + stateSinkObject.Save(3, &c.Inq) + stateSinkObject.Save(4, &c.HasTOS) + stateSinkObject.Save(5, &c.TOS) + stateSinkObject.Save(6, &c.HasTClass) + stateSinkObject.Save(7, &c.TClass) + stateSinkObject.Save(8, &c.HasIPPacketInfo) + stateSinkObject.Save(9, &c.PacketInfo) + stateSinkObject.Save(10, &c.HasOriginalDstAddress) + stateSinkObject.Save(11, &c.OriginalDstAddress) + stateSinkObject.Save(12, &c.SockErr) +} + +func (c *ControlMessages) afterLoad() {} + +func (c *ControlMessages) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.HasTimestamp) + stateSourceObject.Load(1, &c.Timestamp) + stateSourceObject.Load(2, &c.HasInq) + stateSourceObject.Load(3, &c.Inq) + stateSourceObject.Load(4, &c.HasTOS) + stateSourceObject.Load(5, &c.TOS) + stateSourceObject.Load(6, &c.HasTClass) + stateSourceObject.Load(7, &c.TClass) + stateSourceObject.Load(8, &c.HasIPPacketInfo) + stateSourceObject.Load(9, &c.PacketInfo) + stateSourceObject.Load(10, &c.HasOriginalDstAddress) + stateSourceObject.Load(11, &c.OriginalDstAddress) + stateSourceObject.Load(12, &c.SockErr) +} + +func (l *LinkPacketInfo) StateTypeName() string { + return "pkg/tcpip.LinkPacketInfo" +} + +func (l *LinkPacketInfo) StateFields() []string { + return []string{ + "Protocol", + "PktType", + } +} + +func (l *LinkPacketInfo) beforeSave() {} + +func (l *LinkPacketInfo) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.Protocol) + stateSinkObject.Save(1, &l.PktType) +} + +func (l *LinkPacketInfo) afterLoad() {} + +func (l *LinkPacketInfo) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.Protocol) + stateSourceObject.Load(1, &l.PktType) +} + +func (l *LingerOption) StateTypeName() string { + return "pkg/tcpip.LingerOption" +} + +func (l *LingerOption) StateFields() []string { + return []string{ + "Enabled", + "Timeout", + } +} + +func (l *LingerOption) beforeSave() {} + +func (l *LingerOption) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.Enabled) + stateSinkObject.Save(1, &l.Timeout) +} + +func (l *LingerOption) afterLoad() {} + +func (l *LingerOption) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.Enabled) + stateSourceObject.Load(1, &l.Timeout) +} + +func (i *IPPacketInfo) StateTypeName() string { + return "pkg/tcpip.IPPacketInfo" +} + +func (i *IPPacketInfo) StateFields() []string { + return []string{ + "NIC", + "LocalAddr", + "DestinationAddr", + } +} + +func (i *IPPacketInfo) beforeSave() {} + +func (i *IPPacketInfo) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.NIC) + stateSinkObject.Save(1, &i.LocalAddr) + stateSinkObject.Save(2, &i.DestinationAddr) +} + +func (i *IPPacketInfo) afterLoad() {} + +func (i *IPPacketInfo) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.NIC) + stateSourceObject.Load(1, &i.LocalAddr) + stateSourceObject.Load(2, &i.DestinationAddr) +} + +func init() { + state.Register((*ErrAborted)(nil)) + state.Register((*ErrAddressFamilyNotSupported)(nil)) + state.Register((*ErrAlreadyBound)(nil)) + state.Register((*ErrAlreadyConnected)(nil)) + state.Register((*ErrAlreadyConnecting)(nil)) + state.Register((*ErrBadAddress)(nil)) + state.Register((*ErrBadBuffer)(nil)) + state.Register((*ErrBadLocalAddress)(nil)) + state.Register((*ErrBroadcastDisabled)(nil)) + state.Register((*ErrClosedForReceive)(nil)) + state.Register((*ErrClosedForSend)(nil)) + state.Register((*ErrConnectStarted)(nil)) + state.Register((*ErrConnectionAborted)(nil)) + state.Register((*ErrConnectionRefused)(nil)) + state.Register((*ErrConnectionReset)(nil)) + state.Register((*ErrDestinationRequired)(nil)) + state.Register((*ErrDuplicateAddress)(nil)) + state.Register((*ErrDuplicateNICID)(nil)) + state.Register((*ErrInvalidEndpointState)(nil)) + state.Register((*ErrInvalidOptionValue)(nil)) + state.Register((*ErrMalformedHeader)(nil)) + state.Register((*ErrMessageTooLong)(nil)) + state.Register((*ErrNetworkUnreachable)(nil)) + state.Register((*ErrNoBufferSpace)(nil)) + state.Register((*ErrNoPortAvailable)(nil)) + state.Register((*ErrNoRoute)(nil)) + state.Register((*ErrNoSuchFile)(nil)) + state.Register((*ErrNotConnected)(nil)) + state.Register((*ErrNotPermitted)(nil)) + state.Register((*ErrNotSupported)(nil)) + state.Register((*ErrPortInUse)(nil)) + state.Register((*ErrQueueSizeNotSupported)(nil)) + state.Register((*ErrTimeout)(nil)) + state.Register((*ErrUnknownDevice)(nil)) + state.Register((*ErrUnknownNICID)(nil)) + state.Register((*ErrUnknownProtocol)(nil)) + state.Register((*ErrUnknownProtocolOption)(nil)) + state.Register((*ErrWouldBlock)(nil)) + state.Register((*sockErrorList)(nil)) + state.Register((*sockErrorEntry)(nil)) + state.Register((*SocketOptions)(nil)) + state.Register((*LocalSockError)(nil)) + state.Register((*SockError)(nil)) + state.Register((*FullAddress)(nil)) + state.Register((*ControlMessages)(nil)) + state.Register((*LinkPacketInfo)(nil)) + state.Register((*LingerOption)(nil)) + state.Register((*IPPacketInfo)(nil)) +} diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go deleted file mode 100644 index 269081ff8..000000000 --- a/pkg/tcpip/tcpip_test.go +++ /dev/null @@ -1,346 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcpip - -import ( - "bytes" - "fmt" - "io" - "net" - "strings" - "testing" - - "github.com/google/go-cmp/cmp" -) - -func TestLimitedWriter_Write(t *testing.T) { - var b bytes.Buffer - l := LimitedWriter{ - W: &b, - N: 5, - } - if n, err := l.Write([]byte{0, 1, 2}); err != nil { - t.Errorf("got l.Write(3/5) = (_, %s), want nil", err) - } else if n != 3 { - t.Errorf("got l.Write(3/5) = (%d, _), want 3", n) - } - if n, err := l.Write([]byte{3, 4, 5}); err != io.ErrShortWrite { - t.Errorf("got l.Write(3/2) = (_, %s), want io.ErrShortWrite", err) - } else if n != 2 { - t.Errorf("got l.Write(3/2) = (%d, _), want 2", n) - } - if l.N != 0 { - t.Errorf("got l.N = %d, want 0", l.N) - } - l.N = 1 - if n, err := l.Write([]byte{5}); err != nil { - t.Errorf("got l.Write(1/1) = (_, %s), want nil", err) - } else if n != 1 { - t.Errorf("got l.Write(1/1) = (%d, _), want 1", n) - } - if diff := cmp.Diff(b.Bytes(), []byte{0, 1, 2, 3, 4, 5}); diff != "" { - t.Errorf("%T wrote incorrect data: (-want +got):\n%s", l, diff) - } -} - -func TestSubnetContains(t *testing.T) { - tests := []struct { - s Address - m AddressMask - a Address - want bool - }{ - {"\xa0", "\xf0", "\x90", false}, - {"\xa0", "\xf0", "\xa0", true}, - {"\xa0", "\xf0", "\xa5", true}, - {"\xa0", "\xf0", "\xaf", true}, - {"\xa0", "\xf0", "\xb0", false}, - {"\xa0", "\xf0", "", false}, - {"\xa0", "\xf0", "\xa0\x00", false}, - {"\xc2\x80", "\xff\xf0", "\xc2\x80", true}, - {"\xc2\x80", "\xff\xf0", "\xc2\x00", false}, - {"\xc2\x00", "\xff\xf0", "\xc2\x00", true}, - {"\xc2\x00", "\xff\xf0", "\xc2\x80", false}, - } - for _, tt := range tests { - s, err := NewSubnet(tt.s, tt.m) - if err != nil { - t.Errorf("NewSubnet(%v, %v) = %v", tt.s, tt.m, err) - continue - } - if got := s.Contains(tt.a); got != tt.want { - t.Errorf("Subnet(%v).Contains(%v) = %v, want %v", s, tt.a, got, tt.want) - } - } -} - -func TestSubnetBits(t *testing.T) { - tests := []struct { - a AddressMask - want1 int - want0 int - }{ - {"\x00", 0, 8}, - {"\x00\x00", 0, 16}, - {"\x36", 0, 8}, - {"\x5c", 0, 8}, - {"\x5c\x5c", 0, 16}, - {"\x5c\x36", 0, 16}, - {"\x36\x5c", 0, 16}, - {"\x36\x36", 0, 16}, - {"\xff", 8, 0}, - {"\xff\xff", 16, 0}, - } - for _, tt := range tests { - s := &Subnet{mask: tt.a} - got1, got0 := s.Bits() - if got1 != tt.want1 || got0 != tt.want0 { - t.Errorf("Subnet{mask: %x}.Bits() = %d, %d, want %d, %d", tt.a, got1, got0, tt.want1, tt.want0) - } - } -} - -func TestSubnetPrefix(t *testing.T) { - tests := []struct { - a AddressMask - want int - }{ - {"\x00", 0}, - {"\x00\x00", 0}, - {"\x36", 0}, - {"\x86", 1}, - {"\xc5", 2}, - {"\xff\x00", 8}, - {"\xff\x36", 8}, - {"\xff\x8c", 9}, - {"\xff\xc8", 10}, - {"\xff", 8}, - {"\xff\xff", 16}, - } - for _, tt := range tests { - s := &Subnet{mask: tt.a} - got := s.Prefix() - if got != tt.want { - t.Errorf("Subnet{mask: %x}.Bits() = %d want %d", tt.a, got, tt.want) - } - } -} - -func TestSubnetCreation(t *testing.T) { - tests := []struct { - a Address - m AddressMask - want error - }{ - {"\xa0", "\xf0", nil}, - {"\xa0\xa0", "\xf0", errSubnetLengthMismatch}, - {"\xaa", "\xf0", errSubnetAddressMasked}, - {"", "", nil}, - } - for _, tt := range tests { - if _, err := NewSubnet(tt.a, tt.m); err != tt.want { - t.Errorf("NewSubnet(%v, %v) = %v, want %v", tt.a, tt.m, err, tt.want) - } - } -} - -func TestAddressString(t *testing.T) { - for _, want := range []string{ - // Taken from stdlib. - "2001:db8::123:12:1", - "2001:db8::1", - "2001:db8:0:1:0:1:0:1", - "2001:db8:1:0:1:0:1:0", - "2001::1:0:0:1", - "2001:db8:0:0:1::", - "2001:db8::1:0:0:1", - "2001:db8::a:b:c:d", - - // Leading zeros. - "::1", - // Trailing zeros. - "8::", - // No zeros. - "1:1:1:1:1:1:1:1", - // Longer sequence is after other zeros, but not at the end. - "1:0:0:1::1", - // Longer sequence is at the beginning, shorter sequence is at - // the end. - "::1:1:1:0:0", - // Longer sequence is not at the beginning, shorter sequence is - // at the end. - "1::1:1:0:0", - // Longer sequence is at the beginning, shorter sequence is not - // at the end. - "::1:1:0:0:1", - // Neither sequence is at an end, longer is after shorter. - "1:0:0:1::1", - // Shorter sequence is at the beginning, longer sequence is not - // at the end. - "0:0:1:1::1", - // Shorter sequence is at the beginning, longer sequence is at - // the end. - "0:0:1:1:1::", - // Short sequences at both ends, longer one in the middle. - "0:1:1::1:1:0", - // Short sequences at both ends, longer one in the middle. - "0:1::1:0:0", - // Short sequences at both ends, longer one in the middle. - "0:0:1::1:0", - // Longer sequence surrounded by shorter sequences, but none at - // the end. - "1:0:1::1:0:1", - } { - addr := Address(net.ParseIP(want)) - if got := addr.String(); got != want { - t.Errorf("Address(%x).String() = '%s', want = '%s'", addr, got, want) - } - } -} - -func TestStatsString(t *testing.T) { - got := fmt.Sprintf("%+v", Stats{}.FillIn()) - - matchers := []string{ - // Print root-level stats correctly. - "UnknownProtocolRcvdPackets:0", - // Print protocol-specific stats correctly. - "TCP:{ActiveConnectionOpenings:0", - } - - for _, m := range matchers { - if !strings.Contains(got, m) { - t.Errorf("string.Contains(got, %q) = false", m) - } - } - if t.Failed() { - t.Logf(`got = fmt.Sprintf("%%+v", Stats{}.FillIn()) = %q`, got) - } -} - -func TestAddressWithPrefixSubnet(t *testing.T) { - tests := []struct { - addr Address - prefixLen int - subnetAddr Address - subnetMask AddressMask - }{ - {"\xaa\x55\x33\x42", -1, "\x00\x00\x00\x00", "\x00\x00\x00\x00"}, - {"\xaa\x55\x33\x42", 0, "\x00\x00\x00\x00", "\x00\x00\x00\x00"}, - {"\xaa\x55\x33\x42", 1, "\x80\x00\x00\x00", "\x80\x00\x00\x00"}, - {"\xaa\x55\x33\x42", 7, "\xaa\x00\x00\x00", "\xfe\x00\x00\x00"}, - {"\xaa\x55\x33\x42", 8, "\xaa\x00\x00\x00", "\xff\x00\x00\x00"}, - {"\xaa\x55\x33\x42", 24, "\xaa\x55\x33\x00", "\xff\xff\xff\x00"}, - {"\xaa\x55\x33\x42", 31, "\xaa\x55\x33\x42", "\xff\xff\xff\xfe"}, - {"\xaa\x55\x33\x42", 32, "\xaa\x55\x33\x42", "\xff\xff\xff\xff"}, - {"\xaa\x55\x33\x42", 33, "\xaa\x55\x33\x42", "\xff\xff\xff\xff"}, - } - for _, tt := range tests { - ap := AddressWithPrefix{Address: tt.addr, PrefixLen: tt.prefixLen} - gotSubnet := ap.Subnet() - wantSubnet, err := NewSubnet(tt.subnetAddr, tt.subnetMask) - if err != nil { - t.Errorf("NewSubnet(%q, %q) failed: %s", tt.subnetAddr, tt.subnetMask, err) - continue - } - if gotSubnet != wantSubnet { - t.Errorf("got subnet = %q, want = %q", gotSubnet, wantSubnet) - } - } -} - -func TestAddressUnspecified(t *testing.T) { - tests := []struct { - addr Address - unspecified bool - }{ - { - addr: "", - unspecified: true, - }, - { - addr: "\x00", - unspecified: true, - }, - { - addr: "\x01", - unspecified: false, - }, - { - addr: "\x00\x00", - unspecified: true, - }, - { - addr: "\x01\x00", - unspecified: false, - }, - { - addr: "\x00\x01", - unspecified: false, - }, - { - addr: "\x01\x01", - unspecified: false, - }, - } - - for _, test := range tests { - t.Run(fmt.Sprintf("addr=%s", test.addr), func(t *testing.T) { - if got := test.addr.Unspecified(); got != test.unspecified { - t.Fatalf("got addr.Unspecified() = %t, want = %t", got, test.unspecified) - } - }) - } -} - -func TestAddressMatchingPrefix(t *testing.T) { - tests := []struct { - addrA Address - addrB Address - prefix uint8 - }{ - { - addrA: "\x01\x01", - addrB: "\x01\x01", - prefix: 16, - }, - { - addrA: "\x01\x01", - addrB: "\x01\x00", - prefix: 15, - }, - { - addrA: "\x01\x01", - addrB: "\x81\x00", - prefix: 0, - }, - { - addrA: "\x01\x01", - addrB: "\x01\x80", - prefix: 8, - }, - { - addrA: "\x01\x01", - addrB: "\x02\x80", - prefix: 6, - }, - } - - for _, test := range tests { - if got := test.addrA.MatchingPrefix(test.addrB); got != test.prefix { - t.Errorf("got (%s).MatchingPrefix(%s) = %d, want = %d", test.addrA, test.addrB, got, test.prefix) - } - } -} diff --git a/pkg/tcpip/tcpip_unsafe_state_autogen.go b/pkg/tcpip/tcpip_unsafe_state_autogen.go new file mode 100644 index 000000000..8f6fc08cf --- /dev/null +++ b/pkg/tcpip/tcpip_unsafe_state_autogen.go @@ -0,0 +1,33 @@ +// automatically generated by stateify. + +// +build go1.9 +// +build !go1.17 + +package tcpip + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (s *StdClock) StateTypeName() string { + return "pkg/tcpip.StdClock" +} + +func (s *StdClock) StateFields() []string { + return []string{} +} + +func (s *StdClock) beforeSave() {} + +func (s *StdClock) StateSave(stateSinkObject state.Sink) { + s.beforeSave() +} + +func (s *StdClock) afterLoad() {} + +func (s *StdClock) StateLoad(stateSourceObject state.Source) { +} + +func init() { + state.Register((*StdClock)(nil)) +} diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD deleted file mode 100644 index 58aabe547..000000000 --- a/pkg/tcpip/tests/integration/BUILD +++ /dev/null @@ -1,129 +0,0 @@ -load("//tools:defs.bzl", "go_test") - -package(licenses = ["notice"]) - -go_test( - name = "forward_test", - size = "small", - srcs = ["forward_test.go"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/checker", - "//pkg/tcpip/network/arp", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/tests/utils", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) - -go_test( - name = "iptables_test", - size = "small", - srcs = ["iptables_test.go"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/tests/utils", - "//pkg/tcpip/transport/udp", - ], -) - -go_test( - name = "link_resolution_test", - size = "small", - srcs = ["link_resolution_test.go"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/faketime", - "//pkg/tcpip/header", - "//pkg/tcpip/link/pipe", - "//pkg/tcpip/network/arp", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/tests/utils", - "//pkg/tcpip/transport/icmp", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - "@com_github_google_go_cmp//cmp:go_default_library", - "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", - ], -) - -go_test( - name = "loopback_test", - size = "small", - srcs = ["loopback_test.go"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/tests/utils", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) - -go_test( - name = "multicast_broadcast_test", - size = "small", - srcs = ["multicast_broadcast_test.go"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/tests/utils", - "//pkg/tcpip/transport/icmp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) - -go_test( - name = "route_test", - size = "small", - srcs = ["route_test.go"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/tests/utils", - "//pkg/tcpip/transport/icmp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go deleted file mode 100644 index 0cb9d034e..000000000 --- a/pkg/tcpip/tests/integration/forward_test.go +++ /dev/null @@ -1,299 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package forward_test - -import ( - "bytes" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/tests/utils" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -func TestForwarding(t *testing.T) { - const listenPort = 8080 - - type endpointAndAddresses struct { - serverEP tcpip.Endpoint - serverAddr tcpip.Address - serverReadableCH chan struct{} - - clientEP tcpip.Endpoint - clientAddr tcpip.Address - clientReadableCH chan struct{} - } - - newEP := func(t *testing.T, s *stack.Stack, transProto tcpip.TransportProtocolNumber, netProto tcpip.NetworkProtocolNumber) (tcpip.Endpoint, chan struct{}) { - t.Helper() - var wq waiter.Queue - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - ep, err := s.NewEndpoint(transProto, netProto, &wq) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d, _): %s", transProto, netProto, err) - } - - t.Cleanup(func() { - wq.EventUnregister(&we) - }) - - return ep, ch - } - - tests := []struct { - name string - epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses - }{ - { - name: "IPv4 host1 server with host2 client", - epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { - ep1, ep1WECH := newEP(t, host1Stack, proto, ipv4.ProtocolNumber) - ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber) - return endpointAndAddresses{ - serverEP: ep1, - serverAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address, - serverReadableCH: ep1WECH, - - clientEP: ep2, - clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address, - clientReadableCH: ep2WECH, - } - }, - }, - { - name: "IPv6 host2 server with host1 client", - epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { - ep1, ep1WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber) - ep2, ep2WECH := newEP(t, host1Stack, proto, ipv6.ProtocolNumber) - return endpointAndAddresses{ - serverEP: ep1, - serverAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address, - serverReadableCH: ep1WECH, - - clientEP: ep2, - clientAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address, - clientReadableCH: ep2WECH, - } - }, - }, - { - name: "IPv4 host2 server with routerNIC1 client", - epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { - ep1, ep1WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber) - ep2, ep2WECH := newEP(t, routerStack, proto, ipv4.ProtocolNumber) - return endpointAndAddresses{ - serverEP: ep1, - serverAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address, - serverReadableCH: ep1WECH, - - clientEP: ep2, - clientAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - clientReadableCH: ep2WECH, - } - }, - }, - { - name: "IPv6 routerNIC2 server with host1 client", - epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { - ep1, ep1WECH := newEP(t, routerStack, proto, ipv6.ProtocolNumber) - ep2, ep2WECH := newEP(t, host1Stack, proto, ipv6.ProtocolNumber) - return endpointAndAddresses{ - serverEP: ep1, - serverAddr: utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, - serverReadableCH: ep1WECH, - - clientEP: ep2, - clientAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address, - clientReadableCH: ep2WECH, - } - }, - }, - } - - subTests := []struct { - name string - proto tcpip.TransportProtocolNumber - expectedConnectErr tcpip.Error - setupServerSide func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) - needRemoteAddr bool - }{ - { - name: "UDP", - proto: udp.ProtocolNumber, - expectedConnectErr: nil, - setupServerSide: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { - t.Helper() - - if err := ep.Connect(clientAddr); err != nil { - t.Fatalf("ep.Connect(%#v): %s", clientAddr, err) - } - return nil, nil - }, - needRemoteAddr: true, - }, - { - name: "TCP", - proto: tcp.ProtocolNumber, - expectedConnectErr: &tcpip.ErrConnectStarted{}, - setupServerSide: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { - t.Helper() - - if err := ep.Listen(1); err != nil { - t.Fatalf("ep.Listen(1): %s", err) - } - var addr tcpip.FullAddress - for { - newEP, wq, err := ep.Accept(&addr) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - <-ch - continue - } - if err != nil { - t.Fatalf("ep.Accept(_): %s", err) - } - if diff := cmp.Diff(clientAddr, addr, checker.IgnoreCmpPath( - "NIC", - )); diff != "" { - t.Errorf("accepted address mismatch (-want +got):\n%s", diff) - } - - we, newCH := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - return newEP, newCH - } - }, - needRemoteAddr: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - for _, subTest := range subTests { - t.Run(subTest.name, func(t *testing.T) { - stackOpts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, - } - - host1Stack := stack.New(stackOpts) - routerStack := stack.New(stackOpts) - host2Stack := stack.New(stackOpts) - utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack) - - epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto) - defer epsAndAddrs.serverEP.Close() - defer epsAndAddrs.clientEP.Close() - - serverAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverAddr, Port: listenPort} - if err := epsAndAddrs.serverEP.Bind(serverAddr); err != nil { - t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", serverAddr, err) - } - clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr} - if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil { - t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err) - } - - { - err := epsAndAddrs.clientEP.Connect(serverAddr) - if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" { - t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", serverAddr, diff) - } - } - if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil { - t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err) - } else { - clientAddr = addr - clientAddr.NIC = 0 - } - - serverEP := epsAndAddrs.serverEP - serverCH := epsAndAddrs.serverReadableCH - if ep, ch := subTest.setupServerSide(t, serverEP, serverCH, clientAddr); ep != nil { - defer ep.Close() - serverEP = ep - serverCH = ch - } - - write := func(ep tcpip.Endpoint, data []byte) { - t.Helper() - - var r bytes.Reader - r.Reset(data) - var wOpts tcpip.WriteOptions - n, err := ep.Write(&r, wOpts) - if err != nil { - t.Fatalf("ep.Write(_, %#v): %s", wOpts, err) - } - if want := int64(len(data)); n != want { - t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want) - } - } - - data := []byte{1, 2, 3, 4} - write(epsAndAddrs.clientEP, data) - - read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) { - t.Helper() - - // Wait for the endpoint to be readable. - <-ch - var buf bytes.Buffer - opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr} - res, err := ep.Read(&buf, opts) - if err != nil { - t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err) - } - - readResult := tcpip.ReadResult{ - Count: len(data), - Total: len(data), - } - if subTest.needRemoteAddr { - readResult.RemoteAddr = expectedFrom - } - if diff := cmp.Diff(readResult, res, checker.IgnoreCmpPath( - "ControlMessages", - "RemoteAddr.NIC", - )); diff != "" { - t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) - } - if diff := cmp.Diff(buf.Bytes(), data); diff != "" { - t.Errorf("received data mismatch (-want +got):\n%s", diff) - } - - if t.Failed() { - t.FailNow() - } - } - - read(serverCH, serverEP, data, clientAddr) - - data = []byte{5, 6, 7, 8, 9, 10, 11, 12} - write(serverEP, data) - read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, serverAddr) - }) - } - }) - } -} diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go deleted file mode 100644 index 480174070..000000000 --- a/pkg/tcpip/tests/integration/iptables_test.go +++ /dev/null @@ -1,647 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package iptables_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/tests/utils" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" -) - -type inputIfNameMatcher struct { - name string -} - -var _ stack.Matcher = (*inputIfNameMatcher)(nil) - -func (*inputIfNameMatcher) Name() string { - return "inputIfNameMatcher" -} - -func (im *inputIfNameMatcher) Match(hook stack.Hook, _ *stack.PacketBuffer, inNicName, _ string) (bool, bool) { - return (hook == stack.Input && im.name != "" && im.name == inNicName), false -} - -const ( - nicID = 1 - nicName = "nic1" - anotherNicName = "nic2" - linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") - srcAddrV4 = "\x0a\x00\x00\x01" - dstAddrV4 = "\x0a\x00\x00\x02" - srcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - dstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - payloadSize = 20 -) - -func genStackV6(t *testing.T) (*stack.Stack, *channel.Endpoint) { - t.Helper() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - }) - e := channel.New(0, header.IPv6MinimumMTU, linkAddr) - nicOpts := stack.NICOptions{Name: nicName} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err) - } - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, dstAddrV6); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, dstAddrV6, err) - } - return s, e -} - -func genStackV4(t *testing.T) (*stack.Stack, *channel.Endpoint) { - t.Helper() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - }) - e := channel.New(0, header.IPv4MinimumMTU, linkAddr) - nicOpts := stack.NICOptions{Name: nicName} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err) - } - if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, dstAddrV4); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, dstAddrV4, err) - } - return s, e -} - -func genPacketV6() *stack.PacketBuffer { - pktSize := header.IPv6MinimumSize + payloadSize - hdr := buffer.NewPrependable(pktSize) - ip := header.IPv6(hdr.Prepend(pktSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: payloadSize, - TransportProtocol: 99, - HopLimit: 255, - SrcAddr: srcAddrV6, - DstAddr: dstAddrV6, - }) - vv := hdr.View().ToVectorisedView() - return stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv}) -} - -func genPacketV4() *stack.PacketBuffer { - pktSize := header.IPv4MinimumSize + payloadSize - hdr := buffer.NewPrependable(pktSize) - ip := header.IPv4(hdr.Prepend(pktSize)) - ip.Encode(&header.IPv4Fields{ - TOS: 0, - TotalLength: uint16(pktSize), - ID: 1, - Flags: 0, - FragmentOffset: 16, - TTL: 48, - Protocol: 99, - SrcAddr: srcAddrV4, - DstAddr: dstAddrV4, - }) - ip.SetChecksum(0) - ip.SetChecksum(^ip.CalculateChecksum()) - vv := hdr.View().ToVectorisedView() - return stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv}) -} - -func TestIPTablesStatsForInput(t *testing.T) { - tests := []struct { - name string - setupStack func(*testing.T) (*stack.Stack, *channel.Endpoint) - setupFilter func(*testing.T, *stack.Stack) - genPacket func() *stack.PacketBuffer - proto tcpip.NetworkProtocolNumber - expectReceived int - expectInputDropped int - }{ - { - name: "IPv6 Accept", - setupStack: genStackV6, - setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ }, - genPacket: genPacketV6, - proto: header.IPv6ProtocolNumber, - expectReceived: 1, - expectInputDropped: 0, - }, - { - name: "IPv4 Accept", - setupStack: genStackV4, - setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ }, - genPacket: genPacketV4, - proto: header.IPv4ProtocolNumber, - expectReceived: 1, - expectInputDropped: 0, - }, - { - name: "IPv6 Drop (input interface matches)", - setupStack: genStackV6, - setupFilter: func(t *testing.T, s *stack.Stack) { - t.Helper() - ipt := s.IPTables() - filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Input] - filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: nicName} - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}} - // Make sure the packet is not dropped by the next rule. - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) - } - }, - genPacket: genPacketV6, - proto: header.IPv6ProtocolNumber, - expectReceived: 1, - expectInputDropped: 1, - }, - { - name: "IPv4 Drop (input interface matches)", - setupStack: genStackV4, - setupFilter: func(t *testing.T, s *stack.Stack) { - t.Helper() - ipt := s.IPTables() - filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Input] - filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: nicName} - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}} - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) - } - }, - genPacket: genPacketV4, - proto: header.IPv4ProtocolNumber, - expectReceived: 1, - expectInputDropped: 1, - }, - { - name: "IPv6 Accept (input interface does not match)", - setupStack: genStackV6, - setupFilter: func(t *testing.T, s *stack.Stack) { - t.Helper() - ipt := s.IPTables() - filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Input] - filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: anotherNicName} - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) - } - }, - genPacket: genPacketV6, - proto: header.IPv6ProtocolNumber, - expectReceived: 1, - expectInputDropped: 0, - }, - { - name: "IPv4 Accept (input interface does not match)", - setupStack: genStackV4, - setupFilter: func(t *testing.T, s *stack.Stack) { - t.Helper() - ipt := s.IPTables() - filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Input] - filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: anotherNicName} - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) - } - }, - genPacket: genPacketV4, - proto: header.IPv4ProtocolNumber, - expectReceived: 1, - expectInputDropped: 0, - }, - { - name: "IPv6 Drop (input interface does not match but invert is true)", - setupStack: genStackV6, - setupFilter: func(t *testing.T, s *stack.Stack) { - t.Helper() - ipt := s.IPTables() - filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Input] - filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{ - InputInterface: anotherNicName, - InputInterfaceInvert: true, - } - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) - } - }, - genPacket: genPacketV6, - proto: header.IPv6ProtocolNumber, - expectReceived: 1, - expectInputDropped: 1, - }, - { - name: "IPv4 Drop (input interface does not match but invert is true)", - setupStack: genStackV4, - setupFilter: func(t *testing.T, s *stack.Stack) { - t.Helper() - ipt := s.IPTables() - filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Input] - filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{ - InputInterface: anotherNicName, - InputInterfaceInvert: true, - } - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) - } - }, - genPacket: genPacketV4, - proto: header.IPv4ProtocolNumber, - expectReceived: 1, - expectInputDropped: 1, - }, - { - name: "IPv6 Accept (input interface does not match using a matcher)", - setupStack: genStackV6, - setupFilter: func(t *testing.T, s *stack.Stack) { - t.Helper() - ipt := s.IPTables() - filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Input] - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}} - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) - } - }, - genPacket: genPacketV6, - proto: header.IPv6ProtocolNumber, - expectReceived: 1, - expectInputDropped: 0, - }, - { - name: "IPv4 Accept (input interface does not match using a matcher)", - setupStack: genStackV4, - setupFilter: func(t *testing.T, s *stack.Stack) { - t.Helper() - ipt := s.IPTables() - filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Input] - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}} - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) - } - }, - genPacket: genPacketV4, - proto: header.IPv4ProtocolNumber, - expectReceived: 1, - expectInputDropped: 0, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s, e := test.setupStack(t) - test.setupFilter(t, s) - e.InjectInbound(test.proto, test.genPacket()) - - if got := int(s.Stats().IP.PacketsReceived.Value()); got != test.expectReceived { - t.Errorf("got PacketReceived = %d, want = %d", got, test.expectReceived) - } - if got := int(s.Stats().IP.IPTablesInputDropped.Value()); got != test.expectInputDropped { - t.Errorf("got IPTablesInputDropped = %d, want = %d", got, test.expectInputDropped) - } - }) - } -} - -var _ stack.LinkEndpoint = (*channelEndpointWithoutWritePacket)(nil) - -// channelEndpointWithoutWritePacket is a channel endpoint that does not support -// stack.LinkEndpoint.WritePacket. -type channelEndpointWithoutWritePacket struct { - *channel.Endpoint - - t *testing.T -} - -func (c *channelEndpointWithoutWritePacket) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { - c.t.Error("unexpectedly called WritePacket; all writes should go through WritePackets") - return &tcpip.ErrNotSupported{} -} - -var _ stack.Matcher = (*udpSourcePortMatcher)(nil) - -type udpSourcePortMatcher struct { - port uint16 -} - -func (*udpSourcePortMatcher) Name() string { - return "udpSourcePortMatcher" -} - -func (m *udpSourcePortMatcher) Match(_ stack.Hook, pkt *stack.PacketBuffer, _, _ string) (matches, hotdrop bool) { - udp := header.UDP(pkt.TransportHeader().View()) - if len(udp) < header.UDPMinimumSize { - // Drop immediately as the packet is invalid. - return false, true - } - - return udp.SourcePort() == m.port, false -} - -func TestIPTableWritePackets(t *testing.T) { - const ( - nicID = 1 - - dropLocalPort = utils.LocalPort - 1 - acceptPackets = 2 - dropPackets = 3 - ) - - udpHdr := func(hdr buffer.View, srcAddr, dstAddr tcpip.Address, srcPort, dstPort uint16) { - u := header.UDP(hdr) - u.Encode(&header.UDPFields{ - SrcPort: srcPort, - DstPort: dstPort, - Length: header.UDPMinimumSize, - }) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, srcAddr, dstAddr, header.UDPMinimumSize) - sum = header.Checksum(hdr, sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - } - - tests := []struct { - name string - setupFilter func(*testing.T, *stack.Stack) - genPacket func(*stack.Route) stack.PacketBufferList - proto tcpip.NetworkProtocolNumber - remoteAddr tcpip.Address - expectSent uint64 - expectOutputDropped uint64 - }{ - { - name: "IPv4 Accept", - setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ }, - genPacket: func(r *stack.Route) stack.PacketBufferList { - var pkts stack.PacketBufferList - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), - }) - hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) - udpHdr(hdr, r.LocalAddress, r.RemoteAddress, utils.LocalPort, utils.RemotePort) - pkts.PushFront(pkt) - - return pkts - }, - proto: header.IPv4ProtocolNumber, - remoteAddr: dstAddrV4, - expectSent: 1, - expectOutputDropped: 0, - }, - { - name: "IPv4 Drop Other Port", - setupFilter: func(t *testing.T, s *stack.Stack) { - t.Helper() - - table := stack.Table{ - Rules: []stack.Rule{ - { - Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}, - }, - { - Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}, - }, - { - Matchers: []stack.Matcher{&udpSourcePortMatcher{port: dropLocalPort}}, - Target: &stack.DropTarget{NetworkProtocol: header.IPv4ProtocolNumber}, - }, - { - Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}, - }, - { - Target: &stack.ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}, - }, - }, - BuiltinChains: [stack.NumHooks]int{ - stack.Prerouting: stack.HookUnset, - stack.Input: 0, - stack.Forward: 1, - stack.Output: 2, - stack.Postrouting: stack.HookUnset, - }, - Underflows: [stack.NumHooks]int{ - stack.Prerouting: stack.HookUnset, - stack.Input: 0, - stack.Forward: 1, - stack.Output: 2, - stack.Postrouting: stack.HookUnset, - }, - } - - if err := s.IPTables().ReplaceTable(stack.FilterID, table, false /* ipv4 */); err != nil { - t.Fatalf("RelaceTable(%d, _, false): %s", stack.FilterID, err) - } - }, - genPacket: func(r *stack.Route) stack.PacketBufferList { - var pkts stack.PacketBufferList - - for i := 0; i < acceptPackets; i++ { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), - }) - hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) - udpHdr(hdr, r.LocalAddress, r.RemoteAddress, utils.LocalPort, utils.RemotePort) - pkts.PushFront(pkt) - } - for i := 0; i < dropPackets; i++ { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), - }) - hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) - udpHdr(hdr, r.LocalAddress, r.RemoteAddress, dropLocalPort, utils.RemotePort) - pkts.PushFront(pkt) - } - - return pkts - }, - proto: header.IPv4ProtocolNumber, - remoteAddr: dstAddrV4, - expectSent: acceptPackets, - expectOutputDropped: dropPackets, - }, - { - name: "IPv6 Accept", - setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ }, - genPacket: func(r *stack.Route) stack.PacketBufferList { - var pkts stack.PacketBufferList - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), - }) - hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) - udpHdr(hdr, r.LocalAddress, r.RemoteAddress, utils.LocalPort, utils.RemotePort) - pkts.PushFront(pkt) - - return pkts - }, - proto: header.IPv6ProtocolNumber, - remoteAddr: dstAddrV6, - expectSent: 1, - expectOutputDropped: 0, - }, - { - name: "IPv6 Drop Other Port", - setupFilter: func(t *testing.T, s *stack.Stack) { - t.Helper() - - table := stack.Table{ - Rules: []stack.Rule{ - { - Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}, - }, - { - Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}, - }, - { - Matchers: []stack.Matcher{&udpSourcePortMatcher{port: dropLocalPort}}, - Target: &stack.DropTarget{NetworkProtocol: header.IPv6ProtocolNumber}, - }, - { - Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}, - }, - { - Target: &stack.ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}, - }, - }, - BuiltinChains: [stack.NumHooks]int{ - stack.Prerouting: stack.HookUnset, - stack.Input: 0, - stack.Forward: 1, - stack.Output: 2, - stack.Postrouting: stack.HookUnset, - }, - Underflows: [stack.NumHooks]int{ - stack.Prerouting: stack.HookUnset, - stack.Input: 0, - stack.Forward: 1, - stack.Output: 2, - stack.Postrouting: stack.HookUnset, - }, - } - - if err := s.IPTables().ReplaceTable(stack.FilterID, table, true /* ipv6 */); err != nil { - t.Fatalf("RelaceTable(%d, _, true): %s", stack.FilterID, err) - } - }, - genPacket: func(r *stack.Route) stack.PacketBufferList { - var pkts stack.PacketBufferList - - for i := 0; i < acceptPackets; i++ { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), - }) - hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) - udpHdr(hdr, r.LocalAddress, r.RemoteAddress, utils.LocalPort, utils.RemotePort) - pkts.PushFront(pkt) - } - for i := 0; i < dropPackets; i++ { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), - }) - hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) - udpHdr(hdr, r.LocalAddress, r.RemoteAddress, dropLocalPort, utils.RemotePort) - pkts.PushFront(pkt) - } - - return pkts - }, - proto: header.IPv6ProtocolNumber, - remoteAddr: dstAddrV6, - expectSent: acceptPackets, - expectOutputDropped: dropPackets, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - e := channelEndpointWithoutWritePacket{ - Endpoint: channel.New(4, header.IPv6MinimumMTU, linkAddr), - t: t, - } - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, srcAddrV6); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, srcAddrV6, err) - } - if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, srcAddrV4); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, srcAddrV4, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: nicID, - }, - }) - - test.setupFilter(t, s) - - r, err := s.FindRoute(nicID, "", test.remoteAddr, test.proto, false) - if err != nil { - t.Fatalf("FindRoute(%d, '', %s, %d, false): %s", nicID, test.remoteAddr, test.proto, err) - } - defer r.Release() - - pkts := test.genPacket(r) - pktsLen := pkts.Len() - if n, err := r.WritePackets(nil /* gso */, pkts, stack.NetworkHeaderParams{ - Protocol: header.UDPProtocolNumber, - TTL: 64, - }); err != nil { - t.Fatalf("WritePackets(...): %s", err) - } else if n != pktsLen { - t.Fatalf("got WritePackets(...) = %d, want = %d", n, pktsLen) - } - - if got := s.Stats().IP.PacketsSent.Value(); got != test.expectSent { - t.Errorf("got PacketSent = %d, want = %d", got, test.expectSent) - } - if got := s.Stats().IP.IPTablesOutputDropped.Value(); got != test.expectOutputDropped { - t.Errorf("got IPTablesOutputDropped = %d, want = %d", got, test.expectOutputDropped) - } - }) - } -} diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go deleted file mode 100644 index 18da67fb1..000000000 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ /dev/null @@ -1,1289 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package link_resolution_test - -import ( - "bytes" - "fmt" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/pipe" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/tests/utils" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -func setupStack(t *testing.T, stackOpts stack.Options, host1NICID, host2NICID tcpip.NICID) (*stack.Stack, *stack.Stack) { - host1Stack := stack.New(stackOpts) - host2Stack := stack.New(stackOpts) - - host1NIC, host2NIC := pipe.New(utils.LinkAddr1, utils.LinkAddr2) - - if err := host1Stack.CreateNIC(host1NICID, utils.NewEthernetEndpoint(host1NIC)); err != nil { - t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) - } - if err := host2Stack.CreateNIC(host2NICID, utils.NewEthernetEndpoint(host2NIC)); err != nil { - t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) - } - - if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv4Addr1); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, utils.Ipv4Addr1, err) - } - if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv4Addr2); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, utils.Ipv4Addr2, err) - } - if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv6Addr1); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, utils.Ipv6Addr1, err) - } - if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv6Addr2); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, utils.Ipv6Addr2, err) - } - - host1Stack.SetRouteTable([]tcpip.Route{ - { - Destination: utils.Ipv4Addr1.AddressWithPrefix.Subnet(), - NIC: host1NICID, - }, - { - Destination: utils.Ipv6Addr1.AddressWithPrefix.Subnet(), - NIC: host1NICID, - }, - }) - host2Stack.SetRouteTable([]tcpip.Route{ - { - Destination: utils.Ipv4Addr2.AddressWithPrefix.Subnet(), - NIC: host2NICID, - }, - { - Destination: utils.Ipv6Addr2.AddressWithPrefix.Subnet(), - NIC: host2NICID, - }, - }) - - return host1Stack, host2Stack -} - -// TestPing tests that two hosts can ping eachother when link resolution is -// enabled. -func TestPing(t *testing.T) { - const ( - host1NICID = 1 - host2NICID = 4 - - // icmpDataOffset is the offset to the data in both ICMPv4 and ICMPv6 echo - // request/reply packets. - icmpDataOffset = 8 - ) - - tests := []struct { - name string - transProto tcpip.TransportProtocolNumber - netProto tcpip.NetworkProtocolNumber - remoteAddr tcpip.Address - icmpBuf func(*testing.T) []byte - }{ - { - name: "IPv4 Ping", - transProto: icmp.ProtocolNumber4, - netProto: ipv4.ProtocolNumber, - remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, - icmpBuf: func(t *testing.T) []byte { - data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} - hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data))) - hdr.SetType(header.ICMPv4Echo) - if n := copy(hdr.Payload(), data[:]); n != len(data) { - t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data)) - } - return hdr - }, - }, - { - name: "IPv6 Ping", - transProto: icmp.ProtocolNumber6, - netProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, - icmpBuf: func(t *testing.T) []byte { - data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} - hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data))) - hdr.SetType(header.ICMPv6EchoRequest) - if n := copy(hdr.Payload(), data[:]); n != len(data) { - t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data)) - } - return hdr - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - stackOpts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, - } - - host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) - - var wq waiter.Queue - we, waiterCH := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - ep, err := host1Stack.NewEndpoint(test.transProto, test.netProto, &wq) - if err != nil { - t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err) - } - defer ep.Close() - - icmpBuf := test.icmpBuf(t) - var r bytes.Reader - r.Reset(icmpBuf) - wOpts := tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: test.remoteAddr}} - if n, err := ep.Write(&r, wOpts); err != nil { - t.Fatalf("ep.Write(_, _): %s", err) - } else if want := int64(len(icmpBuf)); n != want { - t.Fatalf("got ep.Write(_, _) = (%d, _), want = (%d, _)", n, want) - } - - // Wait for the endpoint to be readable. - <-waiterCH - - var buf bytes.Buffer - opts := tcpip.ReadOptions{NeedRemoteAddr: true} - res, err := ep.Read(&buf, opts) - if err != nil { - t.Fatalf("ep.Read(_, %d, %#v): %s", len(icmpBuf), opts, err) - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: buf.Len(), - Total: buf.Len(), - RemoteAddr: tcpip.FullAddress{Addr: test.remoteAddr}, - }, res, checker.IgnoreCmpPath( - "ControlMessages", - "RemoteAddr.NIC", - "RemoteAddr.Port", - )); diff != "" { - t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) - } - if diff := cmp.Diff(buf.Bytes()[icmpDataOffset:], icmpBuf[icmpDataOffset:]); diff != "" { - t.Errorf("received data mismatch (-want +got):\n%s", diff) - } - }) - } -} - -type transportError struct { - origin tcpip.SockErrOrigin - typ uint8 - code uint8 - info uint32 - kind stack.TransportErrorKind -} - -func TestTCPLinkResolutionFailure(t *testing.T) { - const ( - host1NICID = 1 - host2NICID = 4 - ) - - tests := []struct { - name string - netProto tcpip.NetworkProtocolNumber - remoteAddr tcpip.Address - expectedWriteErr tcpip.Error - sockError tcpip.SockError - transErr transportError - }{ - { - name: "IPv4 with resolvable remote", - netProto: ipv4.ProtocolNumber, - remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, - expectedWriteErr: nil, - }, - { - name: "IPv6 with resolvable remote", - netProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, - expectedWriteErr: nil, - }, - { - name: "IPv4 without resolvable remote", - netProto: ipv4.ProtocolNumber, - remoteAddr: utils.Ipv4Addr3.AddressWithPrefix.Address, - expectedWriteErr: &tcpip.ErrNoRoute{}, - sockError: tcpip.SockError{ - Err: &tcpip.ErrNoRoute{}, - Dst: tcpip.FullAddress{ - NIC: host1NICID, - Addr: utils.Ipv4Addr3.AddressWithPrefix.Address, - Port: 1234, - }, - Offender: tcpip.FullAddress{ - NIC: host1NICID, - Addr: utils.Ipv4Addr1.AddressWithPrefix.Address, - }, - NetProto: ipv4.ProtocolNumber, - }, - transErr: transportError{ - origin: tcpip.SockExtErrorOriginICMP, - typ: uint8(header.ICMPv4DstUnreachable), - code: uint8(header.ICMPv4HostUnreachable), - kind: stack.DestinationHostUnreachableTransportError, - }, - }, - { - name: "IPv6 without resolvable remote", - netProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr3.AddressWithPrefix.Address, - expectedWriteErr: &tcpip.ErrNoRoute{}, - sockError: tcpip.SockError{ - Err: &tcpip.ErrNoRoute{}, - Dst: tcpip.FullAddress{ - NIC: host1NICID, - Addr: utils.Ipv6Addr3.AddressWithPrefix.Address, - Port: 1234, - }, - Offender: tcpip.FullAddress{ - NIC: host1NICID, - Addr: utils.Ipv6Addr1.AddressWithPrefix.Address, - }, - NetProto: ipv6.ProtocolNumber, - }, - transErr: transportError{ - origin: tcpip.SockExtErrorOriginICMP6, - typ: uint8(header.ICMPv6DstUnreachable), - code: uint8(header.ICMPv6AddressUnreachable), - kind: stack.DestinationHostUnreachableTransportError, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - stackOpts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, - } - - host1Stack, host2Stack := setupStack(t, stackOpts, host1NICID, host2NICID) - - var listenerWQ waiter.Queue - listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, test.netProto, &listenerWQ) - if err != nil { - t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, test.netProto, err) - } - defer listenerEP.Close() - - listenerAddr := tcpip.FullAddress{Port: 1234} - if err := listenerEP.Bind(listenerAddr); err != nil { - t.Fatalf("listenerEP.Bind(%#v): %s", listenerAddr, err) - } - - if err := listenerEP.Listen(1); err != nil { - t.Fatalf("listenerEP.Listen(1): %s", err) - } - - var clientWQ waiter.Queue - we, ch := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&we, waiter.EventOut|waiter.EventErr) - clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, test.netProto, &clientWQ) - if err != nil { - t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, test.netProto, err) - } - defer clientEP.Close() - - sockOpts := clientEP.SocketOptions() - sockOpts.SetRecvError(true) - - remoteAddr := listenerAddr - remoteAddr.Addr = test.remoteAddr - { - err := clientEP.Connect(remoteAddr) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", remoteAddr, err, &tcpip.ErrConnectStarted{}) - } - } - - // Wait for an error due to link resolution failing, or the endpoint to be - // writable. - <-ch - { - var r bytes.Reader - r.Reset([]byte{0}) - var wOpts tcpip.WriteOptions - _, err := clientEP.Write(&r, wOpts) - if diff := cmp.Diff(test.expectedWriteErr, err); diff != "" { - t.Errorf("unexpected error from clientEP.Write(_, %#v), (-want, +got):\n%s", wOpts, diff) - } - } - - if test.expectedWriteErr == nil { - return - } - - sockErr := sockOpts.DequeueErr() - if sockErr == nil { - t.Fatalf("got sockOpts.DequeueErr() = nil, want = non-nil") - } - - sockErrCmpOpts := []cmp.Option{ - cmpopts.IgnoreUnexported(tcpip.SockError{}), - cmp.Comparer(func(a, b tcpip.Error) bool { - // tcpip.Error holds an unexported field but the errors netstack uses - // are pre defined so we can simply compare pointers. - return a == b - }), - checker.IgnoreCmpPath( - // Ignore the payload since we do not know the TCP seq/ack numbers. - "Payload", - // Ignore the cause since we will compare its properties separately - // since the concrete type of the cause is unknown. - "Cause", - ), - } - - if addr, err := clientEP.GetLocalAddress(); err != nil { - t.Fatalf("clientEP.GetLocalAddress(): %s", err) - } else { - test.sockError.Offender.Port = addr.Port - } - if diff := cmp.Diff(&test.sockError, sockErr, sockErrCmpOpts...); diff != "" { - t.Errorf("socket error mismatch (-want +got):\n%s", diff) - } - - transErr, ok := sockErr.Cause.(stack.TransportError) - if !ok { - t.Fatalf("socket error cause is not a transport error; cause = %#v", sockErr.Cause) - } - if diff := cmp.Diff( - test.transErr, - transportError{ - origin: transErr.Origin(), - typ: transErr.Type(), - code: transErr.Code(), - info: transErr.Info(), - kind: transErr.Kind(), - }, - cmp.AllowUnexported(transportError{}), - ); diff != "" { - t.Errorf("socket error mismatch (-want +got):\n%s", diff) - } - }) - } -} - -func TestGetLinkAddress(t *testing.T) { - const ( - host1NICID = 1 - host2NICID = 4 - ) - - tests := []struct { - name string - netProto tcpip.NetworkProtocolNumber - remoteAddr tcpip.Address - expectedOk bool - }{ - { - name: "IPv4 resolvable", - netProto: ipv4.ProtocolNumber, - remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, - expectedOk: true, - }, - { - name: "IPv6 resolvable", - netProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, - expectedOk: true, - }, - { - name: "IPv4 not resolvable", - netProto: ipv4.ProtocolNumber, - remoteAddr: utils.Ipv4Addr3.AddressWithPrefix.Address, - expectedOk: false, - }, - { - name: "IPv6 not resolvable", - netProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr3.AddressWithPrefix.Address, - expectedOk: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - stackOpts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - } - - host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) - - ch := make(chan stack.LinkResolutionResult, 1) - err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, "", test.netProto, func(r stack.LinkResolutionResult) { - ch <- r - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.remoteAddr, test.netProto, err, &tcpip.ErrWouldBlock{}) - } - wantRes := stack.LinkResolutionResult{Success: test.expectedOk} - if test.expectedOk { - wantRes.LinkAddress = utils.LinkAddr2 - } - if diff := cmp.Diff(wantRes, <-ch); diff != "" { - t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff) - } - }) - } -} - -func TestRouteResolvedFields(t *testing.T) { - const ( - host1NICID = 1 - host2NICID = 4 - ) - - tests := []struct { - name string - netProto tcpip.NetworkProtocolNumber - localAddr tcpip.Address - remoteAddr tcpip.Address - immediatelyResolvable bool - expectedSuccess bool - expectedLinkAddr tcpip.LinkAddress - }{ - { - name: "IPv4 immediately resolvable", - netProto: ipv4.ProtocolNumber, - localAddr: utils.Ipv4Addr1.AddressWithPrefix.Address, - remoteAddr: header.IPv4AllSystems, - immediatelyResolvable: true, - expectedSuccess: true, - expectedLinkAddr: header.EthernetAddressFromMulticastIPv4Address(header.IPv4AllSystems), - }, - { - name: "IPv6 immediately resolvable", - netProto: ipv6.ProtocolNumber, - localAddr: utils.Ipv6Addr1.AddressWithPrefix.Address, - remoteAddr: header.IPv6AllNodesMulticastAddress, - immediatelyResolvable: true, - expectedSuccess: true, - expectedLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress), - }, - { - name: "IPv4 resolvable", - netProto: ipv4.ProtocolNumber, - localAddr: utils.Ipv4Addr1.AddressWithPrefix.Address, - remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, - immediatelyResolvable: false, - expectedSuccess: true, - expectedLinkAddr: utils.LinkAddr2, - }, - { - name: "IPv6 resolvable", - netProto: ipv6.ProtocolNumber, - localAddr: utils.Ipv6Addr1.AddressWithPrefix.Address, - remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, - immediatelyResolvable: false, - expectedSuccess: true, - expectedLinkAddr: utils.LinkAddr2, - }, - { - name: "IPv4 not resolvable", - netProto: ipv4.ProtocolNumber, - localAddr: utils.Ipv4Addr1.AddressWithPrefix.Address, - remoteAddr: utils.Ipv4Addr3.AddressWithPrefix.Address, - immediatelyResolvable: false, - expectedSuccess: false, - }, - { - name: "IPv6 not resolvable", - netProto: ipv6.ProtocolNumber, - localAddr: utils.Ipv6Addr1.AddressWithPrefix.Address, - remoteAddr: utils.Ipv6Addr3.AddressWithPrefix.Address, - immediatelyResolvable: false, - expectedSuccess: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - stackOpts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - } - - host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) - r, err := host1Stack.FindRoute(host1NICID, "", test.remoteAddr, test.netProto, false /* multicastLoop */) - if err != nil { - t.Fatalf("host1Stack.FindRoute(%d, '', %s, %d, false): %s", host1NICID, test.remoteAddr, test.netProto, err) - } - defer r.Release() - - var wantRouteInfo stack.RouteInfo - wantRouteInfo.LocalLinkAddress = utils.LinkAddr1 - wantRouteInfo.LocalAddress = test.localAddr - wantRouteInfo.RemoteAddress = test.remoteAddr - wantRouteInfo.NetProto = test.netProto - wantRouteInfo.Loop = stack.PacketOut - wantRouteInfo.RemoteLinkAddress = test.expectedLinkAddr - - ch := make(chan stack.ResolvedFieldsResult, 1) - - if !test.immediatelyResolvable { - wantUnresolvedRouteInfo := wantRouteInfo - wantUnresolvedRouteInfo.RemoteLinkAddress = "" - - err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) { - ch <- r - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got r.ResolvedFields(_) = %s, want = %s", err, &tcpip.ErrWouldBlock{}) - } - if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Success: test.expectedSuccess}, <-ch, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" { - t.Errorf("route resolve result mismatch (-want +got):\n%s", diff) - } - - if !test.expectedSuccess { - return - } - - // At this point the neighbor table should be populated so the route - // should be immediately resolvable. - } - - if err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) { - ch <- r - }); err != nil { - t.Errorf("r.ResolvedFields(_): %s", err) - } - select { - case routeResolveRes := <-ch: - if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Success: true}, routeResolveRes, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" { - t.Errorf("route resolve result from resolved route mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected route to be immediately resolvable") - } - }) - } -} - -func TestWritePacketsLinkResolution(t *testing.T) { - const ( - host1NICID = 1 - host2NICID = 4 - ) - - tests := []struct { - name string - netProto tcpip.NetworkProtocolNumber - remoteAddr tcpip.Address - expectedWriteErr tcpip.Error - }{ - { - name: "IPv4", - netProto: ipv4.ProtocolNumber, - remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, - expectedWriteErr: nil, - }, - { - name: "IPv6", - netProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, - expectedWriteErr: nil, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - stackOpts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - } - - host1Stack, host2Stack := setupStack(t, stackOpts, host1NICID, host2NICID) - - var serverWQ waiter.Queue - serverWE, serverCH := waiter.NewChannelEntry(nil) - serverWQ.EventRegister(&serverWE, waiter.EventIn) - serverEP, err := host2Stack.NewEndpoint(udp.ProtocolNumber, test.netProto, &serverWQ) - if err != nil { - t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.netProto, err) - } - defer serverEP.Close() - - serverAddr := tcpip.FullAddress{Port: 1234} - if err := serverEP.Bind(serverAddr); err != nil { - t.Fatalf("serverEP.Bind(%#v): %s", serverAddr, err) - } - - r, err := host1Stack.FindRoute(host1NICID, "", test.remoteAddr, test.netProto, false /* multicastLoop */) - if err != nil { - t.Fatalf("host1Stack.FindRoute(%d, '', %s, %d, false): %s", host1NICID, test.remoteAddr, test.netProto, err) - } - defer r.Release() - - data := []byte{1, 2} - var pkts stack.PacketBufferList - for _, d := range data { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()), - Data: buffer.View([]byte{d}).ToVectorisedView(), - }) - pkt.TransportProtocolNumber = udp.ProtocolNumber - length := uint16(pkt.Size()) - udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) - udpHdr.Encode(&header.UDPFields{ - SrcPort: 5555, - DstPort: serverAddr.Port, - Length: length, - }) - xsum := r.PseudoHeaderChecksum(udp.ProtocolNumber, length) - for _, v := range pkt.Data.Views() { - xsum = header.Checksum(v, xsum) - } - udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum)) - - pkts.PushBack(pkt) - } - - params := stack.NetworkHeaderParams{ - Protocol: udp.ProtocolNumber, - TTL: 64, - TOS: stack.DefaultTOS, - } - - if n, err := r.WritePackets(nil /* gso */, pkts, params); err != nil { - t.Fatalf("r.WritePackets(nil, %#v, _): %s", params, err) - } else if want := pkts.Len(); want != n { - t.Fatalf("got r.WritePackets(nil, %#v, _) = %d, want = %d", n, params, want) - } - - var writer bytes.Buffer - count := 0 - for { - var rOpts tcpip.ReadOptions - res, err := serverEP.Read(&writer, rOpts) - if err != nil { - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Should not have anymore bytes to read after we read the sent - // number of bytes. - if count == len(data) { - break - } - - <-serverCH - continue - } - - t.Fatalf("serverEP.Read(_, %#v): %s", rOpts, err) - } - count += res.Count - } - - if got, want := host2Stack.Stats().UDP.PacketsReceived.Value(), uint64(len(data)); got != want { - t.Errorf("got host2Stack.Stats().UDP.PacketsReceived.Value() = %d, want = %d", got, want) - } - if diff := cmp.Diff(data, writer.Bytes()); diff != "" { - t.Errorf("read bytes mismatch (-want +got):\n%s", diff) - } - }) - } -} - -type eventType int - -const ( - entryAdded eventType = iota - entryChanged - entryRemoved -) - -func (t eventType) String() string { - switch t { - case entryAdded: - return "add" - case entryChanged: - return "change" - case entryRemoved: - return "remove" - default: - return fmt.Sprintf("unknown (%d)", t) - } -} - -type eventInfo struct { - eventType eventType - nicID tcpip.NICID - entry stack.NeighborEntry -} - -func (e eventInfo) String() string { - return fmt.Sprintf("%s event for NIC #%d, %#v", e.eventType, e.nicID, e.entry) -} - -var _ stack.NUDDispatcher = (*nudDispatcher)(nil) - -type nudDispatcher struct { - c chan eventInfo -} - -func (d *nudDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry stack.NeighborEntry) { - e := eventInfo{ - eventType: entryAdded, - nicID: nicID, - entry: entry, - } - d.c <- e -} - -func (d *nudDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry stack.NeighborEntry) { - e := eventInfo{ - eventType: entryChanged, - nicID: nicID, - entry: entry, - } - d.c <- e -} - -func (d *nudDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.NeighborEntry) { - e := eventInfo{ - eventType: entryRemoved, - nicID: nicID, - entry: entry, - } - d.c <- e -} - -func (d *nudDispatcher) waitForEvent(want eventInfo) error { - if diff := cmp.Diff(want, <-d.c, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" { - return fmt.Errorf("got invalid event (-want +got):\n%s", diff) - } - return nil -} - -// TestTCPConfirmNeighborReachability tests that TCP informs layers beneath it -// that the neighbor used for a route is reachable. -func TestTCPConfirmNeighborReachability(t *testing.T) { - tests := []struct { - name string - netProto tcpip.NetworkProtocolNumber - remoteAddr tcpip.Address - neighborAddr tcpip.Address - getEndpoints func(*testing.T, *stack.Stack, *stack.Stack, *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) - isHost1Listener bool - }{ - { - name: "IPv4 active connection through neighbor", - netProto: ipv4.ProtocolNumber, - remoteAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address, - neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { - var listenerWQ waiter.Queue - listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) - if err != nil { - t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) - } - - var clientWQ waiter.Queue - clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.EventOut) - clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) - if err != nil { - listenerEP.Close() - t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) - } - - return listenerEP, clientEP, clientCH - }, - }, - { - name: "IPv6 active connection through neighbor", - netProto: ipv6.ProtocolNumber, - remoteAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address, - neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { - var listenerWQ waiter.Queue - listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) - if err != nil { - t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) - } - - var clientWQ waiter.Queue - clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.EventOut) - clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) - if err != nil { - listenerEP.Close() - t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) - } - - return listenerEP, clientEP, clientCH - }, - }, - { - name: "IPv4 active connection to neighbor", - netProto: ipv4.ProtocolNumber, - remoteAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { - var listenerWQ waiter.Queue - listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) - if err != nil { - t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) - } - - var clientWQ waiter.Queue - clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.EventOut) - clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) - if err != nil { - listenerEP.Close() - t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) - } - - return listenerEP, clientEP, clientCH - }, - }, - { - name: "IPv6 active connection to neighbor", - netProto: ipv6.ProtocolNumber, - remoteAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, - neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { - var listenerWQ waiter.Queue - listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) - if err != nil { - t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) - } - - var clientWQ waiter.Queue - clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.EventOut) - clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) - if err != nil { - listenerEP.Close() - t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) - } - - return listenerEP, clientEP, clientCH - }, - }, - { - name: "IPv4 passive connection to neighbor", - netProto: ipv4.ProtocolNumber, - remoteAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address, - neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { - var listenerWQ waiter.Queue - listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) - if err != nil { - t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) - } - - var clientWQ waiter.Queue - clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.EventOut) - clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) - if err != nil { - listenerEP.Close() - t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) - } - - return listenerEP, clientEP, clientCH - }, - isHost1Listener: true, - }, - { - name: "IPv6 passive connection to neighbor", - netProto: ipv6.ProtocolNumber, - remoteAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address, - neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { - var listenerWQ waiter.Queue - listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) - if err != nil { - t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) - } - - var clientWQ waiter.Queue - clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.EventOut) - clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) - if err != nil { - listenerEP.Close() - t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) - } - - return listenerEP, clientEP, clientCH - }, - isHost1Listener: true, - }, - { - name: "IPv4 passive connection through neighbor", - netProto: ipv4.ProtocolNumber, - remoteAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address, - neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { - var listenerWQ waiter.Queue - listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) - if err != nil { - t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) - } - - var clientWQ waiter.Queue - clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.EventOut) - clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) - if err != nil { - listenerEP.Close() - t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) - } - - return listenerEP, clientEP, clientCH - }, - isHost1Listener: true, - }, - { - name: "IPv6 passive connection through neighbor", - netProto: ipv6.ProtocolNumber, - remoteAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address, - neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { - var listenerWQ waiter.Queue - listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) - if err != nil { - t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) - } - - var clientWQ waiter.Queue - clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.EventOut) - clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) - if err != nil { - listenerEP.Close() - t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) - } - - return listenerEP, clientEP, clientCH - }, - isHost1Listener: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - nudDisp := nudDispatcher{ - c: make(chan eventInfo, 3), - } - stackOpts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, - Clock: clock, - } - host1StackOpts := stackOpts - host1StackOpts.NUDDisp = &nudDisp - - host1Stack := stack.New(host1StackOpts) - routerStack := stack.New(stackOpts) - host2Stack := stack.New(stackOpts) - utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack) - - // Add a reachable dynamic entry to our neighbor table for the remote. - { - ch := make(chan stack.LinkResolutionResult, 1) - err := host1Stack.GetLinkAddress(utils.Host1NICID, test.neighborAddr, "", test.netProto, func(r stack.LinkResolutionResult) { - ch <- r - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", utils.Host1NICID, test.neighborAddr, test.netProto, err, &tcpip.ErrWouldBlock{}) - } - if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: utils.LinkAddr2, Success: true}, <-ch); diff != "" { - t.Fatalf("link resolution mismatch (-want +got):\n%s", diff) - } - } - if err := nudDisp.waitForEvent(eventInfo{ - eventType: entryAdded, - nicID: utils.Host1NICID, - entry: stack.NeighborEntry{State: stack.Incomplete, Addr: test.neighborAddr}, - }); err != nil { - t.Fatalf("error waiting for initial NUD event: %s", err) - } - if err := nudDisp.waitForEvent(eventInfo{ - eventType: entryChanged, - nicID: utils.Host1NICID, - entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, - }); err != nil { - t.Fatalf("error waiting for reachable NUD event: %s", err) - } - - // Wait for the remote's neighbor entry to be stale before creating a - // TCP connection from host1 to some remote. - nudConfigs, err := host1Stack.NUDConfigurations(utils.Host1NICID, test.netProto) - if err != nil { - t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", utils.Host1NICID, test.netProto, err) - } - // The maximum reachable time for a neighbor is some maximum random factor - // applied to the base reachable time. - // - // See NUDConfigurations.BaseReachableTime for more information. - maxReachableTime := time.Duration(float32(nudConfigs.BaseReachableTime) * nudConfigs.MaxRandomFactor) - clock.Advance(maxReachableTime) - if err := nudDisp.waitForEvent(eventInfo{ - eventType: entryChanged, - nicID: utils.Host1NICID, - entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, - }); err != nil { - t.Fatalf("error waiting for stale NUD event: %s", err) - } - - listenerEP, clientEP, clientCH := test.getEndpoints(t, host1Stack, routerStack, host2Stack) - defer listenerEP.Close() - defer clientEP.Close() - listenerAddr := tcpip.FullAddress{Addr: test.remoteAddr, Port: 1234} - if err := listenerEP.Bind(listenerAddr); err != nil { - t.Fatalf("listenerEP.Bind(%#v): %s", listenerAddr, err) - } - if err := listenerEP.Listen(1); err != nil { - t.Fatalf("listenerEP.Listen(1): %s", err) - } - { - err := clientEP.Connect(listenerAddr) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", listenerAddr, err, &tcpip.ErrConnectStarted{}) - } - } - - // Wait for the TCP handshake to complete then make sure the neighbor is - // reachable without entering the probe state as TCP should provide NUD - // with confirmation that the neighbor is reachable (indicated by a - // successful 3-way handshake). - <-clientCH - if err := nudDisp.waitForEvent(eventInfo{ - eventType: entryChanged, - nicID: utils.Host1NICID, - entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, - }); err != nil { - t.Fatalf("error waiting for delay NUD event: %s", err) - } - if err := nudDisp.waitForEvent(eventInfo{ - eventType: entryChanged, - nicID: utils.Host1NICID, - entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, - }); err != nil { - t.Fatalf("error waiting for reachable NUD event: %s", err) - } - - // Wait for the neighbor to be stale again then send data to the remote. - // - // On successful transmission, the neighbor should become reachable - // without probing the neighbor as a TCP ACK would be received which is an - // indication of the neighbor being reachable. - clock.Advance(maxReachableTime) - if err := nudDisp.waitForEvent(eventInfo{ - eventType: entryChanged, - nicID: utils.Host1NICID, - entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, - }); err != nil { - t.Fatalf("error waiting for stale NUD event: %s", err) - } - var r bytes.Reader - r.Reset([]byte{0}) - var wOpts tcpip.WriteOptions - if _, err := clientEP.Write(&r, wOpts); err != nil { - t.Errorf("clientEP.Write(_, %#v): %s", wOpts, err) - } - if err := nudDisp.waitForEvent(eventInfo{ - eventType: entryChanged, - nicID: utils.Host1NICID, - entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, - }); err != nil { - t.Fatalf("error waiting for delay NUD event: %s", err) - } - if test.isHost1Listener { - // If host1 is not the client, host1 does not send any data so TCP - // has no way to know it is making forward progress. Because of this, - // TCP should not mark the route reachable and NUD should go through the - // probe state. - clock.Advance(nudConfigs.DelayFirstProbeTime) - if err := nudDisp.waitForEvent(eventInfo{ - eventType: entryChanged, - nicID: utils.Host1NICID, - entry: stack.NeighborEntry{State: stack.Probe, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, - }); err != nil { - t.Fatalf("error waiting for probe NUD event: %s", err) - } - } - if err := nudDisp.waitForEvent(eventInfo{ - eventType: entryChanged, - nicID: utils.Host1NICID, - entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, - }); err != nil { - t.Fatalf("error waiting for reachable NUD event: %s", err) - } - }) - } -} - -func TestDAD(t *testing.T) { - dadConfigs := stack.DADConfigurations{ - DupAddrDetectTransmits: 1, - RetransmitTimer: time.Second, - } - - tests := []struct { - name string - netProto tcpip.NetworkProtocolNumber - dadNetProto tcpip.NetworkProtocolNumber - remoteAddr tcpip.Address - expectedResolved bool - }{ - { - name: "IPv4 own address", - netProto: ipv4.ProtocolNumber, - dadNetProto: arp.ProtocolNumber, - remoteAddr: utils.Ipv4Addr1.AddressWithPrefix.Address, - expectedResolved: true, - }, - { - name: "IPv6 own address", - netProto: ipv6.ProtocolNumber, - dadNetProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr1.AddressWithPrefix.Address, - expectedResolved: true, - }, - { - name: "IPv4 duplicate address", - netProto: ipv4.ProtocolNumber, - dadNetProto: arp.ProtocolNumber, - remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, - expectedResolved: false, - }, - { - name: "IPv6 duplicate address", - netProto: ipv6.ProtocolNumber, - dadNetProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, - expectedResolved: false, - }, - { - name: "IPv4 no duplicate address", - netProto: ipv4.ProtocolNumber, - dadNetProto: arp.ProtocolNumber, - remoteAddr: utils.Ipv4Addr3.AddressWithPrefix.Address, - expectedResolved: true, - }, - { - name: "IPv6 no duplicate address", - netProto: ipv6.ProtocolNumber, - dadNetProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr3.AddressWithPrefix.Address, - expectedResolved: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - stackOpts := stack.Options{ - Clock: clock, - NetworkProtocols: []stack.NetworkProtocolFactory{ - arp.NewProtocol, - ipv4.NewProtocol, - ipv6.NewProtocol, - }, - } - - host1Stack, _ := setupStack(t, stackOpts, utils.Host1NICID, utils.Host2NICID) - - // DAD should be disabled by default. - if res, err := host1Stack.CheckDuplicateAddress(utils.Host1NICID, test.netProto, test.remoteAddr, func(r stack.DADResult) { - t.Errorf("unexpectedly called DAD completion handler when DAD was supposed to be disabled") - }); err != nil { - t.Fatalf("host1Stack.CheckDuplicateAddress(%d, %d, %s, _): %s", utils.Host1NICID, test.netProto, test.remoteAddr, err) - } else if res != stack.DADDisabled { - t.Errorf("got host1Stack.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", utils.Host1NICID, test.netProto, test.remoteAddr, res, stack.DADDisabled) - } - - // Enable DAD then attempt to check if an address is duplicated. - netEP, err := host1Stack.GetNetworkEndpoint(utils.Host1NICID, test.dadNetProto) - if err != nil { - t.Fatalf("host1Stack.GetNetworkEndpoint(%d, %d): %s", utils.Host1NICID, test.dadNetProto, err) - } - dad, ok := netEP.(stack.DuplicateAddressDetector) - if !ok { - t.Fatalf("expected %T to implement stack.DuplicateAddressDetector", netEP) - } - dad.SetDADConfigurations(dadConfigs) - ch := make(chan stack.DADResult, 3) - if res, err := host1Stack.CheckDuplicateAddress(utils.Host1NICID, test.netProto, test.remoteAddr, func(r stack.DADResult) { - ch <- r - }); err != nil { - t.Fatalf("host1Stack.CheckDuplicateAddress(%d, %d, %s, _): %s", utils.Host1NICID, test.netProto, test.remoteAddr, err) - } else if res != stack.DADStarting { - t.Errorf("got host1Stack.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", utils.Host1NICID, test.netProto, test.remoteAddr, res, stack.DADStarting) - } - - expectResults := 1 - if test.expectedResolved { - const delta = time.Nanosecond - clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer - delta) - select { - case r := <-ch: - t.Fatalf("unexpectedly got DAD result before the DAD timeout; r = %#v", r) - default: - } - - // If we expect the resolve to succeed try requesting DAD again on the - // same address. The handler for the new request should be called once - // the original DAD request completes. - expectResults = 2 - if res, err := host1Stack.CheckDuplicateAddress(utils.Host1NICID, test.netProto, test.remoteAddr, func(r stack.DADResult) { - ch <- r - }); err != nil { - t.Fatalf("host1Stack.CheckDuplicateAddress(%d, %d, %s, _): %s", utils.Host1NICID, test.netProto, test.remoteAddr, err) - } else if res != stack.DADAlreadyRunning { - t.Errorf("got host1Stack.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", utils.Host1NICID, test.netProto, test.remoteAddr, res, stack.DADAlreadyRunning) - } - - clock.Advance(delta) - } - - for i := 0; i < expectResults; i++ { - if diff := cmp.Diff(stack.DADResult{Resolved: test.expectedResolved}, <-ch); diff != "" { - t.Errorf("(i=%d) DAD result mismatch (-want +got):\n%s", i, diff) - } - } - - // Should have no more results. - select { - case r := <-ch: - t.Errorf("unexpectedly got an extra DAD result; r = %#v", r) - default: - } - }) - } -} diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go deleted file mode 100644 index c56155ea2..000000000 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ /dev/null @@ -1,504 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package loopback_test - -import ( - "bytes" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/tests/utils" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil) - -type ndpDispatcher struct{} - -func (*ndpDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, tcpip.Error) { -} - -func (*ndpDispatcher) OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool { - return false -} - -func (*ndpDispatcher) OnDefaultRouterInvalidated(tcpip.NICID, tcpip.Address) {} - -func (*ndpDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool { - return false -} - -func (*ndpDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) {} - -func (*ndpDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool { - return true -} - -func (*ndpDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) {} - -func (*ndpDispatcher) OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) {} - -func (*ndpDispatcher) OnRecursiveDNSServerOption(tcpip.NICID, []tcpip.Address, time.Duration) {} - -func (*ndpDispatcher) OnDNSSearchListOption(tcpip.NICID, []string, time.Duration) {} - -func (*ndpDispatcher) OnDHCPv6Configuration(tcpip.NICID, ipv6.DHCPv6ConfigurationFromNDPRA) {} - -// TestInitialLoopbackAddresses tests that the loopback interface does not -// auto-generate a link-local address when it is brought up. -func TestInitialLoopbackAddresses(t *testing.T) { - const nicID = 1 - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPDisp: &ndpDispatcher{}, - AutoGenLinkLocal: true, - OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(nicID tcpip.NICID, nicName string) string { - t.Fatalf("should not attempt to get name for NIC with ID = %d; nicName = %s", nicID, nicName) - return "" - }, - }, - })}, - }) - - if err := s.CreateNIC(nicID, loopback.New()); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - - nicsInfo := s.NICInfo() - if nicInfo, ok := nicsInfo[nicID]; !ok { - t.Fatalf("did not find NIC with ID = %d in s.NICInfo() = %#v", nicID, nicsInfo) - } else if got := len(nicInfo.ProtocolAddresses); got != 0 { - t.Fatalf("got len(nicInfo.ProtocolAddresses) = %d, want = 0; nicInfo.ProtocolAddresses = %#v", got, nicInfo.ProtocolAddresses) - } -} - -// TestLoopbackAcceptAllInSubnetUDP tests that a loopback interface considers -// itself bound to all addresses in the subnet of an assigned address and UDP -// traffic is sent/received correctly. -func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) { - const ( - nicID = 1 - localPort = 80 - ) - - data := []byte{1, 2, 3, 4} - - ipv4ProtocolAddress := tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: utils.Ipv4Addr, - } - ipv4Bytes := []byte(ipv4ProtocolAddress.AddressWithPrefix.Address) - ipv4Bytes[len(ipv4Bytes)-1]++ - otherIPv4Address := tcpip.Address(ipv4Bytes) - - ipv6ProtocolAddress := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: utils.Ipv6Addr, - } - ipv6Bytes := []byte(utils.Ipv6Addr.Address) - ipv6Bytes[len(ipv6Bytes)-1]++ - otherIPv6Address := tcpip.Address(ipv6Bytes) - - tests := []struct { - name string - addAddress tcpip.ProtocolAddress - bindAddr tcpip.Address - dstAddr tcpip.Address - expectRx bool - }{ - { - name: "IPv4 bind to wildcard and send to assigned address", - addAddress: ipv4ProtocolAddress, - dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, - expectRx: true, - }, - { - name: "IPv4 bind to wildcard and send to other subnet-local address", - addAddress: ipv4ProtocolAddress, - dstAddr: otherIPv4Address, - expectRx: true, - }, - { - name: "IPv4 bind to wildcard send to other address", - addAddress: ipv4ProtocolAddress, - dstAddr: utils.RemoteIPv4Addr, - expectRx: false, - }, - { - name: "IPv4 bind to other subnet-local address and send to assigned address", - addAddress: ipv4ProtocolAddress, - bindAddr: otherIPv4Address, - dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, - expectRx: false, - }, - { - name: "IPv4 bind and send to other subnet-local address", - addAddress: ipv4ProtocolAddress, - bindAddr: otherIPv4Address, - dstAddr: otherIPv4Address, - expectRx: true, - }, - { - name: "IPv4 bind to assigned address and send to other subnet-local address", - addAddress: ipv4ProtocolAddress, - bindAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, - dstAddr: otherIPv4Address, - expectRx: false, - }, - - { - name: "IPv6 bind and send to assigned address", - addAddress: ipv6ProtocolAddress, - bindAddr: utils.Ipv6Addr.Address, - dstAddr: utils.Ipv6Addr.Address, - expectRx: true, - }, - { - name: "IPv6 bind to wildcard and send to other subnet-local address", - addAddress: ipv6ProtocolAddress, - dstAddr: otherIPv6Address, - expectRx: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - if err := s.CreateNIC(nicID, loopback.New()); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - if err := s.AddProtocolAddress(nicID, test.addAddress); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, test.addAddress, err) - } - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: nicID, - }, - }) - - var wq waiter.Queue - rep, err := s.NewEndpoint(udp.ProtocolNumber, test.addAddress.Protocol, &wq) - if err != nil { - t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err) - } - defer rep.Close() - - bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: localPort} - if err := rep.Bind(bindAddr); err != nil { - t.Fatalf("rep.Bind(%+v): %s", bindAddr, err) - } - - sep, err := s.NewEndpoint(udp.ProtocolNumber, test.addAddress.Protocol, &wq) - if err != nil { - t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err) - } - defer sep.Close() - - wopts := tcpip.WriteOptions{ - To: &tcpip.FullAddress{ - Addr: test.dstAddr, - Port: localPort, - }, - } - var r bytes.Reader - r.Reset(data) - n, err := sep.Write(&r, wopts) - if err != nil { - t.Fatalf("sep.Write(_, _): %s", err) - } - if want := int64(len(data)); n != want { - t.Fatalf("got sep.Write(_, _) = (%d, nil), want = (%d, nil)", n, want) - } - - var buf bytes.Buffer - opts := tcpip.ReadOptions{NeedRemoteAddr: true} - if res, err := rep.Read(&buf, opts); test.expectRx { - if err != nil { - t.Fatalf("rep.Read(_, %#v): %s", opts, err) - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: buf.Len(), - Total: buf.Len(), - RemoteAddr: tcpip.FullAddress{ - Addr: test.addAddress.AddressWithPrefix.Address, - }, - }, res, - checker.IgnoreCmpPath("ControlMessages", "RemoteAddr.NIC", "RemoteAddr.Port"), - ); diff != "" { - t.Errorf("rep.Read: unexpected result (-want +got):\n%s", diff) - } - if diff := cmp.Diff(data, buf.Bytes()); diff != "" { - t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) - } - } else if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got rep.Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), &tcpip.ErrWouldBlock{}) - } - }) - } -} - -// TestLoopbackSubnetLifetimeBoundToAddr tests that the lifetime of an address -// in a loopback interface's associated subnet is bound to the permanently bound -// address. -func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) { - const nicID = 1 - - protoAddr := tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: utils.Ipv4Addr, - } - addrBytes := []byte(utils.Ipv4Addr.Address) - addrBytes[len(addrBytes)-1]++ - otherAddr := tcpip.Address(addrBytes) - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - }) - if err := s.CreateNIC(nicID, loopback.New()); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) - } - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }, - }) - - r, err := s.FindRoute(nicID, otherAddr, utils.RemoteIPv4Addr, ipv4.ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, otherAddr, utils.RemoteIPv4Addr, ipv4.ProtocolNumber, err) - } - defer r.Release() - - params := stack.NetworkHeaderParams{ - Protocol: 111, - TTL: 64, - TOS: stack.DefaultTOS, - } - data := buffer.View([]byte{1, 2, 3, 4}) - if err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()), - Data: data.ToVectorisedView(), - })); err != nil { - t.Fatalf("r.WritePacket(nil, %#v, _): %s", params, err) - } - - // Removing the address should make the endpoint invalid. - if err := s.RemoveAddress(nicID, protoAddr.AddressWithPrefix.Address); err != nil { - t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, protoAddr.AddressWithPrefix.Address, err) - } - { - err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()), - Data: data.ToVectorisedView(), - })) - if _, ok := err.(*tcpip.ErrInvalidEndpointState); !ok { - t.Fatalf("got r.WritePacket(nil, %#v, _) = %s, want = %s", params, err, &tcpip.ErrInvalidEndpointState{}) - } - } -} - -// TestLoopbackAcceptAllInSubnetTCP tests that a loopback interface considers -// itself bound to all addresses in the subnet of an assigned address and TCP -// traffic is sent/received correctly. -func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) { - const ( - nicID = 1 - localPort = 80 - ) - - ipv4ProtocolAddress := tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: utils.Ipv4Addr, - } - ipv4ProtocolAddress.AddressWithPrefix.PrefixLen = 8 - ipv4Bytes := []byte(ipv4ProtocolAddress.AddressWithPrefix.Address) - ipv4Bytes[len(ipv4Bytes)-1]++ - otherIPv4Address := tcpip.Address(ipv4Bytes) - - ipv6ProtocolAddress := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: utils.Ipv6Addr, - } - ipv6Bytes := []byte(utils.Ipv6Addr.Address) - ipv6Bytes[len(ipv6Bytes)-1]++ - otherIPv6Address := tcpip.Address(ipv6Bytes) - - tests := []struct { - name string - addAddress tcpip.ProtocolAddress - bindAddr tcpip.Address - dstAddr tcpip.Address - expectAccept bool - }{ - { - name: "IPv4 bind to wildcard and send to assigned address", - addAddress: ipv4ProtocolAddress, - dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, - expectAccept: true, - }, - { - name: "IPv4 bind to wildcard and send to other subnet-local address", - addAddress: ipv4ProtocolAddress, - dstAddr: otherIPv4Address, - expectAccept: true, - }, - { - name: "IPv4 bind to wildcard send to other address", - addAddress: ipv4ProtocolAddress, - dstAddr: utils.RemoteIPv4Addr, - expectAccept: false, - }, - { - name: "IPv4 bind to other subnet-local address and send to assigned address", - addAddress: ipv4ProtocolAddress, - bindAddr: otherIPv4Address, - dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, - expectAccept: false, - }, - { - name: "IPv4 bind and send to other subnet-local address", - addAddress: ipv4ProtocolAddress, - bindAddr: otherIPv4Address, - dstAddr: otherIPv4Address, - expectAccept: true, - }, - { - name: "IPv4 bind to assigned address and send to other subnet-local address", - addAddress: ipv4ProtocolAddress, - bindAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, - dstAddr: otherIPv4Address, - expectAccept: false, - }, - - { - name: "IPv6 bind and send to assigned address", - addAddress: ipv6ProtocolAddress, - bindAddr: utils.Ipv6Addr.Address, - dstAddr: utils.Ipv6Addr.Address, - expectAccept: true, - }, - { - name: "IPv6 bind to wildcard and send to other subnet-local address", - addAddress: ipv6ProtocolAddress, - dstAddr: otherIPv6Address, - expectAccept: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, - }) - if err := s.CreateNIC(nicID, loopback.New()); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - if err := s.AddProtocolAddress(nicID, test.addAddress); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, test.addAddress, err) - } - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: nicID, - }, - }) - - var wq waiter.Queue - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - listeningEndpoint, err := s.NewEndpoint(tcp.ProtocolNumber, test.addAddress.Protocol, &wq) - if err != nil { - t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err) - } - defer listeningEndpoint.Close() - - bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: localPort} - if err := listeningEndpoint.Bind(bindAddr); err != nil { - t.Fatalf("listeningEndpoint.Bind(%#v): %s", bindAddr, err) - } - - if err := listeningEndpoint.Listen(1); err != nil { - t.Fatalf("listeningEndpoint.Listen(1): %s", err) - } - - connectingEndpoint, err := s.NewEndpoint(tcp.ProtocolNumber, test.addAddress.Protocol, &wq) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err) - } - defer connectingEndpoint.Close() - - connectAddr := tcpip.FullAddress{ - Addr: test.dstAddr, - Port: localPort, - } - { - err := connectingEndpoint.Connect(connectAddr) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("connectingEndpoint.Connect(%#v): %s", connectAddr, err) - } - } - - if !test.expectAccept { - _, _, err := listeningEndpoint.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got listeningEndpoint.Accept(nil) = %s, want = %s", err, &tcpip.ErrWouldBlock{}) - } - return - } - - // Wait for the listening endpoint to be "readable". That is, wait for a - // new connection. - <-ch - var addr tcpip.FullAddress - if _, _, err := listeningEndpoint.Accept(&addr); err != nil { - t.Fatalf("listeningEndpoint.Accept(nil): %s", err) - } - if addr.Addr != test.addAddress.AddressWithPrefix.Address { - t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.addAddress.AddressWithPrefix.Address) - } - }) - } -} diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go deleted file mode 100644 index e4439ba79..000000000 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ /dev/null @@ -1,751 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package multicast_broadcast_test - -import ( - "bytes" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/tests/utils" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - defaultMTU = 1280 - ttl = 255 -) - -// TestPingMulticastBroadcast tests that responding to an Echo Request destined -// to a multicast or broadcast address uses a unicast source address for the -// reply. -func TestPingMulticastBroadcast(t *testing.T) { - const nicID = 1 - - rxIPv4ICMP := func(e *channel.Endpoint, dst tcpip.Address) { - totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize - hdr := buffer.NewPrependable(totalLen) - pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - pkt.SetType(header.ICMPv4Echo) - pkt.SetCode(0) - pkt.SetChecksum(0) - pkt.SetChecksum(^header.Checksum(pkt, 0)) - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLen), - Protocol: uint8(icmp.ProtocolNumber4), - TTL: ttl, - SrcAddr: utils.RemoteIPv4Addr, - DstAddr: dst, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - } - - rxIPv6ICMP := func(e *channel.Endpoint, dst tcpip.Address) { - totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize - hdr := buffer.NewPrependable(totalLen) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) - pkt.SetType(header.ICMPv6EchoRequest) - pkt.SetCode(0) - pkt.SetChecksum(0) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, utils.RemoteIPv6Addr, dst, buffer.VectorisedView{})) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - TransportProtocol: icmp.ProtocolNumber6, - HopLimit: ttl, - SrcAddr: utils.RemoteIPv6Addr, - DstAddr: dst, - }) - - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - } - - tests := []struct { - name string - dstAddr tcpip.Address - }{ - { - name: "IPv4 unicast", - dstAddr: utils.Ipv4Addr.Address, - }, - { - name: "IPv4 directed broadcast", - dstAddr: utils.Ipv4SubnetBcast, - }, - { - name: "IPv4 broadcast", - dstAddr: header.IPv4Broadcast, - }, - { - name: "IPv4 all-systems multicast", - dstAddr: header.IPv4AllSystems, - }, - { - name: "IPv6 unicast", - dstAddr: utils.Ipv6Addr.Address, - }, - { - name: "IPv6 all-nodes multicast", - dstAddr: header.IPv6AllNodesMulticastAddress, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, - }) - // We only expect a single packet in response to our ICMP Echo Request. - e := channel.New(1, defaultMTU, "") - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: utils.Ipv4Addr} - if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err) - } - ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: utils.Ipv6Addr} - if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv6ProtoAddr, err) - } - - // Default routes for IPv4 and IPv6 so ICMP can find a route to the remote - // node when attempting to send the ICMP Echo Reply. - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv6EmptySubnet, - NIC: nicID, - }, - { - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }, - }) - - var rxICMP func(*channel.Endpoint, tcpip.Address) - var expectedSrc tcpip.Address - var expectedDst tcpip.Address - var protoNum tcpip.NetworkProtocolNumber - switch l := len(test.dstAddr); l { - case header.IPv4AddressSize: - rxICMP = rxIPv4ICMP - expectedSrc = utils.Ipv4Addr.Address - expectedDst = utils.RemoteIPv4Addr - protoNum = header.IPv4ProtocolNumber - case header.IPv6AddressSize: - rxICMP = rxIPv6ICMP - expectedSrc = utils.Ipv6Addr.Address - expectedDst = utils.RemoteIPv6Addr - protoNum = header.IPv6ProtocolNumber - default: - t.Fatalf("got unexpected address length = %d bytes", l) - } - - rxICMP(e, test.dstAddr) - pkt, ok := e.Read() - if !ok { - t.Fatal("expected ICMP response") - } - - if pkt.Route.LocalAddress != expectedSrc { - t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, expectedSrc) - } - if pkt.Route.RemoteAddress != expectedDst { - t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, expectedDst) - } - - src, dst := s.NetworkProtocolInstance(protoNum).ParseAddresses(stack.PayloadSince(pkt.Pkt.NetworkHeader())) - if src != expectedSrc { - t.Errorf("got pkt source = %s, want = %s", src, expectedSrc) - } - if dst != expectedDst { - t.Errorf("got pkt destination = %s, want = %s", dst, expectedDst) - } - }) - } - -} - -func rxIPv4UDP(e *channel.Endpoint, src, dst tcpip.Address, data []byte) { - payloadLen := header.UDPMinimumSize + len(data) - totalLen := header.IPv4MinimumSize + payloadLen - hdr := buffer.NewPrependable(totalLen) - u := header.UDP(hdr.Prepend(payloadLen)) - u.Encode(&header.UDPFields{ - SrcPort: utils.RemotePort, - DstPort: utils.LocalPort, - Length: uint16(payloadLen), - }) - copy(u.Payload(), data) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(payloadLen)) - sum = header.Checksum(data, sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLen), - Protocol: uint8(udp.ProtocolNumber), - TTL: ttl, - SrcAddr: src, - DstAddr: dst, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) -} - -func rxIPv6UDP(e *channel.Endpoint, src, dst tcpip.Address, data []byte) { - payloadLen := header.UDPMinimumSize + len(data) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadLen) - u := header.UDP(hdr.Prepend(payloadLen)) - u.Encode(&header.UDPFields{ - SrcPort: utils.RemotePort, - DstPort: utils.LocalPort, - Length: uint16(payloadLen), - }) - copy(u.Payload(), data) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(payloadLen)) - sum = header.Checksum(data, sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLen), - TransportProtocol: udp.ProtocolNumber, - HopLimit: ttl, - SrcAddr: src, - DstAddr: dst, - }) - - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) -} - -// TestIncomingMulticastAndBroadcast tests receiving a packet destined to some -// multicast or broadcast address. -func TestIncomingMulticastAndBroadcast(t *testing.T) { - const nicID = 1 - - data := []byte{1, 2, 3, 4} - - tests := []struct { - name string - proto tcpip.NetworkProtocolNumber - remoteAddr tcpip.Address - localAddr tcpip.AddressWithPrefix - rxUDP func(*channel.Endpoint, tcpip.Address, tcpip.Address, []byte) - bindAddr tcpip.Address - dstAddr tcpip.Address - expectRx bool - }{ - { - name: "IPv4 unicast binding to unicast", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - bindAddr: utils.Ipv4Addr.Address, - dstAddr: utils.Ipv4Addr.Address, - expectRx: true, - }, - { - name: "IPv4 unicast binding to broadcast", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - bindAddr: header.IPv4Broadcast, - dstAddr: utils.Ipv4Addr.Address, - expectRx: false, - }, - { - name: "IPv4 unicast binding to wildcard", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - dstAddr: utils.Ipv4Addr.Address, - expectRx: true, - }, - - { - name: "IPv4 directed broadcast binding to subnet broadcast", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - bindAddr: utils.Ipv4SubnetBcast, - dstAddr: utils.Ipv4SubnetBcast, - expectRx: true, - }, - { - name: "IPv4 directed broadcast binding to broadcast", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - bindAddr: header.IPv4Broadcast, - dstAddr: utils.Ipv4SubnetBcast, - expectRx: false, - }, - { - name: "IPv4 directed broadcast binding to wildcard", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - dstAddr: utils.Ipv4SubnetBcast, - expectRx: true, - }, - - { - name: "IPv4 broadcast binding to broadcast", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - bindAddr: header.IPv4Broadcast, - dstAddr: header.IPv4Broadcast, - expectRx: true, - }, - { - name: "IPv4 broadcast binding to subnet broadcast", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - bindAddr: utils.Ipv4SubnetBcast, - dstAddr: header.IPv4Broadcast, - expectRx: false, - }, - { - name: "IPv4 broadcast binding to wildcard", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - dstAddr: utils.Ipv4SubnetBcast, - expectRx: true, - }, - - { - name: "IPv4 all-systems multicast binding to all-systems multicast", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - bindAddr: header.IPv4AllSystems, - dstAddr: header.IPv4AllSystems, - expectRx: true, - }, - { - name: "IPv4 all-systems multicast binding to wildcard", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - dstAddr: header.IPv4AllSystems, - expectRx: true, - }, - { - name: "IPv4 all-systems multicast binding to unicast", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - bindAddr: utils.Ipv4Addr.Address, - dstAddr: header.IPv4AllSystems, - expectRx: false, - }, - - // IPv6 has no notion of a broadcast. - { - name: "IPv6 unicast binding to wildcard", - dstAddr: utils.Ipv6Addr.Address, - proto: header.IPv6ProtocolNumber, - remoteAddr: utils.RemoteIPv6Addr, - localAddr: utils.Ipv6Addr, - rxUDP: rxIPv6UDP, - expectRx: true, - }, - { - name: "IPv6 broadcast-like address binding to wildcard", - dstAddr: utils.Ipv6SubnetBcast, - proto: header.IPv6ProtocolNumber, - remoteAddr: utils.RemoteIPv6Addr, - localAddr: utils.Ipv6Addr, - rxUDP: rxIPv6UDP, - expectRx: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - e := channel.New(0, defaultMTU, "") - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr} - if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) - } - - var wq waiter.Queue - ep, err := s.NewEndpoint(udp.ProtocolNumber, test.proto, &wq) - if err != nil { - t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err) - } - defer ep.Close() - - bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: utils.LocalPort} - if err := ep.Bind(bindAddr); err != nil { - t.Fatalf("ep.Bind(%#v): %s", bindAddr, err) - } - - test.rxUDP(e, test.remoteAddr, test.dstAddr, data) - var buf bytes.Buffer - var opts tcpip.ReadOptions - if res, err := ep.Read(&buf, opts); test.expectRx { - if err != nil { - t.Fatalf("ep.Read(_, %#v): %s", opts, err) - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: buf.Len(), - Total: buf.Len(), - }, res, checker.IgnoreCmpPath("ControlMessages")); diff != "" { - t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) - } - if diff := cmp.Diff(data, buf.Bytes()); diff != "" { - t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) - } - } else if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), &tcpip.ErrWouldBlock{}) - } - }) - } -} - -// TestReuseAddrAndBroadcast makes sure broadcast packets are received by all -// interested endpoints. -func TestReuseAddrAndBroadcast(t *testing.T) { - const ( - nicID = 1 - localPort = 9000 - loopbackBroadcast = tcpip.Address("\x7f\xff\xff\xff") - ) - - tests := []struct { - name string - broadcastAddr tcpip.Address - }{ - { - name: "Subnet directed broadcast", - broadcastAddr: loopbackBroadcast, - }, - { - name: "IPv4 broadcast", - broadcastAddr: header.IPv4Broadcast, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - if err := s.CreateNIC(nicID, loopback.New()); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - protoAddr := tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: "\x7f\x00\x00\x01", - PrefixLen: 8, - }, - } - if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - // We use the empty subnet instead of just the loopback subnet so we - // also have a route to the IPv4 Broadcast address. - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }, - }) - - type endpointAndWaiter struct { - ep tcpip.Endpoint - ch chan struct{} - } - var eps []endpointAndWaiter - // We create endpoints that bind to both the wildcard address and the - // broadcast address to make sure both of these types of "broadcast - // interested" endpoints receive broadcast packets. - for _, bindWildcard := range []bool{false, true} { - // Create multiple endpoints for each type of "broadcast interested" - // endpoint so we can test that all endpoints receive the broadcast - // packet. - for i := 0; i < 2; i++ { - var wq waiter.Queue - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("(eps[%d]) NewEndpoint(%d, %d, _): %s", len(eps), udp.ProtocolNumber, ipv4.ProtocolNumber, err) - } - defer ep.Close() - - ep.SocketOptions().SetReuseAddress(true) - ep.SocketOptions().SetBroadcast(true) - - bindAddr := tcpip.FullAddress{Port: localPort} - if bindWildcard { - if err := ep.Bind(bindAddr); err != nil { - t.Fatalf("eps[%d].Bind(%#v): %s", len(eps), bindAddr, err) - } - } else { - bindAddr.Addr = test.broadcastAddr - if err := ep.Bind(bindAddr); err != nil { - t.Fatalf("eps[%d].Bind(%#v): %s", len(eps), bindAddr, err) - } - } - - eps = append(eps, endpointAndWaiter{ep: ep, ch: ch}) - } - } - - for i, wep := range eps { - writeOpts := tcpip.WriteOptions{ - To: &tcpip.FullAddress{ - Addr: test.broadcastAddr, - Port: localPort, - }, - } - data := []byte{byte(i), 2, 3, 4} - var r bytes.Reader - r.Reset(data) - if n, err := wep.ep.Write(&r, writeOpts); err != nil { - t.Fatalf("eps[%d].Write(_, _): %s", i, err) - } else if want := int64(len(data)); n != want { - t.Fatalf("got eps[%d].Write(_, _) = (%d, nil), want = (%d, nil)", i, n, want) - } - - for j, rep := range eps { - // Wait for the endpoint to become readable. - <-rep.ch - - var buf bytes.Buffer - result, err := rep.ep.Read(&buf, tcpip.ReadOptions{}) - if err != nil { - t.Errorf("(eps[%d] write) eps[%d].Read: %s", i, j, err) - continue - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: buf.Len(), - Total: buf.Len(), - }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { - t.Errorf("(eps[%d] write) eps[%d].Read: unexpected result (-want +got):\n%s", i, j, diff) - } - if diff := cmp.Diff([]byte(data), buf.Bytes()); diff != "" { - t.Errorf("(eps[%d] write) got UDP payload from eps[%d] mismatch (-want +got):\n%s", i, j, diff) - } - } - } - }) - } -} - -func TestUDPAddRemoveMembershipSocketOption(t *testing.T) { - const ( - nicID = 1 - ) - - data := []byte{1, 2, 3, 4} - - tests := []struct { - name string - proto tcpip.NetworkProtocolNumber - remoteAddr tcpip.Address - localAddr tcpip.AddressWithPrefix - rxUDP func(*channel.Endpoint, tcpip.Address, tcpip.Address, []byte) - multicastAddr tcpip.Address - }{ - { - name: "IPv4 unicast binding to unicast", - multicastAddr: "\xe0\x01\x02\x03", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - }, - { - name: "IPv6 broadcast-like address binding to wildcard", - multicastAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04", - proto: header.IPv6ProtocolNumber, - remoteAddr: utils.RemoteIPv6Addr, - localAddr: utils.Ipv6Addr, - rxUDP: rxIPv6UDP, - }, - } - - subTests := []struct { - name string - specifyNICID bool - specifyNICAddr bool - }{ - { - name: "Specify NIC ID and NIC address", - specifyNICID: true, - specifyNICAddr: true, - }, - { - name: "Don't specify NIC ID or NIC address", - specifyNICID: false, - specifyNICAddr: false, - }, - { - name: "Specify NIC ID but don't specify NIC address", - specifyNICID: true, - specifyNICAddr: false, - }, - { - name: "Don't specify NIC ID but specify NIC address", - specifyNICID: false, - specifyNICAddr: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - for _, subTest := range subTests { - t.Run(subTest.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - e := channel.New(0, defaultMTU, "") - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr} - if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) - } - - // Set the route table so that UDP can find a NIC that is - // routable to the multicast address when the NIC isn't specified. - if !subTest.specifyNICID && !subTest.specifyNICAddr { - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv6EmptySubnet, - NIC: nicID, - }, - { - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }, - }) - } - - var wq waiter.Queue - ep, err := s.NewEndpoint(udp.ProtocolNumber, test.proto, &wq) - if err != nil { - t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err) - } - defer ep.Close() - - bindAddr := tcpip.FullAddress{Port: utils.LocalPort} - if err := ep.Bind(bindAddr); err != nil { - t.Fatalf("ep.Bind(%#v): %s", bindAddr, err) - } - - memOpt := tcpip.MembershipOption{MulticastAddr: test.multicastAddr} - if subTest.specifyNICID { - memOpt.NIC = nicID - } - if subTest.specifyNICAddr { - memOpt.InterfaceAddr = test.localAddr.Address - } - - // We should receive UDP packets to the group once we join the - // multicast group. - addOpt := tcpip.AddMembershipOption(memOpt) - if err := ep.SetSockOpt(&addOpt); err != nil { - t.Fatalf("ep.SetSockOpt(&%#v): %s", addOpt, err) - } - test.rxUDP(e, test.remoteAddr, test.multicastAddr, data) - var buf bytes.Buffer - result, err := ep.Read(&buf, tcpip.ReadOptions{}) - if err != nil { - t.Fatalf("ep.Read: %s", err) - } else { - if diff := cmp.Diff(tcpip.ReadResult{ - Count: buf.Len(), - Total: buf.Len(), - }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { - t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) - } - if diff := cmp.Diff(data, buf.Bytes()); diff != "" { - t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) - } - } - - // We should not receive UDP packets to the group once we leave - // the multicast group. - removeOpt := tcpip.RemoveMembershipOption(memOpt) - if err := ep.SetSockOpt(&removeOpt); err != nil { - t.Fatalf("ep.SetSockOpt(&%#v): %s", removeOpt, err) - } - { - _, err := ep.Read(&buf, tcpip.ReadOptions{}) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got ep.Read = (_, %s), want = (_, %s)", err, &tcpip.ErrWouldBlock{}) - } - } - }) - } - }) - } -} diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go deleted file mode 100644 index 4455f6dd7..000000000 --- a/pkg/tcpip/tests/integration/route_test.go +++ /dev/null @@ -1,430 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package route_test - -import ( - "bytes" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/tests/utils" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -// TestLocalPing tests pinging a remote that is local the stack. -// -// This tests that a local route is created and packets do not leave the stack. -func TestLocalPing(t *testing.T) { - const ( - nicID = 1 - ipv4Loopback = tcpip.Address("\x7f\x00\x00\x01") - - // icmpDataOffset is the offset to the data in both ICMPv4 and ICMPv6 echo - // request/reply packets. - icmpDataOffset = 8 - ) - - channelEP := func() stack.LinkEndpoint { return channel.New(1, header.IPv6MinimumMTU, "") } - channelEPCheck := func(t *testing.T, e stack.LinkEndpoint) { - channelEP := e.(*channel.Endpoint) - if n := channelEP.Drain(); n != 0 { - t.Fatalf("got channelEP.Drain() = %d, want = 0", n) - } - } - - ipv4ICMPBuf := func(t *testing.T) buffer.View { - data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} - hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data))) - hdr.SetType(header.ICMPv4Echo) - if n := copy(hdr.Payload(), data[:]); n != len(data) { - t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data)) - } - return buffer.View(hdr) - } - - ipv6ICMPBuf := func(t *testing.T) buffer.View { - data := [8]byte{1, 2, 3, 4, 5, 6, 7, 9} - hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data))) - hdr.SetType(header.ICMPv6EchoRequest) - if n := copy(hdr.Payload(), data[:]); n != len(data) { - t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data)) - } - return buffer.View(hdr) - } - - tests := []struct { - name string - transProto tcpip.TransportProtocolNumber - netProto tcpip.NetworkProtocolNumber - linkEndpoint func() stack.LinkEndpoint - localAddr tcpip.Address - icmpBuf func(*testing.T) buffer.View - expectedConnectErr tcpip.Error - checkLinkEndpoint func(t *testing.T, e stack.LinkEndpoint) - }{ - { - name: "IPv4 loopback", - transProto: icmp.ProtocolNumber4, - netProto: ipv4.ProtocolNumber, - linkEndpoint: loopback.New, - localAddr: ipv4Loopback, - icmpBuf: ipv4ICMPBuf, - checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, - }, - { - name: "IPv6 loopback", - transProto: icmp.ProtocolNumber6, - netProto: ipv6.ProtocolNumber, - linkEndpoint: loopback.New, - localAddr: header.IPv6Loopback, - icmpBuf: ipv6ICMPBuf, - checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, - }, - { - name: "IPv4 non-loopback", - transProto: icmp.ProtocolNumber4, - netProto: ipv4.ProtocolNumber, - linkEndpoint: channelEP, - localAddr: utils.Ipv4Addr.Address, - icmpBuf: ipv4ICMPBuf, - checkLinkEndpoint: channelEPCheck, - }, - { - name: "IPv6 non-loopback", - transProto: icmp.ProtocolNumber6, - netProto: ipv6.ProtocolNumber, - linkEndpoint: channelEP, - localAddr: utils.Ipv6Addr.Address, - icmpBuf: ipv6ICMPBuf, - checkLinkEndpoint: channelEPCheck, - }, - { - name: "IPv4 loopback without local address", - transProto: icmp.ProtocolNumber4, - netProto: ipv4.ProtocolNumber, - linkEndpoint: loopback.New, - icmpBuf: ipv4ICMPBuf, - expectedConnectErr: &tcpip.ErrNoRoute{}, - checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, - }, - { - name: "IPv6 loopback without local address", - transProto: icmp.ProtocolNumber6, - netProto: ipv6.ProtocolNumber, - linkEndpoint: loopback.New, - icmpBuf: ipv6ICMPBuf, - expectedConnectErr: &tcpip.ErrNoRoute{}, - checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, - }, - { - name: "IPv4 non-loopback without local address", - transProto: icmp.ProtocolNumber4, - netProto: ipv4.ProtocolNumber, - linkEndpoint: channelEP, - icmpBuf: ipv4ICMPBuf, - expectedConnectErr: &tcpip.ErrNoRoute{}, - checkLinkEndpoint: channelEPCheck, - }, - { - name: "IPv6 non-loopback without local address", - transProto: icmp.ProtocolNumber6, - netProto: ipv6.ProtocolNumber, - linkEndpoint: channelEP, - icmpBuf: ipv6ICMPBuf, - expectedConnectErr: &tcpip.ErrNoRoute{}, - checkLinkEndpoint: channelEPCheck, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, - HandleLocal: true, - }) - e := test.linkEndpoint() - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - - if len(test.localAddr) != 0 { - if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err) - } - } - - var wq waiter.Queue - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - ep, err := s.NewEndpoint(test.transProto, test.netProto, &wq) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err) - } - defer ep.Close() - - connAddr := tcpip.FullAddress{Addr: test.localAddr} - { - err := ep.Connect(connAddr) - if diff := cmp.Diff(test.expectedConnectErr, err); diff != "" { - t.Fatalf("unexpected error from ep.Connect(%#v), (-want, +got):\n%s", connAddr, diff) - } - } - - if test.expectedConnectErr != nil { - return - } - - payload := test.icmpBuf(t) - var r bytes.Reader - r.Reset(payload) - var wOpts tcpip.WriteOptions - if n, err := ep.Write(&r, wOpts); err != nil { - t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err) - } else if n != int64(len(payload)) { - t.Fatalf("got ep.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", payload, wOpts, n, len(payload)) - } - - // Wait for the endpoint to become readable. - <-ch - - var buf bytes.Buffer - opts := tcpip.ReadOptions{NeedRemoteAddr: true} - res, err := ep.Read(&buf, opts) - if err != nil { - t.Fatalf("ep.Read(_, %#v): %s", opts, err) - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: buf.Len(), - Total: buf.Len(), - RemoteAddr: tcpip.FullAddress{Addr: test.localAddr}, - }, res, checker.IgnoreCmpPath( - "ControlMessages", - "RemoteAddr.NIC", - "RemoteAddr.Port", - )); diff != "" { - t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) - } - if diff := cmp.Diff(buf.Bytes()[icmpDataOffset:], []byte(payload[icmpDataOffset:])); diff != "" { - t.Errorf("received data mismatch (-want +got):\n%s", diff) - } - - test.checkLinkEndpoint(t, e) - }) - } -} - -// TestLocalUDP tests sending UDP packets between two endpoints that are local -// to the stack. -// -// This tests that that packets never leave the stack and the addresses -// used when sending a packet. -func TestLocalUDP(t *testing.T) { - const ( - nicID = 1 - ) - - tests := []struct { - name string - canBePrimaryAddr tcpip.ProtocolAddress - firstPrimaryAddr tcpip.ProtocolAddress - }{ - { - name: "IPv4", - canBePrimaryAddr: utils.Ipv4Addr1, - firstPrimaryAddr: utils.Ipv4Addr2, - }, - { - name: "IPv6", - canBePrimaryAddr: utils.Ipv6Addr1, - firstPrimaryAddr: utils.Ipv6Addr2, - }, - } - - subTests := []struct { - name string - addAddress bool - expectedWriteErr tcpip.Error - }{ - { - name: "Unassigned local address", - addAddress: false, - expectedWriteErr: &tcpip.ErrNoRoute{}, - }, - { - name: "Assigned local address", - addAddress: true, - expectedWriteErr: nil, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - for _, subTest := range subTests { - t.Run(subTest.name, func(t *testing.T) { - stackOpts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - HandleLocal: true, - } - - s := stack.New(stackOpts) - ep := channel.New(1, header.IPv6MinimumMTU, "") - - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - - if subTest.addAddress { - if err := s.AddProtocolAddressWithOptions(nicID, test.canBePrimaryAddr, stack.CanBePrimaryEndpoint); err != nil { - t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.canBePrimaryAddr, stack.FirstPrimaryEndpoint, err) - } - if err := s.AddProtocolAddressWithOptions(nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint, err) - } - } - - var serverWQ waiter.Queue - serverWE, serverCH := waiter.NewChannelEntry(nil) - serverWQ.EventRegister(&serverWE, waiter.EventIn) - server, err := s.NewEndpoint(udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, &serverWQ) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d): %s", udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, err) - } - defer server.Close() - - bindAddr := tcpip.FullAddress{Port: 80} - if err := server.Bind(bindAddr); err != nil { - t.Fatalf("server.Bind(%#v): %s", bindAddr, err) - } - - var clientWQ waiter.Queue - clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.EventIn) - client, err := s.NewEndpoint(udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, &clientWQ) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d): %s", udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, err) - } - defer client.Close() - - serverAddr := tcpip.FullAddress{ - Addr: test.canBePrimaryAddr.AddressWithPrefix.Address, - Port: 80, - } - - clientPayload := []byte{1, 2, 3, 4} - { - var r bytes.Reader - r.Reset(clientPayload) - wOpts := tcpip.WriteOptions{ - To: &serverAddr, - } - if n, err := client.Write(&r, wOpts); err != subTest.expectedWriteErr { - t.Fatalf("got client.Write(%#v, %#v) = (%d, %s), want = (_, %s)", clientPayload, wOpts, n, err, subTest.expectedWriteErr) - } else if subTest.expectedWriteErr != nil { - // Nothing else to test if we expected not to be able to send the - // UDP packet. - return - } else if n != int64(len(clientPayload)) { - t.Fatalf("got client.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", clientPayload, wOpts, n, len(clientPayload)) - } - } - - // Wait for the server endpoint to become readable. - <-serverCH - - var clientAddr tcpip.FullAddress - var readBuf bytes.Buffer - if read, err := server.Read(&readBuf, tcpip.ReadOptions{NeedRemoteAddr: true}); err != nil { - t.Fatalf("server.Read(_): %s", err) - } else { - clientAddr = read.RemoteAddr - - if diff := cmp.Diff(tcpip.ReadResult{ - Count: readBuf.Len(), - Total: readBuf.Len(), - RemoteAddr: tcpip.FullAddress{ - Addr: test.canBePrimaryAddr.AddressWithPrefix.Address, - }, - }, read, checker.IgnoreCmpPath( - "ControlMessages", - "RemoteAddr.NIC", - "RemoteAddr.Port", - )); diff != "" { - t.Errorf("server.Read: unexpected result (-want +got):\n%s", diff) - } - if diff := cmp.Diff(buffer.View(clientPayload), buffer.View(readBuf.Bytes())); diff != "" { - t.Errorf("server read clientPayload mismatch (-want +got):\n%s", diff) - } - if t.Failed() { - t.FailNow() - } - } - - serverPayload := []byte{1, 2, 3, 4} - { - var r bytes.Reader - r.Reset(serverPayload) - wOpts := tcpip.WriteOptions{ - To: &clientAddr, - } - if n, err := server.Write(&r, wOpts); err != nil { - t.Fatalf("server.Write(%#v, %#v): %s", serverPayload, wOpts, err) - } else if n != int64(len(serverPayload)) { - t.Fatalf("got server.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", serverPayload, wOpts, n, len(serverPayload)) - } - } - - // Wait for the client endpoint to become readable. - <-clientCH - - readBuf.Reset() - if read, err := client.Read(&readBuf, tcpip.ReadOptions{NeedRemoteAddr: true}); err != nil { - t.Fatalf("client.Read(_): %s", err) - } else { - if diff := cmp.Diff(tcpip.ReadResult{ - Count: readBuf.Len(), - Total: readBuf.Len(), - RemoteAddr: tcpip.FullAddress{Addr: serverAddr.Addr}, - }, read, checker.IgnoreCmpPath( - "ControlMessages", - "RemoteAddr.NIC", - "RemoteAddr.Port", - )); diff != "" { - t.Errorf("client.Read: unexpected result (-want +got):\n%s", diff) - } - if diff := cmp.Diff(buffer.View(serverPayload), buffer.View(readBuf.Bytes())); diff != "" { - t.Errorf("client read serverPayload mismatch (-want +got):\n%s", diff) - } - if t.Failed() { - t.FailNow() - } - } - }) - } - }) - } -} diff --git a/pkg/tcpip/tests/utils/BUILD b/pkg/tcpip/tests/utils/BUILD deleted file mode 100644 index 433004148..000000000 --- a/pkg/tcpip/tests/utils/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "utils", - srcs = ["utils.go"], - visibility = ["//pkg/tcpip/tests:__subpackages__"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/header", - "//pkg/tcpip/link/ethernet", - "//pkg/tcpip/link/nested", - "//pkg/tcpip/link/pipe", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go deleted file mode 100644 index f414a2234..000000000 --- a/pkg/tcpip/tests/utils/utils.go +++ /dev/null @@ -1,314 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package utils holds common testing utilities for tcpip. -package utils - -import ( - "net" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/ethernet" - "gvisor.dev/gvisor/pkg/tcpip/link/nested" - "gvisor.dev/gvisor/pkg/tcpip/link/pipe" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -// Common NIC IDs used by tests. -const ( - Host1NICID = 1 - RouterNICID1 = 2 - RouterNICID2 = 3 - Host2NICID = 4 -) - -// Common link addresses used by tests. -const ( - LinkAddr1 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") - LinkAddr2 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x07") - LinkAddr3 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x08") - LinkAddr4 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") -) - -// Common IP addresses used by tests. -var ( - Ipv4Addr = tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()), - PrefixLen: 24, - } - Ipv4Subnet = Ipv4Addr.Subnet() - Ipv4SubnetBcast = Ipv4Subnet.Broadcast() - - Ipv6Addr = tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("200a::1").To16()), - PrefixLen: 64, - } - Ipv6Subnet = Ipv6Addr.Subnet() - Ipv6SubnetBcast = Ipv6Subnet.Broadcast() - - Ipv4Addr1 = tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), - PrefixLen: 24, - }, - } - Ipv4Addr2 = tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), - PrefixLen: 8, - }, - } - Ipv4Addr3 = tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("192.168.0.3").To4()), - PrefixLen: 8, - }, - } - Ipv6Addr1 = tcpip.ProtocolAddress{ - Protocol: ipv6.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("a::1").To16()), - PrefixLen: 64, - }, - } - Ipv6Addr2 = tcpip.ProtocolAddress{ - Protocol: ipv6.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("a::2").To16()), - PrefixLen: 64, - }, - } - Ipv6Addr3 = tcpip.ProtocolAddress{ - Protocol: ipv6.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("a::3").To16()), - PrefixLen: 64, - }, - } - - // Remote addrs. - RemoteIPv4Addr = tcpip.Address(net.ParseIP("10.0.0.1").To4()) - RemoteIPv6Addr = tcpip.Address(net.ParseIP("200b::1").To16()) -) - -// Common ports for testing. -const ( - RemotePort = 5555 - LocalPort = 80 -) - -// Common IP addresses used for testing. -var ( - Host1IPv4Addr = tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), - PrefixLen: 24, - }, - } - RouterNIC1IPv4Addr = tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), - PrefixLen: 24, - }, - } - RouterNIC2IPv4Addr = tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()), - PrefixLen: 8, - }, - } - Host2IPv4Addr = tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("10.0.0.2").To4()), - PrefixLen: 8, - }, - } - Host1IPv6Addr = tcpip.ProtocolAddress{ - Protocol: ipv6.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("a::2").To16()), - PrefixLen: 64, - }, - } - RouterNIC1IPv6Addr = tcpip.ProtocolAddress{ - Protocol: ipv6.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("a::1").To16()), - PrefixLen: 64, - }, - } - RouterNIC2IPv6Addr = tcpip.ProtocolAddress{ - Protocol: ipv6.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("b::1").To16()), - PrefixLen: 64, - }, - } - Host2IPv6Addr = tcpip.ProtocolAddress{ - Protocol: ipv6.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("b::2").To16()), - PrefixLen: 64, - }, - } -) - -// NewEthernetEndpoint returns an ethernet link endpoint that wraps an inner -// link endpoint and checks the destination link address before delivering -// network packets to the network dispatcher. -// -// See ethernet.Endpoint for more details. -func NewEthernetEndpoint(ep stack.LinkEndpoint) *EndpointWithDestinationCheck { - var e EndpointWithDestinationCheck - e.Endpoint.Init(ethernet.New(ep), &e) - return &e -} - -// EndpointWithDestinationCheck is a link endpoint that checks the destination -// link address before delivering network packets to the network dispatcher. -type EndpointWithDestinationCheck struct { - nested.Endpoint -} - -var _ stack.NetworkDispatcher = (*EndpointWithDestinationCheck)(nil) -var _ stack.LinkEndpoint = (*EndpointWithDestinationCheck)(nil) - -// DeliverNetworkPacket implements stack.NetworkDispatcher. -func (e *EndpointWithDestinationCheck) DeliverNetworkPacket(src, dst tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - if dst == e.Endpoint.LinkAddress() || dst == header.EthernetBroadcastAddress || header.IsMulticastEthernetAddress(dst) { - e.Endpoint.DeliverNetworkPacket(src, dst, proto, pkt) - } -} - -// SetupRoutedStacks creates the NICs, sets forwarding, adds addresses and sets -// the route tables for the passed stacks. -func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) { - host1NIC, routerNIC1 := pipe.New(LinkAddr1, LinkAddr2) - routerNIC2, host2NIC := pipe.New(LinkAddr3, LinkAddr4) - - if err := host1Stack.CreateNIC(Host1NICID, NewEthernetEndpoint(host1NIC)); err != nil { - t.Fatalf("host1Stack.CreateNIC(%d, _): %s", Host1NICID, err) - } - if err := routerStack.CreateNIC(RouterNICID1, NewEthernetEndpoint(routerNIC1)); err != nil { - t.Fatalf("routerStack.CreateNIC(%d, _): %s", RouterNICID1, err) - } - if err := routerStack.CreateNIC(RouterNICID2, NewEthernetEndpoint(routerNIC2)); err != nil { - t.Fatalf("routerStack.CreateNIC(%d, _): %s", RouterNICID2, err) - } - if err := host2Stack.CreateNIC(Host2NICID, NewEthernetEndpoint(host2NIC)); err != nil { - t.Fatalf("host2Stack.CreateNIC(%d, _): %s", Host2NICID, err) - } - - if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err) - } - if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err) - } - - if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv4Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", Host1NICID, Host1IPv4Addr, err) - } - if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv4Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID1, RouterNIC1IPv4Addr, err) - } - if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv4Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID2, RouterNIC2IPv4Addr, err) - } - if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv4Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", Host2NICID, Host2IPv4Addr, err) - } - if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv6Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", Host1NICID, Host1IPv6Addr, err) - } - if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv6Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID1, RouterNIC1IPv6Addr, err) - } - if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv6Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID2, RouterNIC2IPv6Addr, err) - } - if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv6Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", Host2NICID, Host2IPv6Addr, err) - } - - host1Stack.SetRouteTable([]tcpip.Route{ - { - Destination: Host1IPv4Addr.AddressWithPrefix.Subnet(), - NIC: Host1NICID, - }, - { - Destination: Host1IPv6Addr.AddressWithPrefix.Subnet(), - NIC: Host1NICID, - }, - { - Destination: Host2IPv4Addr.AddressWithPrefix.Subnet(), - Gateway: RouterNIC1IPv4Addr.AddressWithPrefix.Address, - NIC: Host1NICID, - }, - { - Destination: Host2IPv6Addr.AddressWithPrefix.Subnet(), - Gateway: RouterNIC1IPv6Addr.AddressWithPrefix.Address, - NIC: Host1NICID, - }, - }) - routerStack.SetRouteTable([]tcpip.Route{ - { - Destination: RouterNIC1IPv4Addr.AddressWithPrefix.Subnet(), - NIC: RouterNICID1, - }, - { - Destination: RouterNIC1IPv6Addr.AddressWithPrefix.Subnet(), - NIC: RouterNICID1, - }, - { - Destination: RouterNIC2IPv4Addr.AddressWithPrefix.Subnet(), - NIC: RouterNICID2, - }, - { - Destination: RouterNIC2IPv6Addr.AddressWithPrefix.Subnet(), - NIC: RouterNICID2, - }, - }) - host2Stack.SetRouteTable([]tcpip.Route{ - { - Destination: Host2IPv4Addr.AddressWithPrefix.Subnet(), - NIC: Host2NICID, - }, - { - Destination: Host2IPv6Addr.AddressWithPrefix.Subnet(), - NIC: Host2NICID, - }, - { - Destination: Host1IPv4Addr.AddressWithPrefix.Subnet(), - Gateway: RouterNIC2IPv4Addr.AddressWithPrefix.Address, - NIC: Host2NICID, - }, - { - Destination: Host1IPv6Addr.AddressWithPrefix.Subnet(), - Gateway: RouterNIC2IPv6Addr.AddressWithPrefix.Address, - NIC: Host2NICID, - }, - }) -} diff --git a/pkg/tcpip/time.s b/pkg/tcpip/time.s deleted file mode 100644 index fb37360ac..000000000 --- a/pkg/tcpip/time.s +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Empty assembly file so empty func definitions work. diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go deleted file mode 100644 index a82384c49..000000000 --- a/pkg/tcpip/timer_test.go +++ /dev/null @@ -1,268 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcpip_test - -import ( - "sync" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" -) - -const ( - shortDuration = 1 * time.Nanosecond - middleDuration = 100 * time.Millisecond - longDuration = 1 * time.Second -) - -func TestJobReschedule(t *testing.T) { - var clock tcpip.StdClock - var wg sync.WaitGroup - var lock sync.Mutex - - for i := 0; i < 2; i++ { - wg.Add(1) - - go func() { - lock.Lock() - // Assigning a new timer value updates the timer's locker and function. - // This test makes sure there is no data race when reassigning a timer - // that has an active timer (even if it has been stopped as a stopped - // timer may be blocked on a lock before it can check if it has been - // stopped while another goroutine holds the same lock). - job := tcpip.NewJob(&clock, &lock, func() { - wg.Done() - }) - job.Schedule(shortDuration) - lock.Unlock() - }() - } - wg.Wait() -} - -func TestJobExecution(t *testing.T) { - t.Parallel() - - var clock tcpip.StdClock - var lock sync.Mutex - ch := make(chan struct{}) - - job := tcpip.NewJob(&clock, &lock, func() { - ch <- struct{}{} - }) - job.Schedule(shortDuration) - - // Wait for timer to fire. - select { - case <-ch: - case <-time.After(middleDuration): - t.Fatal("timed out waiting for timer to fire") - } - - // The timer should have fired only once. - select { - case <-ch: - t.Fatal("no other timers should have fired") - case <-time.After(middleDuration): - } -} - -func TestCancellableTimerResetFromLongDuration(t *testing.T) { - t.Parallel() - - var clock tcpip.StdClock - var lock sync.Mutex - ch := make(chan struct{}) - - job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) - job.Schedule(middleDuration) - - lock.Lock() - job.Cancel() - lock.Unlock() - - job.Schedule(shortDuration) - - // Wait for timer to fire. - select { - case <-ch: - case <-time.After(middleDuration): - t.Fatal("timed out waiting for timer to fire") - } - - // The timer should have fired only once. - select { - case <-ch: - t.Fatal("no other timers should have fired") - case <-time.After(middleDuration): - } -} - -func TestJobRescheduleFromShortDuration(t *testing.T) { - t.Parallel() - - var clock tcpip.StdClock - var lock sync.Mutex - ch := make(chan struct{}) - - lock.Lock() - job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) - job.Schedule(shortDuration) - job.Cancel() - lock.Unlock() - - // Wait for timer to fire if it wasn't correctly stopped. - select { - case <-ch: - t.Fatal("timer fired after being stopped") - case <-time.After(middleDuration): - } - - job.Schedule(shortDuration) - - // Wait for timer to fire. - select { - case <-ch: - case <-time.After(middleDuration): - t.Fatal("timed out waiting for timer to fire") - } - - // The timer should have fired only once. - select { - case <-ch: - t.Fatal("no other timers should have fired") - case <-time.After(middleDuration): - } -} - -func TestJobImmediatelyCancel(t *testing.T) { - t.Parallel() - - var clock tcpip.StdClock - var lock sync.Mutex - ch := make(chan struct{}) - - for i := 0; i < 1000; i++ { - lock.Lock() - job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) - job.Schedule(shortDuration) - job.Cancel() - lock.Unlock() - } - - // Wait for timer to fire if it wasn't correctly stopped. - select { - case <-ch: - t.Fatal("timer fired after being stopped") - case <-time.After(middleDuration): - } -} - -func TestJobCancelledRescheduleWithoutLock(t *testing.T) { - t.Parallel() - - var clock tcpip.StdClock - var lock sync.Mutex - ch := make(chan struct{}) - - lock.Lock() - job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) - job.Schedule(shortDuration) - job.Cancel() - lock.Unlock() - - for i := 0; i < 10; i++ { - job.Schedule(middleDuration) - - lock.Lock() - // Sleep until the timer fires and gets blocked trying to take the lock. - time.Sleep(middleDuration * 2) - job.Cancel() - lock.Unlock() - } - - // Wait for double the duration so timers that weren't correctly stopped can - // fire. - select { - case <-ch: - t.Fatal("timer fired after being stopped") - case <-time.After(middleDuration * 2): - } -} - -func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { - t.Parallel() - - var clock tcpip.StdClock - var lock sync.Mutex - ch := make(chan struct{}) - - lock.Lock() - job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) - job.Schedule(shortDuration) - for i := 0; i < 10; i++ { - // Sleep until the timer fires and gets blocked trying to take the lock. - time.Sleep(middleDuration) - job.Cancel() - job.Schedule(shortDuration) - } - lock.Unlock() - - // Wait for double the duration for the last timer to fire. - select { - case <-ch: - case <-time.After(middleDuration): - t.Fatal("timed out waiting for timer to fire") - } - - // The timer should have fired only once. - select { - case <-ch: - t.Fatal("no other timers should have fired") - case <-time.After(middleDuration): - } -} - -func TestManyJobReschedulesUnderLock(t *testing.T) { - t.Parallel() - - var clock tcpip.StdClock - var lock sync.Mutex - ch := make(chan struct{}) - - lock.Lock() - job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) - job.Schedule(shortDuration) - for i := 0; i < 10; i++ { - job.Cancel() - job.Schedule(shortDuration) - } - lock.Unlock() - - // Wait for double the duration for the last timer to fire. - select { - case <-ch: - case <-time.After(middleDuration): - t.Fatal("timed out waiting for timer to fire") - } - - // The timer should have fired only once. - select { - case <-ch: - t.Fatal("no other timers should have fired") - case <-time.After(middleDuration): - } -} diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD deleted file mode 100644 index 7e5c79776..000000000 --- a/pkg/tcpip/transport/icmp/BUILD +++ /dev/null @@ -1,40 +0,0 @@ -load("//tools:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "icmp_packet_list", - out = "icmp_packet_list.go", - package = "icmp", - prefix = "icmpPacket", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*icmpPacket", - "Linker": "*icmpPacket", - }, -) - -go_library( - name = "icmp", - srcs = [ - "endpoint.go", - "endpoint_state.go", - "icmp_packet_list.go", - "protocol.go", - ], - imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sleep", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/ports", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/raw", - "//pkg/tcpip/transport/tcp", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/transport/icmp/icmp_packet_list.go b/pkg/tcpip/transport/icmp/icmp_packet_list.go new file mode 100644 index 000000000..0aacdad3f --- /dev/null +++ b/pkg/tcpip/transport/icmp/icmp_packet_list.go @@ -0,0 +1,221 @@ +package icmp + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type icmpPacketElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (icmpPacketElementMapper) linkerFor(elem *icmpPacket) *icmpPacket { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type icmpPacketList struct { + head *icmpPacket + tail *icmpPacket +} + +// Reset resets list l to the empty state. +func (l *icmpPacketList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *icmpPacketList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *icmpPacketList) Front() *icmpPacket { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *icmpPacketList) Back() *icmpPacket { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *icmpPacketList) Len() (count int) { + for e := l.Front(); e != nil; e = (icmpPacketElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *icmpPacketList) PushFront(e *icmpPacket) { + linker := icmpPacketElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + icmpPacketElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *icmpPacketList) PushBack(e *icmpPacket) { + linker := icmpPacketElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + icmpPacketElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *icmpPacketList) PushBackList(m *icmpPacketList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + icmpPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head) + icmpPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *icmpPacketList) InsertAfter(b, e *icmpPacket) { + bLinker := icmpPacketElementMapper{}.linkerFor(b) + eLinker := icmpPacketElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + icmpPacketElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *icmpPacketList) InsertBefore(a, e *icmpPacket) { + aLinker := icmpPacketElementMapper{}.linkerFor(a) + eLinker := icmpPacketElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + icmpPacketElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *icmpPacketList) Remove(e *icmpPacket) { + linker := icmpPacketElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + icmpPacketElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + icmpPacketElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type icmpPacketEntry struct { + next *icmpPacket + prev *icmpPacket +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *icmpPacketEntry) Next() *icmpPacket { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *icmpPacketEntry) Prev() *icmpPacket { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *icmpPacketEntry) SetNext(elem *icmpPacket) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *icmpPacketEntry) SetPrev(elem *icmpPacket) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/icmp/icmp_state_autogen.go b/pkg/tcpip/transport/icmp/icmp_state_autogen.go new file mode 100644 index 000000000..fe5af3d97 --- /dev/null +++ b/pkg/tcpip/transport/icmp/icmp_state_autogen.go @@ -0,0 +1,160 @@ +// automatically generated by stateify. + +package icmp + +import ( + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func (p *icmpPacket) StateTypeName() string { + return "pkg/tcpip/transport/icmp.icmpPacket" +} + +func (p *icmpPacket) StateFields() []string { + return []string{ + "icmpPacketEntry", + "senderAddress", + "data", + "timestamp", + } +} + +func (p *icmpPacket) beforeSave() {} + +func (p *icmpPacket) StateSave(stateSinkObject state.Sink) { + p.beforeSave() + var dataValue buffer.VectorisedView = p.saveData() + stateSinkObject.SaveValue(2, dataValue) + stateSinkObject.Save(0, &p.icmpPacketEntry) + stateSinkObject.Save(1, &p.senderAddress) + stateSinkObject.Save(3, &p.timestamp) +} + +func (p *icmpPacket) afterLoad() {} + +func (p *icmpPacket) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &p.icmpPacketEntry) + stateSourceObject.Load(1, &p.senderAddress) + stateSourceObject.Load(3, &p.timestamp) + stateSourceObject.LoadValue(2, new(buffer.VectorisedView), func(y interface{}) { p.loadData(y.(buffer.VectorisedView)) }) +} + +func (e *endpoint) StateTypeName() string { + return "pkg/tcpip/transport/icmp.endpoint" +} + +func (e *endpoint) StateFields() []string { + return []string{ + "TransportEndpointInfo", + "DefaultSocketOptionsHandler", + "waiterQueue", + "uniqueID", + "rcvReady", + "rcvList", + "rcvBufSizeMax", + "rcvBufSize", + "rcvClosed", + "shutdownFlags", + "state", + "ttl", + "owner", + "ops", + } +} + +func (e *endpoint) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + var rcvBufSizeMaxValue int = e.saveRcvBufSizeMax() + stateSinkObject.SaveValue(6, rcvBufSizeMaxValue) + stateSinkObject.Save(0, &e.TransportEndpointInfo) + stateSinkObject.Save(1, &e.DefaultSocketOptionsHandler) + stateSinkObject.Save(2, &e.waiterQueue) + stateSinkObject.Save(3, &e.uniqueID) + stateSinkObject.Save(4, &e.rcvReady) + stateSinkObject.Save(5, &e.rcvList) + stateSinkObject.Save(7, &e.rcvBufSize) + stateSinkObject.Save(8, &e.rcvClosed) + stateSinkObject.Save(9, &e.shutdownFlags) + stateSinkObject.Save(10, &e.state) + stateSinkObject.Save(11, &e.ttl) + stateSinkObject.Save(12, &e.owner) + stateSinkObject.Save(13, &e.ops) +} + +func (e *endpoint) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.TransportEndpointInfo) + stateSourceObject.Load(1, &e.DefaultSocketOptionsHandler) + stateSourceObject.Load(2, &e.waiterQueue) + stateSourceObject.Load(3, &e.uniqueID) + stateSourceObject.Load(4, &e.rcvReady) + stateSourceObject.Load(5, &e.rcvList) + stateSourceObject.Load(7, &e.rcvBufSize) + stateSourceObject.Load(8, &e.rcvClosed) + stateSourceObject.Load(9, &e.shutdownFlags) + stateSourceObject.Load(10, &e.state) + stateSourceObject.Load(11, &e.ttl) + stateSourceObject.Load(12, &e.owner) + stateSourceObject.Load(13, &e.ops) + stateSourceObject.LoadValue(6, new(int), func(y interface{}) { e.loadRcvBufSizeMax(y.(int)) }) + stateSourceObject.AfterLoad(e.afterLoad) +} + +func (l *icmpPacketList) StateTypeName() string { + return "pkg/tcpip/transport/icmp.icmpPacketList" +} + +func (l *icmpPacketList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *icmpPacketList) beforeSave() {} + +func (l *icmpPacketList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *icmpPacketList) afterLoad() {} + +func (l *icmpPacketList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *icmpPacketEntry) StateTypeName() string { + return "pkg/tcpip/transport/icmp.icmpPacketEntry" +} + +func (e *icmpPacketEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *icmpPacketEntry) beforeSave() {} + +func (e *icmpPacketEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *icmpPacketEntry) afterLoad() {} + +func (e *icmpPacketEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func init() { + state.Register((*icmpPacket)(nil)) + state.Register((*endpoint)(nil)) + state.Register((*icmpPacketList)(nil)) + state.Register((*icmpPacketEntry)(nil)) +} diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD deleted file mode 100644 index b989b1209..000000000 --- a/pkg/tcpip/transport/packet/BUILD +++ /dev/null @@ -1,37 +0,0 @@ -load("//tools:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "packet_list", - out = "packet_list.go", - package = "packet", - prefix = "packet", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*packet", - "Linker": "*packet", - }, -) - -go_library( - name = "packet", - srcs = [ - "endpoint.go", - "endpoint_state.go", - "packet_list.go", - ], - imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/sleep", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/transport/packet/packet_list.go b/pkg/tcpip/transport/packet/packet_list.go new file mode 100644 index 000000000..2c983aad0 --- /dev/null +++ b/pkg/tcpip/transport/packet/packet_list.go @@ -0,0 +1,221 @@ +package packet + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type packetElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (packetElementMapper) linkerFor(elem *packet) *packet { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type packetList struct { + head *packet + tail *packet +} + +// Reset resets list l to the empty state. +func (l *packetList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *packetList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *packetList) Front() *packet { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *packetList) Back() *packet { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *packetList) Len() (count int) { + for e := l.Front(); e != nil; e = (packetElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *packetList) PushFront(e *packet) { + linker := packetElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + packetElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *packetList) PushBack(e *packet) { + linker := packetElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + packetElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *packetList) PushBackList(m *packetList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + packetElementMapper{}.linkerFor(l.tail).SetNext(m.head) + packetElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *packetList) InsertAfter(b, e *packet) { + bLinker := packetElementMapper{}.linkerFor(b) + eLinker := packetElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + packetElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *packetList) InsertBefore(a, e *packet) { + aLinker := packetElementMapper{}.linkerFor(a) + eLinker := packetElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + packetElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *packetList) Remove(e *packet) { + linker := packetElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + packetElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + packetElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type packetEntry struct { + next *packet + prev *packet +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *packetEntry) Next() *packet { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *packetEntry) Prev() *packet { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *packetEntry) SetNext(elem *packet) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *packetEntry) SetPrev(elem *packet) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/packet/packet_state_autogen.go b/pkg/tcpip/transport/packet/packet_state_autogen.go new file mode 100644 index 000000000..e78427555 --- /dev/null +++ b/pkg/tcpip/transport/packet/packet_state_autogen.go @@ -0,0 +1,163 @@ +// automatically generated by stateify. + +package packet + +import ( + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func (p *packet) StateTypeName() string { + return "pkg/tcpip/transport/packet.packet" +} + +func (p *packet) StateFields() []string { + return []string{ + "packetEntry", + "data", + "timestampNS", + "senderAddr", + "packetInfo", + } +} + +func (p *packet) beforeSave() {} + +func (p *packet) StateSave(stateSinkObject state.Sink) { + p.beforeSave() + var dataValue buffer.VectorisedView = p.saveData() + stateSinkObject.SaveValue(1, dataValue) + stateSinkObject.Save(0, &p.packetEntry) + stateSinkObject.Save(2, &p.timestampNS) + stateSinkObject.Save(3, &p.senderAddr) + stateSinkObject.Save(4, &p.packetInfo) +} + +func (p *packet) afterLoad() {} + +func (p *packet) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &p.packetEntry) + stateSourceObject.Load(2, &p.timestampNS) + stateSourceObject.Load(3, &p.senderAddr) + stateSourceObject.Load(4, &p.packetInfo) + stateSourceObject.LoadValue(1, new(buffer.VectorisedView), func(y interface{}) { p.loadData(y.(buffer.VectorisedView)) }) +} + +func (ep *endpoint) StateTypeName() string { + return "pkg/tcpip/transport/packet.endpoint" +} + +func (ep *endpoint) StateFields() []string { + return []string{ + "TransportEndpointInfo", + "DefaultSocketOptionsHandler", + "netProto", + "waiterQueue", + "cooked", + "rcvList", + "rcvBufSizeMax", + "rcvBufSize", + "rcvClosed", + "closed", + "bound", + "boundNIC", + "lastError", + "ops", + } +} + +func (ep *endpoint) StateSave(stateSinkObject state.Sink) { + ep.beforeSave() + var rcvBufSizeMaxValue int = ep.saveRcvBufSizeMax() + stateSinkObject.SaveValue(6, rcvBufSizeMaxValue) + stateSinkObject.Save(0, &ep.TransportEndpointInfo) + stateSinkObject.Save(1, &ep.DefaultSocketOptionsHandler) + stateSinkObject.Save(2, &ep.netProto) + stateSinkObject.Save(3, &ep.waiterQueue) + stateSinkObject.Save(4, &ep.cooked) + stateSinkObject.Save(5, &ep.rcvList) + stateSinkObject.Save(7, &ep.rcvBufSize) + stateSinkObject.Save(8, &ep.rcvClosed) + stateSinkObject.Save(9, &ep.closed) + stateSinkObject.Save(10, &ep.bound) + stateSinkObject.Save(11, &ep.boundNIC) + stateSinkObject.Save(12, &ep.lastError) + stateSinkObject.Save(13, &ep.ops) +} + +func (ep *endpoint) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &ep.TransportEndpointInfo) + stateSourceObject.Load(1, &ep.DefaultSocketOptionsHandler) + stateSourceObject.Load(2, &ep.netProto) + stateSourceObject.Load(3, &ep.waiterQueue) + stateSourceObject.Load(4, &ep.cooked) + stateSourceObject.Load(5, &ep.rcvList) + stateSourceObject.Load(7, &ep.rcvBufSize) + stateSourceObject.Load(8, &ep.rcvClosed) + stateSourceObject.Load(9, &ep.closed) + stateSourceObject.Load(10, &ep.bound) + stateSourceObject.Load(11, &ep.boundNIC) + stateSourceObject.Load(12, &ep.lastError) + stateSourceObject.Load(13, &ep.ops) + stateSourceObject.LoadValue(6, new(int), func(y interface{}) { ep.loadRcvBufSizeMax(y.(int)) }) + stateSourceObject.AfterLoad(ep.afterLoad) +} + +func (l *packetList) StateTypeName() string { + return "pkg/tcpip/transport/packet.packetList" +} + +func (l *packetList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *packetList) beforeSave() {} + +func (l *packetList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *packetList) afterLoad() {} + +func (l *packetList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *packetEntry) StateTypeName() string { + return "pkg/tcpip/transport/packet.packetEntry" +} + +func (e *packetEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *packetEntry) beforeSave() {} + +func (e *packetEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *packetEntry) afterLoad() {} + +func (e *packetEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func init() { + state.Register((*packet)(nil)) + state.Register((*endpoint)(nil)) + state.Register((*packetList)(nil)) + state.Register((*packetEntry)(nil)) +} diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD deleted file mode 100644 index 2eab09088..000000000 --- a/pkg/tcpip/transport/raw/BUILD +++ /dev/null @@ -1,39 +0,0 @@ -load("//tools:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "raw_packet_list", - out = "raw_packet_list.go", - package = "raw", - prefix = "rawPacket", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*rawPacket", - "Linker": "*rawPacket", - }, -) - -go_library( - name = "raw", - srcs = [ - "endpoint.go", - "endpoint_state.go", - "protocol.go", - "raw_packet_list.go", - ], - imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/sleep", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/packet", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/transport/raw/raw_packet_list.go b/pkg/tcpip/transport/raw/raw_packet_list.go new file mode 100644 index 000000000..48804ff1b --- /dev/null +++ b/pkg/tcpip/transport/raw/raw_packet_list.go @@ -0,0 +1,221 @@ +package raw + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type rawPacketElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (rawPacketElementMapper) linkerFor(elem *rawPacket) *rawPacket { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type rawPacketList struct { + head *rawPacket + tail *rawPacket +} + +// Reset resets list l to the empty state. +func (l *rawPacketList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *rawPacketList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *rawPacketList) Front() *rawPacket { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *rawPacketList) Back() *rawPacket { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *rawPacketList) Len() (count int) { + for e := l.Front(); e != nil; e = (rawPacketElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *rawPacketList) PushFront(e *rawPacket) { + linker := rawPacketElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + rawPacketElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *rawPacketList) PushBack(e *rawPacket) { + linker := rawPacketElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + rawPacketElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *rawPacketList) PushBackList(m *rawPacketList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + rawPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head) + rawPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *rawPacketList) InsertAfter(b, e *rawPacket) { + bLinker := rawPacketElementMapper{}.linkerFor(b) + eLinker := rawPacketElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + rawPacketElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *rawPacketList) InsertBefore(a, e *rawPacket) { + aLinker := rawPacketElementMapper{}.linkerFor(a) + eLinker := rawPacketElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + rawPacketElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *rawPacketList) Remove(e *rawPacket) { + linker := rawPacketElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + rawPacketElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + rawPacketElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type rawPacketEntry struct { + next *rawPacket + prev *rawPacket +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *rawPacketEntry) Next() *rawPacket { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *rawPacketEntry) Prev() *rawPacket { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *rawPacketEntry) SetNext(elem *rawPacket) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *rawPacketEntry) SetPrev(elem *rawPacket) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/raw/raw_state_autogen.go b/pkg/tcpip/transport/raw/raw_state_autogen.go new file mode 100644 index 000000000..db4b393a7 --- /dev/null +++ b/pkg/tcpip/transport/raw/raw_state_autogen.go @@ -0,0 +1,157 @@ +// automatically generated by stateify. + +package raw + +import ( + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func (p *rawPacket) StateTypeName() string { + return "pkg/tcpip/transport/raw.rawPacket" +} + +func (p *rawPacket) StateFields() []string { + return []string{ + "rawPacketEntry", + "data", + "timestampNS", + "senderAddr", + } +} + +func (p *rawPacket) beforeSave() {} + +func (p *rawPacket) StateSave(stateSinkObject state.Sink) { + p.beforeSave() + var dataValue buffer.VectorisedView = p.saveData() + stateSinkObject.SaveValue(1, dataValue) + stateSinkObject.Save(0, &p.rawPacketEntry) + stateSinkObject.Save(2, &p.timestampNS) + stateSinkObject.Save(3, &p.senderAddr) +} + +func (p *rawPacket) afterLoad() {} + +func (p *rawPacket) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &p.rawPacketEntry) + stateSourceObject.Load(2, &p.timestampNS) + stateSourceObject.Load(3, &p.senderAddr) + stateSourceObject.LoadValue(1, new(buffer.VectorisedView), func(y interface{}) { p.loadData(y.(buffer.VectorisedView)) }) +} + +func (e *endpoint) StateTypeName() string { + return "pkg/tcpip/transport/raw.endpoint" +} + +func (e *endpoint) StateFields() []string { + return []string{ + "TransportEndpointInfo", + "DefaultSocketOptionsHandler", + "waiterQueue", + "associated", + "rcvList", + "rcvBufSize", + "rcvBufSizeMax", + "rcvClosed", + "closed", + "connected", + "bound", + "owner", + "ops", + } +} + +func (e *endpoint) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + var rcvBufSizeMaxValue int = e.saveRcvBufSizeMax() + stateSinkObject.SaveValue(6, rcvBufSizeMaxValue) + stateSinkObject.Save(0, &e.TransportEndpointInfo) + stateSinkObject.Save(1, &e.DefaultSocketOptionsHandler) + stateSinkObject.Save(2, &e.waiterQueue) + stateSinkObject.Save(3, &e.associated) + stateSinkObject.Save(4, &e.rcvList) + stateSinkObject.Save(5, &e.rcvBufSize) + stateSinkObject.Save(7, &e.rcvClosed) + stateSinkObject.Save(8, &e.closed) + stateSinkObject.Save(9, &e.connected) + stateSinkObject.Save(10, &e.bound) + stateSinkObject.Save(11, &e.owner) + stateSinkObject.Save(12, &e.ops) +} + +func (e *endpoint) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.TransportEndpointInfo) + stateSourceObject.Load(1, &e.DefaultSocketOptionsHandler) + stateSourceObject.Load(2, &e.waiterQueue) + stateSourceObject.Load(3, &e.associated) + stateSourceObject.Load(4, &e.rcvList) + stateSourceObject.Load(5, &e.rcvBufSize) + stateSourceObject.Load(7, &e.rcvClosed) + stateSourceObject.Load(8, &e.closed) + stateSourceObject.Load(9, &e.connected) + stateSourceObject.Load(10, &e.bound) + stateSourceObject.Load(11, &e.owner) + stateSourceObject.Load(12, &e.ops) + stateSourceObject.LoadValue(6, new(int), func(y interface{}) { e.loadRcvBufSizeMax(y.(int)) }) + stateSourceObject.AfterLoad(e.afterLoad) +} + +func (l *rawPacketList) StateTypeName() string { + return "pkg/tcpip/transport/raw.rawPacketList" +} + +func (l *rawPacketList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *rawPacketList) beforeSave() {} + +func (l *rawPacketList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *rawPacketList) afterLoad() {} + +func (l *rawPacketList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *rawPacketEntry) StateTypeName() string { + return "pkg/tcpip/transport/raw.rawPacketEntry" +} + +func (e *rawPacketEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *rawPacketEntry) beforeSave() {} + +func (e *rawPacketEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *rawPacketEntry) afterLoad() {} + +func (e *rawPacketEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func init() { + state.Register((*rawPacket)(nil)) + state.Register((*endpoint)(nil)) + state.Register((*rawPacketList)(nil)) + state.Register((*rawPacketEntry)(nil)) +} diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD deleted file mode 100644 index fcdd032c5..000000000 --- a/pkg/tcpip/transport/tcp/BUILD +++ /dev/null @@ -1,134 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test", "more_shards") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "tcp_segment_list", - out = "tcp_segment_list.go", - package = "tcp", - prefix = "segment", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*segment", - "Linker": "*segment", - }, -) - -go_template_instance( - name = "tcp_endpoint_list", - out = "tcp_endpoint_list.go", - package = "tcp", - prefix = "endpoint", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*endpoint", - "Linker": "*endpoint", - }, -) - -go_library( - name = "tcp", - srcs = [ - "accept.go", - "connect.go", - "connect_unsafe.go", - "cubic.go", - "cubic_state.go", - "dispatcher.go", - "endpoint.go", - "endpoint_state.go", - "forwarder.go", - "protocol.go", - "rack.go", - "rack_state.go", - "rcv.go", - "rcv_state.go", - "reno.go", - "reno_recovery.go", - "sack.go", - "sack_recovery.go", - "sack_scoreboard.go", - "segment.go", - "segment_heap.go", - "segment_queue.go", - "segment_state.go", - "segment_unsafe.go", - "snd.go", - "snd_state.go", - "tcp_endpoint_list.go", - "tcp_segment_list.go", - "timer.go", - ], - imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/rand", - "//pkg/sleep", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/hash/jenkins", - "//pkg/tcpip/header", - "//pkg/tcpip/header/parse", - "//pkg/tcpip/ports", - "//pkg/tcpip/seqnum", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/raw", - "//pkg/waiter", - "@com_github_google_btree//:go_default_library", - ], -) - -go_test( - name = "tcp_x_test", - size = "medium", - srcs = [ - "dual_stack_test.go", - "sack_scoreboard_test.go", - "tcp_noracedetector_test.go", - "tcp_rack_test.go", - "tcp_sack_test.go", - "tcp_test.go", - "tcp_timestamp_test.go", - ], - shard_count = more_shards, - deps = [ - ":tcp", - "//pkg/rand", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/ports", - "//pkg/tcpip/seqnum", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp/testing/context", - "//pkg/test/testutil", - "//pkg/waiter", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) - -go_test( - name = "rcv_test", - size = "small", - srcs = ["rcv_test.go"], - deps = [ - "//pkg/tcpip/header", - "//pkg/tcpip/seqnum", - ], -) - -go_test( - name = "tcp_test", - size = "small", - srcs = ["timer_test.go"], - library = ":tcp", - deps = ["//pkg/sleep"], -) diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go deleted file mode 100644 index 2d90246e4..000000000 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ /dev/null @@ -1,651 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_test - -import ( - "strings" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" - "gvisor.dev/gvisor/pkg/waiter" -) - -func TestV4MappedConnectOnV6Only(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - // Start connection attempt, it must fail. - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Fatalf("Unexpected return value from Connect: %v", err) - } -} - -func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) { - // Start connection attempt. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventOut) - defer c.WQ.EventUnregister(&we) - - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("Unexpected return value from Connect: %v", err) - } - - // Receive SYN packet. - b := c.GetPacket() - synCheckers := append(checkers, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - )) - checker.IPv4(t, b, synCheckers...) - - tcp := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcp.SequenceNumber()) - - iss := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: tcp.DestinationPort(), - DstPort: tcp.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Receive ACK packet. - ackCheckers := append(checkers, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+1), - )) - checker.IPv4(t, c.GetPacket(), ackCheckers...) - - // Wait for connection to be established. - select { - case <-ch: - if err := c.EP.LastError(); err != nil { - t.Fatalf("Unexpected error when connecting: %v", err) - } - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for connection") - } -} - -func TestV4MappedConnect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Test the connection request. - testV4Connect(t, c) -} - -func TestV4ConnectWhenBoundToWildcard(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test the connection request. - testV4Connect(t, c) -} - -func TestV4ConnectWhenBoundToV4MappedWildcard(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to v4 mapped wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test the connection request. - testV4Connect(t, c) -} - -func TestV4ConnectWhenBoundToV4Mapped(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to v4 mapped address. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test the connection request. - testV4Connect(t, c) -} - -func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) { - // Start connection attempt to IPv6 address. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventOut) - defer c.WQ.EventUnregister(&we) - - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("Unexpected return value from Connect: %v", err) - } - - // Receive SYN packet. - b := c.GetV6Packet() - synCheckers := append(checkers, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - )) - checker.IPv6(t, b, synCheckers...) - - tcp := header.TCP(header.IPv6(b).Payload()) - c.IRS = seqnum.Value(tcp.SequenceNumber()) - - iss := seqnum.Value(789) - c.SendV6Packet(nil, &context.Headers{ - SrcPort: tcp.DestinationPort(), - DstPort: tcp.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Receive ACK packet. - ackCheckers := append(checkers, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+1), - )) - checker.IPv6(t, c.GetV6Packet(), ackCheckers...) - - // Wait for connection to be established. - select { - case <-ch: - if err := c.EP.LastError(); err != nil { - t.Fatalf("Unexpected error when connecting: %v", err) - } - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for connection") - } -} - -func TestV6Connect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Test the connection request. - testV6Connect(t, c) -} - -func TestV6ConnectV6Only(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - // Test the connection request. - testV6Connect(t, c) -} - -func TestV6ConnectWhenBoundToWildcard(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test the connection request. - testV6Connect(t, c) -} - -func TestStackV6OnlyConnectWhenBoundToWildcard(t *testing.T) { - c := context.NewWithOpts(t, context.Options{ - EnableV6: true, - MTU: defaultMTU, - }) - defer c.Cleanup() - - // Create a v6 endpoint but don't set the v6-only TCP option. - c.CreateV6Endpoint(false) - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test the connection request. - testV6Connect(t, c) -} - -func TestV6ConnectWhenBoundToLocalAddress(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to local address. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV6Addr}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test the connection request. - testV6Connect(t, c) -} - -func TestV4RefuseOnV6Only(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Start listening. - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Send a SYN request. - irs := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the RST reply. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.TCPAckNum(uint32(irs)+1), - ), - ) -} - -func TestV6RefuseOnBoundToV4Mapped(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind and listen. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Send a SYN request. - irs := seqnum.Value(789) - c.SendV6Packet(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the RST reply. - checker.IPv6(t, c.GetV6Packet(), - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.TCPAckNum(uint32(irs)+1), - ), - ) -} - -func testV4Accept(t *testing.T, c *context.Context) { - c.SetGSOEnabled(true) - defer c.SetGSOEnabled(false) - - // Start listening. - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Send a SYN request. - irs := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - iss := seqnum.Value(tcp.SequenceNumber()) - checker.IPv4(t, b, - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs)+1), - ), - ) - - // Send ACK. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - nep, _, err := c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Wait for connection to be established. - select { - case <-ch: - nep, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Check the peer address. - addr, err := nep.GetRemoteAddress() - if err != nil { - t.Fatalf("GetRemoteAddress failed failed: %v", err) - } - - if addr.Addr != context.TestAddr { - t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestAddr) - } - - var r strings.Reader - data := "Don't panic" - r.Reset(data) - nep.Write(&r, tcpip.WriteOptions{}) - b = c.GetPacket() - tcp = header.TCP(header.IPv4(b).Payload()) - if string(tcp.Payload()) != data { - t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data) - } -} - -func TestV4AcceptOnV6(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testV4Accept(t, c) -} - -func TestV4AcceptOnBoundToV4MappedWildcard(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to v4 mapped wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testV4Accept(t, c) -} - -func TestV4AcceptOnBoundToV4Mapped(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind and listen. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testV4Accept(t, c) -} - -func TestV6AcceptOnV6(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind and listen. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Send a SYN request. - irs := seqnum.Value(789) - c.SendV6Packet(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetV6Packet() - tcp := header.TCP(header.IPv6(b).Payload()) - iss := seqnum.Value(tcp.SequenceNumber()) - checker.IPv6(t, b, - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs)+1), - ), - ) - - // Send ACK. - c.SendV6Packet(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - var addr tcpip.FullAddress - _, _, err := c.EP.Accept(&addr) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Wait for connection to be established. - select { - case <-ch: - _, _, err = c.EP.Accept(&addr) - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - if addr.Addr != context.TestV6Addr { - t.Errorf("Unexpected remote address: got %s, want %s", addr.Addr, context.TestV6Addr) - } -} - -func TestV4AcceptOnV4(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testV4Accept(t, c) -} - -func testV4ListenClose(t *testing.T, c *context.Context) { - // Set the SynRcvd threshold to zero to force a syn cookie based accept - // to happen. - var opt tcpip.TCPSynRcvdCountThresholdOption - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } - - const n = uint16(32) - - // Start listening. - if err := c.EP.Listen(int(tcp.SynRcvdCountThreshold + 1)); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - irs := seqnum.Value(789) - for i := uint16(0); i < n; i++ { - // Send a SYN request. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + i, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - } - - // Each of these ACK's will cause a syn-cookie based connection to be - // accepted and delivered to the listening endpoint. - for i := uint16(0); i < n; i++ { - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - iss := seqnum.Value(tcp.SequenceNumber()) - // Send ACK. - c.SendPacket(nil, &context.Headers{ - SrcPort: tcp.DestinationPort(), - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - } - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - nep, _, err := c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Wait for connection to be established. - select { - case <-ch: - nep, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(10 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - nep.Close() - c.EP.Close() -} - -func TestV4ListenCloseOnV4(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testV4ListenClose(t, c) -} diff --git a/pkg/tcpip/transport/tcp/rcv_test.go b/pkg/tcpip/transport/tcp/rcv_test.go deleted file mode 100644 index 8a026ec46..000000000 --- a/pkg/tcpip/transport/tcp/rcv_test.go +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package rcv_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" -) - -func TestAcceptable(t *testing.T) { - for _, tt := range []struct { - segSeq seqnum.Value - segLen seqnum.Size - rcvNxt, rcvAcc seqnum.Value - want bool - }{ - // The segment is smaller than the window. - {105, 2, 100, 104, false}, - {105, 2, 101, 105, true}, - {105, 2, 102, 106, true}, - {105, 2, 103, 107, true}, - {105, 2, 104, 108, true}, - {105, 2, 105, 109, true}, - {105, 2, 106, 110, true}, - {105, 2, 107, 111, false}, - - // The segment is larger than the window. - {105, 4, 103, 105, true}, - {105, 4, 104, 106, true}, - {105, 4, 105, 107, true}, - {105, 4, 106, 108, true}, - {105, 4, 107, 109, true}, - {105, 4, 108, 110, true}, - {105, 4, 109, 111, false}, - {105, 4, 110, 112, false}, - - // The segment has no width. - {105, 0, 100, 102, false}, - {105, 0, 101, 103, false}, - {105, 0, 102, 104, false}, - {105, 0, 103, 105, true}, - {105, 0, 104, 106, true}, - {105, 0, 105, 107, true}, - {105, 0, 106, 108, false}, - {105, 0, 107, 109, false}, - - // The receive window has no width. - {105, 2, 103, 103, false}, - {105, 2, 104, 104, false}, - {105, 2, 105, 105, false}, - {105, 2, 106, 106, false}, - {105, 2, 107, 107, false}, - {105, 2, 108, 108, false}, - {105, 2, 109, 109, false}, - } { - if got := header.Acceptable(tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc); got != tt.want { - t.Errorf("header.Acceptable(%d, %d, %d, %d) = %t, want %t", tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc, got, tt.want) - } - } -} diff --git a/pkg/tcpip/transport/tcp/sack_scoreboard_test.go b/pkg/tcpip/transport/tcp/sack_scoreboard_test.go deleted file mode 100644 index b4e5ba0df..000000000 --- a/pkg/tcpip/transport/tcp/sack_scoreboard_test.go +++ /dev/null @@ -1,249 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" -) - -const smss = 1500 - -func initScoreboard(blocks []header.SACKBlock, iss seqnum.Value) *tcp.SACKScoreboard { - s := tcp.NewSACKScoreboard(smss, iss) - for _, blk := range blocks { - s.Insert(blk) - } - return s -} - -func TestSACKScoreboardIsSACKED(t *testing.T) { - type blockTest struct { - block header.SACKBlock - sacked bool - } - testCases := []struct { - comment string - scoreboardBlocks []header.SACKBlock - blockTests []blockTest - iss seqnum.Value - }{ - { - "Test holes and unsacked SACK blocks in SACKed ranges and insertion of overlapping SACK blocks", - []header.SACKBlock{{10, 20}, {10, 30}, {30, 40}, {41, 50}, {5, 10}, {1, 50}, {111, 120}, {101, 110}, {52, 120}}, - []blockTest{ - {header.SACKBlock{15, 21}, true}, - {header.SACKBlock{200, 201}, false}, - {header.SACKBlock{50, 51}, false}, - {header.SACKBlock{53, 120}, true}, - }, - 0, - }, - { - "Test disjoint SACKBlocks", - []header.SACKBlock{{2288624809, 2288810057}, {2288811477, 2288838565}}, - []blockTest{ - {header.SACKBlock{2288624809, 2288810057}, true}, - {header.SACKBlock{2288811477, 2288838565}, true}, - {header.SACKBlock{2288810057, 2288811477}, false}, - }, - 2288624809, - }, - { - "Test sequence number wrap around", - []header.SACKBlock{{4294254144, 225652}, {5340409, 5350509}}, - []blockTest{ - {header.SACKBlock{4294254144, 4294254145}, true}, - {header.SACKBlock{4294254143, 4294254144}, false}, - {header.SACKBlock{4294254144, 1}, true}, - {header.SACKBlock{225652, 5350509}, false}, - {header.SACKBlock{5340409, 5350509}, true}, - {header.SACKBlock{5350509, 5350609}, false}, - }, - 4294254144, - }, - { - "Test disjoint SACKBlocks out of order", - []header.SACKBlock{{827450276, 827454536}, {827426028, 827428868}}, - []blockTest{ - {header.SACKBlock{827426028, 827428867}, true}, - {header.SACKBlock{827450168, 827450275}, false}, - }, - 827426000, - }, - } - for _, tc := range testCases { - sb := initScoreboard(tc.scoreboardBlocks, tc.iss) - for _, blkTest := range tc.blockTests { - if want, got := blkTest.sacked, sb.IsSACKED(blkTest.block); got != want { - t.Errorf("%s: s.IsSACKED(%v) = %v, want %v", tc.comment, blkTest.block, got, want) - } - } - } -} - -func TestSACKScoreboardIsRangeLost(t *testing.T) { - s := tcp.NewSACKScoreboard(10, 0) - s.Insert(header.SACKBlock{1, 25}) - s.Insert(header.SACKBlock{25, 50}) - s.Insert(header.SACKBlock{51, 100}) - s.Insert(header.SACKBlock{111, 120}) - s.Insert(header.SACKBlock{101, 110}) - s.Insert(header.SACKBlock{121, 141}) - s.Insert(header.SACKBlock{145, 146}) - s.Insert(header.SACKBlock{147, 148}) - s.Insert(header.SACKBlock{149, 150}) - s.Insert(header.SACKBlock{153, 154}) - s.Insert(header.SACKBlock{155, 156}) - testCases := []struct { - block header.SACKBlock - lost bool - }{ - // Block not covered by SACK block and has more than - // nDupAckThreshold discontiguous SACK blocks after it as well - // as (nDupAckThreshold -1) * 10 (smss) bytes that have been - // SACKED above the sequence number covered by this block. - {block: header.SACKBlock{0, 1}, lost: true}, - - // These blocks have all been SACKed and should not be - // considered lost. - {block: header.SACKBlock{1, 2}, lost: false}, - {block: header.SACKBlock{25, 26}, lost: false}, - {block: header.SACKBlock{1, 45}, lost: false}, - - // Same as the first case above. - {block: header.SACKBlock{50, 51}, lost: true}, - - // This block has been SACKed and should not be considered lost. - {block: header.SACKBlock{119, 120}, lost: false}, - - // This one should return true because there are > - // (nDupAckThreshold - 1) * 10 (smss) bytes that have been - // sacked above this sequence number. - {block: header.SACKBlock{120, 121}, lost: true}, - - // This block has been SACKed and should not be considered lost. - {block: header.SACKBlock{125, 126}, lost: false}, - - // This block has not been SACKed and there are nDupAckThreshold - // number of SACKed blocks after it. - {block: header.SACKBlock{141, 145}, lost: true}, - - // This block has not been SACKed and there are less than - // nDupAckThreshold SACKed sequences after it. - {block: header.SACKBlock{151, 152}, lost: false}, - } - for _, tc := range testCases { - if want, got := tc.lost, s.IsRangeLost(tc.block); got != want { - t.Errorf("s.IsRangeLost(%v) = %v, want %v", tc.block, got, want) - } - } -} - -func TestSACKScoreboardIsLost(t *testing.T) { - s := tcp.NewSACKScoreboard(10, 0) - s.Insert(header.SACKBlock{1, 25}) - s.Insert(header.SACKBlock{25, 50}) - s.Insert(header.SACKBlock{51, 100}) - s.Insert(header.SACKBlock{111, 120}) - s.Insert(header.SACKBlock{101, 110}) - s.Insert(header.SACKBlock{121, 141}) - s.Insert(header.SACKBlock{121, 141}) - s.Insert(header.SACKBlock{145, 146}) - s.Insert(header.SACKBlock{147, 148}) - s.Insert(header.SACKBlock{149, 150}) - s.Insert(header.SACKBlock{153, 154}) - s.Insert(header.SACKBlock{155, 156}) - testCases := []struct { - seq seqnum.Value - lost bool - }{ - // Sequence number not covered by SACK block and has more than - // nDupAckThreshold discontiguous SACK blocks after it as well - // as (nDupAckThreshold -1) * 10 (smss) bytes that have been - // SACKED above the sequence number. - {seq: 0, lost: true}, - - // These sequence numbers have all been SACKed and should not be - // considered lost. - {seq: 1, lost: false}, - {seq: 25, lost: false}, - {seq: 45, lost: false}, - - // Same as first case above. - {seq: 50, lost: true}, - - // This block has been SACKed and should not be considered lost. - {seq: 119, lost: false}, - - // This one should return true because there are > - // (nDupAckThreshold - 1) * 10 (smss) bytes that have been - // sacked above this sequence number. - {seq: 120, lost: true}, - - // This sequence number has been SACKed and should not be - // considered lost. - {seq: 125, lost: false}, - - // This sequence number has not been SACKed and there are - // nDupAckThreshold number of SACKed blocks after it. - {seq: 141, lost: true}, - - // This sequence number has not been SACKed and there are less - // than nDupAckThreshold SACKed sequences after it. - {seq: 151, lost: false}, - } - for _, tc := range testCases { - if want, got := tc.lost, s.IsLost(tc.seq); got != want { - t.Errorf("s.IsLost(%v) = %v, want %v", tc.seq, got, want) - } - } -} - -func TestSACKScoreboardDelete(t *testing.T) { - blocks := []header.SACKBlock{{4294254144, 225652}, {5340409, 5350509}} - s := initScoreboard(blocks, 4294254143) - s.Delete(5340408) - if s.Empty() { - t.Fatalf("s.Empty() = true, want false") - } - if got, want := s.Sacked(), blocks[1].Start.Size(blocks[1].End); got != want { - t.Fatalf("incorrect sacked bytes in scoreboard got: %v, want: %v", got, want) - } - s.Delete(5340410) - if s.Empty() { - t.Fatal("s.Empty() = true, want false") - } - newSB := header.SACKBlock{5340410, 5350509} - if !s.IsSACKED(newSB) { - t.Fatalf("s.IsSACKED(%v) = false, want true, scoreboard: %v", newSB, s) - } - s.Delete(5350509) - lastOctet := header.SACKBlock{5350508, 5350509} - if s.IsSACKED(lastOctet) { - t.Fatalf("s.IsSACKED(%v) = false, want true", lastOctet) - } - - s.Delete(5350510) - if !s.Empty() { - t.Fatal("s.Empty() = false, want true") - } - if got, want := s.Sacked(), seqnum.Size(0); got != want { - t.Fatalf("incorrect sacked bytes in scoreboard got: %v, want: %v", got, want) - } -} diff --git a/pkg/tcpip/transport/tcp/tcp_endpoint_list.go b/pkg/tcpip/transport/tcp/tcp_endpoint_list.go new file mode 100644 index 000000000..a7dc5df81 --- /dev/null +++ b/pkg/tcpip/transport/tcp/tcp_endpoint_list.go @@ -0,0 +1,221 @@ +package tcp + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type endpointElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (endpointElementMapper) linkerFor(elem *endpoint) *endpoint { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type endpointList struct { + head *endpoint + tail *endpoint +} + +// Reset resets list l to the empty state. +func (l *endpointList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *endpointList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *endpointList) Front() *endpoint { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *endpointList) Back() *endpoint { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *endpointList) Len() (count int) { + for e := l.Front(); e != nil; e = (endpointElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *endpointList) PushFront(e *endpoint) { + linker := endpointElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + endpointElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *endpointList) PushBack(e *endpoint) { + linker := endpointElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + endpointElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *endpointList) PushBackList(m *endpointList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + endpointElementMapper{}.linkerFor(l.tail).SetNext(m.head) + endpointElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *endpointList) InsertAfter(b, e *endpoint) { + bLinker := endpointElementMapper{}.linkerFor(b) + eLinker := endpointElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + endpointElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *endpointList) InsertBefore(a, e *endpoint) { + aLinker := endpointElementMapper{}.linkerFor(a) + eLinker := endpointElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + endpointElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *endpointList) Remove(e *endpoint) { + linker := endpointElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + endpointElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + endpointElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type endpointEntry struct { + next *endpoint + prev *endpoint +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *endpointEntry) Next() *endpoint { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *endpointEntry) Prev() *endpoint { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *endpointEntry) SetNext(elem *endpoint) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *endpointEntry) SetPrev(elem *endpoint) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go deleted file mode 100644 index ced3a9c58..000000000 --- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go +++ /dev/null @@ -1,558 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// These tests are flaky when run under the go race detector due to some -// iterations taking long enough that the retransmit timer can kick in causing -// the congestion window measurements to fail due to extra packets etc. -// -// +build !race - -package tcp_test - -import ( - "bytes" - "fmt" - "math" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" - "gvisor.dev/gvisor/pkg/test/testutil" -) - -func TestFastRecovery(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - const iterations = 3 - data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(790, bytesRead) - } - - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) - } - - // Send 3 duplicate acks. This should force an immediate retransmit of - // the pending packet and put the sender into fast recovery. - rtxOffset := bytesRead - maxPayload*expected - for i := 0; i < 3; i++ { - c.SendAck(790, rtxOffset) - } - - // Receive the retransmitted packet. - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - // Wait before checking metrics. - metricPollFn := func() error { - if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want) - } - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.Retransmit.Value = %d, want = %d", got, want) - } - - if got, want := c.Stack().Stats().TCP.FastRecovery.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.FastRecovery.Value = %d, want = %d", got, want) - } - return nil - } - - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } - - // Now send 7 mode duplicate acks. Each of these should cause a window - // inflation by 1 and cause the sender to send an extra packet. - for i := 0; i < 7; i++ { - c.SendAck(790, rtxOffset) - } - - recover := bytesRead - - // Ensure no new packets arrive. - c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.", - 50*time.Millisecond) - - // Acknowledge half of the pending data. - rtxOffset = bytesRead - expected*maxPayload/2 - c.SendAck(790, rtxOffset) - - // Receive the retransmit due to partial ack. - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - // Wait before checking metrics. - metricPollFn = func() error { - if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(2); got != want { - return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want) - } - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(2); got != want { - return fmt.Errorf("got stats.TCP.Retransmit.Value = %d, want = %d", got, want) - } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } - - // Receive the 10 extra packets that should have been released due to - // the congestion window inflation in recovery. - for i := 0; i < 10; i++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // A partial ACK during recovery should reduce congestion window by the - // number acked. Since we had "expected" packets outstanding before sending - // partial ack and we acked expected/2 , the cwnd and outstanding should - // be expected/2 + 10 (7 dupAcks + 3 for the original 3 dupacks that triggered - // fast recovery). Which means the sender should not send any more packets - // till we ack this one. - c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", - 50*time.Millisecond) - - // Acknowledge all pending data to recover point. - c.SendAck(790, recover) - - // At this point, the cwnd should reset to expected/2 and there are 10 - // packets outstanding. - // - // NOTE: Technically netstack is incorrect in that we adjust the cwnd on - // the same segment that takes us out of recovery. But because of that - // the actual cwnd at exit of recovery will be expected/2 + 1 as we - // acked a cwnd worth of packets which will increase the cwnd further by - // 1 in congestion avoidance. - // - // Now in the first iteration since there are 10 packets outstanding. - // We would expect to get expected/2 +1 - 10 packets. But subsequent - // iterations will send us expected/2 + 1 + 1 (per iteration). - expected = expected/2 + 1 - 10 - for i := 0; i < iterations; i++ { - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd.", expected), 50*time.Millisecond) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - - // In cogestion avoidance, the packets trains increase by 1 in - // each iteration. - if i == 0 { - // After the first iteration we expect to get the full - // congestion window worth of packets in every - // iteration. - expected += 10 - } - expected++ - } -} - -func TestExponentialIncreaseDuringSlowStart(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - const iterations = 3 - data := make([]byte, maxPayload*(tcp.InitialCwnd<<(iterations+1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - - // Double the number of expected packets for the next iteration. - expected *= 2 - } -} - -func TestCongestionAvoidance(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - const iterations = 3 - data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(790, bytesRead) - } - - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd (slow start phase).", 50*time.Millisecond) - } - - // Don't acknowledge the first packet of the last packet train. Let's - // wait for them to time out, which will trigger a restart of slow - // start, and initialization of ssthresh to cwnd/2. - rtxOffset := bytesRead - maxPayload*expected - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - - // This part is tricky: when the timeout happened, we had "expected" - // packets pending, cwnd reset to 1, and ssthresh set to expected/2. - // By acknowledging "expected" packets, the slow-start part will - // increase cwnd to expected/2 (which "consumes" expected/2-1 of the - // acknowledgements), then the congestion avoidance part will consume - // an extra expected/2 acks to take cwnd to expected/2 + 1. One ack - // remains in the "ack count" (which will cause cwnd to be incremented - // once it reaches cwnd acks). - // - // So we're straight into congestion avoidance with cwnd set to - // expected/2 + 1. - // - // Check that packets trains of cwnd packets are sent, and that cwnd is - // incremented by 1 after we acknowledge each packet. - expected = expected/2 + 1 - for i := 0; i < iterations; i++ { - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd (congestion avoidance phase).", 50*time.Millisecond) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - - // In cogestion avoidance, the packets trains increase by 1 in - // each iteration. - expected++ - } -} - -// cubicCwnd returns an estimate of a cubic window given the -// originalCwnd, wMax, last congestion event time and sRTT. -func cubicCwnd(origCwnd int, wMax int, congEventTime time.Time, sRTT time.Duration) int { - cwnd := float64(origCwnd) - // We wait 50ms between each iteration so sRTT as computed by cubic - // should be close to 50ms. - elapsed := (time.Since(congEventTime) + sRTT).Seconds() - k := math.Cbrt(float64(wMax) * 0.3 / 0.7) - wtRTT := 0.4*math.Pow(elapsed-k, 3) + float64(wMax) - cwnd += (wtRTT - cwnd) / cwnd - return int(cwnd) -} - -func TestCubicCongestionAvoidance(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - enableCUBIC(t, c) - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - const iterations = 3 - data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(790, bytesRead) - } - - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd (during slow-start phase).", 50*time.Millisecond) - } - - // Don't acknowledge the first packet of the last packet train. Let's - // wait for them to time out, which will trigger a restart of slow - // start, and initialization of ssthresh to cwnd * 0.7. - rtxOffset := bytesRead - maxPayload*expected - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - // Acknowledge all pending data. - c.SendAck(790, bytesRead) - - // Store away the time we sent the ACK and assuming a 200ms RTO - // we estimate that the sender will have an RTO 200ms from now - // and go back into slow start. - packetDropTime := time.Now().Add(200 * time.Millisecond) - - // This part is tricky: when the timeout happened, we had "expected" - // packets pending, cwnd reset to 1, and ssthresh set to expected * 0.7. - // By acknowledging "expected" packets, the slow-start part will - // increase cwnd to expected/2 essentially putting the connection - // straight into congestion avoidance. - wMax := expected - // Lower expected as per cubic spec after a congestion event. - expected = int(float64(expected) * 0.7) - cwnd := expected - for i := 0; i < iterations; i++ { - // Cubic grows window independent of ACKs. Cubic Window growth - // is a function of time elapsed since last congestion event. - // As a result the congestion window does not grow - // deterministically in response to ACKs. - // - // We need to roughly estimate what the cwnd of the sender is - // based on when we sent the dupacks. - cwnd := cubicCwnd(cwnd, wMax, packetDropTime, 50*time.Millisecond) - - packetsExpected := cwnd - for j := 0; j < packetsExpected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - t.Logf("expected packets received, next trying to receive any extra packets that may come") - - // If our estimate was correct there should be no more pending packets. - // We attempt to read a packet a few times with a short sleep in between - // to ensure that we don't see the sender send any unexpected packets. - unexpectedPackets := 0 - for { - gotPacket := c.ReceiveNonBlockingAndCheckPacket(data, bytesRead, maxPayload) - if !gotPacket { - break - } - bytesRead += maxPayload - unexpectedPackets++ - time.Sleep(1 * time.Millisecond) - } - if unexpectedPackets != 0 { - t.Fatalf("received %d unexpected packets for iteration %d", unexpectedPackets, i) - } - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd(congestion avoidance)", 5*time.Millisecond) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - } -} - -func TestRetransmit(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - const iterations = 3 - data := make([]byte, maxPayload*(tcp.InitialCwnd<<(iterations+1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in two shots. Packets will only be written at the - // MTU size though. - var r bytes.Reader - r.Reset(data[:len(data)/2]) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - r.Reset(data[len(data)/2:]) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(790, bytesRead) - } - - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) - } - - // Wait for a timeout and retransmit. - rtxOffset := bytesRead - maxPayload*expected - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - metricPollFn := func() error { - if got, want := c.Stack().Stats().TCP.Timeouts.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.Timeouts.Value = %d, want = %d", got, want) - } - - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.Retransmits.Value = %d, want = %d", got, want) - } - - if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want { - return fmt.Errorf("got EP SendErrors.Timeouts.Value = %d, want = %d", got, want) - } - - if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want { - return fmt.Errorf("got EP stats SendErrors.Retransmits.Value = %d, want = %d", got, want) - } - - if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.SlowStartRetransmits.Value = %d, want = %d", got, want) - } - - return nil - } - - // Poll when checking metrics. - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } - - // Acknowledge half of the pending data. - rtxOffset = bytesRead - expected*maxPayload/2 - c.SendAck(790, rtxOffset) - - // Receive the remaining data, making sure that acknowledged data is not - // retransmitted. - for offset := rtxOffset; offset < len(data); offset += maxPayload { - c.ReceiveAndCheckPacket(data, offset, maxPayload) - c.SendAck(790, offset+maxPayload) - } - - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) -} diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go deleted file mode 100644 index 3c13fc8a3..000000000 --- a/pkg/tcpip/transport/tcp/tcp_rack_test.go +++ /dev/null @@ -1,989 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_test - -import ( - "bytes" - "fmt" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" - "gvisor.dev/gvisor/pkg/test/testutil" -) - -const ( - maxPayload = 10 - tsOptionSize = 12 - maxTCPOptionSize = 40 - mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload -) - -func setStackRACKPermitted(t *testing.T, c *context.Context) { - t.Helper() - opt := tcpip.TCPRecovery(tcpip.TCPRACKLossDetection) - if err := c.Stack().SetTransportProtocolOption(header.TCPProtocolNumber, &opt); err != nil { - t.Fatalf("c.s.SetTransportProtocolOption(%d, &%v(%v)): %s", header.TCPProtocolNumber, opt, opt, err) - } -} - -// TestRACKUpdate tests the RACK related fields are updated when an ACK is -// received on a SACK enabled connection. -func TestRACKUpdate(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - var xmitTime time.Time - probeDone := make(chan struct{}) - c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { - // Validate that the endpoint Sender.RACKState is what we expect. - if state.Sender.RACKState.XmitTime.Before(xmitTime) { - t.Fatalf("RACK transmit time failed to update when an ACK is received") - } - - gotSeq := state.Sender.RACKState.EndSequence - wantSeq := state.Sender.SndNxt - if !gotSeq.LessThanEq(wantSeq) || gotSeq.LessThan(wantSeq) { - t.Fatalf("RACK sequence number failed to update, got: %v, but want: %v", gotSeq, wantSeq) - } - - if state.Sender.RACKState.RTT == 0 { - t.Fatalf("RACK RTT failed to update when an ACK is received, got RACKState.RTT == 0 want != 0") - } - close(probeDone) - }) - setStackSACKPermitted(t, c, true) - setStackRACKPermitted(t, c) - createConnectedWithSACKAndTS(c) - - data := make([]byte, maxPayload) - for i := range data { - data[i] = byte(i) - } - - // Write the data. - xmitTime = time.Now() - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - bytesRead := 0 - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += maxPayload - c.SendAck(seqnum.Value(context.TestInitialSequenceNumber).Add(1), bytesRead) - - // Wait for the probe function to finish processing the ACK before the - // test completes. - <-probeDone -} - -// TestRACKDetectReorder tests that RACK detects packet reordering. -func TestRACKDetectReorder(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - t.Skipf("Skipping this test as reorder detection does not consider DSACK.") - - var n int - const ackNumToVerify = 2 - probeDone := make(chan struct{}) - c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { - gotSeq := state.Sender.RACKState.FACK - wantSeq := state.Sender.SndNxt - // FACK should be updated to the highest ending sequence number of the - // segment acknowledged most recently. - if !gotSeq.LessThanEq(wantSeq) || gotSeq.LessThan(wantSeq) { - t.Fatalf("RACK FACK failed to update, got: %v, but want: %v", gotSeq, wantSeq) - } - - n++ - if n < ackNumToVerify { - if state.Sender.RACKState.Reord { - t.Fatalf("RACK reorder detected when there is no reordering") - } - return - } - - if state.Sender.RACKState.Reord == false { - t.Fatalf("RACK reorder detection failed") - } - close(probeDone) - }) - setStackSACKPermitted(t, c, true) - setStackRACKPermitted(t, c) - createConnectedWithSACKAndTS(c) - data := make([]byte, ackNumToVerify*maxPayload) - for i := range data { - data[i] = byte(i) - } - - // Write the data. - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - bytesRead := 0 - for i := 0; i < ackNumToVerify; i++ { - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += maxPayload - } - - start := c.IRS.Add(maxPayload + 1) - end := start.Add(maxPayload) - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}}) - c.SendAck(seq, bytesRead) - - // Wait for the probe function to finish processing the ACK before the - // test completes. - <-probeDone -} - -func sendAndReceiveWithSACK(t *testing.T, c *context.Context, numPackets int, enableRACK bool) []byte { - setStackSACKPermitted(t, c, true) - if enableRACK { - setStackRACKPermitted(t, c) - } - createConnectedWithSACKAndTS(c) - - data := make([]byte, numPackets*maxPayload) - for i := range data { - data[i] = byte(i) - } - - // Write the data. - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - bytesRead := 0 - for i := 0; i < numPackets; i++ { - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += maxPayload - } - - return data -} - -const ( - validDSACKDetected = 1 - failedToDetectDSACK = 2 - invalidDSACKDetected = 3 -) - -func addDSACKSeenCheckerProbe(t *testing.T, c *context.Context, numACK int, probeDone chan int) { - var n int - c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { - // Validate that RACK detects DSACK. - n++ - if n < numACK { - if state.Sender.RACKState.DSACKSeen { - probeDone <- invalidDSACKDetected - } - return - } - - if !state.Sender.RACKState.DSACKSeen { - probeDone <- failedToDetectDSACK - return - } - probeDone <- validDSACKDetected - }) -} - -// TestRACKTLPRecovery tests that RACK sends a tail loss probe (TLP) in the -// case of a tail loss. This simulates a situation where the TLP is able to -// insinuate the SACK holes and sender is able to retransmit the rest. -func TestRACKTLPRecovery(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - // Send 8 packets. - numPackets := 8 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Packets [6-8] are lost. Send cumulative ACK for [1-5]. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 5 * maxPayload - c.SendAck(seq, bytesRead) - - // PTO should fire and send #8 packet as a TLP. - c.ReceiveAndCheckPacketWithOptions(data, 7*maxPayload, maxPayload, tsOptionSize) - var info tcpip.TCPInfoOption - if err := c.EP.GetSockOpt(&info); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) - } - - // Send the SACK after RTT because RACK RFC states that if the ACK for a - // retransmission arrives before the smoothed RTT then the sender should not - // update RACK state as it could be a spurious inference. - time.Sleep(info.RTT) - - // Okay, let the sender know we got #8 using a SACK block. - eighthPStart := c.IRS.Add(1 + seqnum.Size(7*maxPayload)) - eighthPEnd := eighthPStart.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{eighthPStart, eighthPEnd}}) - - // The sender should be entering RACK based loss-recovery and sending #6 and - // #7 one after another. - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += maxPayload - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += 2 * maxPayload - c.SendAck(seq, bytesRead) - - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - // One fast retransmit after the SACK. - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, - // Recovery should be SACK recovery. - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, - // Packets 6, 7 and 8 were retransmitted. - {tcpStats.Retransmits, "stats.TCP.Retransmits", 3}, - // TLP recovery should have been detected. - {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 1}, - // No RTOs should have occurred. - {tcpStats.Timeouts, "stats.TCP.Timeouts", 0}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) - } - } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } -} - -// TestRACKTLPFallbackRTO tests that RACK sends a tail loss probe (TLP) in the -// case of a tail loss. This simulates a situation where either the TLP or its -// ACK is lost. The sender should retransmit when RTO fires. -func TestRACKTLPFallbackRTO(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - // Send 8 packets. - numPackets := 8 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Packets [6-8] are lost. Send cumulative ACK for [1-5]. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 5 * maxPayload - c.SendAck(seq, bytesRead) - - // PTO should fire and send #8 packet as a TLP. - c.ReceiveAndCheckPacketWithOptions(data, 7*maxPayload, maxPayload, tsOptionSize) - - // Either the TLP or the ACK the receiver sent with SACK blocks was lost. - - // Confirm that RTO fires and retransmits packet #6. - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - // No fast retransmits happened. - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 0}, - // No SACK recovery happened. - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 0}, - // TLP was unsuccessful. - {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 0}, - // RTO should have fired. - {tcpStats.Timeouts, "stats.TCP.Timeouts", 1}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) - } - } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } -} - -// TestNoTLPRecoveryOnDSACK tests the scenario where the sender speculates a -// tail loss and sends a TLP. Everything is received and acked. The probe -// segment is DSACKed. No fast recovery should be triggered in this case. -func TestNoTLPRecoveryOnDSACK(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - // Send 8 packets. - numPackets := 8 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Packets [1-5] are received first. [6-8] took a detour and will take a - // while to arrive. Ack [1-5]. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 5 * maxPayload - c.SendAck(seq, bytesRead) - - // The tail loss probe (#8 packet) is received. - c.ReceiveAndCheckPacketWithOptions(data, 7*maxPayload, maxPayload, tsOptionSize) - - // Now that all 8 packets are received + duplicate 8th packet, send ack. - bytesRead += 3 * maxPayload - eighthPStart := c.IRS.Add(1 + seqnum.Size(7*maxPayload)) - eighthPEnd := eighthPStart.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{eighthPStart, eighthPEnd}}) - - // Wait for RTO and make sure that nothing else is received. - var info tcpip.TCPInfoOption - if err := c.EP.GetSockOpt(&info); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) - } - if p := c.GetPacketWithTimeout(info.RTO); p != nil { - t.Errorf("received an unexpected packet: %v", p) - } - - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - // Make sure no recovery was entered. - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 0}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 0}, - {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 0}, - // RTO should not have fired. - {tcpStats.Timeouts, "stats.TCP.Timeouts", 0}, - // Only #8 was retransmitted. - {tcpStats.Retransmits, "stats.TCP.Retransmits", 1}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) - } - } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } -} - -// TestNoTLPOnSACK tests the scenario where there is not exactly a tail loss -// due to the presence of multiple SACK holes. In such a scenario, TLP should -// not be sent. -func TestNoTLPOnSACK(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - // Send 8 packets. - numPackets := 8 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Packets [1-5] and #7 were received. #6 and #8 were dropped. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 5 * maxPayload - seventhStart := c.IRS.Add(1 + seqnum.Size(6*maxPayload)) - seventhEnd := seventhStart.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{seventhStart, seventhEnd}}) - - // The sender should retransmit #6. If the sender sends a TLP, then #8 will - // received and fail this test. - c.ReceiveAndCheckPacketWithOptions(data, 5*maxPayload, maxPayload, tsOptionSize) - - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - // #6 was retransmitted due to SACK recovery. - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, - {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 0}, - // RTO should not have fired. - {tcpStats.Timeouts, "stats.TCP.Timeouts", 0}, - // Only #6 was retransmitted. - {tcpStats.Retransmits, "stats.TCP.Retransmits", 1}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) - } - } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } -} - -// TestRACKOnePacketTailLoss tests the trivial case of a tail loss of only one -// packet. The probe should itself repairs the loss instead of having to go -// into any recovery. -func TestRACKOnePacketTailLoss(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - // Send 3 packets. - numPackets := 3 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Packets [1-2] are received. #3 is lost. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 2 * maxPayload - c.SendAck(seq, bytesRead) - - // PTO should fire and send #3 packet as a TLP. - c.ReceiveAndCheckPacketWithOptions(data, 2*maxPayload, maxPayload, tsOptionSize) - bytesRead += maxPayload - c.SendAck(seq, bytesRead) - - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - // #3 was retransmitted as TLP. - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 0}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 0}, - {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 0}, - // RTO should not have fired. - {tcpStats.Timeouts, "stats.TCP.Timeouts", 0}, - // Only #3 was retransmitted. - {tcpStats.Retransmits, "stats.TCP.Retransmits", 1}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) - } - } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } -} - -// TestRACKDetectDSACK tests that RACK detects DSACK with duplicate segments. -// See: https://tools.ietf.org/html/rfc2883#section-4.1.1. -func TestRACKDetectDSACK(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - probeDone := make(chan int) - const ackNumToVerify = 2 - addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) - - numPackets := 8 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Cumulative ACK for [1-5] packets and SACK #8 packet (to prevent TLP). - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 5 * maxPayload - eighthPStart := c.IRS.Add(1 + seqnum.Size(7*maxPayload)) - eighthPEnd := eighthPStart.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{eighthPStart, eighthPEnd}}) - - // Expect retransmission of #6 packet after RTO expires. - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - - // Send DSACK block for #6 packet indicating both - // initial and retransmitted packet are received and - // packets [1-8] are received. - start := c.IRS.Add(1 + seqnum.Size(bytesRead)) - end := start.Add(maxPayload) - bytesRead += 3 * maxPayload - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Wait for the probe function to finish processing the - // ACK before the test completes. - err := <-probeDone - switch err { - case failedToDetectDSACK: - t.Fatalf("RACK DSACK detection failed") - case invalidDSACKDetected: - t.Fatalf("RACK DSACK detected when there is no duplicate SACK") - } -} - -// TestRACKDetectDSACKWithOutOfOrder tests that RACK detects DSACK with out of -// order segments. -// See: https://tools.ietf.org/html/rfc2883#section-4.1.2. -func TestRACKDetectDSACKWithOutOfOrder(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - probeDone := make(chan int) - const ackNumToVerify = 2 - addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) - - numPackets := 10 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Cumulative ACK for [1-5] packets and SACK for #7 packet (to prevent TLP). - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 5 * maxPayload - seventhPStart := c.IRS.Add(1 + seqnum.Size(6*maxPayload)) - seventhPEnd := seventhPStart.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{seventhPStart, seventhPEnd}}) - - // Expect retransmission of #6 packet. - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - - // Send DSACK block for #6 packet indicating both - // initial and retransmitted packet are received and - // packets [1-7] are received. - start := c.IRS.Add(1 + seqnum.Size(bytesRead)) - end := start.Add(maxPayload) - bytesRead += 2 * maxPayload - // Send DSACK block for #6 along with SACK for out of - // order #9 packet. - start1 := c.IRS.Add(1 + seqnum.Size(bytesRead) + maxPayload) - end1 := start1.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}, {start1, end1}}) - - // Wait for the probe function to finish processing the - // ACK before the test completes. - err := <-probeDone - switch err { - case failedToDetectDSACK: - t.Fatalf("RACK DSACK detection failed") - case invalidDSACKDetected: - t.Fatalf("RACK DSACK detected when there is no duplicate SACK") - } -} - -// TestRACKDetectDSACKWithOutOfOrderDup tests that DSACK is detected on a -// duplicate of out of order packet. -// See: https://tools.ietf.org/html/rfc2883#section-4.1.3 -func TestRACKDetectDSACKWithOutOfOrderDup(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - probeDone := make(chan int) - const ackNumToVerify = 4 - addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) - - numPackets := 10 - sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // ACK [1-5] packets. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 5 * maxPayload - c.SendAck(seq, bytesRead) - - // Send SACK indicating #6 packet is missing and received #7 packet. - offset := seqnum.Size(bytesRead + maxPayload) - start := c.IRS.Add(1 + offset) - end := start.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Send SACK with #6 packet is missing and received [7-8] packets. - end = start.Add(2 * maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Consider #8 packet is duplicated on the network and send DSACK. - dsackStart := c.IRS.Add(1 + offset + maxPayload) - dsackEnd := dsackStart.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{dsackStart, dsackEnd}, {start, end}}) - - // Wait for the probe function to finish processing the ACK before the - // test completes. - err := <-probeDone - switch err { - case failedToDetectDSACK: - t.Fatalf("RACK DSACK detection failed") - case invalidDSACKDetected: - t.Fatalf("RACK DSACK detected when there is no duplicate SACK") - } -} - -// TestRACKDetectDSACKSingleDup tests DSACK for a single duplicate subsegment. -// See: https://tools.ietf.org/html/rfc2883#section-4.2.1. -func TestRACKDetectDSACKSingleDup(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - probeDone := make(chan int) - const ackNumToVerify = 4 - addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) - - numPackets := 4 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Send ACK for #1 packet. - bytesRead := maxPayload - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendAck(seq, bytesRead) - - // Missing [2-3] packets and received #4 packet. - seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1) - start := c.IRS.Add(1 + seqnum.Size(3*maxPayload)) - end := start.Add(seqnum.Size(maxPayload)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Expect retransmission of #2 packet. - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - - // ACK for retransmitted #2 packet. - bytesRead += maxPayload - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Simulate receving delayed subsegment of #2 packet and delayed #3 packet by - // sending DSACK block for the subsegment. - dsackStart := c.IRS.Add(1 + seqnum.Size(bytesRead)) - dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2)) - c.SendAckWithSACK(seq, numPackets*maxPayload, []header.SACKBlock{{dsackStart, dsackEnd}}) - - // Wait for the probe function to finish processing the ACK before the - // test completes. - err := <-probeDone - switch err { - case failedToDetectDSACK: - t.Fatalf("RACK DSACK detection failed") - case invalidDSACKDetected: - t.Fatalf("RACK DSACK detected when there is no duplicate SACK") - } -} - -// TestRACKDetectDSACKDupWithCumulativeACK tests DSACK for two non-contiguous -// duplicate subsegments covered by the cumulative acknowledgement. -// See: https://tools.ietf.org/html/rfc2883#section-4.2.2. -func TestRACKDetectDSACKDupWithCumulativeACK(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - probeDone := make(chan int) - const ackNumToVerify = 5 - addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) - - numPackets := 6 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Send ACK for #1 packet. - bytesRead := maxPayload - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendAck(seq, bytesRead) - - // Missing [2-5] packets and received #6 packet. - seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1) - start := c.IRS.Add(1 + seqnum.Size(5*maxPayload)) - end := start.Add(seqnum.Size(maxPayload)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Expect retransmission of #2 packet. - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - - // Received delayed #2 packet. - bytesRead += maxPayload - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Received delayed #4 packet. - start1 := c.IRS.Add(1 + seqnum.Size(3*maxPayload)) - end1 := start1.Add(seqnum.Size(maxPayload)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}}) - - // Simulate receiving retransmitted subsegment for #2 packet and delayed #3 - // packet by sending DSACK block for #2 packet. - dsackStart := c.IRS.Add(1 + seqnum.Size(maxPayload)) - dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2)) - c.SendAckWithSACK(seq, 4*maxPayload, []header.SACKBlock{{dsackStart, dsackEnd}, {start, end}}) - - // Wait for the probe function to finish processing the ACK before the - // test completes. - err := <-probeDone - switch err { - case failedToDetectDSACK: - t.Fatalf("RACK DSACK detection failed") - case invalidDSACKDetected: - t.Fatalf("RACK DSACK detected when there is no duplicate SACK") - } -} - -// TestRACKDetectDSACKDup tests two non-contiguous duplicate subsegments not -// covered by the cumulative acknowledgement. -// See: https://tools.ietf.org/html/rfc2883#section-4.2.3. -func TestRACKDetectDSACKDup(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - probeDone := make(chan int) - const ackNumToVerify = 5 - addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) - - numPackets := 7 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Send ACK for #1 packet. - bytesRead := maxPayload - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendAck(seq, bytesRead) - - // Missing [2-6] packets and SACK #7 packet. - seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1) - start := c.IRS.Add(1 + seqnum.Size(6*maxPayload)) - end := start.Add(seqnum.Size(maxPayload)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Received delayed #3 packet. - start1 := c.IRS.Add(1 + seqnum.Size(2*maxPayload)) - end1 := start1.Add(seqnum.Size(maxPayload)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}}) - - // Expect retransmission of #2 packet. - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - - // Consider #2 packet has been dropped and SACK #4 packet. - start2 := c.IRS.Add(1 + seqnum.Size(3*maxPayload)) - end2 := start2.Add(seqnum.Size(maxPayload)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start2, end2}, {start1, end1}, {start, end}}) - - // Simulate receiving retransmitted subsegment for #3 packet and delayed #5 - // packet by sending DSACK block for the subsegment. - dsackStart := c.IRS.Add(1 + seqnum.Size(2*maxPayload)) - dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2)) - end1 = end1.Add(seqnum.Size(2 * maxPayload)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{dsackStart, dsackEnd}, {start1, end1}}) - - // Wait for the probe function to finish processing the ACK before the - // test completes. - err := <-probeDone - switch err { - case failedToDetectDSACK: - t.Fatalf("RACK DSACK detection failed") - case invalidDSACKDetected: - t.Fatalf("RACK DSACK detected when there is no duplicate SACK") - } -} - -// TestRACKWithInvalidDSACKBlock tests that DSACK is not detected when DSACK -// is not the first SACK block. -func TestRACKWithInvalidDSACKBlock(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - probeDone := make(chan struct{}) - const ackNumToVerify = 2 - var n int - c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { - // Validate that RACK does not detect DSACK when DSACK block is - // not the first SACK block. - n++ - t.Helper() - if state.Sender.RACKState.DSACKSeen { - t.Fatalf("RACK DSACK detected when there is no duplicate SACK") - } - - if n == ackNumToVerify { - close(probeDone) - } - }) - - numPackets := 10 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Cumulative ACK for [1-5] packets and SACK for #7 packet (to prevent TLP). - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 5 * maxPayload - seventhPStart := c.IRS.Add(1 + seqnum.Size(6*maxPayload)) - seventhPEnd := seventhPStart.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{seventhPStart, seventhPEnd}}) - - // Expect retransmission of #6 packet. - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - - // Send DSACK block for #6 packet indicating both - // initial and retransmitted packet are received and - // packets [1-7] are received. - start := c.IRS.Add(1 + seqnum.Size(bytesRead)) - end := start.Add(maxPayload) - bytesRead += 2 * maxPayload - - // Send DSACK block as second block. The first block is a SACK for #9 packet. - start1 := c.IRS.Add(1 + seqnum.Size(bytesRead) + maxPayload) - end1 := start1.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}}) - - // Wait for the probe function to finish processing the - // ACK before the test completes. - <-probeDone -} - -func addReorderWindowCheckerProbe(c *context.Context, numACK int, probeDone chan error) { - var n int - c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { - // Validate that RACK detects DSACK. - n++ - if n < numACK { - return - } - - if state.Sender.RACKState.ReoWnd == 0 || state.Sender.RACKState.ReoWnd > state.Sender.SRTT { - probeDone <- fmt.Errorf("got RACKState.ReoWnd: %v, expected it to be greater than 0 and less than %v", state.Sender.RACKState.ReoWnd, state.Sender.SRTT) - return - } - - if state.Sender.RACKState.ReoWndIncr != 1 { - probeDone <- fmt.Errorf("got RACKState.ReoWndIncr: %v, want: 1", state.Sender.RACKState.ReoWndIncr) - return - } - - if state.Sender.RACKState.ReoWndPersist > 0 { - probeDone <- fmt.Errorf("got RACKState.ReoWndPersist: %v, want: greater than 0", state.Sender.RACKState.ReoWndPersist) - return - } - probeDone <- nil - }) -} - -func TestRACKCheckReorderWindow(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - probeDone := make(chan error) - const ackNumToVerify = 3 - addReorderWindowCheckerProbe(c, ackNumToVerify, probeDone) - - const numPackets = 7 - sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Send ACK for #1 packet. - bytesRead := maxPayload - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendAck(seq, bytesRead) - - // Missing [2-6] packets and SACK #7 packet. - start := c.IRS.Add(1 + seqnum.Size(6*maxPayload)) - end := start.Add(seqnum.Size(maxPayload)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Received delayed packets [2-6] which indicates there is reordering - // in the connection. - bytesRead += 6 * maxPayload - c.SendAck(seq, bytesRead) - - // Wait for the probe function to finish processing the ACK before the - // test completes. - if err := <-probeDone; err != nil { - t.Fatalf("unexpected values for RACK variables: %v", err) - } -} - -func TestRACKWithDuplicateACK(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - const numPackets = 4 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Send three duplicate ACKs to trigger fast recovery. The first - // segment is considered as lost and will be retransmitted after - // receiving the duplicate ACKs. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - start := c.IRS.Add(1 + seqnum.Size(maxPayload)) - end := start.Add(seqnum.Size(maxPayload)) - for i := 0; i < 3; i++ { - c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}}) - end = end.Add(seqnum.Size(maxPayload)) - } - - // Receive the retransmitted packet. - c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize) - - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, - {tcpStats.FastRecovery, "stats.TCP.FastRecovery", 0}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) - } - } - return nil - } - - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } -} - -// TestRACKUpdateSackedOut tests the sacked out field is updated when a SACK -// is received. -func TestRACKUpdateSackedOut(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - probeDone := make(chan struct{}) - ackNum := 0 - c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { - // Validate that the endpoint Sender.SackedOut is what we expect. - if state.Sender.SackedOut != 2 && ackNum == 0 { - t.Fatalf("SackedOut got updated to wrong value got: %v want: 2", state.Sender.SackedOut) - } - - if state.Sender.SackedOut != 0 && ackNum == 1 { - t.Fatalf("SackedOut got updated to wrong value got: %v want: 0", state.Sender.SackedOut) - } - if ackNum > 0 { - close(probeDone) - } - ackNum++ - }) - - sendAndReceiveWithSACK(t, c, 8, true /* enableRACK */) - - // ACK for [3-5] packets. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - start := c.IRS.Add(seqnum.Size(1 + 3*maxPayload)) - bytesRead := 2 * maxPayload - end := start.Add(seqnum.Size(bytesRead)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - bytesRead += 3 * maxPayload - c.SendAck(seq, bytesRead) - - // Wait for the probe function to finish processing the ACK before the - // test completes. - <-probeDone -} diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go deleted file mode 100644 index 81f800cad..000000000 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ /dev/null @@ -1,705 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_test - -import ( - "bytes" - "fmt" - "log" - "reflect" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" - "gvisor.dev/gvisor/pkg/test/testutil" -) - -// createConnectedWithSACKPermittedOption creates and connects c.ep with the -// SACKPermitted option enabled if the stack in the context has the SACK support -// enabled. -func createConnectedWithSACKPermittedOption(c *context.Context) *context.RawEndpoint { - return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled()}) -} - -// createConnectedWithSACKAndTS creates and connects c.ep with the SACK & TS -// option enabled if the stack in the context has SACK and TS enabled. -func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint { - return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true}) -} - -func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) { - t.Helper() - opt := tcpip.TCPSACKEnabled(enable) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("c.s.SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) - } -} - -// TestSackPermittedConnect establishes a connection with the SACK option -// enabled. -func TestSackPermittedConnect(t *testing.T) { - for _, sackEnabled := range []bool{false, true} { - t.Run(fmt.Sprintf("stack.sackEnabled: %v", sackEnabled), func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - setStackSACKPermitted(t, c, sackEnabled) - rep := createConnectedWithSACKPermittedOption(c) - data := []byte{1, 2, 3} - - rep.SendPacket(data, nil) - savedSeqNum := rep.NextSeqNum - rep.VerifyACKNoSACK() - - // Make an out of order packet and send it. - rep.NextSeqNum += 3 - sackBlocks := []header.SACKBlock{ - {rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))}, - } - rep.SendPacket(data, nil) - - // Restore the saved sequence number so that the - // VerifyXXX calls use the right sequence number for - // checking ACK numbers. - rep.NextSeqNum = savedSeqNum - if sackEnabled { - rep.VerifyACKHasSACK(sackBlocks) - } else { - rep.VerifyACKNoSACK() - } - - // Send the missing segment. - rep.SendPacket(data, nil) - // The ACK should contain the cumulative ACK for all 9 - // bytes sent and no SACK blocks. - rep.NextSeqNum += 3 - // Check that no SACK block is returned in the ACK. - rep.VerifyACKNoSACK() - }) - } -} - -// TestSackDisabledConnect establishes a connection with the SACK option -// disabled and verifies that no SACKs are sent for out of order segments. -func TestSackDisabledConnect(t *testing.T) { - for _, sackEnabled := range []bool{false, true} { - t.Run(fmt.Sprintf("sackEnabled: %v", sackEnabled), func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - setStackSACKPermitted(t, c, sackEnabled) - - rep := c.CreateConnectedWithOptions(header.TCPSynOptions{}) - - data := []byte{1, 2, 3} - - rep.SendPacket(data, nil) - savedSeqNum := rep.NextSeqNum - rep.VerifyACKNoSACK() - - // Make an out of order packet and send it. - rep.NextSeqNum += 3 - rep.SendPacket(data, nil) - - // The ACK should contain the older sequence number and - // no SACK blocks. - rep.NextSeqNum = savedSeqNum - rep.VerifyACKNoSACK() - - // Send the missing segment. - rep.SendPacket(data, nil) - // The ACK should contain the cumulative ACK for all 9 - // bytes sent and no SACK blocks. - rep.NextSeqNum += 3 - // Check that no SACK block is returned in the ACK. - rep.VerifyACKNoSACK() - }) - } -} - -// TestSackPermittedAccept accepts and establishes a connection with the -// SACKPermitted option enabled if the connection request specifies the -// SACKPermitted option. In case of SYN cookies SACK should be disabled as we -// don't encode the SACK information in the cookie. -func TestSackPermittedAccept(t *testing.T) { - type testCase struct { - cookieEnabled bool - sackPermitted bool - wndScale int - wndSize uint16 - } - - testCases := []testCase{ - // When cookie is used window scaling is disabled. - {true, false, -1, 0xffff}, // When cookie is used window scaling is disabled. - {false, true, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default). - } - - for _, tc := range testCases { - t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) { - for _, sackEnabled := range []bool{false, true} { - t.Run(fmt.Sprintf("test stack.sackEnabled: %v", sackEnabled), func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - if tc.cookieEnabled { - // Set the SynRcvd threshold to - // zero to force a syn cookie - // based accept to happen. - var opt tcpip.TCPSynRcvdCountThresholdOption - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } - } - setStackSACKPermitted(t, c, sackEnabled) - - rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted}) - // Now verify no SACK blocks are - // received when sack is disabled. - data := []byte{1, 2, 3} - rep.SendPacket(data, nil) - rep.VerifyACKNoSACK() - - savedSeqNum := rep.NextSeqNum - - // Make an out of order packet and send - // it. - rep.NextSeqNum += 3 - sackBlocks := []header.SACKBlock{ - {rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))}, - } - rep.SendPacket(data, nil) - - // The ACK should contain the older - // sequence number. - rep.NextSeqNum = savedSeqNum - if sackEnabled && tc.sackPermitted { - rep.VerifyACKHasSACK(sackBlocks) - } else { - rep.VerifyACKNoSACK() - } - - // Send the missing segment. - rep.SendPacket(data, nil) - // The ACK should contain the cumulative - // ACK for all 9 bytes sent and no SACK - // blocks. - rep.NextSeqNum += 3 - // Check that no SACK block is returned - // in the ACK. - rep.VerifyACKNoSACK() - }) - } - }) - } -} - -// TestSackDisabledAccept accepts and establishes a connection with -// the SACKPermitted option disabled and verifies that no SACKs are -// sent for out of order packets. -func TestSackDisabledAccept(t *testing.T) { - type testCase struct { - cookieEnabled bool - wndScale int - wndSize uint16 - } - - testCases := []testCase{ - // When cookie is used window scaling is disabled. - {true, -1, 0xffff}, // When cookie is used window scaling is disabled. - {false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default). - } - - for _, tc := range testCases { - t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) { - for _, sackEnabled := range []bool{false, true} { - t.Run(fmt.Sprintf("test: sackEnabled: %v", sackEnabled), func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - if tc.cookieEnabled { - // Set the SynRcvd threshold to - // zero to force a syn cookie - // based accept to happen. - var opt tcpip.TCPSynRcvdCountThresholdOption - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } - } - - setStackSACKPermitted(t, c, sackEnabled) - - rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS}) - - // Now verify no SACK blocks are - // received when sack is disabled. - data := []byte{1, 2, 3} - rep.SendPacket(data, nil) - rep.VerifyACKNoSACK() - savedSeqNum := rep.NextSeqNum - - // Make an out of order packet and send - // it. - rep.NextSeqNum += 3 - rep.SendPacket(data, nil) - - // The ACK should contain the older - // sequence number and no SACK blocks. - rep.NextSeqNum = savedSeqNum - rep.VerifyACKNoSACK() - - // Send the missing segment. - rep.SendPacket(data, nil) - // The ACK should contain the cumulative - // ACK for all 9 bytes sent and no SACK - // blocks. - rep.NextSeqNum += 3 - // Check that no SACK block is returned - // in the ACK. - rep.VerifyACKNoSACK() - }) - } - }) - } -} - -func TestUpdateSACKBlocks(t *testing.T) { - testCases := []struct { - segStart seqnum.Value - segEnd seqnum.Value - rcvNxt seqnum.Value - sackBlocks []header.SACKBlock - updated []header.SACKBlock - }{ - // Trivial cases where current SACK block list is empty and we - // have an out of order delivery. - {10, 11, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 11}}}, - {10, 12, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 12}}}, - {10, 20, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 20}}}, - - // Cases where current SACK block list is not empty and we have - // an out of order delivery. Tests that the updated SACK block - // list has the first block as the one that contains the new - // SACK block representing the segment that was just delivered. - {10, 11, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 11}, {12, 20}}}, - {24, 30, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{24, 30}, {12, 20}}}, - {24, 30, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{24, 30}, {12, 20}, {32, 40}}}, - - // Ensure that we only retain header.MaxSACKBlocks and drop the - // oldest one if adding a new block exceeds - // header.MaxSACKBlocks. - {24, 30, 9, - []header.SACKBlock{{12, 20}, {32, 40}, {42, 50}, {52, 60}, {62, 70}, {72, 80}}, - []header.SACKBlock{{24, 30}, {12, 20}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}}, - - // Cases where segment extends an existing SACK block. - {10, 12, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 20}}}, - {10, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 22}}}, - {10, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 22}}}, - {15, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{12, 22}}}, - {15, 25, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{12, 25}}}, - {11, 25, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{11, 25}}}, - {10, 12, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 20}, {32, 40}}}, - {10, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 22}, {32, 40}}}, - {10, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 22}, {32, 40}}}, - {15, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{12, 22}, {32, 40}}}, - {15, 25, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{12, 25}, {32, 40}}}, - {11, 25, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{11, 25}, {32, 40}}}, - - // Cases where segment contains rcvNxt. - {10, 20, 15, []header.SACKBlock{{20, 30}, {40, 50}}, []header.SACKBlock{{40, 50}}}, - } - - for _, tc := range testCases { - var sack tcp.SACKInfo - copy(sack.Blocks[:], tc.sackBlocks) - sack.NumBlocks = len(tc.sackBlocks) - tcp.UpdateSACKBlocks(&sack, tc.segStart, tc.segEnd, tc.rcvNxt) - if got, want := sack.Blocks[:sack.NumBlocks], tc.updated; !reflect.DeepEqual(got, want) { - t.Errorf("UpdateSACKBlocks(%v, %v, %v, %v), got: %v, want: %v", tc.sackBlocks, tc.segStart, tc.segEnd, tc.rcvNxt, got, want) - } - - } -} - -func TestTrimSackBlockList(t *testing.T) { - testCases := []struct { - rcvNxt seqnum.Value - sackBlocks []header.SACKBlock - trimmed []header.SACKBlock - }{ - // Simple cases where we trim whole entries. - {2, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}}, - {21, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{22, 30}, {32, 40}}}, - {31, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{32, 40}}}, - {40, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{}}, - // Cases where we need to update a block. - {12, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{12, 20}, {22, 30}, {32, 40}}}, - {23, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{23, 30}, {32, 40}}}, - {33, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{33, 40}}}, - {41, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{}}, - } - for _, tc := range testCases { - var sack tcp.SACKInfo - copy(sack.Blocks[:], tc.sackBlocks) - sack.NumBlocks = len(tc.sackBlocks) - tcp.TrimSACKBlockList(&sack, tc.rcvNxt) - if got, want := sack.Blocks[:sack.NumBlocks], tc.trimmed; !reflect.DeepEqual(got, want) { - t.Errorf("TrimSackBlockList(%v, %v), got: %v, want: %v", tc.sackBlocks, tc.rcvNxt, got, want) - } - } -} - -func TestSACKRecovery(t *testing.T) { - const maxPayload = 10 - // See: tcp.makeOptions for why tsOptionSize is set to 12 here. - const tsOptionSize = 12 - // Enabling SACK means the payload size is reduced to account - // for the extra space required for the TCP options. - // - // We increase the MTU by 40 bytes to account for SACK and Timestamp - // options. - const maxTCPOptionSize = 40 - - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload)) - defer c.Cleanup() - - c.Stack().AddTCPProbe(func(s stack.TCPEndpointState) { - // We use log.Printf instead of t.Logf here because this probe - // can fire even when the test function has finished. This is - // because closing the endpoint in cleanup() does not mean the - // actual worker loop terminates immediately as it still has to - // do a full TCP shutdown. But this test can finish running - // before the shutdown is done. Using t.Logf in such a case - // causes the test to panic due to logging after test finished. - log.Printf("state: %+v\n", s) - }) - setStackSACKPermitted(t, c, true) - createConnectedWithSACKAndTS(c) - - const iterations = 3 - data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Do slow start for a few iterations. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(seq, bytesRead) - } - - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) - } - - // Send 3 duplicate acks. This should force an immediate retransmit of - // the pending packet and put the sender into fast recovery. - rtxOffset := bytesRead - maxPayload*expected - start := c.IRS.Add(seqnum.Size(rtxOffset) + 30 + 1) - end := start.Add(10) - for i := 0; i < 3; i++ { - c.SendAckWithSACK(seq, rtxOffset, []header.SACKBlock{{start, end}}) - end = end.Add(10) - } - - // Receive the retransmitted packet. - c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize) - - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, - {tcpStats.Retransmits, "stats.TCP.Retransmits", 1}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, - {tcpStats.FastRecovery, "stats.TCP.FastRecovery", 0}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) - } - } - return nil - } - - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } - - // Now send 7 mode duplicate ACKs. In SACK TCP dupAcks do not cause - // window inflation and sending of packets is completely handled by the - // SACK Recovery algorithm. We should see no packets being released, as - // the cwnd at this point after entering recovery should be half of the - // outstanding number of packets in flight. - for i := 0; i < 7; i++ { - c.SendAckWithSACK(seq, rtxOffset, []header.SACKBlock{{start, end}}) - end = end.Add(10) - } - - recover := bytesRead - - // Ensure no new packets arrive. - c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.", - 50*time.Millisecond) - - // Acknowledge half of the pending data. This along with the 10 sacked - // segments above should reduce the outstanding below the current - // congestion window allowing the sender to transmit data. - rtxOffset = bytesRead - expected*maxPayload/2 - - // Now send a partial ACK w/ a SACK block that indicates that the next 3 - // segments are lost and we have received 6 segments after the lost - // segments. This should cause the sender to immediately transmit all 3 - // segments in response to this ACK unlike in FastRecovery where only 1 - // segment is retransmitted per ACK. - start = c.IRS.Add(seqnum.Size(rtxOffset) + 30 + 1) - end = start.Add(60) - c.SendAckWithSACK(seq, rtxOffset, []header.SACKBlock{{start, end}}) - - // At this point, we acked expected/2 packets and we SACKED 6 packets and - // 3 segments were considered lost due to the SACK block we sent. - // - // So total packets outstanding can be calculated as follows after 7 - // iterations of slow start -> 10/20/40/80/160/320/640. So expected - // should be 640 at start, then we went to recover at which point the - // cwnd should be set to 320 + 3 (for the 3 dupAcks which have left the - // network). - // Outstanding at this point after acking half the window - // (320 packets) will be: - // outstanding = 640-320-6(due to SACK block)-3 = 311 - // - // The last 3 is due to the fact that the first 3 packets after - // rtxOffset will be considered lost due to the SACK blocks sent. - // Receive the retransmit due to partial ack. - - c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize) - // Receive the 2 extra packets that should have been retransmitted as - // those should be considered lost and immediately retransmitted based - // on the SACK information in the previous ACK sent above. - for i := 0; i < 2; i++ { - c.ReceiveAndCheckPacketWithOptions(data, rtxOffset+maxPayload*(i+1), maxPayload, tsOptionSize) - } - - // Now we should get 9 more new unsent packets as the cwnd is 323 and - // outstanding is 311. - for i := 0; i < 9; i++ { - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += maxPayload - } - - metricPollFn = func() error { - // In SACK recovery only the first segment is fast retransmitted when - // entering recovery. - if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want) - } - - if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.FastRetransmit.Value(), uint64(1); got != want { - return fmt.Errorf("got EP stats SendErrors.FastRetransmit = %d, want = %d", got, want) - } - - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want { - return fmt.Errorf("got stats.TCP.Retransmits.Value = %d, want = %d", got, want) - } - - if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(4); got != want { - return fmt.Errorf("got EP stats Stats.SendErrors.Retransmits = %d, want = %d", got, want) - } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } - - c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", 50*time.Millisecond) - - // Acknowledge all pending data to recover point. - c.SendAck(seq, recover) - - // At this point, the cwnd should reset to expected/2 and there are 9 - // packets outstanding. - // - // Now in the first iteration since there are 9 packets outstanding. - // We would expect to get expected/2 - 9 packets. But subsequent - // iterations will send us expected/2 + 1 (per iteration). - expected = expected/2 - 9 - for i := 0; i < iterations; i++ { - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += maxPayload - } - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd and iteration: %d.", expected, i), 50*time.Millisecond) - - // Acknowledge all the data received so far. - c.SendAck(seq, bytesRead) - - // In cogestion avoidance, the packets trains increase by 1 in - // each iteration. - if i == 0 { - // After the first iteration we expect to get the full - // congestion window worth of packets in every - // iteration. - expected += 9 - } - expected++ - } -} - -// TestRecoveryEntry tests the following two properties of entering recovery: -// - Fast SACK recovery is entered when SND.UNA is considered lost by the SACK -// scoreboard but dupack count is still below threshold. -// - Only enter recovery when at least one more byte of data beyond the highest -// byte that was outstanding when fast retransmit was last entered is acked. -func TestRecoveryEntry(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - numPackets := 5 - data := sendAndReceiveWithSACK(t, c, numPackets, false /* enableRACK */) - - // Ack #1 packet. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendAck(seq, maxPayload) - - // Now SACK #3, #4 and #5 packets. This will simulate a situation where - // SND.UNA should be considered lost and the sender should enter fast recovery - // (even though dupack count is still below threshold). - p3Start := c.IRS.Add(1 + seqnum.Size(2*maxPayload)) - p3End := p3Start.Add(maxPayload) - p4Start := p3End - p4End := p4Start.Add(maxPayload) - p5Start := p4End - p5End := p5Start.Add(maxPayload) - c.SendAckWithSACK(seq, maxPayload, []header.SACKBlock{{p3Start, p3End}, {p4Start, p4End}, {p5Start, p5End}}) - - // Expect #2 to be retransmitted. - c.ReceiveAndCheckPacketWithOptions(data, maxPayload, maxPayload, tsOptionSize) - - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - // SACK recovery must have happened. - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, - // #2 was retransmitted. - {tcpStats.Retransmits, "stats.TCP.Retransmits", 1}, - // No RTOs should have fired yet. - {tcpStats.Timeouts, "stats.TCP.Timeouts", 0}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) - } - } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } - - // Send 4 more packets. - var r bytes.Reader - data = append(data, data...) - r.Reset(data[5*maxPayload : 9*maxPayload]) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - var sackBlocks []header.SACKBlock - bytesRead := numPackets * maxPayload - for i := 0; i < 4; i++ { - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - if i > 0 { - pStart := c.IRS.Add(1 + seqnum.Size(bytesRead)) - sackBlocks = append(sackBlocks, header.SACKBlock{pStart, pStart.Add(maxPayload)}) - c.SendAckWithSACK(seq, 5*maxPayload, sackBlocks) - } - bytesRead += maxPayload - } - - // #6 should be retransmitted after RTO. The sender should NOT enter fast - // recovery because the highest byte that was outstanding when fast recovery - // was last entered is #5 packet's end. And the sender requires at least one - // more byte beyond that (#6 packet start) to be acked to enter recovery. - c.ReceiveAndCheckPacketWithOptions(data, 5*maxPayload, maxPayload, tsOptionSize) - c.SendAck(seq, 9*maxPayload) - - metricPollFn = func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - // Only 1 SACK recovery must have happened. - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, - // #2 and #6 were retransmitted. - {tcpStats.Retransmits, "stats.TCP.Retransmits", 2}, - // RTO should have fired once. - {tcpStats.Timeouts, "stats.TCP.Timeouts", 1}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) - } - } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } -} diff --git a/pkg/tcpip/transport/tcp/tcp_segment_list.go b/pkg/tcpip/transport/tcp/tcp_segment_list.go new file mode 100644 index 000000000..a14cff27e --- /dev/null +++ b/pkg/tcpip/transport/tcp/tcp_segment_list.go @@ -0,0 +1,221 @@ +package tcp + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type segmentElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (segmentElementMapper) linkerFor(elem *segment) *segment { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type segmentList struct { + head *segment + tail *segment +} + +// Reset resets list l to the empty state. +func (l *segmentList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *segmentList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *segmentList) Front() *segment { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *segmentList) Back() *segment { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *segmentList) Len() (count int) { + for e := l.Front(); e != nil; e = (segmentElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *segmentList) PushFront(e *segment) { + linker := segmentElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + segmentElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *segmentList) PushBack(e *segment) { + linker := segmentElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + segmentElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *segmentList) PushBackList(m *segmentList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + segmentElementMapper{}.linkerFor(l.tail).SetNext(m.head) + segmentElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *segmentList) InsertAfter(b, e *segment) { + bLinker := segmentElementMapper{}.linkerFor(b) + eLinker := segmentElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + segmentElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *segmentList) InsertBefore(a, e *segment) { + aLinker := segmentElementMapper{}.linkerFor(a) + eLinker := segmentElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + segmentElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *segmentList) Remove(e *segment) { + linker := segmentElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + segmentElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + segmentElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type segmentEntry struct { + next *segment + prev *segment +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *segmentEntry) Next() *segment { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *segmentEntry) Prev() *segment { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *segmentEntry) SetNext(elem *segment) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *segmentEntry) SetPrev(elem *segment) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go new file mode 100644 index 000000000..4f3f62b98 --- /dev/null +++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go @@ -0,0 +1,1070 @@ +// automatically generated by stateify. + +package tcp + +import ( + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func (c *cubicState) StateTypeName() string { + return "pkg/tcpip/transport/tcp.cubicState" +} + +func (c *cubicState) StateFields() []string { + return []string{ + "wLastMax", + "wMax", + "t", + "numCongestionEvents", + "c", + "k", + "beta", + "wC", + "wEst", + "s", + } +} + +func (c *cubicState) beforeSave() {} + +func (c *cubicState) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + var tValue unixTime = c.saveT() + stateSinkObject.SaveValue(2, tValue) + stateSinkObject.Save(0, &c.wLastMax) + stateSinkObject.Save(1, &c.wMax) + stateSinkObject.Save(3, &c.numCongestionEvents) + stateSinkObject.Save(4, &c.c) + stateSinkObject.Save(5, &c.k) + stateSinkObject.Save(6, &c.beta) + stateSinkObject.Save(7, &c.wC) + stateSinkObject.Save(8, &c.wEst) + stateSinkObject.Save(9, &c.s) +} + +func (c *cubicState) afterLoad() {} + +func (c *cubicState) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.wLastMax) + stateSourceObject.Load(1, &c.wMax) + stateSourceObject.Load(3, &c.numCongestionEvents) + stateSourceObject.Load(4, &c.c) + stateSourceObject.Load(5, &c.k) + stateSourceObject.Load(6, &c.beta) + stateSourceObject.Load(7, &c.wC) + stateSourceObject.Load(8, &c.wEst) + stateSourceObject.Load(9, &c.s) + stateSourceObject.LoadValue(2, new(unixTime), func(y interface{}) { c.loadT(y.(unixTime)) }) +} + +func (s *SACKInfo) StateTypeName() string { + return "pkg/tcpip/transport/tcp.SACKInfo" +} + +func (s *SACKInfo) StateFields() []string { + return []string{ + "Blocks", + "NumBlocks", + } +} + +func (s *SACKInfo) beforeSave() {} + +func (s *SACKInfo) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.Blocks) + stateSinkObject.Save(1, &s.NumBlocks) +} + +func (s *SACKInfo) afterLoad() {} + +func (s *SACKInfo) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.Blocks) + stateSourceObject.Load(1, &s.NumBlocks) +} + +func (r *rcvBufAutoTuneParams) StateTypeName() string { + return "pkg/tcpip/transport/tcp.rcvBufAutoTuneParams" +} + +func (r *rcvBufAutoTuneParams) StateFields() []string { + return []string{ + "measureTime", + "copied", + "prevCopied", + "rtt", + "rttMeasureSeqNumber", + "rttMeasureTime", + "disabled", + } +} + +func (r *rcvBufAutoTuneParams) beforeSave() {} + +func (r *rcvBufAutoTuneParams) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + var measureTimeValue unixTime = r.saveMeasureTime() + stateSinkObject.SaveValue(0, measureTimeValue) + var rttMeasureTimeValue unixTime = r.saveRttMeasureTime() + stateSinkObject.SaveValue(5, rttMeasureTimeValue) + stateSinkObject.Save(1, &r.copied) + stateSinkObject.Save(2, &r.prevCopied) + stateSinkObject.Save(3, &r.rtt) + stateSinkObject.Save(4, &r.rttMeasureSeqNumber) + stateSinkObject.Save(6, &r.disabled) +} + +func (r *rcvBufAutoTuneParams) afterLoad() {} + +func (r *rcvBufAutoTuneParams) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(1, &r.copied) + stateSourceObject.Load(2, &r.prevCopied) + stateSourceObject.Load(3, &r.rtt) + stateSourceObject.Load(4, &r.rttMeasureSeqNumber) + stateSourceObject.Load(6, &r.disabled) + stateSourceObject.LoadValue(0, new(unixTime), func(y interface{}) { r.loadMeasureTime(y.(unixTime)) }) + stateSourceObject.LoadValue(5, new(unixTime), func(y interface{}) { r.loadRttMeasureTime(y.(unixTime)) }) +} + +func (e *EndpointInfo) StateTypeName() string { + return "pkg/tcpip/transport/tcp.EndpointInfo" +} + +func (e *EndpointInfo) StateFields() []string { + return []string{ + "TransportEndpointInfo", + } +} + +func (e *EndpointInfo) beforeSave() {} + +func (e *EndpointInfo) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.TransportEndpointInfo) +} + +func (e *EndpointInfo) afterLoad() {} + +func (e *EndpointInfo) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.TransportEndpointInfo) +} + +func (e *endpoint) StateTypeName() string { + return "pkg/tcpip/transport/tcp.endpoint" +} + +func (e *endpoint) StateFields() []string { + return []string{ + "EndpointInfo", + "DefaultSocketOptionsHandler", + "waiterQueue", + "uniqueID", + "hardError", + "lastError", + "rcvList", + "rcvClosed", + "rcvBufSize", + "rcvBufUsed", + "rcvAutoParams", + "rcvMemUsed", + "ownedByUser", + "state", + "boundNICID", + "ttl", + "isConnectNotified", + "portFlags", + "boundBindToDevice", + "boundPortFlags", + "boundDest", + "effectiveNetProtos", + "workerRunning", + "workerCleanup", + "sendTSOk", + "recentTS", + "recentTSTime", + "tsOffset", + "shutdownFlags", + "tcpRecovery", + "sackPermitted", + "sack", + "delay", + "scoreboard", + "segmentQueue", + "synRcvdCount", + "userMSS", + "maxSynRetries", + "windowClamp", + "sndBufUsed", + "sndClosed", + "sndBufInQueue", + "sndQueue", + "cc", + "packetTooBigCount", + "sndMTU", + "keepalive", + "userTimeout", + "deferAccept", + "acceptedChan", + "rcv", + "snd", + "connectingAddress", + "amss", + "sendTOS", + "gso", + "tcpLingerTimeout", + "closed", + "txHash", + "owner", + "ops", + "lastOutOfWindowAckTime", + } +} + +func (e *endpoint) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + var stateValue EndpointState = e.saveState() + stateSinkObject.SaveValue(13, stateValue) + var recentTSTimeValue unixTime = e.saveRecentTSTime() + stateSinkObject.SaveValue(26, recentTSTimeValue) + var acceptedChanValue []*endpoint = e.saveAcceptedChan() + stateSinkObject.SaveValue(49, acceptedChanValue) + var lastOutOfWindowAckTimeValue unixTime = e.saveLastOutOfWindowAckTime() + stateSinkObject.SaveValue(61, lastOutOfWindowAckTimeValue) + stateSinkObject.Save(0, &e.EndpointInfo) + stateSinkObject.Save(1, &e.DefaultSocketOptionsHandler) + stateSinkObject.Save(2, &e.waiterQueue) + stateSinkObject.Save(3, &e.uniqueID) + stateSinkObject.Save(4, &e.hardError) + stateSinkObject.Save(5, &e.lastError) + stateSinkObject.Save(6, &e.rcvList) + stateSinkObject.Save(7, &e.rcvClosed) + stateSinkObject.Save(8, &e.rcvBufSize) + stateSinkObject.Save(9, &e.rcvBufUsed) + stateSinkObject.Save(10, &e.rcvAutoParams) + stateSinkObject.Save(11, &e.rcvMemUsed) + stateSinkObject.Save(12, &e.ownedByUser) + stateSinkObject.Save(14, &e.boundNICID) + stateSinkObject.Save(15, &e.ttl) + stateSinkObject.Save(16, &e.isConnectNotified) + stateSinkObject.Save(17, &e.portFlags) + stateSinkObject.Save(18, &e.boundBindToDevice) + stateSinkObject.Save(19, &e.boundPortFlags) + stateSinkObject.Save(20, &e.boundDest) + stateSinkObject.Save(21, &e.effectiveNetProtos) + stateSinkObject.Save(22, &e.workerRunning) + stateSinkObject.Save(23, &e.workerCleanup) + stateSinkObject.Save(24, &e.sendTSOk) + stateSinkObject.Save(25, &e.recentTS) + stateSinkObject.Save(27, &e.tsOffset) + stateSinkObject.Save(28, &e.shutdownFlags) + stateSinkObject.Save(29, &e.tcpRecovery) + stateSinkObject.Save(30, &e.sackPermitted) + stateSinkObject.Save(31, &e.sack) + stateSinkObject.Save(32, &e.delay) + stateSinkObject.Save(33, &e.scoreboard) + stateSinkObject.Save(34, &e.segmentQueue) + stateSinkObject.Save(35, &e.synRcvdCount) + stateSinkObject.Save(36, &e.userMSS) + stateSinkObject.Save(37, &e.maxSynRetries) + stateSinkObject.Save(38, &e.windowClamp) + stateSinkObject.Save(39, &e.sndBufUsed) + stateSinkObject.Save(40, &e.sndClosed) + stateSinkObject.Save(41, &e.sndBufInQueue) + stateSinkObject.Save(42, &e.sndQueue) + stateSinkObject.Save(43, &e.cc) + stateSinkObject.Save(44, &e.packetTooBigCount) + stateSinkObject.Save(45, &e.sndMTU) + stateSinkObject.Save(46, &e.keepalive) + stateSinkObject.Save(47, &e.userTimeout) + stateSinkObject.Save(48, &e.deferAccept) + stateSinkObject.Save(50, &e.rcv) + stateSinkObject.Save(51, &e.snd) + stateSinkObject.Save(52, &e.connectingAddress) + stateSinkObject.Save(53, &e.amss) + stateSinkObject.Save(54, &e.sendTOS) + stateSinkObject.Save(55, &e.gso) + stateSinkObject.Save(56, &e.tcpLingerTimeout) + stateSinkObject.Save(57, &e.closed) + stateSinkObject.Save(58, &e.txHash) + stateSinkObject.Save(59, &e.owner) + stateSinkObject.Save(60, &e.ops) +} + +func (e *endpoint) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.EndpointInfo) + stateSourceObject.Load(1, &e.DefaultSocketOptionsHandler) + stateSourceObject.LoadWait(2, &e.waiterQueue) + stateSourceObject.Load(3, &e.uniqueID) + stateSourceObject.Load(4, &e.hardError) + stateSourceObject.Load(5, &e.lastError) + stateSourceObject.LoadWait(6, &e.rcvList) + stateSourceObject.Load(7, &e.rcvClosed) + stateSourceObject.Load(8, &e.rcvBufSize) + stateSourceObject.Load(9, &e.rcvBufUsed) + stateSourceObject.Load(10, &e.rcvAutoParams) + stateSourceObject.Load(11, &e.rcvMemUsed) + stateSourceObject.Load(12, &e.ownedByUser) + stateSourceObject.Load(14, &e.boundNICID) + stateSourceObject.Load(15, &e.ttl) + stateSourceObject.Load(16, &e.isConnectNotified) + stateSourceObject.Load(17, &e.portFlags) + stateSourceObject.Load(18, &e.boundBindToDevice) + stateSourceObject.Load(19, &e.boundPortFlags) + stateSourceObject.Load(20, &e.boundDest) + stateSourceObject.Load(21, &e.effectiveNetProtos) + stateSourceObject.Load(22, &e.workerRunning) + stateSourceObject.Load(23, &e.workerCleanup) + stateSourceObject.Load(24, &e.sendTSOk) + stateSourceObject.Load(25, &e.recentTS) + stateSourceObject.Load(27, &e.tsOffset) + stateSourceObject.Load(28, &e.shutdownFlags) + stateSourceObject.Load(29, &e.tcpRecovery) + stateSourceObject.Load(30, &e.sackPermitted) + stateSourceObject.Load(31, &e.sack) + stateSourceObject.Load(32, &e.delay) + stateSourceObject.Load(33, &e.scoreboard) + stateSourceObject.LoadWait(34, &e.segmentQueue) + stateSourceObject.Load(35, &e.synRcvdCount) + stateSourceObject.Load(36, &e.userMSS) + stateSourceObject.Load(37, &e.maxSynRetries) + stateSourceObject.Load(38, &e.windowClamp) + stateSourceObject.Load(39, &e.sndBufUsed) + stateSourceObject.Load(40, &e.sndClosed) + stateSourceObject.Load(41, &e.sndBufInQueue) + stateSourceObject.LoadWait(42, &e.sndQueue) + stateSourceObject.Load(43, &e.cc) + stateSourceObject.Load(44, &e.packetTooBigCount) + stateSourceObject.Load(45, &e.sndMTU) + stateSourceObject.Load(46, &e.keepalive) + stateSourceObject.Load(47, &e.userTimeout) + stateSourceObject.Load(48, &e.deferAccept) + stateSourceObject.LoadWait(50, &e.rcv) + stateSourceObject.LoadWait(51, &e.snd) + stateSourceObject.Load(52, &e.connectingAddress) + stateSourceObject.Load(53, &e.amss) + stateSourceObject.Load(54, &e.sendTOS) + stateSourceObject.Load(55, &e.gso) + stateSourceObject.Load(56, &e.tcpLingerTimeout) + stateSourceObject.Load(57, &e.closed) + stateSourceObject.Load(58, &e.txHash) + stateSourceObject.Load(59, &e.owner) + stateSourceObject.Load(60, &e.ops) + stateSourceObject.LoadValue(13, new(EndpointState), func(y interface{}) { e.loadState(y.(EndpointState)) }) + stateSourceObject.LoadValue(26, new(unixTime), func(y interface{}) { e.loadRecentTSTime(y.(unixTime)) }) + stateSourceObject.LoadValue(49, new([]*endpoint), func(y interface{}) { e.loadAcceptedChan(y.([]*endpoint)) }) + stateSourceObject.LoadValue(61, new(unixTime), func(y interface{}) { e.loadLastOutOfWindowAckTime(y.(unixTime)) }) + stateSourceObject.AfterLoad(e.afterLoad) +} + +func (k *keepalive) StateTypeName() string { + return "pkg/tcpip/transport/tcp.keepalive" +} + +func (k *keepalive) StateFields() []string { + return []string{ + "idle", + "interval", + "count", + "unacked", + } +} + +func (k *keepalive) beforeSave() {} + +func (k *keepalive) StateSave(stateSinkObject state.Sink) { + k.beforeSave() + stateSinkObject.Save(0, &k.idle) + stateSinkObject.Save(1, &k.interval) + stateSinkObject.Save(2, &k.count) + stateSinkObject.Save(3, &k.unacked) +} + +func (k *keepalive) afterLoad() {} + +func (k *keepalive) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &k.idle) + stateSourceObject.Load(1, &k.interval) + stateSourceObject.Load(2, &k.count) + stateSourceObject.Load(3, &k.unacked) +} + +func (rc *rackControl) StateTypeName() string { + return "pkg/tcpip/transport/tcp.rackControl" +} + +func (rc *rackControl) StateFields() []string { + return []string{ + "dsackSeen", + "endSequence", + "exitedRecovery", + "fack", + "minRTT", + "reorderSeen", + "reoWnd", + "reoWndIncr", + "reoWndPersist", + "rtt", + "rttSeq", + "xmitTime", + "tlpRxtOut", + "tlpHighRxt", + "snd", + } +} + +func (rc *rackControl) beforeSave() {} + +func (rc *rackControl) StateSave(stateSinkObject state.Sink) { + rc.beforeSave() + var xmitTimeValue unixTime = rc.saveXmitTime() + stateSinkObject.SaveValue(11, xmitTimeValue) + stateSinkObject.Save(0, &rc.dsackSeen) + stateSinkObject.Save(1, &rc.endSequence) + stateSinkObject.Save(2, &rc.exitedRecovery) + stateSinkObject.Save(3, &rc.fack) + stateSinkObject.Save(4, &rc.minRTT) + stateSinkObject.Save(5, &rc.reorderSeen) + stateSinkObject.Save(6, &rc.reoWnd) + stateSinkObject.Save(7, &rc.reoWndIncr) + stateSinkObject.Save(8, &rc.reoWndPersist) + stateSinkObject.Save(9, &rc.rtt) + stateSinkObject.Save(10, &rc.rttSeq) + stateSinkObject.Save(12, &rc.tlpRxtOut) + stateSinkObject.Save(13, &rc.tlpHighRxt) + stateSinkObject.Save(14, &rc.snd) +} + +func (rc *rackControl) afterLoad() {} + +func (rc *rackControl) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &rc.dsackSeen) + stateSourceObject.Load(1, &rc.endSequence) + stateSourceObject.Load(2, &rc.exitedRecovery) + stateSourceObject.Load(3, &rc.fack) + stateSourceObject.Load(4, &rc.minRTT) + stateSourceObject.Load(5, &rc.reorderSeen) + stateSourceObject.Load(6, &rc.reoWnd) + stateSourceObject.Load(7, &rc.reoWndIncr) + stateSourceObject.Load(8, &rc.reoWndPersist) + stateSourceObject.Load(9, &rc.rtt) + stateSourceObject.Load(10, &rc.rttSeq) + stateSourceObject.Load(12, &rc.tlpRxtOut) + stateSourceObject.Load(13, &rc.tlpHighRxt) + stateSourceObject.Load(14, &rc.snd) + stateSourceObject.LoadValue(11, new(unixTime), func(y interface{}) { rc.loadXmitTime(y.(unixTime)) }) +} + +func (r *receiver) StateTypeName() string { + return "pkg/tcpip/transport/tcp.receiver" +} + +func (r *receiver) StateFields() []string { + return []string{ + "ep", + "rcvNxt", + "rcvAcc", + "rcvWnd", + "rcvWUP", + "rcvWndScale", + "prevBufUsed", + "closed", + "pendingRcvdSegments", + "pendingBufUsed", + "lastRcvdAckTime", + } +} + +func (r *receiver) beforeSave() {} + +func (r *receiver) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + var lastRcvdAckTimeValue unixTime = r.saveLastRcvdAckTime() + stateSinkObject.SaveValue(10, lastRcvdAckTimeValue) + stateSinkObject.Save(0, &r.ep) + stateSinkObject.Save(1, &r.rcvNxt) + stateSinkObject.Save(2, &r.rcvAcc) + stateSinkObject.Save(3, &r.rcvWnd) + stateSinkObject.Save(4, &r.rcvWUP) + stateSinkObject.Save(5, &r.rcvWndScale) + stateSinkObject.Save(6, &r.prevBufUsed) + stateSinkObject.Save(7, &r.closed) + stateSinkObject.Save(8, &r.pendingRcvdSegments) + stateSinkObject.Save(9, &r.pendingBufUsed) +} + +func (r *receiver) afterLoad() {} + +func (r *receiver) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.ep) + stateSourceObject.Load(1, &r.rcvNxt) + stateSourceObject.Load(2, &r.rcvAcc) + stateSourceObject.Load(3, &r.rcvWnd) + stateSourceObject.Load(4, &r.rcvWUP) + stateSourceObject.Load(5, &r.rcvWndScale) + stateSourceObject.Load(6, &r.prevBufUsed) + stateSourceObject.Load(7, &r.closed) + stateSourceObject.Load(8, &r.pendingRcvdSegments) + stateSourceObject.Load(9, &r.pendingBufUsed) + stateSourceObject.LoadValue(10, new(unixTime), func(y interface{}) { r.loadLastRcvdAckTime(y.(unixTime)) }) +} + +func (r *renoState) StateTypeName() string { + return "pkg/tcpip/transport/tcp.renoState" +} + +func (r *renoState) StateFields() []string { + return []string{ + "s", + } +} + +func (r *renoState) beforeSave() {} + +func (r *renoState) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.s) +} + +func (r *renoState) afterLoad() {} + +func (r *renoState) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.s) +} + +func (rr *renoRecovery) StateTypeName() string { + return "pkg/tcpip/transport/tcp.renoRecovery" +} + +func (rr *renoRecovery) StateFields() []string { + return []string{ + "s", + } +} + +func (rr *renoRecovery) beforeSave() {} + +func (rr *renoRecovery) StateSave(stateSinkObject state.Sink) { + rr.beforeSave() + stateSinkObject.Save(0, &rr.s) +} + +func (rr *renoRecovery) afterLoad() {} + +func (rr *renoRecovery) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &rr.s) +} + +func (sr *sackRecovery) StateTypeName() string { + return "pkg/tcpip/transport/tcp.sackRecovery" +} + +func (sr *sackRecovery) StateFields() []string { + return []string{ + "s", + } +} + +func (sr *sackRecovery) beforeSave() {} + +func (sr *sackRecovery) StateSave(stateSinkObject state.Sink) { + sr.beforeSave() + stateSinkObject.Save(0, &sr.s) +} + +func (sr *sackRecovery) afterLoad() {} + +func (sr *sackRecovery) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &sr.s) +} + +func (s *SACKScoreboard) StateTypeName() string { + return "pkg/tcpip/transport/tcp.SACKScoreboard" +} + +func (s *SACKScoreboard) StateFields() []string { + return []string{ + "smss", + "maxSACKED", + } +} + +func (s *SACKScoreboard) beforeSave() {} + +func (s *SACKScoreboard) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.smss) + stateSinkObject.Save(1, &s.maxSACKED) +} + +func (s *SACKScoreboard) afterLoad() {} + +func (s *SACKScoreboard) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.smss) + stateSourceObject.Load(1, &s.maxSACKED) +} + +func (s *segment) StateTypeName() string { + return "pkg/tcpip/transport/tcp.segment" +} + +func (s *segment) StateFields() []string { + return []string{ + "segmentEntry", + "refCnt", + "ep", + "qFlags", + "srcAddr", + "dstAddr", + "netProto", + "nicID", + "data", + "hdr", + "sequenceNumber", + "ackNumber", + "flags", + "window", + "csum", + "csumValid", + "parsedOptions", + "options", + "hasNewSACKInfo", + "rcvdTime", + "xmitTime", + "xmitCount", + "acked", + "dataMemSize", + "lost", + } +} + +func (s *segment) beforeSave() {} + +func (s *segment) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + var dataValue buffer.VectorisedView = s.saveData() + stateSinkObject.SaveValue(8, dataValue) + var optionsValue []byte = s.saveOptions() + stateSinkObject.SaveValue(17, optionsValue) + var rcvdTimeValue unixTime = s.saveRcvdTime() + stateSinkObject.SaveValue(19, rcvdTimeValue) + var xmitTimeValue unixTime = s.saveXmitTime() + stateSinkObject.SaveValue(20, xmitTimeValue) + stateSinkObject.Save(0, &s.segmentEntry) + stateSinkObject.Save(1, &s.refCnt) + stateSinkObject.Save(2, &s.ep) + stateSinkObject.Save(3, &s.qFlags) + stateSinkObject.Save(4, &s.srcAddr) + stateSinkObject.Save(5, &s.dstAddr) + stateSinkObject.Save(6, &s.netProto) + stateSinkObject.Save(7, &s.nicID) + stateSinkObject.Save(9, &s.hdr) + stateSinkObject.Save(10, &s.sequenceNumber) + stateSinkObject.Save(11, &s.ackNumber) + stateSinkObject.Save(12, &s.flags) + stateSinkObject.Save(13, &s.window) + stateSinkObject.Save(14, &s.csum) + stateSinkObject.Save(15, &s.csumValid) + stateSinkObject.Save(16, &s.parsedOptions) + stateSinkObject.Save(18, &s.hasNewSACKInfo) + stateSinkObject.Save(21, &s.xmitCount) + stateSinkObject.Save(22, &s.acked) + stateSinkObject.Save(23, &s.dataMemSize) + stateSinkObject.Save(24, &s.lost) +} + +func (s *segment) afterLoad() {} + +func (s *segment) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.segmentEntry) + stateSourceObject.Load(1, &s.refCnt) + stateSourceObject.Load(2, &s.ep) + stateSourceObject.Load(3, &s.qFlags) + stateSourceObject.Load(4, &s.srcAddr) + stateSourceObject.Load(5, &s.dstAddr) + stateSourceObject.Load(6, &s.netProto) + stateSourceObject.Load(7, &s.nicID) + stateSourceObject.Load(9, &s.hdr) + stateSourceObject.Load(10, &s.sequenceNumber) + stateSourceObject.Load(11, &s.ackNumber) + stateSourceObject.Load(12, &s.flags) + stateSourceObject.Load(13, &s.window) + stateSourceObject.Load(14, &s.csum) + stateSourceObject.Load(15, &s.csumValid) + stateSourceObject.Load(16, &s.parsedOptions) + stateSourceObject.Load(18, &s.hasNewSACKInfo) + stateSourceObject.Load(21, &s.xmitCount) + stateSourceObject.Load(22, &s.acked) + stateSourceObject.Load(23, &s.dataMemSize) + stateSourceObject.Load(24, &s.lost) + stateSourceObject.LoadValue(8, new(buffer.VectorisedView), func(y interface{}) { s.loadData(y.(buffer.VectorisedView)) }) + stateSourceObject.LoadValue(17, new([]byte), func(y interface{}) { s.loadOptions(y.([]byte)) }) + stateSourceObject.LoadValue(19, new(unixTime), func(y interface{}) { s.loadRcvdTime(y.(unixTime)) }) + stateSourceObject.LoadValue(20, new(unixTime), func(y interface{}) { s.loadXmitTime(y.(unixTime)) }) +} + +func (q *segmentQueue) StateTypeName() string { + return "pkg/tcpip/transport/tcp.segmentQueue" +} + +func (q *segmentQueue) StateFields() []string { + return []string{ + "list", + "ep", + "frozen", + } +} + +func (q *segmentQueue) beforeSave() {} + +func (q *segmentQueue) StateSave(stateSinkObject state.Sink) { + q.beforeSave() + stateSinkObject.Save(0, &q.list) + stateSinkObject.Save(1, &q.ep) + stateSinkObject.Save(2, &q.frozen) +} + +func (q *segmentQueue) afterLoad() {} + +func (q *segmentQueue) StateLoad(stateSourceObject state.Source) { + stateSourceObject.LoadWait(0, &q.list) + stateSourceObject.Load(1, &q.ep) + stateSourceObject.Load(2, &q.frozen) +} + +func (s *sender) StateTypeName() string { + return "pkg/tcpip/transport/tcp.sender" +} + +func (s *sender) StateFields() []string { + return []string{ + "ep", + "lastSendTime", + "dupAckCount", + "fr", + "lr", + "sndCwnd", + "sndSsthresh", + "sndCAAckCount", + "outstanding", + "sackedOut", + "sndWnd", + "sndUna", + "sndNxt", + "rttMeasureSeqNum", + "rttMeasureTime", + "firstRetransmittedSegXmitTime", + "closed", + "writeNext", + "writeList", + "rtt", + "rto", + "minRTO", + "maxRTO", + "maxRetries", + "maxPayloadSize", + "gso", + "sndWndScale", + "maxSentAck", + "state", + "cc", + "rc", + } +} + +func (s *sender) beforeSave() {} + +func (s *sender) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + var lastSendTimeValue unixTime = s.saveLastSendTime() + stateSinkObject.SaveValue(1, lastSendTimeValue) + var rttMeasureTimeValue unixTime = s.saveRttMeasureTime() + stateSinkObject.SaveValue(14, rttMeasureTimeValue) + var firstRetransmittedSegXmitTimeValue unixTime = s.saveFirstRetransmittedSegXmitTime() + stateSinkObject.SaveValue(15, firstRetransmittedSegXmitTimeValue) + stateSinkObject.Save(0, &s.ep) + stateSinkObject.Save(2, &s.dupAckCount) + stateSinkObject.Save(3, &s.fr) + stateSinkObject.Save(4, &s.lr) + stateSinkObject.Save(5, &s.sndCwnd) + stateSinkObject.Save(6, &s.sndSsthresh) + stateSinkObject.Save(7, &s.sndCAAckCount) + stateSinkObject.Save(8, &s.outstanding) + stateSinkObject.Save(9, &s.sackedOut) + stateSinkObject.Save(10, &s.sndWnd) + stateSinkObject.Save(11, &s.sndUna) + stateSinkObject.Save(12, &s.sndNxt) + stateSinkObject.Save(13, &s.rttMeasureSeqNum) + stateSinkObject.Save(16, &s.closed) + stateSinkObject.Save(17, &s.writeNext) + stateSinkObject.Save(18, &s.writeList) + stateSinkObject.Save(19, &s.rtt) + stateSinkObject.Save(20, &s.rto) + stateSinkObject.Save(21, &s.minRTO) + stateSinkObject.Save(22, &s.maxRTO) + stateSinkObject.Save(23, &s.maxRetries) + stateSinkObject.Save(24, &s.maxPayloadSize) + stateSinkObject.Save(25, &s.gso) + stateSinkObject.Save(26, &s.sndWndScale) + stateSinkObject.Save(27, &s.maxSentAck) + stateSinkObject.Save(28, &s.state) + stateSinkObject.Save(29, &s.cc) + stateSinkObject.Save(30, &s.rc) +} + +func (s *sender) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.ep) + stateSourceObject.Load(2, &s.dupAckCount) + stateSourceObject.Load(3, &s.fr) + stateSourceObject.Load(4, &s.lr) + stateSourceObject.Load(5, &s.sndCwnd) + stateSourceObject.Load(6, &s.sndSsthresh) + stateSourceObject.Load(7, &s.sndCAAckCount) + stateSourceObject.Load(8, &s.outstanding) + stateSourceObject.Load(9, &s.sackedOut) + stateSourceObject.Load(10, &s.sndWnd) + stateSourceObject.Load(11, &s.sndUna) + stateSourceObject.Load(12, &s.sndNxt) + stateSourceObject.Load(13, &s.rttMeasureSeqNum) + stateSourceObject.Load(16, &s.closed) + stateSourceObject.Load(17, &s.writeNext) + stateSourceObject.Load(18, &s.writeList) + stateSourceObject.Load(19, &s.rtt) + stateSourceObject.Load(20, &s.rto) + stateSourceObject.Load(21, &s.minRTO) + stateSourceObject.Load(22, &s.maxRTO) + stateSourceObject.Load(23, &s.maxRetries) + stateSourceObject.Load(24, &s.maxPayloadSize) + stateSourceObject.Load(25, &s.gso) + stateSourceObject.Load(26, &s.sndWndScale) + stateSourceObject.Load(27, &s.maxSentAck) + stateSourceObject.Load(28, &s.state) + stateSourceObject.Load(29, &s.cc) + stateSourceObject.Load(30, &s.rc) + stateSourceObject.LoadValue(1, new(unixTime), func(y interface{}) { s.loadLastSendTime(y.(unixTime)) }) + stateSourceObject.LoadValue(14, new(unixTime), func(y interface{}) { s.loadRttMeasureTime(y.(unixTime)) }) + stateSourceObject.LoadValue(15, new(unixTime), func(y interface{}) { s.loadFirstRetransmittedSegXmitTime(y.(unixTime)) }) + stateSourceObject.AfterLoad(s.afterLoad) +} + +func (r *rtt) StateTypeName() string { + return "pkg/tcpip/transport/tcp.rtt" +} + +func (r *rtt) StateFields() []string { + return []string{ + "srtt", + "rttvar", + "srttInited", + } +} + +func (r *rtt) beforeSave() {} + +func (r *rtt) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.srtt) + stateSinkObject.Save(1, &r.rttvar) + stateSinkObject.Save(2, &r.srttInited) +} + +func (r *rtt) afterLoad() {} + +func (r *rtt) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.srtt) + stateSourceObject.Load(1, &r.rttvar) + stateSourceObject.Load(2, &r.srttInited) +} + +func (f *fastRecovery) StateTypeName() string { + return "pkg/tcpip/transport/tcp.fastRecovery" +} + +func (f *fastRecovery) StateFields() []string { + return []string{ + "active", + "first", + "last", + "maxCwnd", + "highRxt", + "rescueRxt", + } +} + +func (f *fastRecovery) beforeSave() {} + +func (f *fastRecovery) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.active) + stateSinkObject.Save(1, &f.first) + stateSinkObject.Save(2, &f.last) + stateSinkObject.Save(3, &f.maxCwnd) + stateSinkObject.Save(4, &f.highRxt) + stateSinkObject.Save(5, &f.rescueRxt) +} + +func (f *fastRecovery) afterLoad() {} + +func (f *fastRecovery) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.active) + stateSourceObject.Load(1, &f.first) + stateSourceObject.Load(2, &f.last) + stateSourceObject.Load(3, &f.maxCwnd) + stateSourceObject.Load(4, &f.highRxt) + stateSourceObject.Load(5, &f.rescueRxt) +} + +func (u *unixTime) StateTypeName() string { + return "pkg/tcpip/transport/tcp.unixTime" +} + +func (u *unixTime) StateFields() []string { + return []string{ + "second", + "nano", + } +} + +func (u *unixTime) beforeSave() {} + +func (u *unixTime) StateSave(stateSinkObject state.Sink) { + u.beforeSave() + stateSinkObject.Save(0, &u.second) + stateSinkObject.Save(1, &u.nano) +} + +func (u *unixTime) afterLoad() {} + +func (u *unixTime) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &u.second) + stateSourceObject.Load(1, &u.nano) +} + +func (l *endpointList) StateTypeName() string { + return "pkg/tcpip/transport/tcp.endpointList" +} + +func (l *endpointList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *endpointList) beforeSave() {} + +func (l *endpointList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *endpointList) afterLoad() {} + +func (l *endpointList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *endpointEntry) StateTypeName() string { + return "pkg/tcpip/transport/tcp.endpointEntry" +} + +func (e *endpointEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *endpointEntry) beforeSave() {} + +func (e *endpointEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *endpointEntry) afterLoad() {} + +func (e *endpointEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func (l *segmentList) StateTypeName() string { + return "pkg/tcpip/transport/tcp.segmentList" +} + +func (l *segmentList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *segmentList) beforeSave() {} + +func (l *segmentList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *segmentList) afterLoad() {} + +func (l *segmentList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *segmentEntry) StateTypeName() string { + return "pkg/tcpip/transport/tcp.segmentEntry" +} + +func (e *segmentEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *segmentEntry) beforeSave() {} + +func (e *segmentEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *segmentEntry) afterLoad() {} + +func (e *segmentEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func init() { + state.Register((*cubicState)(nil)) + state.Register((*SACKInfo)(nil)) + state.Register((*rcvBufAutoTuneParams)(nil)) + state.Register((*EndpointInfo)(nil)) + state.Register((*endpoint)(nil)) + state.Register((*keepalive)(nil)) + state.Register((*rackControl)(nil)) + state.Register((*receiver)(nil)) + state.Register((*renoState)(nil)) + state.Register((*renoRecovery)(nil)) + state.Register((*sackRecovery)(nil)) + state.Register((*SACKScoreboard)(nil)) + state.Register((*segment)(nil)) + state.Register((*segmentQueue)(nil)) + state.Register((*sender)(nil)) + state.Register((*rtt)(nil)) + state.Register((*fastRecovery)(nil)) + state.Register((*unixTime)(nil)) + state.Register((*endpointList)(nil)) + state.Register((*endpointEntry)(nil)) + state.Register((*segmentList)(nil)) + state.Register((*segmentEntry)(nil)) +} diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go deleted file mode 100644 index 0128c1f7e..000000000 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ /dev/null @@ -1,7777 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_test - -import ( - "bytes" - "fmt" - "io/ioutil" - "math" - "strings" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/rand" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/ports" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" - "gvisor.dev/gvisor/pkg/test/testutil" - "gvisor.dev/gvisor/pkg/waiter" -) - -// endpointTester provides helper functions to test a tcpip.Endpoint. -type endpointTester struct { - ep tcpip.Endpoint -} - -// CheckReadError issues a read to the endpoint and checking for an error. -func (e *endpointTester) CheckReadError(t *testing.T, want tcpip.Error) { - t.Helper() - res, got := e.ep.Read(ioutil.Discard, tcpip.ReadOptions{}) - if got != want { - t.Fatalf("ep.Read = %s, want %s", got, want) - } - if diff := cmp.Diff(tcpip.ReadResult{}, res); diff != "" { - t.Errorf("ep.Read: unexpected non-zero result (-want +got):\n%s", diff) - } -} - -// CheckRead issues a read to the endpoint and checking for a success, returning -// the data read. -func (e *endpointTester) CheckRead(t *testing.T) []byte { - t.Helper() - var buf bytes.Buffer - res, err := e.ep.Read(&buf, tcpip.ReadOptions{}) - if err != nil { - t.Fatalf("ep.Read = _, %s; want _, nil", err) - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: buf.Len(), - Total: buf.Len(), - }, res, checker.IgnoreCmpPath("ControlMessages")); diff != "" { - t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) - } - return buf.Bytes() -} - -// CheckReadFull reads from the endpoint for exactly count bytes. -func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-chan struct{}, timeout time.Duration) []byte { - t.Helper() - var buf bytes.Buffer - w := tcpip.LimitedWriter{ - W: &buf, - N: int64(count), - } - for w.N != 0 { - _, err := e.ep.Read(&w, tcpip.ReadOptions{}) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Wait for receive to be notified. - select { - case <-notifyRead: - case <-time.After(timeout): - t.Fatalf("Timed out waiting for data to arrive") - } - continue - } else if err != nil { - t.Fatalf("ep.Read = _, %s; want _, nil", err) - } - } - return buf.Bytes() -} - -const ( - // defaultMTU is the MTU, in bytes, used throughout the tests, except - // where another value is explicitly used. It is chosen to match the MTU - // of loopback interfaces on linux systems. - defaultMTU = 65535 - - // defaultIPv4MSS is the MSS sent by the network stack in SYN/SYN-ACK for an - // IPv4 endpoint when the MTU is set to defaultMTU in the test. - defaultIPv4MSS = defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize -) - -func TestGiveUpConnect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - var wq waiter.Queue - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - // Register for notification, then start connection attempt. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - wq.EventRegister(&waitEntry, waiter.EventHUp) - defer wq.EventUnregister(&waitEntry) - - { - err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) - } - } - - // Close the connection, wait for completion. - ep.Close() - - // Wait for ep to become writable. - <-notifyCh - - // Call Connect again to retreive the handshake failure status - // and stats updates. - { - err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrAborted); !ok { - t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrAborted{}) - } - } - - if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 { - t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = 1", got) - } - - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) - } -} - -func TestConnectIncrementActiveConnection(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - want := stats.TCP.ActiveConnectionOpenings.Value() + 1 - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want { - t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %d, want = %d", got, want) - } -} - -func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - want := stats.TCP.FailedConnectionAttempts.Value() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { - t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want { - t.Errorf("got EP stats.FailedConnectionAttempts = %d, want = %d", got, want) - } -} - -func TestActiveFailedConnectionAttemptIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - c.EP = ep - want := stats.TCP.FailedConnectionAttempts.Value() + 1 - - { - err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Errorf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrNoRoute{}) - } - } - - if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { - t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want { - t.Errorf("got EP stats FailedConnectionAttempts = %d, want = %d", got, want) - } -} - -func TestCloseWithoutConnect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - c.EP.Close() - - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) - } -} - -func TestTCPSegmentsSentIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - // SYN and ACK - want := stats.TCP.SegmentsSent.Value() + 2 - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - if got := stats.TCP.SegmentsSent.Value(); got != want { - t.Errorf("got stats.TCP.SegmentsSent.Value() = %d, want = %d", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).SegmentsSent.Value(); got != want { - t.Errorf("got EP stats SegmentsSent.Value() = %d, want = %d", got, want) - } -} - -func TestTCPResetsSentIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - stats := c.Stack().Stats() - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - want := stats.TCP.SegmentsSent.Value() + 1 - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN request. - iss := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - // If the AckNum is not the increment of the last sequence number, a RST - // segment is sent back in response. - AckNum: c.IRS + 2, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - c.GetPacket() - - metricPollFn := func() error { - if got := stats.TCP.ResetsSent.Value(); got != want { - return fmt.Errorf("got stats.TCP.ResetsSent.Value() = %d, want = %d", got, want) - } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } -} - -// TestTCPResetsSentNoICMP confirms that we don't get an ICMP -// DstUnreachable packet when we try send a packet which is not part -// of an active session. -func TestTCPResetsSentNoICMP(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - stats := c.Stack().Stats() - - // Send a SYN request for a closed port. This should elicit an RST - // but NOT an ICMPv4 DstUnreachable packet. - iss := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - }) - - // Receive whatever comes back. - b := c.GetPacket() - ipHdr := header.IPv4(b) - if got, want := ipHdr.Protocol(), uint8(header.TCPProtocolNumber); got != want { - t.Errorf("unexpected protocol, got = %d, want = %d", got, want) - } - - // Read outgoing ICMP stats and check no ICMP DstUnreachable was recorded. - sent := stats.ICMP.V4.PacketsSent - if got, want := sent.DstUnreachable.Value(), uint64(0); got != want { - t.Errorf("got ICMP DstUnreachable.Value() = %d, want = %d", got, want) - } -} - -// TestTCPResetSentForACKWhenNotUsingSynCookies checks that the stack generates -// a RST if an ACK is received on the listening socket for which there is no -// active handshake in progress and we are not using SYN cookies. -func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set TCPLingerTimeout to 5 seconds so that sockets are marked closed - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN request. - iss := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Lower stackwide TIME_WAIT timeout so that the reservations - // are released instantly on Close. - tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &tcpTW); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, tcpTW, tcpTW, err) - } - - c.EP.Close() - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - } - - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - c.GetPacket() - - // Since an active close was done we need to wait for a little more than - // tcpLingerTimeout for the port reservations to be released and the - // socket to move to a CLOSED state. - time.Sleep(20 * time.Millisecond) - - // Now resend the same ACK, this ACK should generate a RST as there - // should be no endpoint in SYN-RCVD state and we are not using - // syn-cookies yet. The reason we send the same ACK is we need a valid - // cookie(IRS) generated by the netstack without which the ACK will be - // rejected. - c.SendPacket(nil, ackHeaders) - - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(0), - checker.TCPFlags(header.TCPFlagRst))) -} - -func TestTCPResetsReceivedIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - want := stats.TCP.ResetsReceived.Value() + 1 - iss := seqnum.Value(789) - rcvWnd := seqnum.Size(30000) - c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */) - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - SeqNum: iss.Add(1), - AckNum: c.IRS.Add(1), - RcvWnd: rcvWnd, - Flags: header.TCPFlagRst, - }) - - if got := stats.TCP.ResetsReceived.Value(); got != want { - t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want) - } -} - -func TestTCPResetsDoNotGenerateResets(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - want := stats.TCP.ResetsReceived.Value() + 1 - iss := seqnum.Value(789) - rcvWnd := seqnum.Size(30000) - c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */) - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - SeqNum: iss.Add(1), - AckNum: c.IRS.Add(1), - RcvWnd: rcvWnd, - Flags: header.TCPFlagRst, - }) - - if got := stats.TCP.ResetsReceived.Value(); got != want { - t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want) - } - c.CheckNoPacketTimeout("got an unexpected packet", 100*time.Millisecond) -} - -func TestActiveHandshake(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) -} - -func TestNonBlockingClose(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - ep := c.EP - c.EP = nil - - // Close the endpoint and measure how long it takes. - t0 := time.Now() - ep.Close() - if diff := time.Now().Sub(t0); diff > 3*time.Second { - t.Fatalf("Took too long to close: %s", diff) - } -} - -func TestConnectResetAfterClose(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set TCPLinger to 3 seconds so that sockets are marked closed - // after 3 second in FIN_WAIT2 state. - tcpLingerTimeout := 3 * time.Second - opt := tcpip.TCPLingerTimeoutOption(tcpLingerTimeout) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - ep := c.EP - c.EP = nil - - // Close the endpoint, make sure we get a FIN segment, then acknowledge - // to complete closure of sender, but don't send our own FIN. - ep.Close() - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Wait for the ep to give up waiting for a FIN. - time.Sleep(tcpLingerTimeout + 1*time.Second) - - // Now send an ACK and it should trigger a RST as the endpoint should - // not exist anymore. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - for { - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - if tcpHdr.Flags() == header.TCPFlagAck|header.TCPFlagFin { - // This is a retransmit of the FIN, ignore it. - continue - } - - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - // RST is always generated with sndNxt which if the FIN - // has been sent will be 1 higher than the sequence number - // of the FIN itself. - checker.TCPSeqNum(uint32(c.IRS)+2), - checker.TCPAckNum(0), - checker.TCPFlags(header.TCPFlagRst), - ), - ) - break - } -} - -// TestCurrentConnectedIncrement tests increment of the current -// established and connected counters. -func TestCurrentConnectedIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed - // after 1 second in TIME_WAIT state. - tcpTimeWaitTimeout := 1 * time.Second - opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - ep := c.EP - c.EP = nil - - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 1 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 1", got) - } - gotConnected := c.Stack().Stats().TCP.CurrentConnected.Value() - if gotConnected != 1 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 1", gotConnected) - } - - ep.Close() - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) - } - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != gotConnected { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = %d", got, gotConnected) - } - - // Ack and send FIN as well. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Check that the stack acks the FIN. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+2), - checker.TCPAckNum(791), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Wait for a little more than the TIME-WAIT duration for the socket to - // transition to CLOSED state. - time.Sleep(1200 * time.Millisecond) - - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) - } - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) - } -} - -// TestClosingWithEnqueuedSegments tests handling of still enqueued segments -// when the endpoint transitions to StateClose. The in-flight segments would be -// re-enqueued to a any listening endpoint. -func TestClosingWithEnqueuedSegments(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - ep := c.EP - c.EP = nil - - if got, want := tcp.EndpointState(ep.State()), tcp.StateEstablished; got != want { - t.Errorf("unexpected endpoint state: want %d, got %d", want, got) - } - - // Send a FIN for ESTABLISHED --> CLOSED-WAIT - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagFin | header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Get the ACK for the FIN we sent. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(791), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Give the stack a few ms to transition the endpoint out of ESTABLISHED - // state. - time.Sleep(10 * time.Millisecond) - - if got, want := tcp.EndpointState(ep.State()), tcp.StateCloseWait; got != want { - t.Errorf("unexpected endpoint state: want %d, got %d", want, got) - } - - // Close the application endpoint for CLOSE_WAIT --> LAST_ACK - ep.Close() - - // Get the FIN - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(791), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - - if got, want := tcp.EndpointState(ep.State()), tcp.StateLastAck; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - - // Pause the endpoint`s protocolMainLoop. - ep.(interface{ StopWork() }).StopWork() - - // Enqueue last ACK followed by an ACK matching the endpoint - // - // Send Last ACK for LAST_ACK --> CLOSED - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 791, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Send a packet with ACK set, this would generate RST when - // not using SYN cookies as in this test. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 792, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Unpause endpoint`s protocolMainLoop. - ep.(interface{ ResumeWork() }).ResumeWork() - - // Wait for the protocolMainLoop to resume and update state. - time.Sleep(10 * time.Millisecond) - - // Expect the endpoint to be closed. - if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - - if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != 1 { - t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = 1", got) - } - - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) - } - - // Check if the endpoint was moved to CLOSED and netstack a reset in - // response to the ACK packet that we sent after last-ACK. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+2), - checker.TCPAckNum(0), - checker.TCPFlags(header.TCPFlagRst), - ), - ) -} - -func TestSimpleReceive(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - ept := endpointTester{c.EP} - - data := []byte{1, 2, 3} - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Receive data. - v := ept.CheckRead(t) - if !bytes.Equal(data, v) { - t.Fatalf("got data = %v, want = %v", v, data) - } - - // Check that ACK is received. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+len(data))), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -// TestUserSuppliedMSSOnConnect tests that the user supplied MSS is used when -// creating a new active TCP socket. It should be present in the sent TCP -// SYN segment. -func TestUserSuppliedMSSOnConnect(t *testing.T) { - const mtu = 5000 - - ips := []struct { - name string - createEP func(*context.Context) - connectAddr tcpip.Address - checker func(*testing.T, *context.Context, uint16, int) - maxMSS uint16 - }{ - { - name: "IPv4", - createEP: func(c *context.Context) { - c.Create(-1) - }, - connectAddr: context.TestAddr, - checker: func(t *testing.T, c *context.Context, mss uint16, ws int) { - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: ws}))) - }, - maxMSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize, - }, - { - name: "IPv6", - createEP: func(c *context.Context) { - c.CreateV6Endpoint(true) - }, - connectAddr: context.TestV6Addr, - checker: func(t *testing.T, c *context.Context, mss uint16, ws int) { - checker.IPv6(t, c.GetV6Packet(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: ws}))) - }, - maxMSS: mtu - header.IPv6MinimumSize - header.TCPMinimumSize, - }, - } - - for _, ip := range ips { - t.Run(ip.name, func(t *testing.T) { - tests := []struct { - name string - setMSS uint16 - expMSS uint16 - }{ - { - name: "EqualToMaxMSS", - setMSS: ip.maxMSS, - expMSS: ip.maxMSS, - }, - { - name: "LessThanMaxMSS", - setMSS: ip.maxMSS - 1, - expMSS: ip.maxMSS - 1, - }, - { - name: "GreaterThanMaxMSS", - setMSS: ip.maxMSS + 1, - expMSS: ip.maxMSS, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, mtu) - defer c.Cleanup() - - ip.createEP(c) - - // Set the MSS socket option. - if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { - t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err) - } - - // Get expected window size. - rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) - if err != nil { - t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption): %s", err) - } - ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) - - connectAddr := tcpip.FullAddress{Addr: ip.connectAddr, Port: context.TestPort} - { - err := c.EP.Connect(connectAddr) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("Connect(%+v): %s", connectAddr, err) - } - } - - // Receive SYN packet with our user supplied MSS. - ip.checker(t, c, test.expMSS, ws) - }) - } - }) - } -} - -// TestUserSuppliedMSSOnListenAccept tests that the user supplied MSS is used -// when completing the handshake for a new TCP connection from a TCP -// listening socket. It should be present in the sent TCP SYN-ACK segment. -func TestUserSuppliedMSSOnListenAccept(t *testing.T) { - const ( - nonSynCookieAccepts = 2 - totalAccepts = 4 - mtu = 5000 - ) - - ips := []struct { - name string - createEP func(*context.Context) - sendPkt func(*context.Context, *context.Headers) - checker func(*testing.T, *context.Context, uint16, uint16) - maxMSS uint16 - }{ - { - name: "IPv4", - createEP: func(c *context.Context) { - c.Create(-1) - }, - sendPkt: func(c *context.Context, h *context.Headers) { - c.SendPacket(nil, h) - }, - checker: func(t *testing.T, c *context.Context, srcPort, mss uint16) { - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(srcPort), - checker.TCPFlags(header.TCPFlagSyn|header.TCPFlagAck), - checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: -1}))) - }, - maxMSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize, - }, - { - name: "IPv6", - createEP: func(c *context.Context) { - c.CreateV6Endpoint(false) - }, - sendPkt: func(c *context.Context, h *context.Headers) { - c.SendV6Packet(nil, h) - }, - checker: func(t *testing.T, c *context.Context, srcPort, mss uint16) { - checker.IPv6(t, c.GetV6Packet(), checker.TCP( - checker.DstPort(srcPort), - checker.TCPFlags(header.TCPFlagSyn|header.TCPFlagAck), - checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: -1}))) - }, - maxMSS: mtu - header.IPv6MinimumSize - header.TCPMinimumSize, - }, - } - - for _, ip := range ips { - t.Run(ip.name, func(t *testing.T) { - tests := []struct { - name string - setMSS uint16 - expMSS uint16 - }{ - { - name: "EqualToMaxMSS", - setMSS: ip.maxMSS, - expMSS: ip.maxMSS, - }, - { - name: "LessThanMaxMSS", - setMSS: ip.maxMSS - 1, - expMSS: ip.maxMSS - 1, - }, - { - name: "GreaterThanMaxMSS", - setMSS: ip.maxMSS + 1, - expMSS: ip.maxMSS, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, mtu) - defer c.Cleanup() - - ip.createEP(c) - - // Set the SynRcvd threshold to force a syn cookie based accept to happen. - opt := tcpip.TCPSynRcvdCountThresholdOption(nonSynCookieAccepts) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } - - if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { - t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err) - } - - bindAddr := tcpip.FullAddress{Port: context.StackPort} - if err := c.EP.Bind(bindAddr); err != nil { - t.Fatalf("Bind(%+v): %s:", bindAddr, err) - } - - if err := c.EP.Listen(totalAccepts); err != nil { - t.Fatalf("Listen(%d): %s:", totalAccepts, err) - } - - // The first nonSynCookieAccepts packets sent will trigger a gorooutine - // based accept. The rest will trigger a cookie based accept. - for i := 0; i < totalAccepts; i++ { - // Send a SYN requests. - iss := seqnum.Value(i) - srcPort := context.TestPort + uint16(i) - ip.sendPkt(c, &context.Headers{ - SrcPort: srcPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - }) - - // Receive the SYN-ACK reply. - ip.checker(t, c, srcPort, test.expMSS) - } - }) - } - }) - } -} -func TestSendRstOnListenerRxSynAckV4(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatal("Listen failed:", err) - } - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: 100, - AckNum: 200, - }) - - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst), - checker.TCPSeqNum(200))) -} - -func TestSendRstOnListenerRxSynAckV6(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatal("Listen failed:", err) - } - - c.SendV6Packet(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: 100, - AckNum: 200, - }) - - checker.IPv6(t, c.GetV6Packet(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst), - checker.TCPSeqNum(200))) -} - -// TestTCPAckBeforeAcceptV4 tests that once the 3-way handshake is complete, -// peers can send data and expect a response within a reasonable ammount of time -// without calling Accept on the listening endpoint first. -// -// This test uses IPv4. -func TestTCPAckBeforeAcceptV4(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatal("Listen failed:", err) - } - - irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - - // Send data before accepting the connection. - c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - }) - - // Receive ACK for the data we sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(iss+1)), - checker.TCPAckNum(uint32(irs+5)))) -} - -// TestTCPAckBeforeAcceptV6 tests that once the 3-way handshake is complete, -// peers can send data and expect a response within a reasonable ammount of time -// without calling Accept on the listening endpoint first. -// -// This test uses IPv6. -func TestTCPAckBeforeAcceptV6(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatal("Listen failed:", err) - } - - irs, iss := executeV6Handshake(t, c, context.TestPort, false /* synCookiesInUse */) - - // Send data before accepting the connection. - c.SendV6Packet([]byte{1, 2, 3, 4}, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - }) - - // Receive ACK for the data we sent. - checker.IPv6(t, c.GetV6Packet(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(iss+1)), - checker.TCPAckNum(uint32(irs+5)))) -} - -func TestSendRstOnListenerRxAckV4(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1 /* epRcvBuf */) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10 /* backlog */); err != nil { - t.Fatal("Listen failed:", err) - } - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagFin | header.TCPFlagAck, - SeqNum: 100, - AckNum: 200, - }) - - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst), - checker.TCPSeqNum(200))) -} - -func TestSendRstOnListenerRxAckV6(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true /* v6Only */) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10 /* backlog */); err != nil { - t.Fatal("Listen failed:", err) - } - - c.SendV6Packet(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagFin | header.TCPFlagAck, - SeqNum: 100, - AckNum: 200, - }) - - checker.IPv6(t, c.GetV6Packet(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst), - checker.TCPSeqNum(200))) -} - -// TestListenShutdown tests for the listening endpoint replying with RST -// on read shutdown. -func TestListenShutdown(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1 /* epRcvBuf */) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(1 /* backlog */); err != nil { - t.Fatal("Listen failed:", err) - } - - if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { - t.Fatal("Shutdown failed:", err) - } - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: 100, - AckNum: 200, - }) - - // Expect the listening endpoint to reset the connection. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - )) -} - -// TestListenCloseWhileConnect tests for the listening endpoint to -// drain the accept-queue when closed. This should reset all of the -// pending connections that are waiting to be accepted. -func TestListenCloseWhileConnect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1 /* epRcvBuf */) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(1 /* backlog */); err != nil { - t.Fatal("Listen failed:", err) - } - - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&waitEntry, waiter.EventIn) - defer c.WQ.EventUnregister(&waitEntry) - - executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - // Wait for the new endpoint created because of handshake to be delivered - // to the listening endpoint's accept queue. - <-notifyCh - - // Close the listening endpoint. - c.EP.Close() - - // Expect the listening endpoint to reset the connection. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - )) -} - -func TestTOSV4(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - c.EP = ep - - const tos = 0xC0 - if err := c.EP.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil { - t.Errorf("SetSockOptInt(IPv4TOSOption, %d) failed: %s", tos, err) - } - - v, err := c.EP.GetSockOptInt(tcpip.IPv4TOSOption) - if err != nil { - t.Errorf("GetSockoptInt(IPv4TOSOption) failed: %s", err) - } - - if v != tos { - t.Errorf("got GetSockOptInt(IPv4TOSOption) = %d, want = %d", v, tos) - } - - testV4Connect(t, c, checker.TOS(tos, 0)) - - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that data is received. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), // Acknum is initial sequence number + 1 - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - checker.TOS(tos, 0), - ) - - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { - t.Errorf("got data = %x, want = %x", p, data) - } -} - -func TestTrafficClassV6(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - const tos = 0xC0 - if err := c.EP.SetSockOptInt(tcpip.IPv6TrafficClassOption, tos); err != nil { - t.Errorf("SetSockOpInt(IPv6TrafficClassOption, %d) failed: %s", tos, err) - } - - v, err := c.EP.GetSockOptInt(tcpip.IPv6TrafficClassOption) - if err != nil { - t.Fatalf("GetSockoptInt(IPv6TrafficClassOption) failed: %s", err) - } - - if v != tos { - t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = %d, want = %d", v, tos) - } - - // Test the connection request. - testV6Connect(t, c, checker.TOS(tos, 0)) - - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that data is received. - b := c.GetV6Packet() - checker.IPv6(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - checker.TOS(tos, 0), - ) - - if p := b[header.IPv6MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { - t.Errorf("got data = %x, want = %x", p, data) - } -} - -func TestConnectBindToDevice(t *testing.T) { - for _, test := range []struct { - name string - device tcpip.NICID - want tcp.EndpointState - }{ - {"RightDevice", 1, tcp.StateEstablished}, - {"WrongDevice", 2, tcp.StateSynSent}, - {"AnyDevice", 0, tcp.StateEstablished}, - } { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1) - if err := c.EP.SocketOptions().SetBindToDevice(int32(test.device)); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%d)): %s", test.device, test.device, err) - } - // Start connection attempt. - waitEntry, _ := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&waitEntry, waiter.EventOut) - defer c.WQ.EventUnregister(&waitEntry) - - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("unexpected return value from Connect: %s", err) - } - - // Receive SYN packet. - b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - ), - ) - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { - t.Fatalf("unexpected endpoint state: want %s, got %s", want, got) - } - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - iss := seqnum.Value(789) - rcvWnd := seqnum.Size(30000) - c.SendPacket(nil, &context.Headers{ - SrcPort: tcpHdr.DestinationPort(), - DstPort: tcpHdr.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: rcvWnd, - TCPOpts: nil, - }) - - c.GetPacket() - if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want { - t.Fatalf("unexpected endpoint state: want %s, got %s", want, got) - } - }) - } -} - -func TestSynSent(t *testing.T) { - for _, test := range []struct { - name string - reset bool - }{ - {"RstOnSynSent", true}, - {"CloseOnSynSent", false}, - } { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create an endpoint, don't handshake because we want to interfere with the - // handshake process. - c.Create(-1) - - // Start connection attempt. - waitEntry, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&waitEntry, waiter.EventHUp) - defer c.WQ.EventUnregister(&waitEntry) - - addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} - err := c.EP.Connect(addr) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("got Connect(%+v) = %v, want %s", addr, err, &tcpip.ErrConnectStarted{}) - } - - // Receive SYN packet. - b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - ), - ) - - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { - t.Fatalf("got State() = %s, want %s", got, want) - } - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - if test.reset { - // Send a packet with a proper ACK and a RST flag to cause the socket - // to error and close out. - iss := seqnum.Value(789) - rcvWnd := seqnum.Size(30000) - c.SendPacket(nil, &context.Headers{ - SrcPort: tcpHdr.DestinationPort(), - DstPort: tcpHdr.SourcePort(), - Flags: header.TCPFlagRst | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: rcvWnd, - TCPOpts: nil, - }) - } else { - c.EP.Close() - } - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(3 * time.Second): - t.Fatal("timed out waiting for packet to arrive") - } - - ept := endpointTester{c.EP} - if test.reset { - ept.CheckReadError(t, &tcpip.ErrConnectionRefused{}) - } else { - ept.CheckReadError(t, &tcpip.ErrAborted{}) - } - - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) - } - - // Due to the RST the endpoint should be in an error state. - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { - t.Fatalf("got State() = %s, want %s", got, want) - } - }) - } -} - -func TestOutOfOrderReceive(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - // Send second half of data first, with seqnum 3 ahead of expected. - data := []byte{1, 2, 3, 4, 5, 6} - c.SendPacket(data[3:], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 793, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Check that we get an ACK specifying which seqnum is expected. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Wait 200ms and check that no data has been received. - time.Sleep(200 * time.Millisecond) - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - // Send the first 3 bytes now. - c.SendPacket(data[:3], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Receive data. - read := ept.CheckReadFull(t, 6, ch, 5*time.Second) - - // Check that we received the data in proper order. - if !bytes.Equal(data, read) { - t.Fatalf("got data = %v, want = %v", read, data) - } - - // Check that the whole data is acknowledged. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+len(data))), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestOutOfOrderFlood(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - rcvBufSz := math.MaxUint16 - c.CreateConnected(789, 30000, rcvBufSz) - - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - // Send 100 packets before the actual one that is expected. - data := []byte{1, 2, 3, 4, 5, 6} - for i := 0; i < 100; i++ { - c.SendPacket(data[3:], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 796, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - } - - // Send packet with seqnum 793. It must be discarded because the - // out-of-order buffer was filled by the previous packets. - c.SendPacket(data[3:], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 793, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Now send the expected packet, seqnum 790. - c.SendPacket(data[:3], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Check that only packet 790 is acknowledged. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(793), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestRstOnCloseWithUnreadData(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - data := []byte{1, 2, 3} - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(3 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Check that ACK is received, this happens regardless of the read. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+len(data))), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Now that we know we have unread data, let's just close the connection - // and verify that netstack sends an RST rather than a FIN. - c.EP.Close() - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - // We shouldn't consume a sequence number on RST. - checker.TCPSeqNum(uint32(c.IRS)+1), - )) - // The RST puts the endpoint into an error state. - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - - // This final ACK should be ignored because an ACK on a reset doesn't mean - // anything. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + len(data)), - AckNum: c.IRS.Add(seqnum.Size(2)), - RcvWnd: 30000, - }) -} - -func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - data := []byte{1, 2, 3} - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(3 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Check that ACK is received, this happens regardless of the read. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+len(data))), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Cause a FIN to be generated. - c.EP.Shutdown(tcpip.ShutdownWrite) - - // Make sure we get the FIN but DON't ACK IT. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - checker.TCPSeqNum(uint32(c.IRS)+1), - )) - - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - - // Cause a RST to be generated by closing the read end now since we have - // unread data. - c.EP.Shutdown(tcpip.ShutdownRead) - - // Make sure we get the RST - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - // RST is always generated with sndNxt which if the FIN - // has been sent will be 1 higher than the sequence - // number of the FIN itself. - checker.TCPSeqNum(uint32(c.IRS)+2), - )) - // The RST puts the endpoint into an error state. - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - - // The ACK to the FIN should now be rejected since the connection has been - // closed by a RST. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + len(data)), - AckNum: c.IRS.Add(seqnum.Size(2)), - RcvWnd: 30000, - }) -} - -func TestShutdownRead(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } - - ept.CheckReadError(t, &tcpip.ErrClosedForReceive{}) - var want uint64 = 1 - if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want { - t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %d want %d", got, want) - } -} - -func TestFullWindowReceive(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - const rcvBufSz = 10 - c.CreateConnected(789, 30000, rcvBufSz) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - // Fill up the window w/ tcp.SegOverheadFactor*rcvBufSz as netstack multiplies - // the provided buffer value by tcp.SegOverheadFactor to calculate the actual - // receive buffer size. - data := make([]byte, tcp.SegOverheadFactor*rcvBufSz) - for i := range data { - data[i] = byte(i % 255) - } - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(5 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Check that data is acknowledged, and window goes to zero. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+len(data))), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPWindow(0), - ), - ) - - // Receive data and check it. - v := ept.CheckRead(t) - if !bytes.Equal(data, v) { - t.Fatalf("got data = %v, want = %v", v, data) - } - - var want uint64 = 1 - if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ZeroRcvWindowState.Value(); got != want { - t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %d want %d", got, want) - } - - // Check that we get an ACK for the newly non-zero window. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+len(data))), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPWindow(10), - ), - ) -} - -// Test the stack receive window advertisement on receiving segments smaller than -// segment overhead. It tests for the right edge of the window to not grow when -// the endpoint is not being read from. -func TestSmallSegReceiveWindowAdvertisement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - opt := tcpip.TCPReceiveBufferSizeRangeOption{ - Min: 1, - Default: tcp.DefaultReceiveBufferSize, - Max: tcp.DefaultReceiveBufferSize << tcp.FindWndScale(seqnum.Size(tcp.DefaultReceiveBufferSize)), - } - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) - } - - c.AcceptWithOptions(tcp.FindWndScale(seqnum.Size(opt.Default)), header.TCPSynOptions{MSS: defaultIPv4MSS}) - - // Bump up the receive buffer size such that, when the receive window grows, - // the scaled window exceeds maxUint16. - if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, opt.Max); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed: %s", opt.Max, err) - } - - // Keep the payload size < segment overhead and such that it is a multiple - // of the window scaled value. This enables the test to perform equality - // checks on the incoming receive window. - payloadSize := 1 << c.RcvdWindowScale - if payloadSize >= tcp.SegSize { - t.Fatalf("payload size of %d is not less than the segment overhead of %d", payloadSize, tcp.SegSize) - } - payload := generateRandomPayload(t, payloadSize) - payloadLen := seqnum.Size(len(payload)) - iss := seqnum.Value(789) - seqNum := iss.Add(1) - - // Send payload to the endpoint and return the advertised receive window - // from the endpoint. - getIncomingRcvWnd := func() uint32 { - c.SendPacket(payload, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - SeqNum: seqNum, - AckNum: c.IRS.Add(1), - Flags: header.TCPFlagAck, - RcvWnd: 30000, - }) - seqNum = seqNum.Add(payloadLen) - - pkt := c.GetPacket() - return uint32(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.RcvdWindowScale - } - - // Read the advertised receive window with the ACK for payload. - rcvWnd := getIncomingRcvWnd() - - // Check if the subsequent ACK to our send has not grown the right edge of - // the window. - if got, want := getIncomingRcvWnd(), rcvWnd-uint32(len(payload)); got != want { - t.Fatalf("got incomingRcvwnd %d want %d", got, want) - } - - // Read the data so that the subsequent ACK from the endpoint - // grows the right edge of the window. - var buf bytes.Buffer - if _, err := c.EP.Read(&buf, tcpip.ReadOptions{}); err != nil { - t.Fatalf("c.EP.Read: %s", err) - } - - // Check if we have received max uint16 as our advertised - // scaled window now after a read above. - maxRcv := uint32(math.MaxUint16 << c.RcvdWindowScale) - if got, want := getIncomingRcvWnd(), maxRcv; got != want { - t.Fatalf("got incomingRcvwnd %d want %d", got, want) - } - - // Check if the subsequent ACK to our send has not grown the right edge of - // the window. - if got, want := getIncomingRcvWnd(), maxRcv-uint32(len(payload)); got != want { - t.Fatalf("got incomingRcvwnd %d want %d", got, want) - } -} - -func TestNoWindowShrinking(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Start off with a certain receive buffer then cut it in half and verify that - // the right edge of the window does not shrink. - // NOTE: Netstack doubles the value specified here. - rcvBufSize := 65536 - iss := seqnum.Value(789) - // Enable window scaling with a scale of zero from our end. - c.CreateConnectedWithRawOptions(iss, 30000, rcvBufSize, []byte{ - header.TCPOptionWS, 3, 0, header.TCPOptionNOP, - }) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - // Send a 1 byte payload so that we can record the current receive window. - // Send a payload of half the size of rcvBufSize. - seqNum := iss.Add(1) - payload := []byte{1} - c.SendPacket(payload, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqNum, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(5 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Read the 1 byte payload we just sent. - if got, want := payload, ept.CheckRead(t); !bytes.Equal(got, want) { - t.Fatalf("got data: %v, want: %v", got, want) - } - - seqNum = seqNum.Add(1) - // Verify that the ACK does not shrink the window. - pkt := c.GetPacket() - checker.IPv4(t, pkt, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(seqNum)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - // Stash the initial window. - initialWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale - initialLastAcceptableSeq := seqNum.Add(seqnum.Size(initialWnd)) - // Now shrink the receive buffer to half its original size. - if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufSize/2); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err) - } - - data := generateRandomPayload(t, rcvBufSize) - // Send a payload of half the size of rcvBufSize. - c.SendPacket(data[:rcvBufSize/2], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqNum, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - seqNum = seqNum.Add(seqnum.Size(rcvBufSize / 2)) - - // Verify that the ACK does not shrink the window. - pkt = c.GetPacket() - checker.IPv4(t, pkt, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(seqNum)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - newWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale - newLastAcceptableSeq := seqNum.Add(seqnum.Size(newWnd)) - if newLastAcceptableSeq.LessThan(initialLastAcceptableSeq) { - t.Fatalf("receive window shrunk unexpectedly got: %d, want >= %d", newLastAcceptableSeq, initialLastAcceptableSeq) - } - - // Send another payload of half the size of rcvBufSize. This should fill up the - // socket receive buffer and we should see a zero window. - c.SendPacket(data[rcvBufSize/2:], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqNum, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - seqNum = seqNum.Add(seqnum.Size(rcvBufSize / 2)) - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(seqNum)), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPWindow(0), - ), - ) - - // Receive data and check it. - read := ept.CheckReadFull(t, len(data), ch, 5*time.Second) - if !bytes.Equal(data, read) { - t.Fatalf("got data = %v, want = %v", read, data) - } - - // Check that we get an ACK for the newly non-zero window, which is the new - // receive buffer size we set after the connection was established. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(seqNum)), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPWindow(uint16(rcvBufSize/2)>>c.RcvdWindowScale), - ), - ) -} - -func TestSimpleSend(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that data is received. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { - t.Fatalf("got data = %v, want = %v", p, data) - } - - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1 + seqnum.Size(len(data))), - RcvWnd: 30000, - }) -} - -func TestZeroWindowSend(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789 /* iss */, 0 /* rcvWnd */, -1 /* epRcvBuf */) - - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check if we got a zero-window probe. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)), - checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - // Open up the window. Data should be received now. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Check that data is received. - b = c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { - t.Fatalf("got data = %v, want = %v", p, data) - } - - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1 + seqnum.Size(len(data))), - RcvWnd: 30000, - }) -} - -func TestScaledWindowConnect(t *testing.T) { - // This test ensures that window scaling is used when the peer - // does advertise it and connection is established with Connect(). - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set the window size greater than the maximum non-scaled window. - c.CreateConnectedWithRawOptions(789, 30000, 65535*3, []byte{ - header.TCPOptionWS, 3, 0, header.TCPOptionNOP, - }) - - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that data is received, and that advertised window is 0x5fff, - // that is, that it is scaled. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), - checker.TCPWindow(0x5fff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) -} - -func TestNonScaledWindowConnect(t *testing.T) { - // This test ensures that window scaling is not used when the peer - // doesn't advertise it and connection is established with Connect(). - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set the window size greater than the maximum non-scaled window. - c.CreateConnected(789, 30000, 65535*3) - - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that data is received, and that advertised window is 0xffff, - // that is, that it's not scaled. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), - checker.TCPWindow(0xffff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) -} - -func TestScaledWindowAccept(t *testing.T) { - // This test ensures that window scaling is used when the peer - // does advertise it and connection is established with Accept(). - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create EP and start listening. - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - defer ep.Close() - - // Set the window size greater than the maximum non-scaled window. - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %s", err) - } - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Do 3-way handshake. - // wndScale expected is 3 as 65535 * 3 * 2 < 65535 * 2^3 but > 65535 *2 *2 - c.PassiveConnectWithOptions(100, 3 /* wndScale */, header.TCPSynOptions{MSS: defaultIPv4MSS}) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that data is received, and that advertised window is 0x5fff, - // that is, that it is scaled. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), - checker.TCPWindow(0x5fff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) -} - -func TestNonScaledWindowAccept(t *testing.T) { - // This test ensures that window scaling is not used when the peer - // doesn't advertise it and connection is established with Accept(). - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create EP and start listening. - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - defer ep.Close() - - // Set the window size greater than the maximum non-scaled window. - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %s", err) - } - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Do 3-way handshake w/ window scaling disabled. The SYN-ACK to the SYN - // should not carry the window scaling option. - c.PassiveConnect(100, -1, header.TCPSynOptions{MSS: defaultIPv4MSS}) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that data is received, and that advertised window is 0xffff, - // that is, that it's not scaled. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), - checker.TCPWindow(0xffff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) -} - -func TestZeroScaledWindowReceive(t *testing.T) { - // This test ensures that the endpoint sends a non-zero window size - // advertisement when the scaled window transitions from 0 to non-zero, - // but the actual window (not scaled) hasn't gotten to zero. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set the buffer size such that a window scale of 5 will be used. - const bufSz = 65535 * 10 - const ws = uint32(5) - c.CreateConnectedWithRawOptions(789, 30000, bufSz, []byte{ - header.TCPOptionWS, 3, 0, header.TCPOptionNOP, - }) - - // Write chunks of 50000 bytes. - remain := 0 - sent := 0 - data := make([]byte, 50000) - // Keep writing till the window drops below len(data). - for { - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + sent), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - sent += len(data) - pkt := c.GetPacket() - checker.IPv4(t, pkt, - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+sent)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - // Don't reduce window to zero here. - if wnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()); wnd<<ws < len(data) { - remain = wnd << ws - break - } - } - - // Make the window non-zero, but the scaled window zero. - for remain >= 16 { - data = data[:remain-15] - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + sent), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - sent += len(data) - pkt := c.GetPacket() - checker.IPv4(t, pkt, - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+sent)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - // Since the receive buffer is split between window advertisement and - // application data buffer the window does not always reflect the space - // available and actual space available can be a bit more than what is - // advertised in the window. - wnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) - if wnd == 0 { - break - } - remain = wnd << ws - } - - // Read at least 2MSS of data. An ack should be sent in response to that. - // Since buffer space is now split in half between window and application - // data we need to read more than 1 MSS(65536) of data for a non-zero window - // update to be sent. For 1MSS worth of window to be available we need to - // read at least 128KB. Since our segments above were 50KB each it means - // we need to read at 3 packets. - w := tcpip.LimitedWriter{ - W: ioutil.Discard, - N: defaultMTU * 2, - } - for w.N != 0 { - res, err := c.EP.Read(&w, tcpip.ReadOptions{}) - t.Logf("err=%v res=%#v", err, res) - if err != nil { - t.Fatalf("Read failed: %s", err) - } - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+sent)), - checker.TCPWindowGreaterThanEq(uint16(defaultMTU>>ws)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestSegmentMerging(t *testing.T) { - tests := []struct { - name string - stop func(tcpip.Endpoint) - resume func(tcpip.Endpoint) - }{ - { - "stop work", - func(ep tcpip.Endpoint) { - ep.(interface{ StopWork() }).StopWork() - }, - func(ep tcpip.Endpoint) { - ep.(interface{ ResumeWork() }).ResumeWork() - }, - }, - { - "cork", - func(ep tcpip.Endpoint) { - ep.SocketOptions().SetCorkOption(true) - }, - func(ep tcpip.Endpoint) { - ep.SocketOptions().SetCorkOption(false) - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - // Send tcp.InitialCwnd number of segments to fill up - // InitialWindow but don't ACK. That should prevent - // anymore packets from going out. - var r bytes.Reader - for i := 0; i < tcp.InitialCwnd; i++ { - r.Reset([]byte{0}) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %s", i+1, err) - } - } - - // Now send the segments that should get merged as the congestion - // window is full and we won't be able to send any more packets. - var allData []byte - for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { - allData = append(allData, data...) - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %s", i+1, err) - } - } - - // Check that we get tcp.InitialCwnd packets. - for i := 0; i < tcp.InitialCwnd; i++ { - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(header.TCPMinimumSize+1), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+uint32(i)+1), - checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - } - - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1 + 10), // 10 for the 10 bytes of payload. - RcvWnd: 30000, - }) - - // Check that data is received. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(allData)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+11), - checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - if got := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, allData) { - t.Fatalf("got data = %v, want = %v", got, allData) - } - - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(11 + seqnum.Size(len(allData))), - RcvWnd: 30000, - }) - }) - } -} - -func TestDelay(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - c.EP.SocketOptions().SetDelayOption(true) - - var allData []byte - for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { - allData = append(allData, data...) - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %s", i+1, err) - } - } - - seq := c.IRS.Add(1) - for _, want := range [][]byte{allData[:1], allData[1:]} { - // Check that data is received. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(want)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(seq)), - checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - if got := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, want) { - t.Fatalf("got data = %v, want = %v", got, want) - } - - seq = seq.Add(seqnum.Size(len(want))) - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: seq, - RcvWnd: 30000, - }) - } -} - -func TestUndelay(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - c.EP.SocketOptions().SetDelayOption(true) - - allData := [][]byte{{0}, {1, 2, 3}} - for i, data := range allData { - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %s", i+1, err) - } - } - - seq := c.IRS.Add(1) - - // Check that data is received. - first := c.GetPacket() - checker.IPv4(t, first, - checker.PayloadLen(len(allData[0])+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(seq)), - checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - if got, want := first[header.IPv4MinimumSize+header.TCPMinimumSize:], allData[0]; !bytes.Equal(got, want) { - t.Fatalf("got first packet's data = %v, want = %v", got, want) - } - - seq = seq.Add(seqnum.Size(len(allData[0]))) - - // Check that we don't get the second packet yet. - c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond) - - c.EP.SocketOptions().SetDelayOption(false) - - // Check that data is received. - second := c.GetPacket() - checker.IPv4(t, second, - checker.PayloadLen(len(allData[1])+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(seq)), - checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - if got, want := second[header.IPv4MinimumSize+header.TCPMinimumSize:], allData[1]; !bytes.Equal(got, want) { - t.Fatalf("got second packet's data = %v, want = %v", got, want) - } - - seq = seq.Add(seqnum.Size(len(allData[1]))) - - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: seq, - RcvWnd: 30000, - }) -} - -func TestMSSNotDelayed(t *testing.T) { - tests := []struct { - name string - fn func(tcpip.Endpoint) - }{ - {"no-op", func(tcpip.Endpoint) {}}, - {"delay", func(ep tcpip.Endpoint) { ep.SocketOptions().SetDelayOption(true) }}, - {"cork", func(ep tcpip.Endpoint) { ep.SocketOptions().SetCorkOption(true) }}, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - const maxPayload = 100 - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ - header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), - }) - - test.fn(c.EP) - - allData := [][]byte{{0}, make([]byte, maxPayload), make([]byte, maxPayload)} - for i, data := range allData { - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %s", i+1, err) - } - } - - seq := c.IRS.Add(1) - - for i, data := range allData { - // Check that data is received. - packet := c.GetPacket() - checker.IPv4(t, packet, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(seq)), - checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - if got, want := packet[header.IPv4MinimumSize+header.TCPMinimumSize:], data; !bytes.Equal(got, want) { - t.Fatalf("got packet #%d's data = %v, want = %v", i+1, got, want) - } - - seq = seq.Add(seqnum.Size(len(data))) - } - - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: seq, - RcvWnd: 30000, - }) - }) - } -} - -func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { - payloadMultiplier := 10 - dataLen := payloadMultiplier * maxPayload - data := make([]byte, dataLen) - for i := range data { - data[i] = byte(i) - } - - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that data is received in chunks. - bytesReceived := 0 - numPackets := 0 - for bytesReceived != dataLen { - b := c.GetPacket() - numPackets++ - tcpHdr := header.TCP(header.IPv4(b).Payload()) - payloadLen := len(tcpHdr.Payload()) - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1+uint32(bytesReceived)), - checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - pdata := data[bytesReceived : bytesReceived+payloadLen] - if p := tcpHdr.Payload(); !bytes.Equal(pdata, p) { - t.Fatalf("got data = %v, want = %v", p, pdata) - } - bytesReceived += payloadLen - var options []byte - if c.TimeStampEnabled { - // If timestamp option is enabled, echo back the timestamp and increment - // the TSEcr value included in the packet and send that back as the TSVal. - parsedOpts := tcpHdr.ParsedOptions() - tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} - header.EncodeTSOption(parsedOpts.TSEcr+1, parsedOpts.TSVal, tsOpt[2:]) - options = tsOpt[:] - } - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1 + seqnum.Size(bytesReceived)), - RcvWnd: 30000, - TCPOpts: options, - }) - } - if numPackets == 1 { - t.Fatalf("expected write to be broken up into multiple packets, but got 1 packet") - } -} - -func TestSendGreaterThanMTU(t *testing.T) { - const maxPayload = 100 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - testBrokenUpWrite(t, c, maxPayload) -} - -func TestSetTTL(t *testing.T) { - for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} { - t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) { - c := context.New(t, 65535) - defer c.Cleanup() - - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - if err := c.EP.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil { - t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err) - } - - { - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("unexpected return value from Connect: %s", err) - } - } - - // Receive SYN packet. - b := c.GetPacket() - - checker.IPv4(t, b, checker.TTL(wantTTL)) - }) - } -} - -func TestActiveSendMSSLessThanMTU(t *testing.T) { - const maxPayload = 100 - c := context.New(t, 65535) - defer c.Cleanup() - - c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ - header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), - }) - testBrokenUpWrite(t, c, maxPayload) -} - -func TestPassiveSendMSSLessThanMTU(t *testing.T) { - const maxPayload = 100 - const mtu = 1200 - c := context.New(t, mtu) - defer c.Cleanup() - - // Create EP and start listening. - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - defer ep.Close() - - // Set the buffer size to a deterministic size so that we can check the - // window scaling option. - const rcvBufferSize = 0x20000 - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err) - } - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Do 3-way handshake. - c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize}) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Check that data gets properly segmented. - testBrokenUpWrite(t, c, maxPayload) -} - -func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { - const maxPayload = 536 - const mtu = 2000 - c := context.New(t, mtu) - defer c.Cleanup() - - // Set the SynRcvd threshold to zero to force a syn cookie based accept - // to happen. - opt := tcpip.TCPSynRcvdCountThresholdOption(0) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } - - // Create EP and start listening. - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - defer ep.Close() - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Do 3-way handshake. - c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize}) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Check that data gets properly segmented. - testBrokenUpWrite(t, c, maxPayload) -} - -func TestForwarderSendMSSLessThanMTU(t *testing.T) { - const maxPayload = 100 - const mtu = 1200 - c := context.New(t, mtu) - defer c.Cleanup() - - s := c.Stack() - ch := make(chan tcpip.Error, 1) - f := tcp.NewForwarder(s, 65536, 10, func(r *tcp.ForwarderRequest) { - var err tcpip.Error - c.EP, err = r.CreateEndpoint(&c.WQ) - ch <- err - }) - s.SetTransportProtocolHandler(tcp.ProtocolNumber, f.HandlePacket) - - // Do 3-way handshake. - c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize}) - - // Wait for connection to be available. - select { - case err := <-ch: - if err != nil { - t.Fatalf("Error creating endpoint: %s", err) - } - case <-time.After(2 * time.Second): - t.Fatalf("Timed out waiting for connection") - } - - // Check that data gets properly segmented. - testBrokenUpWrite(t, c, maxPayload) -} - -func TestSynOptionsOnActiveConnect(t *testing.T) { - const mtu = 1400 - c := context.New(t, mtu) - defer c.Cleanup() - - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - // Set the buffer size to a deterministic size so that we can check the - // window scaling option. - const rcvBufferSize = 0x20000 - const wndScale = 3 - if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err) - } - - // Start connection attempt. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventOut) - defer c.WQ.EventUnregister(&we) - - { - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) - } - } - - // Receive SYN packet. - b := c.GetPacket() - mss := uint16(mtu - header.IPv4MinimumSize - header.TCPMinimumSize) - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}), - ), - ) - - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - // Wait for retransmit. - time.Sleep(1 * time.Second) - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.SrcPort(tcpHdr.SourcePort()), - checker.TCPSeqNum(tcpHdr.SequenceNumber()), - checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}), - ), - ) - - // Send SYN-ACK. - iss := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: tcpHdr.DestinationPort(), - DstPort: tcpHdr.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Receive ACK packet. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+1), - ), - ) - - // Wait for connection to be established. - select { - case <-ch: - if err := c.EP.LastError(); err != nil { - t.Fatalf("Connect failed: %s", err) - } - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for connection") - } -} - -func TestCloseListener(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create listener. - var wq waiter.Queue - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - if err := ep.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Close the listener and measure how long it takes. - t0 := time.Now() - ep.Close() - if diff := time.Now().Sub(t0); diff > 3*time.Second { - t.Fatalf("Took too long to close: %s", diff) - } -} - -func TestReceiveOnResetConnection(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - // Send RST segment. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagRst, - SeqNum: 790, - RcvWnd: 30000, - }) - - // Try to read. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - -loop: - for { - switch _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err.(type) { - case *tcpip.ErrWouldBlock: - select { - case <-ch: - // Expect the state to be StateError and subsequent Reads to fail with HardError. - _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) - if _, ok := err.(*tcpip.ErrConnectionReset); !ok { - t.Fatalf("got c.EP.Read() = %v, want = %s", err, &tcpip.ErrConnectionReset{}) - } - break loop - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for reset to arrive") - } - case *tcpip.ErrConnectionReset: - break loop - default: - t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, &tcpip.ErrConnectionReset{}) - } - } - - if tcp.EndpointState(c.EP.State()) != tcp.StateError { - t.Fatalf("got EP state is not StateError") - } - if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 { - t.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got) - } - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) - } - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) - } -} - -func TestSendOnResetConnection(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - // Send RST segment. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagRst, - SeqNum: 790, - RcvWnd: 30000, - }) - - // Wait for the RST to be received. - time.Sleep(1 * time.Second) - - // Try to write. - var r bytes.Reader - r.Reset(make([]byte, 10)) - _, err := c.EP.Write(&r, tcpip.WriteOptions{}) - if _, ok := err.(*tcpip.ErrConnectionReset); !ok { - t.Fatalf("got c.EP.Write(...) = %v, want = %s", err, &tcpip.ErrConnectionReset{}) - } -} - -// TestMaxRetransmitsTimeout tests if the connection is timed out after -// a segment has been retransmitted MaxRetries times. -func TestMaxRetransmitsTimeout(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - const numRetries = 2 - opt := tcpip.TCPMaxRetriesOption(numRetries) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } - - c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) - - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&waitEntry, waiter.EventHUp) - defer c.WQ.EventUnregister(&waitEntry) - - var r bytes.Reader - r.Reset(make([]byte, 1)) - _, err := c.EP.Write(&r, tcpip.WriteOptions{}) - if err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Expect first transmit and MaxRetries retransmits. - for i := 0; i < numRetries+1; i++ { - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh), - ), - ) - } - // Wait for the connection to timeout after MaxRetries retransmits. - initRTO := 1 * time.Second - select { - case <-notifyCh: - case <-time.After((2 << numRetries) * initRTO): - t.Fatalf("connection still alive after maximum retransmits.\n") - } - - // Send an ACK and expect a RST as the connection would have been closed. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - }) - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst), - ), - ) - - if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 { - t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got) - } - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) - } -} - -// TestMaxRTO tests if the retransmit interval caps to MaxRTO. -func TestMaxRTO(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - rto := 1 * time.Second - opt := tcpip.TCPMaxRTOOption(rto) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } - - c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) - - var r bytes.Reader - r.Reset(make([]byte, 1)) - _, err := c.EP.Write(&r, tcpip.WriteOptions{}) - if err != nil { - t.Fatalf("Write failed: %s", err) - } - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - const numRetransmits = 2 - for i := 0; i < numRetransmits; i++ { - start := time.Now() - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - if time.Since(start).Round(time.Second).Seconds() != rto.Seconds() { - t.Errorf("Retransmit interval not capped to MaxRTO.\n") - } - } -} - -// TestRetransmitIPv4IDUniqueness tests that the IPv4 Identification field is -// unique on retransmits. -func TestRetransmitIPv4IDUniqueness(t *testing.T) { - for _, tc := range []struct { - name string - size int - }{ - {"1Byte", 1}, - {"512Bytes", 512}, - } { - t.Run(tc.name, func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) - - // Disabling PMTU discovery causes all packets sent from this socket to - // have DF=0. This needs to be done because the IPv4 ID uniqueness - // applies only to non-atomic IPv4 datagrams as defined in RFC 6864 - // Section 4, and datagrams with DF=0 are non-atomic. - if err := c.EP.SetSockOptInt(tcpip.MTUDiscoverOption, tcpip.PMTUDiscoveryDont); err != nil { - t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err) - } - - var r bytes.Reader - r.Reset(make([]byte, tc.size)) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - pkt := c.GetPacket() - checker.IPv4(t, pkt, - checker.FragmentFlags(0), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - idSet := map[uint16]struct{}{header.IPv4(pkt).ID(): {}} - // Expect two retransmitted packets, and that all packets received have - // unique IPv4 ID values. - for i := 0; i <= 2; i++ { - pkt := c.GetPacket() - checker.IPv4(t, pkt, - checker.FragmentFlags(0), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - id := header.IPv4(pkt).ID() - if _, exists := idSet[id]; exists { - t.Fatalf("duplicate IPv4 ID=%d found in retransmitted packet", id) - } - idSet[id] = struct{}{} - } - }) - } -} - -func TestFinImmediately(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - // Shutdown immediately, check that we get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - - // Ack and send FIN as well. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Check that the stack acks the FIN. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+2), - checker.TCPAckNum(791), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestFinRetransmit(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - // Shutdown immediately, check that we get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - - // Don't acknowledge yet. We should get a retransmit of the FIN. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - - // Ack and send FIN as well. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Check that the stack acks the FIN. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+2), - checker.TCPAckNum(791), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestFinWithNoPendingData(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - // Write something out, and have it acknowledged. - view := make([]byte, 10) - var r bytes.Reader - r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - next := uint32(c.IRS) + 1 - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - next += uint32(len(view)) - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - // Shutdown, check that we get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - next++ - - // Ack and send FIN as well. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - // Check that the stack acks the FIN. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(791), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestFinWithPendingDataCwndFull(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - // Write enough segments to fill the congestion window before ACK'ing - // any of them. - view := make([]byte, 10) - var r bytes.Reader - for i := tcp.InitialCwnd; i > 0; i-- { - r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - } - - next := uint32(c.IRS) + 1 - for i := tcp.InitialCwnd; i > 0; i-- { - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - next += uint32(len(view)) - } - - // Shutdown the connection, check that the FIN segment isn't sent - // because the congestion window doesn't allow it. Wait until a - // retransmit is received. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - // Send the ACK that will allow the FIN to be sent as well. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - next++ - - // Send a FIN that acknowledges everything. Get an ACK back. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(791), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestFinWithPendingData(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - // Write something out, and acknowledge it to get cwnd to 2. - view := make([]byte, 10) - var r bytes.Reader - r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - next := uint32(c.IRS) + 1 - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - next += uint32(len(view)) - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - // Write new data, but don't acknowledge it. - r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - next += uint32(len(view)) - - // Shutdown the connection, check that we do get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - next++ - - // Send a FIN that acknowledges everything. Get an ACK back. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(791), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestFinWithPartialAck(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - // Write something out, and acknowledge it to get cwnd to 2. Also send - // FIN from the test side. - view := make([]byte, 10) - var r bytes.Reader - r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - next := uint32(c.IRS) + 1 - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - next += uint32(len(view)) - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - // Check that we get an ACK for the fin. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(791), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - // Write new data, but don't acknowledge it. - r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(791), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - next += uint32(len(view)) - - // Shutdown the connection, check that we do get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(791), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - next++ - - // Send an ACK for the data, but not for the FIN yet. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 791, - AckNum: seqnum.Value(next - 1), - RcvWnd: 30000, - }) - - // Check that we don't get a retransmit of the FIN. - c.CheckNoPacketTimeout("FIN retransmitted when data was ack'd", 100*time.Millisecond) - - // Ack the FIN. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 791, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) -} - -func TestUpdateListenBacklog(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create listener. - var wq waiter.Queue - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - if err := ep.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Update the backlog with another Listen() on the same endpoint. - if err := ep.Listen(20); err != nil { - t.Fatalf("Listen failed to update backlog: %s", err) - } - - ep.Close() -} - -func scaledSendWindow(t *testing.T, scale uint8) { - // This test ensures that the endpoint is using the right scaling by - // sending a buffer that is larger than the window size, and ensuring - // that the endpoint doesn't send more than allowed. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - maxPayload := defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize - c.CreateConnectedWithRawOptions(789, 0, -1 /* epRcvBuf */, []byte{ - header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), - header.TCPOptionWS, 3, scale, header.TCPOptionNOP, - }) - - // Open up the window with a scaled value. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 1, - }) - - // Send some data. Check that it's capped by the window size. - view := make([]byte, 65535) - var r bytes.Reader - r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that only data that fits in the scaled window is sent. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen((1<<scale)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - // Reset the connection to free resources. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagRst, - SeqNum: 790, - }) -} - -func TestScaledSendWindow(t *testing.T) { - for scale := uint8(0); scale <= 14; scale++ { - scaledSendWindow(t, scale) - } -} - -func TestReceivedValidSegmentCountIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - stats := c.Stack().Stats() - want := stats.TCP.ValidSegmentsReceived.Value() + 1 - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - if got := stats.TCP.ValidSegmentsReceived.Value(); got != want { - t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %d, want = %d", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).SegmentsReceived.Value(); got != want { - t.Errorf("got EP stats Stats.SegmentsReceived = %d, want = %d", got, want) - } - // Ensure there were no errors during handshake. If these stats have - // incremented, then the connection should not have been established. - if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoRoute.Value(); got != 0 { - t.Errorf("got EP stats Stats.SendErrors.NoRoute = %d, want = %d", got, 0) - } -} - -func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - stats := c.Stack().Stats() - want := stats.TCP.InvalidSegmentsReceived.Value() + 1 - vv := c.BuildSegment(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - tcpbuf := vv.ToView()[header.IPv4MinimumSize:] - tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4 - - c.SendSegment(vv) - - if got := stats.TCP.InvalidSegmentsReceived.Value(); got != want { - t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %d, want = %d", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want) - } -} - -func TestReceivedIncorrectChecksumIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - stats := c.Stack().Stats() - want := stats.TCP.ChecksumErrors.Value() + 1 - vv := c.BuildSegment([]byte{0x1, 0x2, 0x3}, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - tcpbuf := vv.ToView()[header.IPv4MinimumSize:] - // Overwrite a byte in the payload which should cause checksum - // verification to fail. - tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4 - - c.SendSegment(vv) - - if got := stats.TCP.ChecksumErrors.Value(); got != want { - t.Errorf("got stats.TCP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ChecksumErrors.Value(); got != want { - t.Errorf("got EP stats Stats.ReceiveErrors.ChecksumErrors = %d, want = %d", got, want) - } -} - -func TestReceivedSegmentQueuing(t *testing.T) { - // This test sends 200 segments containing a few bytes each to an - // endpoint and checks that they're all received and acknowledged by - // the endpoint, that is, that none of the segments are dropped by - // internal queues. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - // Send 200 segments. - data := []byte{1, 2, 3} - for i := 0; i < 200; i++ { - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + i*len(data)), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - } - - // Receive ACKs for all segments. - last := seqnum.Value(790 + 200*len(data)) - for { - b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - tcpHdr := header.TCP(header.IPv4(b).Payload()) - ack := seqnum.Value(tcpHdr.AckNumber()) - if ack == last { - break - } - - if last.LessThan(ack) { - t.Fatalf("Acknowledge (%v) beyond the expected (%v)", ack, last) - } - } -} - -func TestReadAfterClosedState(t *testing.T) { - // This test ensures that calling Read() or Peek() after the endpoint - // has transitioned to closedState still works if there is pending - // data. To transition to stateClosed without calling Close(), we must - // shutdown the send path and the peer must send its own FIN. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed - // after 1 second in TIME_WAIT state. - tcpTimeWaitTimeout := 1 * time.Second - opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - // Shutdown immediately for write, check that we get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - - // Send some data and acknowledge the FIN. - data := []byte{1, 2, 3} - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Check that ACK is received. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+2), - checker.TCPAckNum(uint32(791+len(data))), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Give the stack the chance to transition to closed state from - // TIME_WAIT. - time.Sleep(tcpTimeWaitTimeout * 2) - - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Check that peek works. - var peekBuf bytes.Buffer - res, err := c.EP.Read(&peekBuf, tcpip.ReadOptions{Peek: true}) - if err != nil { - t.Fatalf("Peek failed: %s", err) - } - - if got, want := res.Count, len(data); got != want { - t.Fatalf("res.Count = %d, want %d", got, want) - } - if !bytes.Equal(data, peekBuf.Bytes()) { - t.Fatalf("got data = %v, want = %v", peekBuf.Bytes(), data) - } - - // Receive data. - v := ept.CheckRead(t) - if !bytes.Equal(data, v) { - t.Fatalf("got data = %v, want = %v", v, data) - } - - // Now that we drained the queue, check that functions fail with the - // right error code. - ept.CheckReadError(t, &tcpip.ErrClosedForReceive{}) - var buf bytes.Buffer - { - _, err := c.EP.Read(&buf, tcpip.ReadOptions{Peek: true}) - if _, ok := err.(*tcpip.ErrClosedForReceive); !ok { - t.Fatalf("c.EP.Read(_, {Peek: true}) = %v, %s; want _, %s", res, err, &tcpip.ErrClosedForReceive{}) - } - } -} - -func TestReusePort(t *testing.T) { - // This test ensures that ports are immediately available for reuse - // after Close on the endpoints using them returns. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // First case, just an endpoint that was bound. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - c.EP.SocketOptions().SetReuseAddress(true) - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - c.EP.Close() - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - c.EP.SocketOptions().SetReuseAddress(true) - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - c.EP.Close() - - // Second case, an endpoint that was bound and is connecting.. - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - c.EP.SocketOptions().SetReuseAddress(true) - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - { - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) - } - } - c.EP.Close() - - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - c.EP.SocketOptions().SetReuseAddress(true) - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - c.EP.Close() - - // Third case, an endpoint that was bound and is listening. - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - c.EP.SocketOptions().SetReuseAddress(true) - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - c.EP.Close() - - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - c.EP.SocketOptions().SetReuseAddress(true) - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } -} - -func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { - t.Helper() - - s, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption) - if err != nil { - t.Fatalf("GetSockOpt failed: %s", err) - } - - if int(s) != v { - t.Fatalf("got receive buffer size = %d, want = %d", s, v) - } -} - -func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { - t.Helper() - - if s := ep.SocketOptions().GetSendBufferSize(); int(s) != v { - t.Fatalf("got send buffer size = %d, want = %d", s, v) - } -} - -func TestDefaultBufferSizes(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, - }) - - // Check the default values. - ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - defer func() { - if ep != nil { - ep.Close() - } - }() - - checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize) - checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) - - // Change the default send buffer size. - { - opt := tcpip.TCPSendBufferSizeRangeOption{ - Min: 1, - Default: tcp.DefaultSendBufferSize * 2, - Max: tcp.DefaultSendBufferSize * 20, - } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) - } - } - - ep.Close() - ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - - checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2) - checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) - - // Change the default receive buffer size. - { - opt := tcpip.TCPReceiveBufferSizeRangeOption{ - Min: 1, - Default: tcp.DefaultReceiveBufferSize * 3, - Max: tcp.DefaultReceiveBufferSize * 30, - } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) - } - } - - ep.Close() - ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - - checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2) - checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*3) -} - -func TestMinMaxBufferSizes(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, - }) - - // Check the default values. - ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - defer ep.Close() - - // Change the min/max values for send/receive - { - opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20} - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) - } - } - - { - opt := tcpip.TCPSendBufferSizeRangeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30} - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) - } - } - - // Set values below the min/2. - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 99); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 199) failed: %s", err) - } - - checkRecvBufferSize(t, ep, 200) - - ep.SocketOptions().SetSendBufferSize(149, true) - - checkSendBufferSize(t, ep, 300) - - // Set values above the max. - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1+tcp.DefaultReceiveBufferSize*20); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption) failed: %s", err) - } - - // Values above max are capped at max and then doubled. - checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20*2) - - ep.SocketOptions().SetSendBufferSize(1+tcp.DefaultSendBufferSize*30, true) - - // Values above max are capped at max and then doubled. - checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30*2) -} - -func TestBindToDeviceOption(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}}) - - ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - defer ep.Close() - - if err := s.CreateNIC(321, loopback.New()); err != nil { - t.Errorf("CreateNIC failed: %s", err) - } - - // nicIDPtr is used instead of taking the address of NICID literals, which is - // a compiler error. - nicIDPtr := func(s tcpip.NICID) *tcpip.NICID { - return &s - } - - testActions := []struct { - name string - setBindToDevice *tcpip.NICID - setBindToDeviceError tcpip.Error - getBindToDevice int32 - }{ - {"GetDefaultValue", nil, nil, 0}, - {"BindToNonExistent", nicIDPtr(999), &tcpip.ErrUnknownDevice{}, 0}, - {"BindToExistent", nicIDPtr(321), nil, 321}, - {"UnbindToDevice", nicIDPtr(0), nil, 0}, - } - for _, testAction := range testActions { - t.Run(testAction.name, func(t *testing.T) { - if testAction.setBindToDevice != nil { - bindToDevice := int32(*testAction.setBindToDevice) - if gotErr, wantErr := ep.SocketOptions().SetBindToDevice(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { - t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr) - } - } - bindToDevice := ep.SocketOptions().GetBindToDevice() - if bindToDevice != testAction.getBindToDevice { - t.Errorf("got bindToDevice = %d, want %d", bindToDevice, testAction.getBindToDevice) - } - }) - } -} - -func makeStack() (*stack.Stack, tcpip.Error) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ - ipv4.NewProtocol, - ipv6.NewProtocol, - }, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, - }) - - id := loopback.New() - if testing.Verbose() { - id = sniffer.New(id) - } - - if err := s.CreateNIC(1, id); err != nil { - return nil, err - } - - for _, ct := range []struct { - number tcpip.NetworkProtocolNumber - address tcpip.Address - }{ - {ipv4.ProtocolNumber, context.StackAddr}, - {ipv6.ProtocolNumber, context.StackV6Addr}, - } { - if err := s.AddAddress(1, ct.number, ct.address); err != nil { - return nil, err - } - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: 1, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: 1, - }, - }) - - return s, nil -} - -func TestSelfConnect(t *testing.T) { - // This test ensures that intentional self-connects work. In particular, - // it checks that if an endpoint binds to say 127.0.0.1:1000 then - // connects to 127.0.0.1:1000, then it will be connected to itself, and - // is able to send and receive data through the same endpoint. - s, err := makeStack() - if err != nil { - t.Fatal(err) - } - - var wq waiter.Queue - ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - defer ep.Close() - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - // Register for notification, then start connection attempt. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - wq.EventRegister(&waitEntry, waiter.EventOut) - defer wq.EventUnregister(&waitEntry) - - { - err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) - } - } - - <-notifyCh - if err := ep.LastError(); err != nil { - t.Fatalf("Connect failed: %s", err) - } - - // Write something. - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Read back what was written. - wq.EventUnregister(&waitEntry) - wq.EventRegister(&waitEntry, waiter.EventIn) - ept := endpointTester{ep} - rd := ept.CheckReadFull(t, len(data), notifyCh, 5*time.Second) - - if !bytes.Equal(data, rd) { - t.Fatalf("got data = %v, want = %v", rd, data) - } -} - -func TestConnectAvoidsBoundPorts(t *testing.T) { - addressTypes := func(t *testing.T, network string) []string { - switch network { - case "ipv4": - return []string{"v4"} - case "ipv6": - return []string{"v6"} - case "dual": - return []string{"v6", "mapped"} - default: - t.Fatalf("unknown network: '%s'", network) - } - - panic("unreachable") - } - - address := func(t *testing.T, addressType string, isAny bool) tcpip.Address { - switch addressType { - case "v4": - if isAny { - return "" - } - return context.StackAddr - case "v6": - if isAny { - return "" - } - return context.StackV6Addr - case "mapped": - if isAny { - return context.V4MappedWildcardAddr - } - return context.StackV4MappedAddr - default: - t.Fatalf("unknown address type: '%s'", addressType) - } - - panic("unreachable") - } - // This test ensures that Endpoint.Connect doesn't select already-bound ports. - networks := []string{"ipv4", "ipv6", "dual"} - for _, exhaustedNetwork := range networks { - t.Run(fmt.Sprintf("exhaustedNetwork=%s", exhaustedNetwork), func(t *testing.T) { - for _, exhaustedAddressType := range addressTypes(t, exhaustedNetwork) { - t.Run(fmt.Sprintf("exhaustedAddressType=%s", exhaustedAddressType), func(t *testing.T) { - for _, isAny := range []bool{false, true} { - t.Run(fmt.Sprintf("isAny=%t", isAny), func(t *testing.T) { - for _, candidateNetwork := range networks { - t.Run(fmt.Sprintf("candidateNetwork=%s", candidateNetwork), func(t *testing.T) { - for _, candidateAddressType := range addressTypes(t, candidateNetwork) { - t.Run(fmt.Sprintf("candidateAddressType=%s", candidateAddressType), func(t *testing.T) { - s, err := makeStack() - if err != nil { - t.Fatal(err) - } - - var wq waiter.Queue - var eps []tcpip.Endpoint - defer func() { - for _, ep := range eps { - ep.Close() - } - }() - makeEP := func(network string) tcpip.Endpoint { - var networkProtocolNumber tcpip.NetworkProtocolNumber - switch network { - case "ipv4": - networkProtocolNumber = ipv4.ProtocolNumber - case "ipv6", "dual": - networkProtocolNumber = ipv6.ProtocolNumber - default: - t.Fatalf("unknown network: '%s'", network) - } - ep, err := s.NewEndpoint(tcp.ProtocolNumber, networkProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - eps = append(eps, ep) - switch network { - case "ipv4": - case "ipv6": - ep.SocketOptions().SetV6Only(true) - case "dual": - ep.SocketOptions().SetV6Only(false) - default: - t.Fatalf("unknown network: '%s'", network) - } - return ep - } - - var v4reserved, v6reserved bool - switch exhaustedAddressType { - case "v4", "mapped": - v4reserved = true - case "v6": - v6reserved = true - // Dual stack sockets bound to v6 any reserve on v4 as - // well. - if isAny { - switch exhaustedNetwork { - case "ipv6": - case "dual": - v4reserved = true - default: - t.Fatalf("unknown address type: '%s'", exhaustedNetwork) - } - } - default: - t.Fatalf("unknown address type: '%s'", exhaustedAddressType) - } - var collides bool - switch candidateAddressType { - case "v4", "mapped": - collides = v4reserved - case "v6": - collides = v6reserved - default: - t.Fatalf("unknown address type: '%s'", candidateAddressType) - } - - for i := ports.FirstEphemeral; i <= math.MaxUint16; i++ { - if makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil { - t.Fatalf("Bind(%d) failed: %s", i, err) - } - } - var want tcpip.Error = &tcpip.ErrConnectStarted{} - if collides { - want = &tcpip.ErrNoPortAvailable{} - } - if err := makeEP(candidateNetwork).Connect(tcpip.FullAddress{Addr: address(t, candidateAddressType, false), Port: 31337}); err != want { - t.Fatalf("got ep.Connect(..) = %s, want = %s", err, want) - } - }) - } - }) - } - }) - } - }) - } - }) - } -} - -func TestPathMTUDiscovery(t *testing.T) { - // This test verifies the stack retransmits packets after it receives an - // ICMP packet indicating that the path MTU has been exceeded. - c := context.New(t, 1500) - defer c.Cleanup() - - // Create new connection with MSS of 1460. - const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize - c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ - header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), - }) - - // Send 3200 bytes of data. - const writeSize = 3200 - data := make([]byte, writeSize) - for i := range data { - data[i] = byte(i) - } - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - receivePackets := func(c *context.Context, sizes []int, which int, seqNum uint32) []byte { - var ret []byte - for i, size := range sizes { - p := c.GetPacket() - if i == which { - ret = p - } - checker.IPv4(t, p, - checker.PayloadLen(size+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(seqNum), - checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - seqNum += uint32(size) - } - return ret - } - - // Receive three packets. - sizes := []int{maxPayload, maxPayload, writeSize - 2*maxPayload} - first := receivePackets(c, sizes, 0, uint32(c.IRS)+1) - - // Send "packet too big" messages back to netstack. - const newMTU = 1200 - const newMaxPayload = newMTU - header.IPv4MinimumSize - header.TCPMinimumSize - mtu := []byte{0, 0, newMTU / 256, newMTU % 256} - c.SendICMPPacket(header.ICMPv4DstUnreachable, header.ICMPv4FragmentationNeeded, mtu, first, newMTU) - - // See retransmitted packets. None exceeding the new max. - sizes = []int{newMaxPayload, maxPayload - newMaxPayload, newMaxPayload, maxPayload - newMaxPayload, writeSize - 2*maxPayload} - receivePackets(c, sizes, -1, uint32(c.IRS)+1) -} - -func TestTCPEndpointProbe(t *testing.T) { - c := context.New(t, 1500) - defer c.Cleanup() - - invoked := make(chan struct{}) - c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { - // Validate that the endpoint ID is what we expect. - // - // We don't do an extensive validation of every field but a - // basic sanity test. - if got, want := state.ID.LocalAddress, tcpip.Address(context.StackAddr); got != want { - t.Fatalf("got LocalAddress: %q, want: %q", got, want) - } - if got, want := state.ID.LocalPort, c.Port; got != want { - t.Fatalf("got LocalPort: %d, want: %d", got, want) - } - if got, want := state.ID.RemoteAddress, tcpip.Address(context.TestAddr); got != want { - t.Fatalf("got RemoteAddress: %q, want: %q", got, want) - } - if got, want := state.ID.RemotePort, uint16(context.TestPort); got != want { - t.Fatalf("got RemotePort: %d, want: %d", got, want) - } - - invoked <- struct{}{} - }) - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - data := []byte{1, 2, 3} - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - select { - case <-invoked: - case <-time.After(100 * time.Millisecond): - t.Fatalf("TCP Probe function was not called") - } -} - -func TestStackSetCongestionControl(t *testing.T) { - testCases := []struct { - cc tcpip.CongestionControlOption - err tcpip.Error - }{ - {"reno", nil}, - {"cubic", nil}, - {"blahblah", &tcpip.ErrNoSuchFile{}}, - } - - for _, tc := range testCases { - t.Run(fmt.Sprintf("SetTransportProtocolOption(.., %v)", tc.cc), func(t *testing.T) { - c := context.New(t, 1500) - defer c.Cleanup() - - s := c.Stack() - - var oldCC tcpip.CongestionControlOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &oldCC); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %s", tcp.ProtocolNumber, &oldCC, err) - } - - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &tc.cc); err != tc.err { - t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%s)) = %s, want = %s", tcp.ProtocolNumber, tc.cc, tc.cc, err, tc.err) - } - - var cc tcpip.CongestionControlOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err) - } - - got, want := cc, oldCC - // If SetTransportProtocolOption is expected to succeed - // then the returned value for congestion control should - // match the one specified in the - // SetTransportProtocolOption call above, else it should - // be what it was before the call to - // SetTransportProtocolOption. - if tc.err == nil { - want = tc.cc - } - if got != want { - t.Fatalf("got congestion control: %v, want: %v", got, want) - } - }) - } -} - -func TestStackAvailableCongestionControl(t *testing.T) { - c := context.New(t, 1500) - defer c.Cleanup() - - s := c.Stack() - - // Query permitted congestion control algorithms. - var aCC tcpip.TCPAvailableCongestionControlOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err) - } - if got, want := aCC, tcpip.TCPAvailableCongestionControlOption("reno cubic"); got != want { - t.Fatalf("got tcpip.TCPAvailableCongestionControlOption: %v, want: %v", got, want) - } -} - -func TestStackSetAvailableCongestionControl(t *testing.T) { - c := context.New(t, 1500) - defer c.Cleanup() - - s := c.Stack() - - // Setting AvailableCongestionControlOption should fail. - aCC := tcpip.TCPAvailableCongestionControlOption("xyz") - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil { - t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%s)) = nil, want non-nil", tcp.ProtocolNumber, aCC, aCC) - } - - // Verify that we still get the expected list of congestion control options. - var cc tcpip.TCPAvailableCongestionControlOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { - t.Fatalf("s.TransportProtocolOptio(%d, &%T(%s)): %s", tcp.ProtocolNumber, cc, cc, err) - } - if got, want := cc, tcpip.TCPAvailableCongestionControlOption("reno cubic"); got != want { - t.Fatalf("got tcpip.TCPAvailableCongestionControlOption = %s, want = %s", got, want) - } -} - -func TestEndpointSetCongestionControl(t *testing.T) { - testCases := []struct { - cc tcpip.CongestionControlOption - err tcpip.Error - }{ - {"reno", nil}, - {"cubic", nil}, - {"blahblah", &tcpip.ErrNoSuchFile{}}, - } - - for _, connected := range []bool{false, true} { - for _, tc := range testCases { - t.Run(fmt.Sprintf("SetSockOpt(.., %v) w/ connected = %v", tc.cc, connected), func(t *testing.T) { - c := context.New(t, 1500) - defer c.Cleanup() - - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - var oldCC tcpip.CongestionControlOption - if err := c.EP.GetSockOpt(&oldCC); err != nil { - t.Fatalf("c.EP.GetSockOpt(&%T) = %s", oldCC, err) - } - - if connected { - c.Connect(789 /* iss */, 32768 /* rcvWnd */, nil) - } - - if err := c.EP.SetSockOpt(&tc.cc); err != tc.err { - t.Fatalf("got c.EP.SetSockOpt(&%#v) = %s, want %s", tc.cc, err, tc.err) - } - - var cc tcpip.CongestionControlOption - if err := c.EP.GetSockOpt(&cc); err != nil { - t.Fatalf("c.EP.GetSockOpt(&%T): %s", cc, err) - } - - got, want := cc, oldCC - // If SetSockOpt is expected to succeed then the - // returned value for congestion control should match - // the one specified in the SetSockOpt above, else it - // should be what it was before the call to SetSockOpt. - if tc.err == nil { - want = tc.cc - } - if got != want { - t.Fatalf("got congestion control = %+v, want = %+v", got, want) - } - }) - } - } -} - -func enableCUBIC(t *testing.T, c *context.Context) { - t.Helper() - opt := tcpip.CongestionControlOption("cubic") - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)) %s", tcp.ProtocolNumber, opt, opt, err) - } -} - -func TestKeepalive(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - const keepAliveIdle = 100 * time.Millisecond - const keepAliveInterval = 3 * time.Second - keepAliveIdleOpt := tcpip.KeepaliveIdleOption(keepAliveIdle) - if err := c.EP.SetSockOpt(&keepAliveIdleOpt); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIdleOpt, keepAliveIdle, err) - } - keepAliveIntervalOpt := tcpip.KeepaliveIntervalOption(keepAliveInterval) - if err := c.EP.SetSockOpt(&keepAliveIntervalOpt); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIntervalOpt, keepAliveInterval, err) - } - c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5) - if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5); err != nil { - t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5): %s", err) - } - c.EP.SocketOptions().SetKeepAlive(true) - - // 5 unacked keepalives are sent. ACK each one, and check that the - // connection stays alive after 5. - for i := 0; i < 10; i++ { - b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)), - checker.TCPAckNum(uint32(790)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Acknowledge the keepalive. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS, - RcvWnd: 30000, - }) - } - - // Check that the connection is still alive. - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - // Send some data and wait before ACKing it. Keepalives should be disabled - // during this period. - view := make([]byte, 3) - var r bytes.Reader - r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - next := uint32(c.IRS) + 1 - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - // Wait for the packet to be retransmitted. Verify that no keepalives - // were sent. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh), - ), - ) - c.CheckNoPacket("Keepalive packet received while unACKed data is pending") - - next += uint32(len(view)) - - // Send ACK. Keepalives should start sending again. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - // Now receive 5 keepalives, but don't ACK them. The connection - // should be reset after 5. - for i := 0; i < 5; i++ { - b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(next-1)), - checker.TCPAckNum(uint32(790)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - } - - // Sleep for a litte over the KeepAlive interval to make sure - // the timer has time to fire after the last ACK and close the - // close the socket. - time.Sleep(keepAliveInterval + keepAliveInterval/2) - - // The connection should be terminated after 5 unacked keepalives. - // Send an ACK to trigger a RST from the stack as the endpoint should - // be dead. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(next)), - checker.TCPAckNum(uint32(0)), - checker.TCPFlags(header.TCPFlagRst), - ), - ) - - if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 { - t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got) - } - - ept.CheckReadError(t, &tcpip.ErrTimeout{}) - - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) - } - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) - } -} - -func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { - t.Helper() - // Send a SYN request. - irs = seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: srcPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - iss = seqnum.Value(tcp.SequenceNumber()) - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(srcPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs) + 1), - } - - if synCookieInUse { - // When cookies are in use window scaling is disabled. - tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{ - WS: -1, - MSS: c.MSSWithoutOptions(), - })) - } - - checker.IPv4(t, b, checker.TCP(tcpCheckers...)) - - // Send ACK. - c.SendPacket(nil, &context.Headers{ - SrcPort: srcPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - return irs, iss -} - -func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { - t.Helper() - // Send a SYN request. - irs = seqnum.Value(789) - c.SendV6Packet(nil, &context.Headers{ - SrcPort: srcPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetV6Packet() - tcp := header.TCP(header.IPv6(b).Payload()) - iss = seqnum.Value(tcp.SequenceNumber()) - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(srcPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs) + 1), - } - - if synCookieInUse { - // When cookies are in use window scaling is disabled. - tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{ - WS: -1, - MSS: c.MSSWithoutOptionsV6(), - })) - } - - checker.IPv6(t, b, checker.TCP(tcpCheckers...)) - - // Send ACK. - c.SendV6Packet(nil, &context.Headers{ - SrcPort: srcPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - return irs, iss -} - -// TestListenBacklogFull tests that netstack does not complete handshakes if the -// listen backlog for the endpoint is full. -func TestListenBacklogFull(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - // Test acceptance. - // Start listening. - listenBacklog := 10 - if err := c.EP.Listen(listenBacklog); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - lastPortOffset := uint16(0) - for ; int(lastPortOffset) < listenBacklog; lastPortOffset++ { - executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */) - } - - time.Sleep(50 * time.Millisecond) - - // Now execute send one more SYN. The stack should not respond as the backlog - // is full at this point. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + uint16(lastPortOffset), - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: seqnum.Value(789), - RcvWnd: 30000, - }) - c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) - - // Try to accept the connections in the backlog. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - for i := 0; i < listenBacklog; i++ { - _, _, err = c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Wait for connection to be established. - select { - case <-ch: - _, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - } - - // Now verify that there are no more connections that can be accepted. - _, _, err = c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - select { - case <-ch: - t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) - case <-time.After(1 * time.Second): - } - } - - // Now a new handshake must succeed. - executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */) - - newEP, _, err := c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Wait for connection to be established. - select { - case <-ch: - newEP, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Now verify that the TCP socket is usable and in a connected state. - data := "Don't panic" - var r strings.Reader - r.Reset(data) - newEP.Write(&r, tcpip.WriteOptions{}) - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - if string(tcp.Payload()) != data { - t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data) - } -} - -// TestListenNoAcceptMulticastBroadcastV4 makes sure that TCP segments with a -// non unicast IPv4 address are not accepted. -func TestListenNoAcceptNonUnicastV4(t *testing.T) { - multicastAddr := tcpip.Address("\xe0\x00\x01\x02") - otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03") - subnet := context.StackAddrWithPrefix.Subnet() - subnetBroadcastAddr := subnet.Broadcast() - - tests := []struct { - name string - srcAddr tcpip.Address - dstAddr tcpip.Address - }{ - { - name: "SourceUnspecified", - srcAddr: header.IPv4Any, - dstAddr: context.StackAddr, - }, - { - name: "SourceBroadcast", - srcAddr: header.IPv4Broadcast, - dstAddr: context.StackAddr, - }, - { - name: "SourceOurMulticast", - srcAddr: multicastAddr, - dstAddr: context.StackAddr, - }, - { - name: "SourceOtherMulticast", - srcAddr: otherMulticastAddr, - dstAddr: context.StackAddr, - }, - { - name: "DestUnspecified", - srcAddr: context.TestAddr, - dstAddr: header.IPv4Any, - }, - { - name: "DestBroadcast", - srcAddr: context.TestAddr, - dstAddr: header.IPv4Broadcast, - }, - { - name: "DestOurMulticast", - srcAddr: context.TestAddr, - dstAddr: multicastAddr, - }, - { - name: "DestOtherMulticast", - srcAddr: context.TestAddr, - dstAddr: otherMulticastAddr, - }, - { - name: "SrcSubnetBroadcast", - srcAddr: subnetBroadcastAddr, - dstAddr: context.StackAddr, - }, - { - name: "DestSubnetBroadcast", - srcAddr: context.TestAddr, - dstAddr: subnetBroadcastAddr, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1) - - if err := c.Stack().JoinGroup(header.IPv4ProtocolNumber, 1, multicastAddr); err != nil { - t.Fatalf("JoinGroup failed: %s", err) - } - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - irs := seqnum.Value(789) - c.SendPacketWithAddrs(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }, test.srcAddr, test.dstAddr) - c.CheckNoPacket("Should not have received a response") - - // Handle normal packet. - c.SendPacketWithAddrs(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }, context.TestAddr, context.StackAddr) - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs)+1))) - }) - } -} - -// TestListenNoAcceptMulticastBroadcastV6 makes sure that TCP segments with a -// non unicast IPv6 address are not accepted. -func TestListenNoAcceptNonUnicastV6(t *testing.T) { - multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01") - otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02") - - tests := []struct { - name string - srcAddr tcpip.Address - dstAddr tcpip.Address - }{ - { - "SourceUnspecified", - header.IPv6Any, - context.StackV6Addr, - }, - { - "SourceAllNodes", - header.IPv6AllNodesMulticastAddress, - context.StackV6Addr, - }, - { - "SourceOurMulticast", - multicastAddr, - context.StackV6Addr, - }, - { - "SourceOtherMulticast", - otherMulticastAddr, - context.StackV6Addr, - }, - { - "DestUnspecified", - context.TestV6Addr, - header.IPv6Any, - }, - { - "DestAllNodes", - context.TestV6Addr, - header.IPv6AllNodesMulticastAddress, - }, - { - "DestOurMulticast", - context.TestV6Addr, - multicastAddr, - }, - { - "DestOtherMulticast", - context.TestV6Addr, - otherMulticastAddr, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - if err := c.Stack().JoinGroup(header.IPv6ProtocolNumber, 1, multicastAddr); err != nil { - t.Fatalf("JoinGroup failed: %s", err) - } - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - irs := seqnum.Value(789) - c.SendV6PacketWithAddrs(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }, test.srcAddr, test.dstAddr) - c.CheckNoPacket("Should not have received a response") - - // Handle normal packet. - c.SendV6PacketWithAddrs(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }, context.TestV6Addr, context.StackV6Addr) - checker.IPv6(t, c.GetV6Packet(), - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs)+1))) - }) - } -} - -func TestListenSynRcvdQueueFull(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - // Test acceptance. - // Start listening. - listenBacklog := 1 - if err := c.EP.Listen(listenBacklog); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send two SYN's the first one should get a SYN-ACK, the - // second one should not get any response and is dropped as - // the synRcvd count will be equal to backlog. - irs := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - iss := seqnum.Value(tcp.SequenceNumber()) - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs) + 1), - } - checker.IPv4(t, b, checker.TCP(tcpCheckers...)) - - // Now execute send one more SYN. The stack should not respond as the backlog - // is full at this point. - // - // NOTE: we did not complete the handshake for the previous one so the - // accept backlog should be empty and there should be one connection in - // synRcvd state. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + 1, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: seqnum.Value(889), - RcvWnd: 30000, - }) - c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) - - // Now complete the previous connection and verify that there is a connection - // to accept. - // Send ACK. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - - // Try to accept the connections in the backlog. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - newEP, _, err := c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Wait for connection to be established. - select { - case <-ch: - newEP, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Now verify that the TCP socket is usable and in a connected state. - data := "Don't panic" - var r strings.Reader - r.Reset(data) - newEP.Write(&r, tcpip.WriteOptions{}) - pkt := c.GetPacket() - tcp = header.TCP(header.IPv4(pkt).Payload()) - if string(tcp.Payload()) != data { - t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data) - } -} - -func TestListenBacklogFullSynCookieInUse(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - opt := tcpip.TCPSynRcvdCountThresholdOption(1) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } - - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - // Start listening. - listenBacklog := 1 - if err := c.EP.Listen(listenBacklog); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - executeHandshake(t, c, context.TestPort, false) - - // Wait for this to be delivered to the accept queue. - time.Sleep(50 * time.Millisecond) - - // Send a SYN request. - irs := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - // pick a different src port for new SYN. - SrcPort: context.TestPort + 1, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - // The Syn should be dropped as the endpoint's backlog is full. - c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) - - // Verify that there is only one acceptable connection at this point. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - _, _, err = c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Wait for connection to be established. - select { - case <-ch: - _, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Now verify that there are no more connections that can be accepted. - _, _, err = c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - select { - case <-ch: - t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) - case <-time.After(1 * time.Second): - } - } -} - -func TestSYNRetransmit(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - // Start listening. - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send the same SYN packet multiple times. We should still get a valid SYN-ACK - // reply. - irs := seqnum.Value(789) - for i := 0; i < 5; i++ { - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - } - - // Receive the SYN-ACK reply. - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs) + 1), - } - checker.IPv4(t, c.GetPacket(), checker.TCP(tcpCheckers...)) -} - -func TestSynRcvdBadSeqNumber(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - // Start listening. - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN to get a SYN-ACK. This should put the ep into SYN-RCVD state - irs := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - iss := seqnum.Value(tcpHdr.SequenceNumber()) - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs) + 1), - } - checker.IPv4(t, b, checker.TCP(tcpCheckers...)) - - // Now send a packet with an out-of-window sequence number - largeSeqnum := irs + seqnum.Value(tcpHdr.WindowSize()) + 1 - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: largeSeqnum, - AckNum: iss + 1, - RcvWnd: 30000, - }) - - // Should receive an ACK with the expected SEQ number - b = c.GetPacket() - tcpCheckers = []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPAckNum(uint32(irs) + 1), - checker.TCPSeqNum(uint32(iss + 1)), - } - checker.IPv4(t, b, checker.TCP(tcpCheckers...)) - - // Now that the socket replied appropriately with the ACK, - // complete the connection to test that the large SEQ num - // did not change the state from SYN-RCVD. - - // Send ACK to move to ESTABLISHED state. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - - newEP, _, err := c.EP.Accept(nil) - switch err.(type) { - case nil, *tcpip.ErrWouldBlock: - default: - t.Fatalf("Accept failed: %s", err) - } - - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Try to accept the connections in the backlog. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - // Wait for connection to be established. - select { - case <-ch: - newEP, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Now verify that the TCP socket is usable and in a connected state. - data := "Don't panic" - var r strings.Reader - r.Reset(data) - if _, err := newEP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - pkt := c.GetPacket() - tcpHdr = header.TCP(header.IPv4(pkt).Payload()) - if string(tcpHdr.Payload()) != data { - t.Fatalf("unexpected data: got %s, want %s", string(tcpHdr.Payload()), data) - } -} - -func TestPassiveConnectionAttemptIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - c.EP = ep - if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %s", err) - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateListen; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - - stats := c.Stack().Stats() - want := stats.TCP.PassiveConnectionOpenings.Value() + 1 - - srcPort := uint16(context.TestPort) - executeHandshake(t, c, srcPort+1, false) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - // Verify that there is only one acceptable connection at this point. - _, _, err = c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Wait for connection to be established. - select { - case <-ch: - _, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - if got := stats.TCP.PassiveConnectionOpenings.Value(); got != want { - t.Errorf("got stats.TCP.PassiveConnectionOpenings.Value() = %d, want = %d", got, want) - } -} - -func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - c.EP = ep - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - srcPort := uint16(context.TestPort) - // Now attempt a handshakes it will fill up the accept backlog. - executeHandshake(t, c, srcPort, false) - - // Give time for the final ACK to be processed as otherwise the next handshake could - // get accepted before the previous one based on goroutine scheduling. - time.Sleep(50 * time.Millisecond) - - want := stats.TCP.ListenOverflowSynDrop.Value() + 1 - - // Now we will send one more SYN and this one should get dropped - // Send a SYN request. - c.SendPacket(nil, &context.Headers{ - SrcPort: srcPort + 2, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: seqnum.Value(789), - RcvWnd: 30000, - }) - - time.Sleep(50 * time.Millisecond) - if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want { - t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %d, want = %d", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want { - t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %d, want = %d", got, want) - } - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - // Now check that there is one acceptable connections. - _, _, err = c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Wait for connection to be established. - select { - case <-ch: - _, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } -} - -func TestEndpointBindListenAcceptState(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - - ept := endpointTester{ep} - ept.CheckReadError(t, &tcpip.ErrNotConnected{}) - if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 { - t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %d want %d", got, 1) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - - c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS}) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - aep, _, err := ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Wait for connection to be established. - select { - case <-ch: - aep, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - { - err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrAlreadyConnected); !ok { - t.Errorf("unexpected error attempting to call connect on an established endpoint, got: %v, want: %s", err, &tcpip.ErrAlreadyConnected{}) - } - } - // Listening endpoint remains in listen state. - if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - - ep.Close() - // Give worker goroutines time to receive the close notification. - time.Sleep(1 * time.Second) - if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - // Accepted endpoint remains open when the listen endpoint is closed. - if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - -} - -// This test verifies that the auto tuning does not grow the receive buffer if -// the application is not reading the data actively. -func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { - const mtu = 1500 - const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize - - c := context.New(t, mtu) - defer c.Cleanup() - - stk := c.Stack() - // Set lower limits for auto-tuning tests. This is required because the - // test stops the worker which can cause packets to be dropped because - // the segment queue holding unprocessed packets is limited to 500. - const receiveBufferSize = 80 << 10 // 80KB. - const maxReceiveBufferSize = receiveBufferSize * 10 - { - opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize} - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) - } - } - - // Enable auto-tuning. - { - opt := tcpip.TCPModerateReceiveBufferOption(true) - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) - } - } - // Change the expected window scale to match the value needed for the - // maximum buffer size defined above. - c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize)) - - rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4}) - - // NOTE: The timestamp values in the sent packets are meaningless to the - // peer so we just increment the timestamp value by 1 every batch as we - // are not really using them for anything. Send a single byte to verify - // the advertised window. - tsVal := rawEP.TSVal + 1 - - // Introduce a 25ms latency by delaying the first byte. - latency := 25 * time.Millisecond - time.Sleep(latency) - // Send an initial payload with atleast segment overhead size. The receive - // window would not grow for smaller segments. - rawEP.SendPacketWithTS(make([]byte, tcp.SegSize), tsVal) - - pkt := rawEP.VerifyAndReturnACKWithTS(tsVal) - rcvWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() - - time.Sleep(25 * time.Millisecond) - - // Allocate a large enough payload for the test. - payloadSize := receiveBufferSize * 2 - b := make([]byte, int(payloadSize)) - - worker := (c.EP).(interface { - StopWork() - ResumeWork() - }) - tsVal++ - - // Stop the worker goroutine. - worker.StopWork() - start := 0 - end := payloadSize / 2 - packetsSent := 0 - for ; start < end; start += mss { - packetEnd := start + mss - if start+mss > end { - packetEnd = end - } - rawEP.SendPacketWithTS(b[start:packetEnd], tsVal) - packetsSent++ - } - - // Resume the worker so that it only sees the packets once all of them - // are waiting to be read. - worker.ResumeWork() - - // Since we sent almost the full receive buffer worth of data (some may have - // been dropped due to segment overheads), we should get a zero window back. - pkt = c.GetPacket() - tcpHdr := header.TCP(header.IPv4(pkt).Payload()) - gotRcvWnd := tcpHdr.WindowSize() - wantAckNum := tcpHdr.AckNumber() - if got, want := int(gotRcvWnd), 0; got != want { - t.Fatalf("got rcvWnd: %d, want: %d", got, want) - } - - time.Sleep(25 * time.Millisecond) - // Verify that sending more data when receiveBuffer is exhausted. - rawEP.SendPacketWithTS(b[start:start+mss], tsVal) - - // Now read all the data from the endpoint and verify that advertised - // window increases to the full available buffer size. - for { - _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - break - } - } - - // Verify that we receive a non-zero window update ACK. When running - // under thread santizer this test can end up sending more than 1 - // ack, 1 for the non-zero window - p := c.GetPacket() - checker.IPv4(t, p, checker.TCP( - checker.TCPAckNum(uint32(wantAckNum)), - func(t *testing.T, h header.Transport) { - tcp, ok := h.(header.TCP) - if !ok { - return - } - // We use 10% here as the error margin upwards as the initial window we - // got was afer 1 segment was already in the receive buffer queue. - tolerance := 1.1 - if w := tcp.WindowSize(); w == 0 || w > uint16(float64(rcvWnd)*tolerance) { - t.Errorf("expected a non-zero window: got %d, want <= %d", w, uint16(float64(rcvWnd)*tolerance)) - } - }, - )) -} - -// This test verifies that the advertised window is auto-tuned up as the -// application is reading the data that is being received. -func TestReceiveBufferAutoTuning(t *testing.T) { - const mtu = 1500 - const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize - - c := context.New(t, mtu) - defer c.Cleanup() - - // Enable Auto-tuning. - stk := c.Stack() - // Disable out of window rate limiting for this test by setting it to 0 as we - // use out of window ACKs to measure the advertised window. - var tcpInvalidRateLimit stack.TCPInvalidRateLimitOption - if err := stk.SetOption(tcpInvalidRateLimit); err != nil { - t.Fatalf("e.stack.SetOption(%#v) = %s", tcpInvalidRateLimit, err) - } - - const receiveBufferSize = 80 << 10 // 80KB. - const maxReceiveBufferSize = receiveBufferSize * 10 - { - opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize} - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) - } - } - - // Enable auto-tuning. - { - opt := tcpip.TCPModerateReceiveBufferOption(true) - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) - } - } - // Change the expected window scale to match the value needed for the - // maximum buffer size used by stack. - c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize)) - - rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4}) - tsVal := uint32(rawEP.TSVal) - rawEP.NextSeqNum-- - rawEP.SendPacketWithTS(nil, tsVal) - rawEP.NextSeqNum++ - pkt := rawEP.VerifyAndReturnACKWithTS(tsVal) - curRcvWnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale - scaleRcvWnd := func(rcvWnd int) uint16 { - return uint16(rcvWnd >> uint16(c.WindowScale)) - } - // Allocate a large array to send to the endpoint. - b := make([]byte, receiveBufferSize*48) - - // In every iteration we will send double the number of bytes sent in - // the previous iteration and read the same from the app. The received - // window should grow by at least 2x of bytes read by the app in every - // RTT. - offset := 0 - payloadSize := receiveBufferSize / 8 - worker := (c.EP).(interface { - StopWork() - ResumeWork() - }) - latency := 1 * time.Millisecond - for i := 0; i < 5; i++ { - tsVal++ - - // Stop the worker goroutine. - worker.StopWork() - start := offset - end := offset + payloadSize - totalSent := 0 - packetsSent := 0 - for ; start < end; start += mss { - rawEP.SendPacketWithTS(b[start:start+mss], tsVal) - totalSent += mss - packetsSent++ - } - - // Resume it so that it only sees the packets once all of them - // are waiting to be read. - worker.ResumeWork() - - // Give 1ms for the worker to process the packets. - time.Sleep(1 * time.Millisecond) - - lastACK := c.GetPacket() - // Discard any intermediate ACKs and only check the last ACK we get in a - // short time period of few ms. - for { - time.Sleep(1 * time.Millisecond) - pkt := c.GetPacketNonBlocking() - if pkt == nil { - break - } - lastACK = pkt - } - if got, want := int(header.TCP(header.IPv4(lastACK).Payload()).WindowSize()), int(scaleRcvWnd(curRcvWnd)); got > want { - t.Fatalf("advertised window got: %d, want <= %d", got, want) - } - - // Now read all the data from the endpoint and invoke the - // moderation API to allow for receive buffer auto-tuning - // to happen before we measure the new window. - totalCopied := 0 - for { - res, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - break - } - totalCopied += res.Count - } - - // Invoke the moderation API. This is required for auto-tuning - // to happen. This method is normally expected to be invoked - // from a higher layer than tcpip.Endpoint. So we simulate - // copying to userspace by invoking it explicitly here. - c.EP.ModerateRecvBuf(totalCopied) - - // Now send a keep-alive packet to trigger an ACK so that we can - // measure the new window. - rawEP.NextSeqNum-- - rawEP.SendPacketWithTS(nil, tsVal) - rawEP.NextSeqNum++ - - if i == 0 { - // In the first iteration the receiver based RTT is not - // yet known as a result the moderation code should not - // increase the advertised window. - rawEP.VerifyACKRcvWnd(scaleRcvWnd(curRcvWnd)) - } else { - // Read loop above could generate an ACK if the window had dropped to - // zero and then read had opened it up. - lastACK := c.GetPacket() - // Discard any intermediate ACKs and only check the last ACK we get in a - // short time period of few ms. - for { - time.Sleep(1 * time.Millisecond) - pkt := c.GetPacketNonBlocking() - if pkt == nil { - break - } - lastACK = pkt - } - curRcvWnd = int(header.TCP(header.IPv4(lastACK).Payload()).WindowSize()) << c.WindowScale - // If thew new current window is close maxReceiveBufferSize then terminate - // the loop. This can happen before all iterations are done due to timing - // differences when running the test. - if int(float64(curRcvWnd)*1.1) > maxReceiveBufferSize/2 { - break - } - // Increase the latency after first two iterations to - // establish a low RTT value in the receiver since it - // only tracks the lowest value. This ensures that when - // ModerateRcvBuf is called the elapsed time is always > - // rtt. Without this the test is flaky due to delays due - // to scheduling/wakeup etc. - latency += 50 * time.Millisecond - } - time.Sleep(latency) - offset += payloadSize - payloadSize *= 2 - } - // Check that at the end of our iterations the receive window grew close to the maximum - // permissible size of maxReceiveBufferSize/2 - if got, want := int(float64(curRcvWnd)*1.1), maxReceiveBufferSize/2; got < want { - t.Fatalf("unexpected rcvWnd got: %d, want > %d", got, want) - } - -} - -func TestDelayEnabled(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - checkDelayOption(t, c, false, false) // Delay is disabled by default. - - for _, v := range []struct { - delayEnabled tcpip.TCPDelayEnabled - wantDelayOption bool - }{ - {delayEnabled: false, wantDelayOption: false}, - {delayEnabled: true, wantDelayOption: true}, - } { - c := context.New(t, defaultMTU) - defer c.Cleanup() - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &v.delayEnabled); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, v.delayEnabled, v.delayEnabled, err) - } - checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption) - } -} - -func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcpip.TCPDelayEnabled, wantDelayOption bool) { - t.Helper() - - var gotDelayEnabled tcpip.TCPDelayEnabled - if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil { - t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %s", err) - } - if gotDelayEnabled != wantDelayEnabled { - t.Errorf("TransportProtocolOption(tcp, &gotDelayEnabled) got %t, want %t", gotDelayEnabled, wantDelayEnabled) - } - - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, new(waiter.Queue)) - if err != nil { - t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %s", err) - } - gotDelayOption := ep.SocketOptions().GetDelayOption() - if gotDelayOption != wantDelayOption { - t.Errorf("ep.GetSockOptBool(tcpip.DelayOption) got: %t, want: %t", gotDelayOption, wantDelayOption) - } -} - -func TestTCPLingerTimeout(t *testing.T) { - c := context.New(t, 1500 /* mtu */) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - testCases := []struct { - name string - tcpLingerTimeout time.Duration - want time.Duration - }{ - {"NegativeLingerTimeout", -123123, -1}, - // Zero is treated same as the stack's default TCP_LINGER2 timeout. - {"ZeroLingerTimeout", 0, tcp.DefaultTCPLingerTimeout}, - {"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second}, - // Values > stack's TCPLingerTimeout are capped to the stack's - // value. Defaults to tcp.DefaultTCPLingerTimeout(60 seconds) - {"AboveMaxLingerTimeout", tcp.MaxTCPLingerTimeout + 5*time.Second, tcp.MaxTCPLingerTimeout}, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - v := tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout) - if err := c.EP.SetSockOpt(&v); err != nil { - t.Fatalf("SetSockOpt(&%T(%s)) = %s", v, tc.tcpLingerTimeout, err) - } - - v = 0 - if err := c.EP.GetSockOpt(&v); err != nil { - t.Fatalf("GetSockOpt(&%T) = %s", v, err) - } - if got, want := time.Duration(v), tc.want; got != want { - t.Fatalf("got linger timeout = %s, want = %s", got, want) - } - }) - } -} - -func TestTCPTimeWaitRSTIgnored(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN request. - iss := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - c.EP.Close() - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - } - - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+2)), - checker.TCPAckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) - - // Now send a RST and this should be ignored and not - // generate an ACK. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagRst, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - }) - - c.CheckNoPacketTimeout("unexpected packet received in TIME_WAIT state", 1*time.Second) - - // Out of order ACK should generate an immediate ACK in - // TIME_WAIT. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 3, - }) - - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+2)), - checker.TCPAckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) -} - -func TestTCPTimeWaitOutOfOrder(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN request. - iss := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - c.EP.Close() - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - } - - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+2)), - checker.TCPAckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) - - // Out of order ACK should generate an immediate ACK in - // TIME_WAIT. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 3, - }) - - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+2)), - checker.TCPAckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) -} - -func TestTCPTimeWaitNewSyn(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN request. - iss := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - c.EP.Close() - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - } - - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+2)), - checker.TCPAckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) - - // Send a SYN request w/ sequence number lower than - // the highest sequence number sent. We just reuse - // the same number. - iss = seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second) - - // drain any older notifications from the notification channel before attempting - // 2nd connection. - select { - case <-ch: - default: - } - - // Send a SYN request w/ sequence number higher than - // the highest sequence number sent. - iss = seqnum.Value(792) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b = c.GetPacket() - tcpHdr = header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders = &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } -} - -func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed - // after 5 seconds in TIME_WAIT state. - tcpTimeWaitTimeout := 5 * time.Second - opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, tcpTimeWaitTimeout, err) - } - - want := c.Stack().Stats().TCP.EstablishedClosed.Value() + 1 - - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN request. - iss := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - c.EP.Close() - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - } - - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+2)), - checker.TCPAckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) - - time.Sleep(2 * time.Second) - - // Now send a duplicate FIN. This should cause the TIME_WAIT to extend - // by another 5 seconds and also send us a duplicate ACK as it should - // indicate that the final ACK was potentially lost. - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+2)), - checker.TCPAckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) - - // Sleep for 4 seconds so at this point we are 1 second past the - // original tcpLingerTimeout of 5 seconds. - time.Sleep(4 * time.Second) - - // Send an ACK and it should not generate any packet as the socket - // should still be in TIME_WAIT for another another 5 seconds due - // to the duplicate FIN we sent earlier. - *ackHeaders = *finHeaders - ackHeaders.SeqNum = ackHeaders.SeqNum + 1 - ackHeaders.Flags = header.TCPFlagAck - c.SendPacket(nil, ackHeaders) - - c.CheckNoPacketTimeout("unexpected packet received from endpoint in TIME_WAIT", 1*time.Second) - // Now sleep for another 2 seconds so that we are past the - // extended TIME_WAIT of 7 seconds (2 + 5). - time.Sleep(2 * time.Second) - - // Resend the same ACK. - c.SendPacket(nil, ackHeaders) - - // Receive the RST that should be generated as there is no valid - // endpoint. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(ackHeaders.AckNum)), - checker.TCPAckNum(0), - checker.TCPFlags(header.TCPFlagRst))) - - if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want { - t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = %d", got, want) - } - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) - } -} - -func TestTCPCloseWithData(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed - // after 5 seconds in TIME_WAIT state. - tcpTimeWaitTimeout := 5 * time.Second - opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, tcpTimeWaitTimeout, err) - } - - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN request. - iss := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - RcvWnd: 30000, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Now trigger a passive close by sending a FIN. - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - RcvWnd: 30000, - } - - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) - - // Now write a few bytes and then close the endpoint. - data := []byte{1, 2, 3} - - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that data is received. - b = c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+2), // Acknum is initial sequence number + 1 - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { - t.Errorf("got data = %x, want = %x", p, data) - } - - c.EP.Close() - // Check the FIN. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)+uint32(len(data))), - checker.TCPAckNum(uint32(iss+2)), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - - // First send a partial ACK. - ackHeaders = &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 2, - AckNum: c.IRS + 1 + seqnum.Value(len(data)-1), - RcvWnd: 30000, - } - c.SendPacket(nil, ackHeaders) - - // Now send a full ACK. - ackHeaders = &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 2, - AckNum: c.IRS + 1 + seqnum.Value(len(data)), - RcvWnd: 30000, - } - c.SendPacket(nil, ackHeaders) - - // Now ACK the FIN. - ackHeaders.AckNum++ - c.SendPacket(nil, ackHeaders) - - // Now send an ACK and we should get a RST back as the endpoint should - // be in CLOSED state. - ackHeaders = &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 2, - AckNum: c.IRS + 1 + seqnum.Value(len(data)), - RcvWnd: 30000, - } - c.SendPacket(nil, ackHeaders) - - // Check the RST. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(ackHeaders.AckNum)), - checker.TCPAckNum(0), - checker.TCPFlags(header.TCPFlagRst))) -} - -func TestTCPUserTimeout(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&waitEntry, waiter.EventHUp) - defer c.WQ.EventUnregister(&waitEntry) - - origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() - - // Ensure that on the next retransmit timer fire, the user timeout has - // expired. - initRTO := 1 * time.Second - userTimeout := initRTO / 2 - v := tcpip.TCPUserTimeoutOption(userTimeout) - if err := c.EP.SetSockOpt(&v); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%s): %s", v, userTimeout, err) - } - - // Send some data and wait before ACKing it. - view := make([]byte, 3) - var r bytes.Reader - r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - next := uint32(c.IRS) + 1 - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - // Wait for the retransmit timer to be fired and the user timeout to cause - // close of the connection. - select { - case <-notifyCh: - case <-time.After(2 * initRTO): - t.Fatalf("connection still alive after %s, should have been closed after :%s", 2*initRTO, userTimeout) - } - - // No packet should be received as the connection should be silently - // closed due to timeout. - c.CheckNoPacket("unexpected packet received after userTimeout has expired") - - next += uint32(len(view)) - - // The connection should be terminated after userTimeout has expired. - // Send an ACK to trigger a RST from the stack as the endpoint should - // be dead. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(next)), - checker.TCPAckNum(uint32(0)), - checker.TCPFlags(header.TCPFlagRst), - ), - ) - - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrTimeout{}) - - if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { - t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want) - } - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) - } -} - -func TestKeepaliveWithUserTimeout(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() - - const keepAliveIdle = 100 * time.Millisecond - const keepAliveInterval = 3 * time.Second - keepAliveIdleOption := tcpip.KeepaliveIdleOption(keepAliveIdle) - if err := c.EP.SetSockOpt(&keepAliveIdleOption); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIdleOption, keepAliveIdle, err) - } - keepAliveIntervalOption := tcpip.KeepaliveIntervalOption(keepAliveInterval) - if err := c.EP.SetSockOpt(&keepAliveIntervalOption); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIntervalOption, keepAliveInterval, err) - } - if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10); err != nil { - t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10): %s", err) - } - c.EP.SocketOptions().SetKeepAlive(true) - - // Set userTimeout to be the duration to be 1 keepalive - // probes. Which means that after the first probe is sent - // the second one should cause the connection to be - // closed due to userTimeout being hit. - userTimeout := tcpip.TCPUserTimeoutOption(keepAliveInterval) - if err := c.EP.SetSockOpt(&userTimeout); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", userTimeout, keepAliveInterval, err) - } - - // Check that the connection is still alive. - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - // Now receive 1 keepalives, but don't ACK it. - b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)), - checker.TCPAckNum(uint32(790)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Sleep for a litte over the KeepAlive interval to make sure - // the timer has time to fire after the last ACK and close the - // close the socket. - time.Sleep(keepAliveInterval + keepAliveInterval/2) - - // The connection should be closed with a timeout. - // Send an ACK to trigger a RST from the stack as the endpoint should - // be dead. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: seqnum.Value(c.IRS + 1), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(uint32(0)), - checker.TCPFlags(header.TCPFlagRst), - ), - ) - - ept.CheckReadError(t, &tcpip.ErrTimeout{}) - if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { - t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want) - } - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) - } -} - -func TestIncreaseWindowOnRead(t *testing.T) { - // This test ensures that the endpoint sends an ack, - // after read() when the window grows by more than 1 MSS. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - const rcvBuf = 65535 * 10 - c.CreateConnected(789, 30000, rcvBuf) - - // Write chunks of ~30000 bytes. It's important that two - // payloads make it equal or longer than MSS. - remain := rcvBuf * 2 - sent := 0 - data := make([]byte, defaultMTU/2) - - for remain > len(data) { - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + sent), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - sent += len(data) - remain -= len(data) - pkt := c.GetPacket() - checker.IPv4(t, pkt, - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+sent)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - // Break once the window drops below defaultMTU/2 - if wnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize(); wnd < defaultMTU/2 { - break - } - } - - // We now have < 1 MSS in the buffer space. Read at least > 2 MSS - // worth of data as receive buffer space - w := tcpip.LimitedWriter{ - W: ioutil.Discard, - // defaultMTU is a good enough estimate for the MSS used for this - // connection. - N: defaultMTU * 2, - } - for w.N != 0 { - _, err := c.EP.Read(&w, tcpip.ReadOptions{}) - if err != nil { - t.Fatalf("Read failed: %s", err) - } - } - - // After reading > MSS worth of data, we surely crossed MSS. See the ack: - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+sent)), - checker.TCPWindow(uint16(0xffff)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestIncreaseWindowOnBufferResize(t *testing.T) { - // This test ensures that the endpoint sends an ack, - // after available recv buffer grows to more than 1 MSS. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - const rcvBuf = 65535 * 10 - c.CreateConnected(789, 30000, rcvBuf) - - // Write chunks of ~30000 bytes. It's important that two - // payloads make it equal or longer than MSS. - remain := rcvBuf - sent := 0 - data := make([]byte, defaultMTU/2) - - for remain > len(data) { - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + sent), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - sent += len(data) - remain -= len(data) - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+sent)), - checker.TCPWindowLessThanEq(0xffff), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - } - - // Increasing the buffer from should generate an ACK, - // since window grew from small value to larger equal MSS - c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBuf*2) - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+sent)), - checker.TCPWindow(uint16(0xffff)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestTCPDeferAccept(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatal("Listen failed:", err) - } - - const tcpDeferAccept = 1 * time.Second - tcpDeferAcceptOption := tcpip.TCPDeferAcceptOption(tcpDeferAccept) - if err := c.EP.SetSockOpt(&tcpDeferAcceptOption); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", tcpDeferAcceptOption, tcpDeferAccept, err) - } - - irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - - _, _, err := c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got c.EP.Accept(nil) = %v, want: %s", err, &tcpip.ErrWouldBlock{}) - } - - // Send data. This should result in an acceptable endpoint. - c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - }) - - // Receive ACK for the data we sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(iss+1)), - checker.TCPAckNum(uint32(irs+5)))) - - // Give a bit of time for the socket to be delivered to the accept queue. - time.Sleep(50 * time.Millisecond) - aep, _, err := c.EP.Accept(nil) - if err != nil { - t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err) - } - - aep.Close() - // Closing aep without reading the data should trigger a RST. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.TCPSeqNum(uint32(iss+1)), - checker.TCPAckNum(uint32(irs+5)))) -} - -func TestTCPDeferAcceptTimeout(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatal("Listen failed:", err) - } - - const tcpDeferAccept = 1 * time.Second - tcpDeferAcceptOpt := tcpip.TCPDeferAcceptOption(tcpDeferAccept) - if err := c.EP.SetSockOpt(&tcpDeferAcceptOpt); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%s)) failed: %s", tcpDeferAcceptOpt, tcpDeferAccept, err) - } - - irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - - _, _, err := c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got c.EP.Accept(nil) = %v, want: %s", err, &tcpip.ErrWouldBlock{}) - } - - // Sleep for a little of the tcpDeferAccept timeout. - time.Sleep(tcpDeferAccept + 100*time.Millisecond) - - // On timeout expiry we should get a SYN-ACK retransmission. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs)+1))) - - // Send data. This should result in an acceptable endpoint. - c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - }) - - // Receive ACK for the data we sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(iss+1)), - checker.TCPAckNum(uint32(irs+5)))) - - // Give sometime for the endpoint to be delivered to the accept queue. - time.Sleep(50 * time.Millisecond) - aep, _, err := c.EP.Accept(nil) - if err != nil { - t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err) - } - - aep.Close() - // Closing aep without reading the data should trigger a RST. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.TCPSeqNum(uint32(iss+1)), - checker.TCPAckNum(uint32(irs+5)))) -} - -func TestResetDuringClose(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - iss := seqnum.Value(789) - c.CreateConnected(iss, 30000, -1 /* epRecvBuf */) - // Send some data to make sure there is some unread - // data to trigger a reset on c.Close. - irs := c.IRS - c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(1), - AckNum: irs.Add(1), - RcvWnd: 30000, - }) - - // Receive ACK for the data we sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(irs.Add(1))), - checker.TCPAckNum(uint32(iss.Add(5))))) - - // Close in a separate goroutine so that we can trigger - // a race with the RST we send below. This should not - // panic due to the route being released depeding on - // whether Close() sends an active RST or the RST sent - // below is processed by the worker first. - var wg sync.WaitGroup - - wg.Add(1) - go func() { - defer wg.Done() - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - SeqNum: iss.Add(5), - AckNum: c.IRS.Add(5), - RcvWnd: 30000, - Flags: header.TCPFlagRst, - }) - }() - - wg.Add(1) - go func() { - defer wg.Done() - c.EP.Close() - }() - - wg.Wait() -} - -func TestStackTimeWaitReuse(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - s := c.Stack() - var twReuse tcpip.TCPTimeWaitReuseOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &twReuse, err) - } - if got, want := twReuse, tcpip.TCPTimeWaitReuseLoopbackOnly; got != want { - t.Fatalf("got tcpip.TCPTimeWaitReuseOption: %v, want: %v", got, want) - } -} - -func TestSetStackTimeWaitReuse(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - s := c.Stack() - testCases := []struct { - v int - err tcpip.Error - }{ - {int(tcpip.TCPTimeWaitReuseDisabled), nil}, - {int(tcpip.TCPTimeWaitReuseGlobal), nil}, - {int(tcpip.TCPTimeWaitReuseLoopbackOnly), nil}, - {int(tcpip.TCPTimeWaitReuseLoopbackOnly) + 1, &tcpip.ErrInvalidOptionValue{}}, - {int(tcpip.TCPTimeWaitReuseDisabled) - 1, &tcpip.ErrInvalidOptionValue{}}, - } - - for _, tc := range testCases { - opt := tcpip.TCPTimeWaitReuseOption(tc.v) - err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt) - if got, want := err, tc.err; got != want { - t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%d)) = %s, want = %s", tcp.ProtocolNumber, tc.v, tc.v, err, tc.err) - } - if tc.err != nil { - continue - } - - var twReuse tcpip.TCPTimeWaitReuseOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v, want nil", tcp.ProtocolNumber, &twReuse, err) - } - - if got, want := twReuse, tcpip.TCPTimeWaitReuseOption(tc.v); got != want { - t.Fatalf("got tcpip.TCPTimeWaitReuseOption: %v, want: %v", got, want) - } - } -} - -// generateRandomPayload generates a random byte slice of the specified length -// causing a fatal test failure if it is unable to do so. -func generateRandomPayload(t *testing.T, n int) []byte { - t.Helper() - buf := make([]byte, n) - if _, err := rand.Read(buf); err != nil { - t.Fatalf("rand.Read(buf) failed: %s", err) - } - return buf -} diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go deleted file mode 100644 index 5a9745ad7..000000000 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ /dev/null @@ -1,311 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_test - -import ( - "bytes" - "math/rand" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" - "gvisor.dev/gvisor/pkg/waiter" -) - -// createConnectedWithTimestampOption creates and connects c.ep with the -// timestamp option enabled. -func createConnectedWithTimestampOption(c *context.Context) *context.RawEndpoint { - return c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, TSVal: 1}) -} - -// TestTimeStampEnabledConnect tests that netstack sends the timestamp option on -// an active connect and sets the TS Echo Reply fields correctly when the -// SYN-ACK also indicates support for the TS option and provides a TSVal. -func TestTimeStampEnabledConnect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - rep := createConnectedWithTimestampOption(c) - - // Register for read and validate that we have data to read. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - // The following tests ensure that TS option once enabled behaves - // correctly as described in - // https://tools.ietf.org/html/rfc7323#section-4.3. - // - // We are not testing delayed ACKs here, but we do test out of order - // packet delivery and filling the sequence number hole created due to - // the out of order packet. - // - // The test also verifies that the sequence numbers and timestamps are - // as expected. - data := []byte{1, 2, 3} - - // First we increment tsVal by a small amount. - tsVal := rep.TSVal + 100 - rep.SendPacketWithTS(data, tsVal) - rep.VerifyACKWithTS(tsVal) - - // Next we send an out of order packet. - rep.NextSeqNum += 3 - tsVal += 200 - rep.SendPacketWithTS(data, tsVal) - - // The ACK should contain the original sequenceNumber and an older TS. - rep.NextSeqNum -= 6 - rep.VerifyACKWithTS(tsVal - 200) - - // Next we fill the hole and the returned ACK should contain the - // cumulative sequence number acking all data sent till now and have the - // latest timestamp sent below in its TSEcr field. - tsVal -= 100 - rep.SendPacketWithTS(data, tsVal) - rep.NextSeqNum += 3 - rep.VerifyACKWithTS(tsVal) - - // Increment tsVal by a large value that doesn't result in a wrap around. - tsVal += 0x7fffffff - rep.SendPacketWithTS(data, tsVal) - rep.VerifyACKWithTS(tsVal) - - // Increment tsVal again by a large value which should cause the - // timestamp value to wrap around. The returned ACK should contain the - // wrapped around timestamp in its tsEcr field and not the tsVal from - // the previous packet sent above. - tsVal += 0x7fffffff - rep.SendPacketWithTS(data, tsVal) - rep.VerifyACKWithTS(tsVal) - - select { - case <-ch: - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // There should be 5 views to read and each of them should - // contain the same data. - for i := 0; i < 5; i++ { - buf := make([]byte, len(data)) - w := tcpip.SliceWriter(buf) - result, err := c.EP.Read(&w, tcpip.ReadOptions{}) - if err != nil { - t.Fatalf("Unexpected error from Read: %v", err) - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: len(buf), - Total: len(buf), - }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { - t.Errorf("Read: unexpected result (-want +got):\n%s", diff) - } - if got, want := buf, data; bytes.Compare(got, want) != 0 { - t.Fatalf("Data is different: got: %v, want: %v", got, want) - } - } -} - -// TestTimeStampDisabledConnect tests that netstack sends timestamp option on an -// active connect but if the SYN-ACK doesn't specify the TS option then -// timestamp option is not enabled and future packets do not contain a -// timestamp. -func TestTimeStampDisabledConnect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnectedWithOptions(header.TCPSynOptions{}) -} - -func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - if cookieEnabled { - var opt tcpip.TCPSynRcvdCountThresholdOption - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } - } - - t.Logf("Test w/ CookieEnabled = %v", cookieEnabled) - tsVal := rand.Uint32() - c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, TS: true, TSVal: tsVal}) - - // Now send some data and validate that timestamp is echoed correctly in the ACK. - data := []byte{1, 2, 3} - - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %s", err) - } - - // Check that data is received and that the timestamp option TSEcr field - // matches the expected value. - b := c.GetPacket() - checker.IPv4(t, b, - // Add 12 bytes for the timestamp option + 2 NOPs to align at 4 - // byte boundary. - checker.PayloadLen(len(data)+header.TCPMinimumSize+12), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), - checker.TCPWindow(wndSize), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - checker.TCPTimestampChecker(true, 0, tsVal+1), - ), - ) -} - -// TestTimeStampEnabledAccept tests that if the SYN on a passive connect -// specifies the Timestamp option then the Timestamp option is sent on a SYN-ACK -// and echoes the tsVal field of the original SYN in the tcEcr field of the -// SYN-ACK. We cover the cases where SYN cookies are enabled/disabled and verify -// that Timestamp option is enabled in both cases if requested in the original -// SYN. -func TestTimeStampEnabledAccept(t *testing.T) { - testCases := []struct { - cookieEnabled bool - wndScale int - wndSize uint16 - }{ - {true, -1, 0xffff}, // When cookie is used window scaling is disabled. - // DefaultReceiveBufferSize is 1MB >> 5. Advertised window will be 1/2 of that. - {false, 5, 0x4000}, - } - for _, tc := range testCases { - timeStampEnabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize) - } -} - -func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - if cookieEnabled { - var opt tcpip.TCPSynRcvdCountThresholdOption - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } - } - - t.Logf("Test w/ CookieEnabled = %v", cookieEnabled) - c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS}) - - // Now send some data with the accepted connection endpoint and validate - // that no timestamp option is sent in the TCP segment. - data := []byte{1, 2, 3} - - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %s", err) - } - - // Check that data is received and that the timestamp option is disabled - // when SYN cookies are enabled/disabled. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), - checker.TCPWindow(wndSize), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - checker.TCPTimestampChecker(false, 0, 0), - ), - ) -} - -// TestTimeStampDisabledAccept tests that Timestamp option is not used when the -// peer doesn't advertise it and connection is established with Accept(). -func TestTimeStampDisabledAccept(t *testing.T) { - testCases := []struct { - cookieEnabled bool - wndScale int - wndSize uint16 - }{ - {true, -1, 0xffff}, // When cookie is used window scaling is disabled. - // DefaultReceiveBufferSize is 1MB >> 5. Advertised window will be half of - // that. - {false, 5, 0x4000}, - } - for _, tc := range testCases { - timeStampDisabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize) - } -} - -func TestSendGreaterThanMTUWithOptions(t *testing.T) { - const maxPayload = 100 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - createConnectedWithTimestampOption(c) - testBrokenUpWrite(t, c, maxPayload) -} - -func TestSegmentNotDroppedWhenTimestampMissing(t *testing.T) { - const maxPayload = 100 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - rep := createConnectedWithTimestampOption(c) - - // Register for read. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - droppedPacketsStat := c.Stack().Stats().DroppedPackets - droppedPackets := droppedPacketsStat.Value() - data := []byte{1, 2, 3} - // Send a packet with no TCP options/timestamp. - rep.SendPacket(data, nil) - - select { - case <-ch: - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Assert that DroppedPackets was not incremented. - if got, want := droppedPacketsStat.Value(), droppedPackets; got != want { - t.Fatalf("incorrect number of dropped packets, got: %v, want: %v", got, want) - } - - // Issue a read and we should data. - var buf bytes.Buffer - result, err := c.EP.Read(&buf, tcpip.ReadOptions{}) - if err != nil { - t.Fatalf("Unexpected error from Read: %v", err) - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: buf.Len(), - Total: buf.Len(), - }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { - t.Errorf("Read: unexpected result (-want +got):\n%s", diff) - } - if got, want := buf.Bytes(), data; bytes.Compare(got, want) != 0 { - t.Fatalf("Data is different: got: %v, want: %v", got, want) - } -} diff --git a/pkg/tcpip/transport/tcp/tcp_unsafe_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_unsafe_state_autogen.go new file mode 100644 index 000000000..4cb82fcc9 --- /dev/null +++ b/pkg/tcpip/transport/tcp/tcp_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package tcp diff --git a/pkg/tcpip/transport/tcp/testing/context/BUILD b/pkg/tcpip/transport/tcp/testing/context/BUILD deleted file mode 100644 index ce6a2c31d..000000000 --- a/pkg/tcpip/transport/tcp/testing/context/BUILD +++ /dev/null @@ -1,26 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "context", - testonly = 1, - srcs = ["context.go"], - visibility = [ - "//visibility:public", - ], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/seqnum", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go deleted file mode 100644 index b1cb9a324..000000000 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ /dev/null @@ -1,1235 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package context provides a test context for use in tcp tests. It also -// provides helper methods to assert/check certain behaviours. -package context - -import ( - "bytes" - "context" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - // StackAddr is the IPv4 address assigned to the stack. - StackAddr = "\x0a\x00\x00\x01" - - // StackPort is used as the listening port in tests for passive - // connects. - StackPort = 1234 - - // TestAddr is the source address for packets sent to the stack via the - // link layer endpoint. - TestAddr = "\x0a\x00\x00\x02" - - // TestPort is the TCP port used for packets sent to the stack - // via the link layer endpoint. - TestPort = 4096 - - // StackV6Addr is the IPv6 address assigned to the stack. - StackV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - - // TestV6Addr is the source address for packets sent to the stack via - // the link layer endpoint. - TestV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - - // StackV4MappedAddr is StackAddr as a mapped v6 address. - StackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + StackAddr - - // TestV4MappedAddr is TestAddr as a mapped v6 address. - TestV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + TestAddr - - // V4MappedWildcardAddr is the mapped v6 representation of 0.0.0.0. - V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00" - - // TestInitialSequenceNumber is the initial sequence number sent in packets that - // are sent in response to a SYN or in the initial SYN sent to the stack. - TestInitialSequenceNumber = 789 -) - -// StackAddrWithPrefix is StackAddr with its associated prefix length. -var StackAddrWithPrefix = tcpip.AddressWithPrefix{ - Address: StackAddr, - PrefixLen: 24, -} - -// StackV6AddrWithPrefix is StackV6Addr with its associated prefix length. -var StackV6AddrWithPrefix = tcpip.AddressWithPrefix{ - Address: StackV6Addr, - PrefixLen: header.IIDOffsetInIPv6Address * 8, -} - -// Headers is used to represent the TCP header fields when building a -// new packet. -type Headers struct { - // SrcPort holds the src port value to be used in the packet. - SrcPort uint16 - - // DstPort holds the destination port value to be used in the packet. - DstPort uint16 - - // SeqNum is the value of the sequence number field in the TCP header. - SeqNum seqnum.Value - - // AckNum represents the acknowledgement number field in the TCP header. - AckNum seqnum.Value - - // Flags are the TCP flags in the TCP header. - Flags int - - // RcvWnd is the window to be advertised in the ReceiveWindow field of - // the TCP header. - RcvWnd seqnum.Size - - // TCPOpts holds the options to be sent in the option field of the TCP - // header. - TCPOpts []byte -} - -// Options contains options for creating a new test context. -type Options struct { - // EnableV4 indicates whether IPv4 should be enabled. - EnableV4 bool - - // EnableV6 indicates whether IPv4 should be enabled. - EnableV6 bool - - // MTU indicates the maximum transmission unit on the link layer. - MTU uint32 -} - -// Context provides an initialized Network stack and a link layer endpoint -// for use in TCP tests. -type Context struct { - t *testing.T - linkEP *channel.Endpoint - s *stack.Stack - - // IRS holds the initial sequence number in the SYN sent by endpoint in - // case of an active connect or the sequence number sent by the endpoint - // in the SYN-ACK sent in response to a SYN when listening in passive - // mode. - IRS seqnum.Value - - // Port holds the port bound by EP below in case of an active connect or - // the listening port number in case of a passive connect. - Port uint16 - - // EP is the test endpoint in the stack owned by this context. This endpoint - // is used in various tests to either initiate an active connect or is used - // as a passive listening endpoint to accept inbound connections. - EP tcpip.Endpoint - - // Wq is the wait queue associated with EP and is used to block for events - // on EP. - WQ waiter.Queue - - // TimeStampEnabled is true if ep is connected with the timestamp option - // enabled. - TimeStampEnabled bool - - // WindowScale is the expected window scale in SYN packets sent by - // the stack. - WindowScale uint8 - - // RcvdWindowScale is the actual window scale sent by the stack in - // SYN/SYN-ACK. - RcvdWindowScale uint8 -} - -// New allocates and initializes a test context containing a new -// stack and a link-layer endpoint. -func New(t *testing.T, mtu uint32) *Context { - return NewWithOpts(t, Options{ - EnableV4: true, - EnableV6: true, - MTU: mtu, - }) -} - -// NewWithOpts allocates and initializes a test context containing a new -// stack and a link-layer endpoint with specific options. -func NewWithOpts(t *testing.T, opts Options) *Context { - if opts.MTU == 0 { - panic("MTU must be greater than 0") - } - - stackOpts := stack.Options{ - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, - } - if opts.EnableV4 { - stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv4.NewProtocol) - } - if opts.EnableV6 { - stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv6.NewProtocol) - } - s := stack.New(stackOpts) - - const sendBufferSize = 1 << 20 // 1 MiB - const recvBufferSize = 1 << 20 // 1 MiB - // Allow minimum send/receive buffer sizes to be 1 during tests. - sendBufOpt := tcpip.TCPSendBufferSizeRangeOption{Min: 1, Default: sendBufferSize, Max: 10 * sendBufferSize} - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &sendBufOpt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%#v) failed: %s", tcp.ProtocolNumber, sendBufOpt, err) - } - - rcvBufOpt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: recvBufferSize, Max: 10 * recvBufferSize} - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &rcvBufOpt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%#v) failed: %s", tcp.ProtocolNumber, rcvBufOpt, err) - } - - // Increase minimum RTO in tests to avoid test flakes due to early - // retransmit in case the test executors are overloaded and cause timers - // to fire earlier than expected. - minRTOOpt := tcpip.TCPMinRTOOption(3 * time.Second) - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil { - t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err) - } - - // Some of the congestion control tests send up to 640 packets, we so - // set the channel size to 1000. - ep := channel.New(1000, opts.MTU, "") - wep := stack.LinkEndpoint(ep) - if testing.Verbose() { - wep = sniffer.New(ep) - } - nicOpts := stack.NICOptions{Name: "nic1"} - if err := s.CreateNICWithOptions(1, wep, nicOpts); err != nil { - t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err) - } - wep2 := stack.LinkEndpoint(channel.New(1000, opts.MTU, "")) - if testing.Verbose() { - wep2 = sniffer.New(channel.New(1000, opts.MTU, "")) - } - opts2 := stack.NICOptions{Name: "nic2"} - if err := s.CreateNICWithOptions(2, wep2, opts2); err != nil { - t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err) - } - - var routeTable []tcpip.Route - - if opts.EnableV4 { - v4ProtocolAddr := tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: StackAddrWithPrefix, - } - if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err) - } - routeTable = append(routeTable, tcpip.Route{ - Destination: header.IPv4EmptySubnet, - NIC: 1, - }) - } - - if opts.EnableV6 { - v6ProtocolAddr := tcpip.ProtocolAddress{ - Protocol: ipv6.ProtocolNumber, - AddressWithPrefix: StackV6AddrWithPrefix, - } - if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err) - } - routeTable = append(routeTable, tcpip.Route{ - Destination: header.IPv6EmptySubnet, - NIC: 1, - }) - } - - s.SetRouteTable(routeTable) - - return &Context{ - t: t, - s: s, - linkEP: ep, - WindowScale: uint8(tcp.FindWndScale(recvBufferSize)), - } -} - -// Cleanup closes the context endpoint if required. -func (c *Context) Cleanup() { - if c.EP != nil { - c.EP.Close() - } - c.Stack().Close() -} - -// Stack returns a reference to the stack in the Context. -func (c *Context) Stack() *stack.Stack { - return c.s -} - -// CheckNoPacketTimeout verifies that no packet is received during the time -// specified by wait. -func (c *Context) CheckNoPacketTimeout(errMsg string, wait time.Duration) { - c.t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), wait) - defer cancel() - if _, ok := c.linkEP.ReadContext(ctx); ok { - c.t.Fatal(errMsg) - } -} - -// CheckNoPacket verifies that no packet is received for 1 second. -func (c *Context) CheckNoPacket(errMsg string) { - c.CheckNoPacketTimeout(errMsg, 1*time.Second) -} - -// GetPacketWithTimeout reads a packet from the link layer endpoint and verifies -// that it is an IPv4 packet with the expected source and destination -// addresses. If no packet is received in the specified timeout it will return -// nil. -func (c *Context) GetPacketWithTimeout(timeout time.Duration) []byte { - c.t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - p, ok := c.linkEP.ReadContext(ctx) - if !ok { - return nil - } - - if p.Proto != ipv4.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) - } - - // Just check that the stack set the transport protocol number for outbound - // TCP messages. - // TODO(gvisor.dev/issues/3810): Remove when protocol numbers are part - // of the headerinfo. - if p.Pkt.TransportProtocolNumber != tcp.ProtocolNumber { - c.t.Fatalf("got p.Pkt.TransportProtocolNumber = %d, want = %d", p.Pkt.TransportProtocolNumber, tcp.ProtocolNumber) - } - - vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) - b := vv.ToView() - - if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize { - c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize) - } - - checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) - return b -} - -// GetPacket reads a packet from the link layer endpoint and verifies -// that it is an IPv4 packet with the expected source and destination -// addresses. -func (c *Context) GetPacket() []byte { - c.t.Helper() - - p := c.GetPacketWithTimeout(5 * time.Second) - if p == nil { - c.t.Fatalf("Packet wasn't written out") - return nil - } - - return p -} - -// GetPacketNonBlocking reads a packet from the link layer endpoint -// and verifies that it is an IPv4 packet with the expected source -// and destination address. If no packet is available it will return -// nil immediately. -func (c *Context) GetPacketNonBlocking() []byte { - c.t.Helper() - - p, ok := c.linkEP.Read() - if !ok { - return nil - } - - if p.Proto != ipv4.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) - } - - // Just check that the stack set the transport protocol number for outbound - // TCP messages. - // TODO(gvisor.dev/issues/3810): Remove when protocol numbers are part - // of the headerinfo. - if p.Pkt.TransportProtocolNumber != tcp.ProtocolNumber { - c.t.Fatalf("got p.Pkt.TransportProtocolNumber = %d, want = %d", p.Pkt.TransportProtocolNumber, tcp.ProtocolNumber) - } - - vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) - b := vv.ToView() - - checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) - return b -} - -// SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint. -func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code header.ICMPv4Code, p1, p2 []byte, maxTotalSize int) { - // Allocate a buffer data and headers. - buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p2)) - if len(buf) > maxTotalSize { - buf = buf[:maxTotalSize] - } - - ip := header.IPv4(buf) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(len(buf)), - TTL: 65, - Protocol: uint8(header.ICMPv4ProtocolNumber), - SrcAddr: TestAddr, - DstAddr: StackAddr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - icmp := header.ICMPv4(buf[header.IPv4MinimumSize:]) - icmp.SetType(typ) - icmp.SetCode(code) - const icmpv4VariableHeaderOffset = 4 - copy(icmp[icmpv4VariableHeaderOffset:], p1) - copy(icmp[header.ICMPv4PayloadOffset:], p2) - icmp.SetChecksum(0) - checksum := ^header.Checksum(icmp, 0 /* initial */) - icmp.SetChecksum(checksum) - - // Inject packet. - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - }) - c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) -} - -// BuildSegment builds a TCP segment based on the given Headers and payload. -func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView { - return c.BuildSegmentWithAddrs(payload, h, TestAddr, StackAddr) -} - -// BuildSegmentWithAddrs builds a TCP segment based on the given Headers, -// payload and source and destination IPv4 addresses. -func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) buffer.VectorisedView { - // Allocate a buffer for data and headers. - buf := buffer.NewView(header.TCPMinimumSize + header.IPv4MinimumSize + len(h.TCPOpts) + len(payload)) - copy(buf[len(buf)-len(payload):], payload) - copy(buf[len(buf)-len(payload)-len(h.TCPOpts):], h.TCPOpts) - - // Initialize the IP header. - ip := header.IPv4(buf) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(len(buf)), - TTL: 65, - Protocol: uint8(tcp.ProtocolNumber), - SrcAddr: src, - DstAddr: dst, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - // Initialize the TCP header. - t := header.TCP(buf[header.IPv4MinimumSize:]) - t.Encode(&header.TCPFields{ - SrcPort: h.SrcPort, - DstPort: h.DstPort, - SeqNum: uint32(h.SeqNum), - AckNum: uint32(h.AckNum), - DataOffset: uint8(header.TCPMinimumSize + len(h.TCPOpts)), - Flags: uint8(h.Flags), - WindowSize: uint16(h.RcvWnd), - }) - - // Calculate the TCP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t))) - - // Calculate the TCP checksum and set it. - xsum = header.Checksum(payload, xsum) - t.SetChecksum(^t.CalculateChecksum(xsum)) - - // Inject packet. - return buf.ToVectorisedView() -} - -// SendSegment sends a TCP segment that has already been built and written to a -// buffer.VectorisedView. -func (c *Context) SendSegment(s buffer.VectorisedView) { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: s, - }) - c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) -} - -// SendPacket builds and sends a TCP segment(with the provided payload & TCP -// headers) in an IPv4 packet via the link layer endpoint. -func (c *Context) SendPacket(payload []byte, h *Headers) { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: c.BuildSegment(payload, h), - }) - c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) -} - -// SendPacketWithAddrs builds and sends a TCP segment(with the provided payload -// & TCPheaders) in an IPv4 packet via the link layer endpoint using the -// provided source and destination IPv4 addresses. -func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: c.BuildSegmentWithAddrs(payload, h, src, dst), - }) - c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) -} - -// SendAck sends an ACK packet. -func (c *Context) SendAck(seq seqnum.Value, bytesReceived int) { - c.SendAckWithSACK(seq, bytesReceived, nil) -} - -// SendAckWithSACK sends an ACK packet which includes the sackBlocks specified. -func (c *Context) SendAckWithSACK(seq seqnum.Value, bytesReceived int, sackBlocks []header.SACKBlock) { - options := make([]byte, 40) - offset := 0 - if len(sackBlocks) > 0 { - offset += header.EncodeNOP(options[offset:]) - offset += header.EncodeNOP(options[offset:]) - offset += header.EncodeSACKBlocks(sackBlocks, options[offset:]) - } - - c.SendPacket(nil, &Headers{ - SrcPort: TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seq, - AckNum: c.IRS.Add(1 + seqnum.Size(bytesReceived)), - RcvWnd: 30000, - TCPOpts: options[:offset], - }) -} - -// ReceiveAndCheckPacket reads a packet from the link layer endpoint and -// verifies that the packet packet payload of packet matches the slice -// of data indicated by offset & size. -func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) { - c.t.Helper() - - c.ReceiveAndCheckPacketWithOptions(data, offset, size, 0) -} - -// ReceiveAndCheckPacketWithOptions reads a packet from the link layer endpoint -// and verifies that the packet packet payload of packet matches the slice of -// data indicated by offset & size and skips optlen bytes in addition to the IP -// TCP headers when comparing the data. -func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, optlen int) { - c.t.Helper() - - b := c.GetPacket() - checker.IPv4(c.t, b, - checker.PayloadLen(size+header.TCPMinimumSize+optlen), - checker.TCP( - checker.DstPort(TestPort), - checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), - checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - pdata := data[offset:][:size] - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize+optlen:]; bytes.Compare(pdata, p) != 0 { - c.t.Fatalf("Data is different: expected %v, got %v", pdata, p) - } -} - -// ReceiveNonBlockingAndCheckPacket reads a packet from the link layer endpoint -// and verifies that the packet packet payload of packet matches the slice of -// data indicated by offset & size. It returns true if a packet was received and -// processed. -func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int) bool { - c.t.Helper() - - b := c.GetPacketNonBlocking() - if b == nil { - return false - } - checker.IPv4(c.t, b, - checker.PayloadLen(size+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(TestPort), - checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), - checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - pdata := data[offset:][:size] - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; bytes.Compare(pdata, p) != 0 { - c.t.Fatalf("Data is different: expected %v, got %v", pdata, p) - } - return true -} - -// CreateV6Endpoint creates and initializes c.ep as a IPv6 Endpoint. If v6Only -// is true then it sets the IP_V6ONLY option on the socket to make it a IPv6 -// only endpoint instead of a default dual stack socket. -func (c *Context) CreateV6Endpoint(v6only bool) { - var err tcpip.Error - c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &c.WQ) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - c.EP.SocketOptions().SetV6Only(v6only) -} - -// GetV6Packet reads a single packet from the link layer endpoint of the context -// and asserts that it is an IPv6 Packet with the expected src/dest addresses. -func (c *Context) GetV6Packet() []byte { - c.t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - p, ok := c.linkEP.ReadContext(ctx) - if !ok { - c.t.Fatalf("Packet wasn't written out") - return nil - } - - if p.Proto != ipv6.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber) - } - vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) - b := vv.ToView() - - checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr)) - return b -} - -// SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of -// the context. -func (c *Context) SendV6Packet(payload []byte, h *Headers) { - c.SendV6PacketWithAddrs(payload, h, TestV6Addr, StackV6Addr) -} - -// SendV6PacketWithAddrs builds and sends an IPv6 Packet via the link layer -// endpoint of the context using the provided source and destination IPv6 -// addresses. -func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) { - // Allocate a buffer for data and headers. - buf := buffer.NewView(header.TCPMinimumSize + header.IPv6MinimumSize + len(payload)) - copy(buf[len(buf)-len(payload):], payload) - - // Initialize the IP header. - ip := header.IPv6(buf) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(header.TCPMinimumSize + len(payload)), - TransportProtocol: tcp.ProtocolNumber, - HopLimit: 65, - SrcAddr: src, - DstAddr: dst, - }) - - // Initialize the TCP header. - t := header.TCP(buf[header.IPv6MinimumSize:]) - t.Encode(&header.TCPFields{ - SrcPort: h.SrcPort, - DstPort: h.DstPort, - SeqNum: uint32(h.SeqNum), - AckNum: uint32(h.AckNum), - DataOffset: header.TCPMinimumSize, - Flags: uint8(h.Flags), - WindowSize: uint16(h.RcvWnd), - }) - - // Calculate the TCP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t))) - - // Calculate the TCP checksum and set it. - xsum = header.Checksum(payload, xsum) - t.SetChecksum(^t.CalculateChecksum(xsum)) - - // Inject packet. - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - }) - c.linkEP.InjectInbound(ipv6.ProtocolNumber, pkt) -} - -// CreateConnected creates a connected TCP endpoint. -func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int) { - c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil) -} - -// Connect performs the 3-way handshake for c.EP with the provided Initial -// Sequence Number (iss) and receive window(rcvWnd) and any options if -// specified. -// -// It also sets the receive buffer for the endpoint to the specified -// value in epRcvBuf. -// -// PreCondition: c.EP must already be created. -func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) { - c.t.Helper() - - // Start connection attempt. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&waitEntry, waiter.EventOut) - defer c.WQ.EventUnregister(&waitEntry) - - err := c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - c.t.Fatalf("Unexpected return value from Connect: %v", err) - } - - // Receive SYN packet. - b := c.GetPacket() - checker.IPv4(c.t, b, - checker.TCP( - checker.DstPort(TestPort), - checker.TCPFlags(header.TCPFlagSyn), - ), - ) - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { - c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - - tcpHdr := header.TCP(header.IPv4(b).Payload()) - synOpts := header.ParseSynOptions(tcpHdr.Options(), false /* isAck */) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - c.SendPacket(nil, &Headers{ - SrcPort: tcpHdr.DestinationPort(), - DstPort: tcpHdr.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: rcvWnd, - TCPOpts: options, - }) - - // Receive ACK packet. - checker.IPv4(c.t, c.GetPacket(), - checker.TCP( - checker.DstPort(TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+1), - ), - ) - - // Wait for connection to be established. - select { - case <-notifyCh: - if err := c.EP.LastError(); err != nil { - c.t.Fatalf("Unexpected error when connecting: %v", err) - } - case <-time.After(1 * time.Second): - c.t.Fatalf("Timed out waiting for connection") - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want { - c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - - c.RcvdWindowScale = uint8(synOpts.WS) - c.Port = tcpHdr.SourcePort() -} - -// Create creates a TCP endpoint. -func (c *Context) Create(epRcvBuf int) { - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - if epRcvBuf != -1 { - if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, epRcvBuf); err != nil { - c.t.Fatalf("SetSockOpt failed failed: %v", err) - } - } -} - -// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends -// the specified option bytes as the Option field in the initial SYN packet. -// -// It also sets the receive buffer for the endpoint to the specified -// value in epRcvBuf. -func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) { - c.Create(epRcvBuf) - c.Connect(iss, rcvWnd, options) -} - -// RawEndpoint is just a small wrapper around a TCP endpoint's state to make -// sending data and ACK packets easy while being able to manipulate the sequence -// numbers and timestamp values as needed. -type RawEndpoint struct { - C *Context - SrcPort uint16 - DstPort uint16 - Flags int - NextSeqNum seqnum.Value - AckNum seqnum.Value - WndSize seqnum.Size - RecentTS uint32 // Stores the latest timestamp to echo back. - TSVal uint32 // TSVal stores the last timestamp sent by this endpoint. - - // SackPermitted is true if SACKPermitted option was negotiated for this endpoint. - SACKPermitted bool -} - -// SendPacketWithTS embeds the provided tsVal in the Timestamp option -// for the packet to be sent out. -func (r *RawEndpoint) SendPacketWithTS(payload []byte, tsVal uint32) { - r.TSVal = tsVal - tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} - header.EncodeTSOption(r.TSVal, r.RecentTS, tsOpt[2:]) - r.SendPacket(payload, tsOpt[:]) -} - -// SendPacket is a small wrapper function to build and send packets. -func (r *RawEndpoint) SendPacket(payload []byte, opts []byte) { - packetHeaders := &Headers{ - SrcPort: r.SrcPort, - DstPort: r.DstPort, - Flags: r.Flags, - SeqNum: r.NextSeqNum, - AckNum: r.AckNum, - RcvWnd: r.WndSize, - TCPOpts: opts, - } - r.C.SendPacket(payload, packetHeaders) - r.NextSeqNum = r.NextSeqNum.Add(seqnum.Size(len(payload))) -} - -// VerifyAndReturnACKWithTS verifies that the tsEcr field int he ACK matches -// the provided tsVal as well as returns the original packet. -func (r *RawEndpoint) VerifyAndReturnACKWithTS(tsVal uint32) []byte { - r.C.t.Helper() - // Read ACK and verify that tsEcr of ACK packet is [1,2,3,4] - ackPacket := r.C.GetPacket() - checker.IPv4(r.C.t, ackPacket, - checker.TCP( - checker.DstPort(r.SrcPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(r.AckNum)), - checker.TCPAckNum(uint32(r.NextSeqNum)), - checker.TCPTimestampChecker(true, 0, tsVal), - ), - ) - // Store the parsed TSVal from the ack as recentTS. - tcpSeg := header.TCP(header.IPv4(ackPacket).Payload()) - opts := tcpSeg.ParsedOptions() - r.RecentTS = opts.TSVal - return ackPacket -} - -// VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided -// tsVal. -func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) { - r.C.t.Helper() - _ = r.VerifyAndReturnACKWithTS(tsVal) -} - -// VerifyACKRcvWnd verifies that the window advertised by the incoming ACK -// matches the provided rcvWnd. -func (r *RawEndpoint) VerifyACKRcvWnd(rcvWnd uint16) { - r.C.t.Helper() - ackPacket := r.C.GetPacket() - checker.IPv4(r.C.t, ackPacket, - checker.TCP( - checker.DstPort(r.SrcPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(r.AckNum)), - checker.TCPAckNum(uint32(r.NextSeqNum)), - checker.TCPWindow(rcvWnd), - ), - ) -} - -// VerifyACKNoSACK verifies that the ACK does not contain a SACK block. -func (r *RawEndpoint) VerifyACKNoSACK() { - r.VerifyACKHasSACK(nil) -} - -// VerifyACKHasSACK verifies that the ACK contains the specified SACKBlocks. -func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) { - // Read ACK and verify that the TCP options in the segment do - // not contain a SACK block. - ackPacket := r.C.GetPacket() - checker.IPv4(r.C.t, ackPacket, - checker.TCP( - checker.DstPort(r.SrcPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(r.AckNum)), - checker.TCPAckNum(uint32(r.NextSeqNum)), - checker.TCPSACKBlockChecker(sackBlocks), - ), - ) -} - -// CreateConnectedWithOptions creates and connects c.ep with the specified TCP -// options enabled and returns a RawEndpoint which represents the other end of -// the connection. -// -// It also verifies where required(eg.Timestamp) that the ACK to the SYN-ACK -// does not carry an option that was not requested. -func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *RawEndpoint { - var err tcpip.Error - c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - c.t.Fatalf("c.s.NewEndpoint(tcp, ipv4...) = %v", err) - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateInitial; got != want { - c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - - // Start connection attempt. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&waitEntry, waiter.EventOut) - defer c.WQ.EventUnregister(&waitEntry) - - testFullAddr := tcpip.FullAddress{Addr: TestAddr, Port: TestPort} - err = c.EP.Connect(testFullAddr) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - c.t.Fatalf("c.ep.Connect(%v) = %v", testFullAddr, err) - } - // Receive SYN packet. - b := c.GetPacket() - // Validate that the syn has the timestamp option and a valid - // TS value. - mss := uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize) - - checker.IPv4(c.t, b, - checker.TCP( - checker.DstPort(TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.TCPSynOptions(header.TCPSynOptions{ - MSS: mss, - TS: true, - WS: int(c.WindowScale), - SACKPermitted: c.SACKEnabled(), - }), - ), - ) - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { - c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - - tcpSeg := header.TCP(header.IPv4(b).Payload()) - synOptions := header.ParseSynOptions(tcpSeg.Options(), false) - - // Build options w/ tsVal to be sent in the SYN-ACK. - synAckOptions := make([]byte, header.TCPOptionsMaximumSize) - offset := 0 - if wantOptions.WS != -1 { - offset += header.EncodeWSOption(wantOptions.WS, synAckOptions[offset:]) - } - if wantOptions.TS { - offset += header.EncodeTSOption(wantOptions.TSVal, synOptions.TSVal, synAckOptions[offset:]) - } - if wantOptions.SACKPermitted { - offset += header.EncodeSACKPermittedOption(synAckOptions[offset:]) - } - - offset += header.AddTCPOptionPadding(synAckOptions, offset) - - // Build SYN-ACK. - c.IRS = seqnum.Value(tcpSeg.SequenceNumber()) - iss := seqnum.Value(TestInitialSequenceNumber) - c.SendPacket(nil, &Headers{ - SrcPort: tcpSeg.DestinationPort(), - DstPort: tcpSeg.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - TCPOpts: synAckOptions[:offset], - }) - - // Read ACK. - ackPacket := c.GetPacket() - - // Verify TCP header fields. - tcpCheckers := []checker.TransportChecker{ - checker.DstPort(TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(c.IRS) + 1), - checker.TCPAckNum(uint32(iss) + 1), - } - - // Verify that tsEcr of ACK packet is wantOptions.TSVal if the - // timestamp option was enabled, if not then we verify that - // there is no timestamp in the ACK packet. - if wantOptions.TS { - tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(true, 0, wantOptions.TSVal)) - } else { - tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(false, 0, 0)) - } - - checker.IPv4(c.t, ackPacket, checker.TCP(tcpCheckers...)) - - ackSeg := header.TCP(header.IPv4(ackPacket).Payload()) - ackOptions := ackSeg.ParsedOptions() - - // Wait for connection to be established. - select { - case <-notifyCh: - if err := c.EP.LastError(); err != nil { - c.t.Fatalf("Unexpected error when connecting: %v", err) - } - case <-time.After(1 * time.Second): - c.t.Fatalf("Timed out waiting for connection") - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want { - c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - - // Store the source port in use by the endpoint. - c.Port = tcpSeg.SourcePort() - - // Mark in context that timestamp option is enabled for this endpoint. - c.TimeStampEnabled = true - c.RcvdWindowScale = uint8(synOptions.WS) - return &RawEndpoint{ - C: c, - SrcPort: tcpSeg.DestinationPort(), - DstPort: tcpSeg.SourcePort(), - Flags: header.TCPFlagAck | header.TCPFlagPsh, - NextSeqNum: iss + 1, - AckNum: c.IRS.Add(1), - WndSize: 30000, - RecentTS: ackOptions.TSVal, - TSVal: wantOptions.TSVal, - SACKPermitted: wantOptions.SACKPermitted, - } -} - -// AcceptWithOptions initializes a listening endpoint and connects to it with the -// provided options enabled. It also verifies that the SYN-ACK has the expected -// values for the provided options. -// -// The function returns a RawEndpoint representing the other end of the accepted -// endpoint. -func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOptions) *RawEndpoint { - // Create EP and start listening. - wq := &waiter.Queue{} - ep, err := c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - - if err := ep.Bind(tcpip.FullAddress{Port: StackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { - c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - if err := ep.Listen(10); err != nil { - c.t.Fatalf("Listen failed: %v", err) - } - if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { - c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - rep := c.PassiveConnectWithOptions(100, wndScale, synOptions) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - c.t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - c.t.Fatalf("Timed out waiting for accept") - } - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want { - c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - return rep -} - -// PassiveConnect just disables WindowScaling and delegates the call to -// PassiveConnectWithOptions. -func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCPSynOptions) { - synOptions.WS = -1 - c.PassiveConnectWithOptions(maxPayload, wndScale, synOptions) -} - -// PassiveConnectWithOptions initiates a new connection (with the specified TCP -// options enabled) to the port on which the Context.ep is listening for new -// connections. It also validates that the SYN-ACK has the expected values for -// the enabled options. -// -// NOTE: MSS is not a negotiated option and it can be asymmetric -// in each direction. This function uses the maxPayload to set the MSS to be -// sent to the peer on a connect and validates that the MSS in the SYN-ACK -// response is equal to the MTU - (tcphdr len + iphdr len). -// -// wndScale is the expected window scale in the SYN-ACK and synOptions.WS is the -// value of the window scaling option to be sent in the SYN. If synOptions.WS > -// 0 then we send the WindowScale option. -func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions) *RawEndpoint { - c.t.Helper() - opts := make([]byte, header.TCPOptionsMaximumSize) - offset := 0 - offset += header.EncodeMSSOption(uint32(maxPayload), opts) - - if synOptions.WS >= 0 { - offset += header.EncodeWSOption(3, opts[offset:]) - } - if synOptions.TS { - offset += header.EncodeTSOption(synOptions.TSVal, synOptions.TSEcr, opts[offset:]) - } - - if synOptions.SACKPermitted { - offset += header.EncodeSACKPermittedOption(opts[offset:]) - } - - paddingToAdd := 4 - offset%4 - // Now add any padding bytes that might be required to quad align the - // options. - for i := offset; i < offset+paddingToAdd; i++ { - opts[i] = header.TCPOptionNOP - } - offset += paddingToAdd - - // Send a SYN request. - iss := seqnum.Value(TestInitialSequenceNumber) - c.SendPacket(nil, &Headers{ - SrcPort: TestPort, - DstPort: StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - TCPOpts: opts[:offset], - }) - - // Receive the SYN-ACK reply. Make sure MSS and other expected options - // are present. - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - rcvdSynOptions := header.ParseSynOptions(tcp.Options(), true /* isAck */) - c.IRS = seqnum.Value(tcp.SequenceNumber()) - - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(StackPort), - checker.DstPort(TestPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.TCPAckNum(uint32(iss) + 1), - checker.TCPSynOptions(header.TCPSynOptions{MSS: synOptions.MSS, WS: wndScale, SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled()}), - } - - // If TS option was enabled in the original SYN then add a checker to - // validate the Timestamp option in the SYN-ACK. - if synOptions.TS { - tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(synOptions.TS, 0, synOptions.TSVal)) - } else { - tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(false, 0, 0)) - } - - checker.IPv4(c.t, b, checker.TCP(tcpCheckers...)) - rcvWnd := seqnum.Size(30000) - ackHeaders := &Headers{ - SrcPort: TestPort, - DstPort: StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - RcvWnd: rcvWnd, - } - - // If WS was expected to be in effect then scale the advertised window - // correspondingly. - if synOptions.WS > 0 { - ackHeaders.RcvWnd = rcvWnd >> byte(synOptions.WS) - } - - parsedOpts := tcp.ParsedOptions() - if synOptions.TS { - // Echo the tsVal back to the peer in the tsEcr field of the - // timestamp option. - // Increment TSVal by 1 from the value sent in the SYN and echo - // the TSVal in the SYN-ACK in the TSEcr field. - opts := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} - header.EncodeTSOption(synOptions.TSVal+1, parsedOpts.TSVal, opts[2:]) - ackHeaders.TCPOpts = opts[:] - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - c.RcvdWindowScale = uint8(rcvdSynOptions.WS) - c.Port = StackPort - - return &RawEndpoint{ - C: c, - SrcPort: TestPort, - DstPort: StackPort, - Flags: header.TCPFlagPsh | header.TCPFlagAck, - NextSeqNum: iss + 1, - AckNum: c.IRS + 1, - WndSize: rcvWnd, - SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled(), - RecentTS: parsedOpts.TSVal, - TSVal: synOptions.TSVal + 1, - } -} - -// SACKEnabled returns true if the TCP Protocol option SACKEnabled is set to true -// for the Stack in the context. -func (c *Context) SACKEnabled() bool { - var v tcpip.TCPSACKEnabled - if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &v); err != nil { - // Stack doesn't support SACK. So just return. - return false - } - return bool(v) -} - -// SetGSOEnabled enables or disables generic segmentation offload. -func (c *Context) SetGSOEnabled(enable bool) { - if enable { - c.linkEP.LinkEPCapabilities |= stack.CapabilityHardwareGSO - } else { - c.linkEP.LinkEPCapabilities &^= stack.CapabilityHardwareGSO - } -} - -// MSSWithoutOptions returns the value for the MSS used by the stack when no -// options are in use. -func (c *Context) MSSWithoutOptions() uint16 { - return uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize) -} - -// MSSWithoutOptionsV6 returns the value for the MSS used by the stack when no -// options are in use for IPv6 packets. -func (c *Context) MSSWithoutOptionsV6() uint16 { - return uint16(c.linkEP.MTU() - header.IPv6MinimumSize - header.TCPMinimumSize) -} diff --git a/pkg/tcpip/transport/tcp/timer_test.go b/pkg/tcpip/transport/tcp/timer_test.go deleted file mode 100644 index dbd6dff54..000000000 --- a/pkg/tcpip/transport/tcp/timer_test.go +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp - -import ( - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sleep" -) - -func TestCleanup(t *testing.T) { - const ( - timerDurationSeconds = 2 - isAssertedTimeoutSeconds = timerDurationSeconds + 1 - ) - - tmr := timer{} - w := sleep.Waker{} - tmr.init(&w) - tmr.enable(timerDurationSeconds * time.Second) - tmr.cleanup() - - if want := (timer{}); tmr != want { - t.Errorf("got tmr = %+v, want = %+v", tmr, want) - } - - // The waker should not be asserted. - for i := 0; i < isAssertedTimeoutSeconds; i++ { - time.Sleep(time.Second) - if w.IsAsserted() { - t.Fatalf("waker asserted unexpectedly") - } - } -} diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD deleted file mode 100644 index 3ad6994a7..000000000 --- a/pkg/tcpip/transport/tcpconntrack/BUILD +++ /dev/null @@ -1,23 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "tcpconntrack", - srcs = ["tcp_conntrack.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip/header", - "//pkg/tcpip/seqnum", - ], -) - -go_test( - name = "tcpconntrack_test", - size = "small", - srcs = ["tcp_conntrack_test.go"], - deps = [ - ":tcpconntrack", - "//pkg/tcpip/header", - ], -) diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go deleted file mode 100644 index 5e271b7ca..000000000 --- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go +++ /dev/null @@ -1,511 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcpconntrack_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack" -) - -// connected creates a connection tracker TCB and sets it to a connected state -// by performing a 3-way handshake. -func connected(t *testing.T, iss, irs uint32, isw, irw uint16) *tcpconntrack.TCB { - // Send SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: iss, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: irw, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive SYN-ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: irs, - AckNum: iss + 1, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn | header.TCPFlagAck, - WindowSize: isw, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: iss + 1, - AckNum: irs + 1, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: irw, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - return &tcb -} - -func TestConnectionRefused(t *testing.T) { - // Send SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive RST. - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: 1235, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagRst | header.TCPFlagAck, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultReset { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset) - } -} - -func TestConnectionRefusedInSynRcvd(t *testing.T) { - // Send SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive SYN. - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Receive RST with no ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 790, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagRst, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultReset { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset) - } -} - -func TestConnectionResetInSynRcvd(t *testing.T) { - // Send SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive SYN. - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send RST with no ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 1235, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagRst, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultReset { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset) - } -} - -func TestRetransmitOnSynSent(t *testing.T) { - // Send initial SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Retransmit the same SYN. - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultConnecting { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultConnecting) - } -} - -func TestRetransmitOnSynRcvd(t *testing.T) { - // Send initial SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive SYN. This will cause the state to go to SYN-RCVD. - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Retransmit the original SYN. - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Transmit a SYN-ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 790, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn | header.TCPFlagAck, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } -} - -func TestClosedBySelf(t *testing.T) { - tcb := connected(t, 1234, 789, 30000, 50000) - - // Send FIN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1235, - AckNum: 790, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Receive FIN/ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 790, - AckNum: 1236, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 1236, - AckNum: 791, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultClosedBySelf { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedBySelf) - } -} - -func TestClosedByPeer(t *testing.T) { - tcb := connected(t, 1234, 789, 30000, 50000) - - // Receive FIN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 790, - AckNum: 1235, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send FIN/ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 1235, - AckNum: 791, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Receive ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 791, - AckNum: 1236, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultClosedByPeer { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedByPeer) - } -} - -func TestSendAndReceiveDataClosedBySelf(t *testing.T) { - sseq := uint32(1234) - rseq := uint32(789) - tcb := connected(t, sseq, rseq, 30000, 50000) - sseq++ - rseq++ - - // Send some data. - tcp := make(header.TCP, header.TCPMinimumSize+1024) - - for i := uint32(0); i < 10; i++ { - // Send some data. - tcp.Encode(&header.TCPFields{ - SeqNum: sseq, - AckNum: rseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 30000, - }) - sseq += uint32(len(tcp)) - header.TCPMinimumSize - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Receive ack for data. - tcp.Encode(&header.TCPFields{ - SeqNum: rseq, - AckNum: sseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp[:header.TCPMinimumSize]); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - } - - for i := uint32(0); i < 10; i++ { - // Receive some data. - tcp.Encode(&header.TCPFields{ - SeqNum: rseq, - AckNum: sseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 50000, - }) - rseq += uint32(len(tcp)) - header.TCPMinimumSize - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send ack for data. - tcp.Encode(&header.TCPFields{ - SeqNum: sseq, - AckNum: rseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp[:header.TCPMinimumSize]); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - } - - // Send FIN. - tcp = tcp[:header.TCPMinimumSize] - tcp.Encode(&header.TCPFields{ - SeqNum: sseq, - AckNum: rseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 30000, - }) - sseq++ - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Receive FIN/ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: rseq, - AckNum: sseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 50000, - }) - rseq++ - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: sseq, - AckNum: rseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultClosedBySelf { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedBySelf) - } -} - -func TestIgnoreBadResetOnSynSent(t *testing.T) { - // Send SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive a RST with a bad ACK, it should not cause the connection to - // be reset. - acks := []uint32{1234, 1236, 1000, 5000} - flags := []uint8{header.TCPFlagRst, header.TCPFlagRst | header.TCPFlagAck} - for _, a := range acks { - for _, f := range flags { - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: a, - DataOffset: header.TCPMinimumSize, - Flags: f, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultConnecting { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - } - } - - // Complete the handshake. - // Receive SYN-ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: 1235, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn | header.TCPFlagAck, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 1235, - AckNum: 790, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } -} diff --git a/pkg/tcpip/transport/tcpconntrack/tcpconntrack_state_autogen.go b/pkg/tcpip/transport/tcpconntrack/tcpconntrack_state_autogen.go new file mode 100644 index 000000000..ff53204da --- /dev/null +++ b/pkg/tcpip/transport/tcpconntrack/tcpconntrack_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package tcpconntrack diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD deleted file mode 100644 index 153e8c950..000000000 --- a/pkg/tcpip/transport/udp/BUILD +++ /dev/null @@ -1,63 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "udp_packet_list", - out = "udp_packet_list.go", - package = "udp", - prefix = "udpPacket", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*udpPacket", - "Linker": "*udpPacket", - }, -) - -go_library( - name = "udp", - srcs = [ - "endpoint.go", - "endpoint_state.go", - "forwarder.go", - "protocol.go", - "udp_packet_list.go", - ], - imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sleep", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/header/parse", - "//pkg/tcpip/ports", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/raw", - "//pkg/waiter", - ], -) - -go_test( - name = "udp_x_test", - size = "small", - srcs = ["udp_test.go"], - deps = [ - ":udp", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/icmp", - "//pkg/waiter", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) diff --git a/pkg/tcpip/transport/udp/udp_packet_list.go b/pkg/tcpip/transport/udp/udp_packet_list.go new file mode 100644 index 000000000..c396f77c9 --- /dev/null +++ b/pkg/tcpip/transport/udp/udp_packet_list.go @@ -0,0 +1,221 @@ +package udp + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type udpPacketElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (udpPacketElementMapper) linkerFor(elem *udpPacket) *udpPacket { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type udpPacketList struct { + head *udpPacket + tail *udpPacket +} + +// Reset resets list l to the empty state. +func (l *udpPacketList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *udpPacketList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *udpPacketList) Front() *udpPacket { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *udpPacketList) Back() *udpPacket { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *udpPacketList) Len() (count int) { + for e := l.Front(); e != nil; e = (udpPacketElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *udpPacketList) PushFront(e *udpPacket) { + linker := udpPacketElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + udpPacketElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *udpPacketList) PushBack(e *udpPacket) { + linker := udpPacketElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + udpPacketElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *udpPacketList) PushBackList(m *udpPacketList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + udpPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head) + udpPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *udpPacketList) InsertAfter(b, e *udpPacket) { + bLinker := udpPacketElementMapper{}.linkerFor(b) + eLinker := udpPacketElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + udpPacketElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *udpPacketList) InsertBefore(a, e *udpPacket) { + aLinker := udpPacketElementMapper{}.linkerFor(a) + eLinker := udpPacketElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + udpPacketElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *udpPacketList) Remove(e *udpPacket) { + linker := udpPacketElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + udpPacketElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + udpPacketElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type udpPacketEntry struct { + next *udpPacket + prev *udpPacket +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *udpPacketEntry) Next() *udpPacket { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *udpPacketEntry) Prev() *udpPacket { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *udpPacketEntry) SetNext(elem *udpPacket) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *udpPacketEntry) SetPrev(elem *udpPacket) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/udp/udp_state_autogen.go b/pkg/tcpip/transport/udp/udp_state_autogen.go new file mode 100644 index 000000000..16900d0f9 --- /dev/null +++ b/pkg/tcpip/transport/udp/udp_state_autogen.go @@ -0,0 +1,229 @@ +// automatically generated by stateify. + +package udp + +import ( + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func (u *udpPacket) StateTypeName() string { + return "pkg/tcpip/transport/udp.udpPacket" +} + +func (u *udpPacket) StateFields() []string { + return []string{ + "udpPacketEntry", + "senderAddress", + "destinationAddress", + "packetInfo", + "data", + "timestamp", + "tos", + } +} + +func (u *udpPacket) beforeSave() {} + +func (u *udpPacket) StateSave(stateSinkObject state.Sink) { + u.beforeSave() + var dataValue buffer.VectorisedView = u.saveData() + stateSinkObject.SaveValue(4, dataValue) + stateSinkObject.Save(0, &u.udpPacketEntry) + stateSinkObject.Save(1, &u.senderAddress) + stateSinkObject.Save(2, &u.destinationAddress) + stateSinkObject.Save(3, &u.packetInfo) + stateSinkObject.Save(5, &u.timestamp) + stateSinkObject.Save(6, &u.tos) +} + +func (u *udpPacket) afterLoad() {} + +func (u *udpPacket) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &u.udpPacketEntry) + stateSourceObject.Load(1, &u.senderAddress) + stateSourceObject.Load(2, &u.destinationAddress) + stateSourceObject.Load(3, &u.packetInfo) + stateSourceObject.Load(5, &u.timestamp) + stateSourceObject.Load(6, &u.tos) + stateSourceObject.LoadValue(4, new(buffer.VectorisedView), func(y interface{}) { u.loadData(y.(buffer.VectorisedView)) }) +} + +func (e *endpoint) StateTypeName() string { + return "pkg/tcpip/transport/udp.endpoint" +} + +func (e *endpoint) StateFields() []string { + return []string{ + "TransportEndpointInfo", + "DefaultSocketOptionsHandler", + "waiterQueue", + "uniqueID", + "rcvReady", + "rcvList", + "rcvBufSizeMax", + "rcvBufSize", + "rcvClosed", + "state", + "dstPort", + "ttl", + "multicastTTL", + "multicastAddr", + "multicastNICID", + "portFlags", + "lastError", + "boundBindToDevice", + "boundPortFlags", + "sendTOS", + "shutdownFlags", + "multicastMemberships", + "effectiveNetProtos", + "owner", + "ops", + } +} + +func (e *endpoint) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + var rcvBufSizeMaxValue int = e.saveRcvBufSizeMax() + stateSinkObject.SaveValue(6, rcvBufSizeMaxValue) + stateSinkObject.Save(0, &e.TransportEndpointInfo) + stateSinkObject.Save(1, &e.DefaultSocketOptionsHandler) + stateSinkObject.Save(2, &e.waiterQueue) + stateSinkObject.Save(3, &e.uniqueID) + stateSinkObject.Save(4, &e.rcvReady) + stateSinkObject.Save(5, &e.rcvList) + stateSinkObject.Save(7, &e.rcvBufSize) + stateSinkObject.Save(8, &e.rcvClosed) + stateSinkObject.Save(9, &e.state) + stateSinkObject.Save(10, &e.dstPort) + stateSinkObject.Save(11, &e.ttl) + stateSinkObject.Save(12, &e.multicastTTL) + stateSinkObject.Save(13, &e.multicastAddr) + stateSinkObject.Save(14, &e.multicastNICID) + stateSinkObject.Save(15, &e.portFlags) + stateSinkObject.Save(16, &e.lastError) + stateSinkObject.Save(17, &e.boundBindToDevice) + stateSinkObject.Save(18, &e.boundPortFlags) + stateSinkObject.Save(19, &e.sendTOS) + stateSinkObject.Save(20, &e.shutdownFlags) + stateSinkObject.Save(21, &e.multicastMemberships) + stateSinkObject.Save(22, &e.effectiveNetProtos) + stateSinkObject.Save(23, &e.owner) + stateSinkObject.Save(24, &e.ops) +} + +func (e *endpoint) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.TransportEndpointInfo) + stateSourceObject.Load(1, &e.DefaultSocketOptionsHandler) + stateSourceObject.Load(2, &e.waiterQueue) + stateSourceObject.Load(3, &e.uniqueID) + stateSourceObject.Load(4, &e.rcvReady) + stateSourceObject.Load(5, &e.rcvList) + stateSourceObject.Load(7, &e.rcvBufSize) + stateSourceObject.Load(8, &e.rcvClosed) + stateSourceObject.Load(9, &e.state) + stateSourceObject.Load(10, &e.dstPort) + stateSourceObject.Load(11, &e.ttl) + stateSourceObject.Load(12, &e.multicastTTL) + stateSourceObject.Load(13, &e.multicastAddr) + stateSourceObject.Load(14, &e.multicastNICID) + stateSourceObject.Load(15, &e.portFlags) + stateSourceObject.Load(16, &e.lastError) + stateSourceObject.Load(17, &e.boundBindToDevice) + stateSourceObject.Load(18, &e.boundPortFlags) + stateSourceObject.Load(19, &e.sendTOS) + stateSourceObject.Load(20, &e.shutdownFlags) + stateSourceObject.Load(21, &e.multicastMemberships) + stateSourceObject.Load(22, &e.effectiveNetProtos) + stateSourceObject.Load(23, &e.owner) + stateSourceObject.Load(24, &e.ops) + stateSourceObject.LoadValue(6, new(int), func(y interface{}) { e.loadRcvBufSizeMax(y.(int)) }) + stateSourceObject.AfterLoad(e.afterLoad) +} + +func (m *multicastMembership) StateTypeName() string { + return "pkg/tcpip/transport/udp.multicastMembership" +} + +func (m *multicastMembership) StateFields() []string { + return []string{ + "nicID", + "multicastAddr", + } +} + +func (m *multicastMembership) beforeSave() {} + +func (m *multicastMembership) StateSave(stateSinkObject state.Sink) { + m.beforeSave() + stateSinkObject.Save(0, &m.nicID) + stateSinkObject.Save(1, &m.multicastAddr) +} + +func (m *multicastMembership) afterLoad() {} + +func (m *multicastMembership) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &m.nicID) + stateSourceObject.Load(1, &m.multicastAddr) +} + +func (l *udpPacketList) StateTypeName() string { + return "pkg/tcpip/transport/udp.udpPacketList" +} + +func (l *udpPacketList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *udpPacketList) beforeSave() {} + +func (l *udpPacketList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *udpPacketList) afterLoad() {} + +func (l *udpPacketList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *udpPacketEntry) StateTypeName() string { + return "pkg/tcpip/transport/udp.udpPacketEntry" +} + +func (e *udpPacketEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *udpPacketEntry) beforeSave() {} + +func (e *udpPacketEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *udpPacketEntry) afterLoad() {} + +func (e *udpPacketEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func init() { + state.Register((*udpPacket)(nil)) + state.Register((*endpoint)(nil)) + state.Register((*multicastMembership)(nil)) + state.Register((*udpPacketList)(nil)) + state.Register((*udpPacketEntry)(nil)) +} diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go deleted file mode 100644 index 5d81dbb94..000000000 --- a/pkg/tcpip/transport/udp/udp_test.go +++ /dev/null @@ -1,2565 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package udp_test - -import ( - "bytes" - "context" - "fmt" - "io/ioutil" - "math/rand" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -// Addresses and ports used for testing. It is recommended that tests stick to -// using these addresses as it allows using the testFlow helper. -// Naming rules: 'stack*'' denotes local addresses and ports, while 'test*' -// represents the remote endpoint. -const ( - v4MappedAddrPrefix = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" - stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - stackV4MappedAddr = v4MappedAddrPrefix + stackAddr - testV4MappedAddr = v4MappedAddrPrefix + testAddr - multicastV4MappedAddr = v4MappedAddrPrefix + multicastAddr - broadcastV4MappedAddr = v4MappedAddrPrefix + broadcastAddr - v4MappedWildcardAddr = v4MappedAddrPrefix + "\x00\x00\x00\x00" - - stackAddr = "\x0a\x00\x00\x01" - stackPort = 1234 - testAddr = "\x0a\x00\x00\x02" - testPort = 4096 - invalidPort = 8192 - multicastAddr = "\xe8\x2b\xd3\xea" - multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" - broadcastAddr = header.IPv4Broadcast - testTOS = 0x80 - - // defaultMTU is the MTU, in bytes, used throughout the tests, except - // where another value is explicitly used. It is chosen to match the MTU - // of loopback interfaces on linux systems. - defaultMTU = 65536 -) - -// header4Tuple stores the 4-tuple {src-IP, src-port, dst-IP, dst-port} used in -// a packet header. These values are used to populate a header or verify one. -// Note that because they are used in packet headers, the addresses are never in -// a V4-mapped format. -type header4Tuple struct { - srcAddr tcpip.FullAddress - dstAddr tcpip.FullAddress -} - -// testFlow implements a helper type used for sending and receiving test -// packets. A given test flow value defines 1) the socket endpoint used for the -// test and 2) the type of packet send or received on the endpoint. E.g., a -// multicastV6Only flow is a V6 multicast packet passing through a V6-only -// endpoint. The type provides helper methods to characterize the flow (e.g., -// isV4) as well as return a proper header4Tuple for it. -type testFlow int - -const ( - unicastV4 testFlow = iota // V4 unicast on a V4 socket - unicastV4in6 // V4-mapped unicast on a V6-dual socket - unicastV6 // V6 unicast on a V6 socket - unicastV6Only // V6 unicast on a V6-only socket - multicastV4 // V4 multicast on a V4 socket - multicastV4in6 // V4-mapped multicast on a V6-dual socket - multicastV6 // V6 multicast on a V6 socket - multicastV6Only // V6 multicast on a V6-only socket - broadcast // V4 broadcast on a V4 socket - broadcastIn6 // V4-mapped broadcast on a V6-dual socket - reverseMulticast4 // V4 multicast src. Must fail. - reverseMulticast6 // V6 multicast src. Must fail. -) - -func (flow testFlow) String() string { - switch flow { - case unicastV4: - return "unicastV4" - case unicastV6: - return "unicastV6" - case unicastV6Only: - return "unicastV6Only" - case unicastV4in6: - return "unicastV4in6" - case multicastV4: - return "multicastV4" - case multicastV6: - return "multicastV6" - case multicastV6Only: - return "multicastV6Only" - case multicastV4in6: - return "multicastV4in6" - case broadcast: - return "broadcast" - case broadcastIn6: - return "broadcastIn6" - case reverseMulticast4: - return "reverseMulticast4" - case reverseMulticast6: - return "reverseMulticast6" - default: - return "unknown" - } -} - -// packetDirection explains if a flow is incoming (read) or outgoing (write). -type packetDirection int - -const ( - incoming packetDirection = iota - outgoing -) - -// header4Tuple returns the header4Tuple for the given flow and direction. Note -// that the tuple contains no mapped addresses as those only exist at the socket -// level but not at the packet header level. -func (flow testFlow) header4Tuple(d packetDirection) header4Tuple { - var h header4Tuple - if flow.isV4() { - if d == outgoing { - h = header4Tuple{ - srcAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort}, - dstAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort}, - } - } else { - h = header4Tuple{ - srcAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort}, - dstAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort}, - } - } - if flow.isMulticast() { - h.dstAddr.Addr = multicastAddr - } else if flow.isBroadcast() { - h.dstAddr.Addr = broadcastAddr - } - } else { // IPv6 - if d == outgoing { - h = header4Tuple{ - srcAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, - dstAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, - } - } else { - h = header4Tuple{ - srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, - dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, - } - } - if flow.isMulticast() { - h.dstAddr.Addr = multicastV6Addr - } - } - if flow.isReverseMulticast() { - h.srcAddr.Addr = flow.getMcastAddr() - } - return h -} - -func (flow testFlow) getMcastAddr() tcpip.Address { - if flow.isV4() { - return multicastAddr - } - return multicastV6Addr -} - -// mapAddrIfApplicable converts the given V4 address into its V4-mapped version -// if it is applicable to the flow. -func (flow testFlow) mapAddrIfApplicable(v4Addr tcpip.Address) tcpip.Address { - if flow.isMapped() { - return v4MappedAddrPrefix + v4Addr - } - return v4Addr -} - -// netProto returns the protocol number used for the network packet. -func (flow testFlow) netProto() tcpip.NetworkProtocolNumber { - if flow.isV4() { - return ipv4.ProtocolNumber - } - return ipv6.ProtocolNumber -} - -// sockProto returns the protocol number used when creating the socket -// endpoint for this flow. -func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber { - switch flow { - case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6, reverseMulticast6: - return ipv6.ProtocolNumber - case unicastV4, multicastV4, broadcast, reverseMulticast4: - return ipv4.ProtocolNumber - default: - panic(fmt.Sprintf("invalid testFlow given: %d", flow)) - } -} - -func (flow testFlow) checkerFn() func(*testing.T, []byte, ...checker.NetworkChecker) { - if flow.isV4() { - return checker.IPv4 - } - return checker.IPv6 -} - -func (flow testFlow) isV6() bool { return !flow.isV4() } -func (flow testFlow) isV4() bool { - return flow.sockProto() == ipv4.ProtocolNumber || flow.isMapped() -} - -func (flow testFlow) isV6Only() bool { - switch flow { - case unicastV6Only, multicastV6Only: - return true - case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6: - return false - default: - panic(fmt.Sprintf("invalid testFlow given: %d", flow)) - } -} - -func (flow testFlow) isMulticast() bool { - switch flow { - case multicastV4, multicastV4in6, multicastV6, multicastV6Only: - return true - case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6: - return false - default: - panic(fmt.Sprintf("invalid testFlow given: %d", flow)) - } -} - -func (flow testFlow) isBroadcast() bool { - switch flow { - case broadcast, broadcastIn6: - return true - case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only, reverseMulticast4, reverseMulticast6: - return false - default: - panic(fmt.Sprintf("invalid testFlow given: %d", flow)) - } -} - -func (flow testFlow) isMapped() bool { - switch flow { - case unicastV4in6, multicastV4in6, broadcastIn6: - return true - case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast, reverseMulticast4, reverseMulticast6: - return false - default: - panic(fmt.Sprintf("invalid testFlow given: %d", flow)) - } -} - -func (flow testFlow) isReverseMulticast() bool { - switch flow { - case reverseMulticast4, reverseMulticast6: - return true - default: - return false - } -} - -type testContext struct { - t *testing.T - linkEP *channel.Endpoint - s *stack.Stack - - ep tcpip.Endpoint - wq waiter.Queue -} - -func newDualTestContext(t *testing.T, mtu uint32) *testContext { - t.Helper() - return newDualTestContextWithOptions(t, mtu, stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, - HandleLocal: true, - }) -} - -func newDualTestContextWithOptions(t *testing.T, mtu uint32, options stack.Options) *testContext { - t.Helper() - - s := stack.New(options) - ep := channel.New(256, mtu, "") - wep := stack.LinkEndpoint(ep) - - if testing.Verbose() { - wep = sniffer.New(ep) - } - if err := s.CreateNIC(1, wep); err != nil { - t.Fatalf("CreateNIC failed: %s", err) - } - - if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress failed: %s", err) - } - - if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil { - t.Fatalf("AddAddress failed: %s", err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: 1, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: 1, - }, - }) - - return &testContext{ - t: t, - s: s, - linkEP: ep, - } -} - -func (c *testContext) cleanup() { - if c.ep != nil { - c.ep.Close() - } -} - -func (c *testContext) createEndpoint(proto tcpip.NetworkProtocolNumber) { - c.t.Helper() - - var err tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, proto, &c.wq) - if err != nil { - c.t.Fatal("NewEndpoint failed: ", err) - } -} - -func (c *testContext) createEndpointForFlow(flow testFlow) { - c.t.Helper() - - c.createEndpoint(flow.sockProto()) - if flow.isV6Only() { - c.ep.SocketOptions().SetV6Only(true) - } else if flow.isBroadcast() { - c.ep.SocketOptions().SetBroadcast(true) - } -} - -// getPacketAndVerify reads a packet from the link endpoint and verifies the -// header against expected values from the given test flow. In addition, it -// calls any extra checker functions provided. -func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte { - c.t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - p, ok := c.linkEP.ReadContext(ctx) - if !ok { - c.t.Fatalf("Packet wasn't written out") - return nil - } - - if p.Proto != flow.netProto() { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) - } - - if got, want := p.Pkt.TransportProtocolNumber, header.UDPProtocolNumber; got != want { - c.t.Errorf("got p.Pkt.TransportProtocolNumber = %d, want = %d", got, want) - } - - vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) - b := vv.ToView() - - h := flow.header4Tuple(outgoing) - checkers = append( - checkers, - checker.SrcAddr(h.srcAddr.Addr), - checker.DstAddr(h.dstAddr.Addr), - checker.UDP(checker.DstPort(h.dstAddr.Port)), - ) - flow.checkerFn()(c.t, b, checkers...) - return b -} - -// injectPacket creates a packet of the given flow and with the given payload, -// and injects it into the link endpoint. If badChecksum is true, the packet has -// a bad checksum in the UDP header. -func (c *testContext) injectPacket(flow testFlow, payload []byte, badChecksum bool) { - c.t.Helper() - - h := flow.header4Tuple(incoming) - if flow.isV4() { - buf := c.buildV4Packet(payload, &h) - if badChecksum { - // Invalidate the UDP header checksum field, taking care to avoid - // overflow to zero, which would disable checksum validation. - for u := header.UDP(buf[header.IPv4MinimumSize:]); ; { - u.SetChecksum(u.Checksum() + 1) - if u.Checksum() != 0 { - break - } - } - } - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - } else { - buf := c.buildV6Packet(payload, &h) - if badChecksum { - // Invalidate the UDP header checksum field (Unlike IPv4, zero is - // a valid checksum value for IPv6 so no need to avoid it). - u := header.UDP(buf[header.IPv6MinimumSize:]) - u.SetChecksum(u.Checksum() + 1) - } - c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - } -} - -// buildV6Packet creates a V6 test packet with the given payload and header -// values in a buffer. -func (c *testContext) buildV6Packet(payload []byte, h *header4Tuple) buffer.View { - // Allocate a buffer for data and headers. - buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) - payloadStart := len(buf) - len(payload) - copy(buf[payloadStart:], payload) - - // Initialize the IP header. - ip := header.IPv6(buf) - ip.Encode(&header.IPv6Fields{ - TrafficClass: testTOS, - PayloadLength: uint16(header.UDPMinimumSize + len(payload)), - TransportProtocol: udp.ProtocolNumber, - HopLimit: 65, - SrcAddr: h.srcAddr.Addr, - DstAddr: h.dstAddr.Addr, - }) - - // Initialize the UDP header. - u := header.UDP(buf[header.IPv6MinimumSize:]) - u.Encode(&header.UDPFields{ - SrcPort: h.srcAddr.Port, - DstPort: h.dstAddr.Port, - Length: uint16(header.UDPMinimumSize + len(payload)), - }) - - // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u))) - - // Calculate the UDP checksum and set it. - xsum = header.Checksum(payload, xsum) - u.SetChecksum(^u.CalculateChecksum(xsum)) - - return buf -} - -// buildV4Packet creates a V4 test packet with the given payload and header -// values in a buffer. -func (c *testContext) buildV4Packet(payload []byte, h *header4Tuple) buffer.View { - // Allocate a buffer for data and headers. - buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) - payloadStart := len(buf) - len(payload) - copy(buf[payloadStart:], payload) - - // Initialize the IP header. - ip := header.IPv4(buf) - ip.Encode(&header.IPv4Fields{ - TOS: testTOS, - TotalLength: uint16(len(buf)), - TTL: 65, - Protocol: uint8(udp.ProtocolNumber), - SrcAddr: h.srcAddr.Addr, - DstAddr: h.dstAddr.Addr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - // Initialize the UDP header. - u := header.UDP(buf[header.IPv4MinimumSize:]) - u.Encode(&header.UDPFields{ - SrcPort: h.srcAddr.Port, - DstPort: h.dstAddr.Port, - Length: uint16(header.UDPMinimumSize + len(payload)), - }) - - // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u))) - - // Calculate the UDP checksum and set it. - xsum = header.Checksum(payload, xsum) - u.SetChecksum(^u.CalculateChecksum(xsum)) - - return buf -} - -func newPayload() []byte { - return newMinPayload(30) -} - -func newMinPayload(minSize int) []byte { - b := make([]byte, minSize+rand.Intn(100)) - for i := range b { - b[i] = byte(rand.Intn(256)) - } - return b -} - -func TestBindToDeviceOption(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}}) - - ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - defer ep.Close() - - opts := stack.NICOptions{Name: "my_device"} - if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil { - t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %s", opts, err) - } - - // nicIDPtr is used instead of taking the address of NICID literals, which is - // a compiler error. - nicIDPtr := func(s tcpip.NICID) *tcpip.NICID { - return &s - } - - testActions := []struct { - name string - setBindToDevice *tcpip.NICID - setBindToDeviceError tcpip.Error - getBindToDevice int32 - }{ - {"GetDefaultValue", nil, nil, 0}, - {"BindToNonExistent", nicIDPtr(999), &tcpip.ErrUnknownDevice{}, 0}, - {"BindToExistent", nicIDPtr(321), nil, 321}, - {"UnbindToDevice", nicIDPtr(0), nil, 0}, - } - for _, testAction := range testActions { - t.Run(testAction.name, func(t *testing.T) { - if testAction.setBindToDevice != nil { - bindToDevice := int32(*testAction.setBindToDevice) - if gotErr, wantErr := ep.SocketOptions().SetBindToDevice(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { - t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr) - } - } - bindToDevice := ep.SocketOptions().GetBindToDevice() - if bindToDevice != testAction.getBindToDevice { - t.Errorf("got bindToDevice = %d, want = %d", bindToDevice, testAction.getBindToDevice) - } - }) - } -} - -// testReadInternal sends a packet of the given test flow into the stack by -// injecting it into the link endpoint. It then attempts to read it from the -// UDP endpoint and depending on if this was expected to succeed verifies its -// correctness including any additional checker functions provided. -func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool, checkers ...checker.ControlMessagesChecker) { - c.t.Helper() - - payload := newPayload() - c.injectPacket(flow, payload, false) - - // Try to receive the data. - we, ch := waiter.NewChannelEntry(nil) - c.wq.EventRegister(&we, waiter.EventIn) - defer c.wq.EventUnregister(&we) - - // Take a snapshot of the stats to validate them at the end of the test. - epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - - var buf bytes.Buffer - res, err := c.ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true}) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Wait for data to become available. - select { - case <-ch: - res, err = c.ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true}) - - case <-time.After(300 * time.Millisecond): - if packetShouldBeDropped { - return // expected to time out - } - c.t.Fatal("timed out waiting for data") - } - } - - if expectReadError && err != nil { - c.checkEndpointReadStats(1, epstats, err) - return - } - - if err != nil { - c.t.Fatal("Read failed:", err) - } - - if packetShouldBeDropped { - c.t.Fatalf("Read unexpectedly received data from %s", res.RemoteAddr.Addr) - } - - // Check the read result. - h := flow.header4Tuple(incoming) - if diff := cmp.Diff(tcpip.ReadResult{ - Count: buf.Len(), - Total: buf.Len(), - RemoteAddr: tcpip.FullAddress{Addr: h.srcAddr.Addr}, - }, res, checker.IgnoreCmpPath( - "ControlMessages", // ControlMessages will be checked later. - "RemoteAddr.NIC", - "RemoteAddr.Port", - )); diff != "" { - c.t.Fatalf("Read: unexpected result (-want +got):\n%s", diff) - } - - // Check the payload. - v := buf.Bytes() - if !bytes.Equal(payload, v) { - c.t.Fatalf("got payload = %x, want = %x", v, payload) - } - - // Run any checkers against the ControlMessages. - for _, f := range checkers { - f(c.t, res.ControlMessages) - } - - c.checkEndpointReadStats(1, epstats, err) -} - -// testRead sends a packet of the given test flow into the stack by injecting it -// into the link endpoint. It then reads it from the UDP endpoint and verifies -// its correctness including any additional checker functions provided. -func testRead(c *testContext, flow testFlow, checkers ...checker.ControlMessagesChecker) { - c.t.Helper() - testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */, checkers...) -} - -// testFailingRead sends a packet of the given test flow into the stack by -// injecting it into the link endpoint. It then tries to read it from the UDP -// endpoint and expects this to fail. -func testFailingRead(c *testContext, flow testFlow, expectReadError bool) { - c.t.Helper() - testReadInternal(c, flow, true /* packetShouldBeDropped */, expectReadError) -} - -func TestBindEphemeralPort(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - if err := c.ep.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("ep.Bind(...) failed: %s", err) - } -} - -func TestBindReservedPort(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %s", err) - } - - addr, err := c.ep.GetLocalAddress() - if err != nil { - t.Fatalf("GetLocalAddress failed: %s", err) - } - - // We can't bind the address reserved by the connected endpoint above. - { - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - defer ep.Close() - { - err := ep.Bind(addr) - if _, ok := err.(*tcpip.ErrPortInUse); !ok { - t.Fatalf("got ep.Bind(...) = %s, want = %s", err, &tcpip.ErrPortInUse{}) - } - } - } - - func() { - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - defer ep.Close() - // We can't bind ipv4-any on the port reserved by the connected endpoint - // above, since the endpoint is dual-stack. - { - err := ep.Bind(tcpip.FullAddress{Port: addr.Port}) - if _, ok := err.(*tcpip.ErrPortInUse); !ok { - t.Fatalf("got ep.Bind(...) = %s, want = %s", err, &tcpip.ErrPortInUse{}) - } - } - // We can bind an ipv4 address on this port, though. - if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil { - t.Fatalf("ep.Bind(...) failed: %s", err) - } - }() - - // Once the connected endpoint releases its port reservation, we are able to - // bind ipv4-any once again. - c.ep.Close() - func() { - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - defer ep.Close() - if err := ep.Bind(tcpip.FullAddress{Port: addr.Port}); err != nil { - t.Fatalf("ep.Bind(...) failed: %s", err) - } - }() -} - -func TestV4ReadOnV6(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV4in6) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - // Test acceptance. - testRead(c, unicastV4in6) -} - -func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV4in6) - - // Bind to v4 mapped wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - // Test acceptance. - testRead(c, unicastV4in6) -} - -func TestV4ReadOnBoundToV4Mapped(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV4in6) - - // Bind to local address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - // Test acceptance. - testRead(c, unicastV4in6) -} - -func TestV6ReadOnV6(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV6) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - // Test acceptance. - testRead(c, unicastV6) -} - -// TestV4ReadSelfSource checks that packets coming from a local IP address are -// correctly dropped when handleLocal is true and not otherwise. -func TestV4ReadSelfSource(t *testing.T) { - for _, tt := range []struct { - name string - handleLocal bool - wantErr tcpip.Error - wantInvalidSource uint64 - }{ - {"HandleLocal", false, nil, 0}, - {"NoHandleLocal", true, &tcpip.ErrWouldBlock{}, 1}, - } { - t.Run(tt.name, func(t *testing.T) { - c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - HandleLocal: tt.handleLocal, - }) - defer c.cleanup() - - c.createEndpointForFlow(unicastV4) - - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - payload := newPayload() - h := unicastV4.header4Tuple(incoming) - h.srcAddr = h.dstAddr - - buf := c.buildV4Packet(payload, &h) - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource { - t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource) - } - - if _, err := c.ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tt.wantErr { - t.Errorf("got c.ep.Read = %s, want = %s", err, tt.wantErr) - } - }) - } -} - -func TestV4ReadOnV4(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV4) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - // Test acceptance. - testRead(c, unicastV4) -} - -// TestReadOnBoundToMulticast checks that an endpoint can bind to a multicast -// address and receive data sent to that address. -func TestReadOnBoundToMulticast(t *testing.T) { - // FIXME(b/128189410): multicastV4in6 currently doesn't work as - // AddMembershipOption doesn't handle V4in6 addresses. - for _, flow := range []testFlow{multicastV4, multicastV6, multicastV6Only} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to multicast address. - mcastAddr := flow.mapAddrIfApplicable(flow.getMcastAddr()) - if err := c.ep.Bind(tcpip.FullAddress{Addr: mcastAddr, Port: stackPort}); err != nil { - c.t.Fatal("Bind failed:", err) - } - - // Join multicast group. - ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: mcastAddr} - if err := c.ep.SetSockOpt(&ifoptSet); err != nil { - c.t.Fatalf("SetSockOpt(&%#v): %s", ifoptSet, err) - } - - // Check that we receive multicast packets but not unicast or broadcast - // ones. - testRead(c, flow) - testFailingRead(c, broadcast, false /* expectReadError */) - testFailingRead(c, unicastV4, false /* expectReadError */) - }) - } -} - -// TestV4ReadOnBoundToBroadcast checks that an endpoint can bind to a broadcast -// address and can receive only broadcast data. -func TestV4ReadOnBoundToBroadcast(t *testing.T) { - for _, flow := range []testFlow{broadcast, broadcastIn6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to broadcast address. - bcastAddr := flow.mapAddrIfApplicable(broadcastAddr) - if err := c.ep.Bind(tcpip.FullAddress{Addr: bcastAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - // Check that we receive broadcast packets but not unicast ones. - testRead(c, flow) - testFailingRead(c, unicastV4, false /* expectReadError */) - }) - } -} - -// TestReadFromMulticast checks that an endpoint will NOT receive a packet -// that was sent with multicast SOURCE address. -func TestReadFromMulticast(t *testing.T) { - for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - testFailingRead(c, flow, false /* expectReadError */) - }) - } -} - -// TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY -// and receive broadcast and unicast data. -func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) { - for _, flow := range []testFlow{broadcast, broadcastIn6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s (", err) - } - - // Check that we receive both broadcast and unicast packets. - testRead(c, flow) - testRead(c, unicastV4) - }) - } -} - -// testFailingWrite sends a packet of the given test flow into the UDP endpoint -// and verifies it fails with the provided error code. -func testFailingWrite(c *testContext, flow testFlow, wantErr tcpip.Error) { - c.t.Helper() - // Take a snapshot of the stats to validate them at the end of the test. - epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - h := flow.header4Tuple(outgoing) - writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) - - var r bytes.Reader - r.Reset(newPayload()) - _, gotErr := c.ep.Write(&r, tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, - }) - c.checkEndpointWriteStats(1, epstats, gotErr) - if gotErr != wantErr { - c.t.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr) - } -} - -// testWrite sends a packet of the given test flow from the UDP endpoint to the -// flow's destination address:port. It then receives it from the link endpoint -// and verifies its correctness including any additional checker functions -// provided. -func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 { - c.t.Helper() - return testWriteAndVerifyInternal(c, flow, true, checkers...) -} - -// testWriteWithoutDestination sends a packet of the given test flow from the -// UDP endpoint without giving a destination address:port. It then receives it -// from the link endpoint and verifies its correctness including any additional -// checker functions provided. -func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 { - c.t.Helper() - return testWriteAndVerifyInternal(c, flow, false, checkers...) -} - -func testWriteNoVerify(c *testContext, flow testFlow, setDest bool) buffer.View { - c.t.Helper() - // Take a snapshot of the stats to validate them at the end of the test. - epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - - writeOpts := tcpip.WriteOptions{} - if setDest { - h := flow.header4Tuple(outgoing) - writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) - writeOpts = tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, - } - } - var r bytes.Reader - payload := newPayload() - r.Reset(payload) - n, err := c.ep.Write(&r, writeOpts) - if err != nil { - c.t.Fatalf("Write failed: %s", err) - } - if n != int64(len(payload)) { - c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) - } - c.checkEndpointWriteStats(1, epstats, err) - return payload -} - -func testWriteAndVerifyInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 { - c.t.Helper() - payload := testWriteNoVerify(c, flow, setDest) - // Received the packet and check the payload. - b := c.getPacketAndVerify(flow, checkers...) - var udp header.UDP - if flow.isV4() { - udp = header.UDP(header.IPv4(b).Payload()) - } else { - udp = header.UDP(header.IPv6(b).Payload()) - } - if !bytes.Equal(payload, udp.Payload()) { - c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) - } - - return udp.SourcePort() -} - -func testDualWrite(c *testContext) uint16 { - c.t.Helper() - - v4Port := testWrite(c, unicastV4in6) - v6Port := testWrite(c, unicastV6) - if v4Port != v6Port { - c.t.Fatalf("expected v4 and v6 ports to be equal: got v4Port = %d, v6Port = %d", v4Port, v6Port) - } - - return v4Port -} - -func TestDualWriteUnbound(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - testDualWrite(c) -} - -func TestDualWriteBoundToWildcard(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - p := testDualWrite(c) - if p != stackPort { - c.t.Fatalf("Bad port: got %v, want %v", p, stackPort) - } -} - -func TestDualWriteConnectedToV6(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Connect to v6 address. - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - testWrite(c, unicastV6) - - // Write to V4 mapped address. - testFailingWrite(c, unicastV4in6, &tcpip.ErrNetworkUnreachable{}) - const want = 1 - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != want { - c.t.Fatalf("Endpoint stat not updated. got %d want %d", got, want) - } -} - -func TestDualWriteConnectedToV4Mapped(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Connect to v4 mapped address. - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - testWrite(c, unicastV4in6) - - // Write to v6 address. - testFailingWrite(c, unicastV6, &tcpip.ErrInvalidEndpointState{}) -} - -func TestV4WriteOnV6Only(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV6Only) - - // Write to V4 mapped address. - testFailingWrite(c, unicastV4in6, &tcpip.ErrNoRoute{}) -} - -func TestV6WriteOnBoundToV4Mapped(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Bind to v4 mapped address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - // Write to v6 address. - testFailingWrite(c, unicastV6, &tcpip.ErrInvalidEndpointState{}) -} - -func TestV6WriteOnConnected(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Connect to v6 address. - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %s", err) - } - - testWriteWithoutDestination(c, unicastV6) -} - -func TestV4WriteOnConnected(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Connect to v4 mapped address. - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %s", err) - } - - testWriteWithoutDestination(c, unicastV4) -} - -func TestWriteOnConnectedInvalidPort(t *testing.T) { - protocols := map[string]tcpip.NetworkProtocolNumber{ - "ipv4": ipv4.ProtocolNumber, - "ipv6": ipv6.ProtocolNumber, - } - for name, pn := range protocols { - t.Run(name, func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(pn) - if err := c.ep.Connect(tcpip.FullAddress{Addr: stackAddr, Port: invalidPort}); err != nil { - c.t.Fatalf("Connect failed: %s", err) - } - writeOpts := tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: stackAddr, Port: invalidPort}, - } - var r bytes.Reader - payload := newPayload() - r.Reset(payload) - n, err := c.ep.Write(&r, writeOpts) - if err != nil { - c.t.Fatalf("c.ep.Write(...) = %+s, want nil", err) - } - if got, want := n, int64(len(payload)); got != want { - c.t.Fatalf("c.ep.Write(...) wrote %d bytes, want %d bytes", got, want) - } - - { - err := c.ep.LastError() - if _, ok := err.(*tcpip.ErrConnectionRefused); !ok { - c.t.Fatalf("expected c.ep.LastError() == ErrConnectionRefused, got: %+v", err) - } - } - }) - } -} - -// TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket -// that is bound to a V4 multicast address. -func TestWriteOnBoundToV4Multicast(t *testing.T) { - for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V4 mcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastAddr, Port: stackPort}); err != nil { - c.t.Fatal("Bind failed:", err) - } - - testWrite(c, flow) - }) - } -} - -// TestWriteOnBoundToV4MappedMulticast checks that we can send packets out of a -// socket that is bound to a V4-mapped multicast address. -func TestWriteOnBoundToV4MappedMulticast(t *testing.T) { - for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V4Mapped mcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - testWrite(c, flow) - }) - } -} - -// TestWriteOnBoundToV6Multicast checks that we can send packets out of a -// socket that is bound to a V6 multicast address. -func TestWriteOnBoundToV6Multicast(t *testing.T) { - for _, flow := range []testFlow{unicastV6, multicastV6} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V6 mcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - testWrite(c, flow) - }) - } -} - -// TestWriteOnBoundToV6Multicast checks that we can send packets out of a -// V6-only socket that is bound to a V6 multicast address. -func TestWriteOnBoundToV6OnlyMulticast(t *testing.T) { - for _, flow := range []testFlow{unicastV6Only, multicastV6Only} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V6 mcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - testWrite(c, flow) - }) - } -} - -// TestWriteOnBoundToBroadcast checks that we can send packets out of a -// socket that is bound to the broadcast address. -func TestWriteOnBoundToBroadcast(t *testing.T) { - for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V4 broadcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastAddr, Port: stackPort}); err != nil { - c.t.Fatal("Bind failed:", err) - } - - testWrite(c, flow) - }) - } -} - -// TestWriteOnBoundToV4MappedBroadcast checks that we can send packets out of a -// socket that is bound to the V4-mapped broadcast address. -func TestWriteOnBoundToV4MappedBroadcast(t *testing.T) { - for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V4Mapped mcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - testWrite(c, flow) - }) - } -} - -func TestReadIncrementsPacketsReceived(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - // Create IPv4 UDP endpoint - c.createEndpoint(ipv6.ProtocolNumber) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - testRead(c, unicastV4) - - var want uint64 = 1 - if got := c.s.Stats().UDP.PacketsReceived.Value(); got != want { - c.t.Fatalf("Read did not increment PacketsReceived: got %v, want %v", got, want) - } -} - -func TestReadIPPacketInfo(t *testing.T) { - tests := []struct { - name string - proto tcpip.NetworkProtocolNumber - flow testFlow - expectedLocalAddr tcpip.Address - expectedDestAddr tcpip.Address - }{ - { - name: "IPv4 unicast", - proto: header.IPv4ProtocolNumber, - flow: unicastV4, - expectedLocalAddr: stackAddr, - expectedDestAddr: stackAddr, - }, - { - name: "IPv4 multicast", - proto: header.IPv4ProtocolNumber, - flow: multicastV4, - // This should actually be a unicast address assigned to the interface. - // - // TODO(gvisor.dev/issue/3556): This check is validating incorrect - // behaviour. We still include the test so that once the bug is - // resolved, this test will start to fail and the individual tasked - // with fixing this bug knows to also fix this test :). - expectedLocalAddr: multicastAddr, - expectedDestAddr: multicastAddr, - }, - { - name: "IPv4 broadcast", - proto: header.IPv4ProtocolNumber, - flow: broadcast, - // This should actually be a unicast address assigned to the interface. - // - // TODO(gvisor.dev/issue/3556): This check is validating incorrect - // behaviour. We still include the test so that once the bug is - // resolved, this test will start to fail and the individual tasked - // with fixing this bug knows to also fix this test :). - expectedLocalAddr: broadcastAddr, - expectedDestAddr: broadcastAddr, - }, - { - name: "IPv6 unicast", - proto: header.IPv6ProtocolNumber, - flow: unicastV6, - expectedLocalAddr: stackV6Addr, - expectedDestAddr: stackV6Addr, - }, - { - name: "IPv6 multicast", - proto: header.IPv6ProtocolNumber, - flow: multicastV6, - // This should actually be a unicast address assigned to the interface. - // - // TODO(gvisor.dev/issue/3556): This check is validating incorrect - // behaviour. We still include the test so that once the bug is - // resolved, this test will start to fail and the individual tasked - // with fixing this bug knows to also fix this test :). - expectedLocalAddr: multicastV6Addr, - expectedDestAddr: multicastV6Addr, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(test.proto) - - bindAddr := tcpip.FullAddress{Port: stackPort} - if err := c.ep.Bind(bindAddr); err != nil { - t.Fatalf("Bind(%+v): %s", bindAddr, err) - } - - if test.flow.isMulticast() { - ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()} - if err := c.ep.SetSockOpt(&ifoptSet); err != nil { - c.t.Fatalf("SetSockOpt(&%#v): %s:", ifoptSet, err) - } - } - - c.ep.SocketOptions().SetReceivePacketInfo(true) - - testRead(c, test.flow, checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{ - NIC: 1, - LocalAddr: test.expectedLocalAddr, - DestinationAddr: test.expectedDestAddr, - })) - - if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 { - t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got) - } - }) - } -} - -func TestReadRecvOriginalDstAddr(t *testing.T) { - tests := []struct { - name string - proto tcpip.NetworkProtocolNumber - flow testFlow - expectedOriginalDstAddr tcpip.FullAddress - }{ - { - name: "IPv4 unicast", - proto: header.IPv4ProtocolNumber, - flow: unicastV4, - expectedOriginalDstAddr: tcpip.FullAddress{1, stackAddr, stackPort}, - }, - { - name: "IPv4 multicast", - proto: header.IPv4ProtocolNumber, - flow: multicastV4, - // This should actually be a unicast address assigned to the interface. - // - // TODO(gvisor.dev/issue/3556): This check is validating incorrect - // behaviour. We still include the test so that once the bug is - // resolved, this test will start to fail and the individual tasked - // with fixing this bug knows to also fix this test :). - expectedOriginalDstAddr: tcpip.FullAddress{1, multicastAddr, stackPort}, - }, - { - name: "IPv4 broadcast", - proto: header.IPv4ProtocolNumber, - flow: broadcast, - // This should actually be a unicast address assigned to the interface. - // - // TODO(gvisor.dev/issue/3556): This check is validating incorrect - // behaviour. We still include the test so that once the bug is - // resolved, this test will start to fail and the individual tasked - // with fixing this bug knows to also fix this test :). - expectedOriginalDstAddr: tcpip.FullAddress{1, broadcastAddr, stackPort}, - }, - { - name: "IPv6 unicast", - proto: header.IPv6ProtocolNumber, - flow: unicastV6, - expectedOriginalDstAddr: tcpip.FullAddress{1, stackV6Addr, stackPort}, - }, - { - name: "IPv6 multicast", - proto: header.IPv6ProtocolNumber, - flow: multicastV6, - // This should actually be a unicast address assigned to the interface. - // - // TODO(gvisor.dev/issue/3556): This check is validating incorrect - // behaviour. We still include the test so that once the bug is - // resolved, this test will start to fail and the individual tasked - // with fixing this bug knows to also fix this test :). - expectedOriginalDstAddr: tcpip.FullAddress{1, multicastV6Addr, stackPort}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(test.proto) - - bindAddr := tcpip.FullAddress{Port: stackPort} - if err := c.ep.Bind(bindAddr); err != nil { - t.Fatalf("Bind(%#v): %s", bindAddr, err) - } - - if test.flow.isMulticast() { - ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()} - if err := c.ep.SetSockOpt(&ifoptSet); err != nil { - c.t.Fatalf("SetSockOpt(&%#v): %s:", ifoptSet, err) - } - } - - c.ep.SocketOptions().SetReceiveOriginalDstAddress(true) - - testRead(c, test.flow, checker.ReceiveOriginalDstAddr(test.expectedOriginalDstAddr)) - - if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 { - t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got) - } - }) - } -} - -func TestWriteIncrementsPacketsSent(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - testDualWrite(c) - - var want uint64 = 2 - if got := c.s.Stats().UDP.PacketsSent.Value(); got != want { - c.t.Fatalf("Write did not increment PacketsSent: got %v, want %v", got, want) - } -} - -func TestNoChecksum(t *testing.T) { - for _, flow := range []testFlow{unicastV4, unicastV6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Disable the checksum generation. - c.ep.SocketOptions().SetNoChecksum(true) - // This option is effective on IPv4 only. - testWrite(c, flow, checker.UDP(checker.NoChecksum(flow.isV4()))) - - // Enable the checksum generation. - c.ep.SocketOptions().SetNoChecksum(false) - testWrite(c, flow, checker.UDP(checker.NoChecksum(false))) - }) - } -} - -var _ stack.NetworkInterface = (*testInterface)(nil) - -type testInterface struct { - stack.NetworkInterface -} - -func (*testInterface) ID() tcpip.NICID { - return 0 -} - -func (*testInterface) Enabled() bool { - return true -} - -func TestTTL(t *testing.T) { - for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - const multicastTTL = 42 - if err := c.ep.SetSockOptInt(tcpip.MulticastTTLOption, multicastTTL); err != nil { - c.t.Fatalf("SetSockOptInt failed: %s", err) - } - - var wantTTL uint8 - if flow.isMulticast() { - wantTTL = multicastTTL - } else { - var p stack.NetworkProtocolFactory - var n tcpip.NetworkProtocolNumber - if flow.isV4() { - p = ipv4.NewProtocol - n = ipv4.ProtocolNumber - } else { - p = ipv6.NewProtocol - n = ipv6.ProtocolNumber - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{p}, - }) - ep := s.NetworkProtocolInstance(n).NewEndpoint(&testInterface{}, nil) - wantTTL = ep.DefaultTTL() - ep.Close() - } - - testWrite(c, flow, checker.TTL(wantTTL)) - }) - } -} - -func TestSetTTL(t *testing.T) { - for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} { - t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - if err := c.ep.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil { - c.t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err) - } - - testWrite(c, flow, checker.TTL(wantTTL)) - }) - } - }) - } -} - -func TestSetTOS(t *testing.T) { - for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - const tos = testTOS - v, err := c.ep.GetSockOptInt(tcpip.IPv4TOSOption) - if err != nil { - c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err) - } - // Test for expected default value. - if v != 0 { - c.t.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, 0) - } - - if err := c.ep.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil { - c.t.Errorf("SetSockOptInt(IPv4TOSOption, 0x%x) failed: %s", tos, err) - } - - v, err = c.ep.GetSockOptInt(tcpip.IPv4TOSOption) - if err != nil { - c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err) - } - - if v != tos { - c.t.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, tos) - } - - testWrite(c, flow, checker.TOS(tos, 0)) - }) - } -} - -func TestSetTClass(t *testing.T) { - for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - const tClass = testTOS - v, err := c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption) - if err != nil { - c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err) - } - // Test for expected default value. - if v != 0 { - c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, 0) - } - - if err := c.ep.SetSockOptInt(tcpip.IPv6TrafficClassOption, tClass); err != nil { - c.t.Errorf("SetSockOptInt(IPv6TrafficClassOption, 0x%x) failed: %s", tClass, err) - } - - v, err = c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption) - if err != nil { - c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err) - } - - if v != tClass { - c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, tClass) - } - - // The header getter for TClass is called TOS, so use that checker. - testWrite(c, flow, checker.TOS(tClass, 0)) - }) - } -} - -func TestReceiveTosTClass(t *testing.T) { - const RcvTOSOpt = "ReceiveTosOption" - const RcvTClassOpt = "ReceiveTClassOption" - - testCases := []struct { - name string - tests []testFlow - }{ - {RcvTOSOpt, []testFlow{unicastV4, broadcast}}, - {RcvTClassOpt, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}}, - } - for _, testCase := range testCases { - for _, flow := range testCase.tests { - t.Run(fmt.Sprintf("%s:flow:%s", testCase.name, flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - name := testCase.name - - var optionGetter func() bool - var optionSetter func(bool) - switch name { - case RcvTOSOpt: - optionGetter = c.ep.SocketOptions().GetReceiveTOS - optionSetter = c.ep.SocketOptions().SetReceiveTOS - case RcvTClassOpt: - optionGetter = c.ep.SocketOptions().GetReceiveTClass - optionSetter = c.ep.SocketOptions().SetReceiveTClass - default: - t.Fatalf("unkown test variant: %s", name) - } - - // Verify that setting and reading the option works. - v := optionGetter() - // Test for expected default value. - if v != false { - c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false) - } - - want := true - optionSetter(want) - - got := optionGetter() - if got != want { - c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, got, want) - } - - // Verify that the correct received TOS or TClass is handed through as - // ancillary data to the ControlMessages struct. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - switch name { - case RcvTClassOpt: - testRead(c, flow, checker.ReceiveTClass(testTOS)) - case RcvTOSOpt: - testRead(c, flow, checker.ReceiveTOS(testTOS)) - default: - t.Fatalf("unknown test variant: %s", name) - } - }) - } - } -} - -func TestMulticastInterfaceOption(t *testing.T) { - for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - for _, bindTyp := range []string{"bound", "unbound"} { - t.Run(bindTyp, func(t *testing.T) { - for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} { - t.Run(optTyp, func(t *testing.T) { - h := flow.header4Tuple(outgoing) - mcastAddr := h.dstAddr.Addr - localIfAddr := h.srcAddr.Addr - - var ifoptSet tcpip.MulticastInterfaceOption - switch optTyp { - case "use local-addr": - ifoptSet.InterfaceAddr = localIfAddr - case "use NICID": - ifoptSet.NIC = 1 - case "use local-addr and NIC": - ifoptSet.InterfaceAddr = localIfAddr - ifoptSet.NIC = 1 - default: - t.Fatal("unknown test variant") - } - - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(flow.sockProto()) - - if bindTyp == "bound" { - // Bind the socket by connecting to the multicast address. - // This may have an influence on how the multicast interface - // is set. - addr := tcpip.FullAddress{ - Addr: flow.mapAddrIfApplicable(mcastAddr), - Port: stackPort, - } - if err := c.ep.Connect(addr); err != nil { - c.t.Fatalf("Connect failed: %s", err) - } - } - - if err := c.ep.SetSockOpt(&ifoptSet); err != nil { - c.t.Fatalf("SetSockOpt(&%#v): %s", ifoptSet, err) - } - - // Verify multicast interface addr and NIC were set correctly. - // Note that NIC must be 1 since this is our outgoing interface. - var ifoptGot tcpip.MulticastInterfaceOption - if err := c.ep.GetSockOpt(&ifoptGot); err != nil { - c.t.Fatalf("GetSockOpt(&%T): %s", ifoptGot, err) - } else if ifoptWant := (tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}); ifoptGot != ifoptWant { - c.t.Errorf("got multicast interface option = %#v, want = %#v", ifoptGot, ifoptWant) - } - }) - } - }) - } - }) - } -} - -// TestV4UnknownDestination verifies that we generate an ICMPv4 Destination -// Unreachable message when a udp datagram is received on ports for which there -// is no bound udp socket. -func TestV4UnknownDestination(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - testCases := []struct { - flow testFlow - icmpRequired bool - // largePayload if true, will result in a payload large enough - // so that the final generated IPv4 packet is larger than - // header.IPv4MinimumProcessableDatagramSize. - largePayload bool - // badChecksum if true, will set an invalid checksum in the - // header. - badChecksum bool - }{ - {unicastV4, true, false, false}, - {unicastV4, true, true, false}, - {unicastV4, false, false, true}, - {unicastV4, false, true, true}, - {multicastV4, false, false, false}, - {multicastV4, false, true, false}, - {broadcast, false, false, false}, - {broadcast, false, true, false}, - } - checksumErrors := uint64(0) - for _, tc := range testCases { - t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t badChecksum:%t", tc.flow, tc.icmpRequired, tc.largePayload, tc.badChecksum), func(t *testing.T) { - payload := newPayload() - if tc.largePayload { - payload = newMinPayload(576) - } - c.injectPacket(tc.flow, payload, tc.badChecksum) - if tc.badChecksum { - checksumErrors++ - if got, want := c.s.Stats().UDP.ChecksumErrors.Value(), checksumErrors; got != want { - t.Fatalf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - } - if !tc.icmpRequired { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - if p, ok := c.linkEP.ReadContext(ctx); ok { - t.Fatalf("unexpected packet received: %+v", p) - } - return - } - - // ICMP required. - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - p, ok := c.linkEP.ReadContext(ctx) - if !ok { - t.Fatalf("packet wasn't written out") - return - } - - vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) - pkt := vv.ToView() - if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want { - t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) - } - - hdr := header.IPv4(pkt) - checker.IPv4(t, hdr, checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4DstUnreachable), - checker.ICMPv4Code(header.ICMPv4PortUnreachable))) - - // We need to compare the included data part of the UDP packet that is in - // the ICMP packet with the matching original data. - icmpPkt := header.ICMPv4(hdr.Payload()) - payloadIPHeader := header.IPv4(icmpPkt.Payload()) - incomingHeaderLength := header.IPv4MinimumSize + header.UDPMinimumSize - wantLen := len(payload) - if tc.largePayload { - // To work out the data size we need to simulate what the sender would - // have done. The wanted size is the total available minus the sum of - // the headers in the UDP AND ICMP packets, given that we know the test - // had only a minimal IP header but the ICMP sender will have allowed - // for a maximally sized packet header. - wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength - } - - // In the case of large payloads the IP packet may be truncated. Update - // the length field before retrieving the udp datagram payload. - // Add back the two headers within the payload. - payloadIPHeader.SetTotalLength(uint16(wantLen + incomingHeaderLength)) - - origDgram := header.UDP(payloadIPHeader.Payload()) - if got, want := len(origDgram.Payload()), wantLen; got != want { - t.Fatalf("unexpected payload length got: %d, want: %d", got, want) - } - if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { - t.Fatalf("unexpected payload got: %d, want: %d", got, want) - } - }) - } -} - -// TestV6UnknownDestination verifies that we generate an ICMPv6 Destination -// Unreachable message when a udp datagram is received on ports for which there -// is no bound udp socket. -func TestV6UnknownDestination(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - testCases := []struct { - flow testFlow - icmpRequired bool - // largePayload if true will result in a payload large enough to - // create an IPv6 packet > header.IPv6MinimumMTU bytes. - largePayload bool - // badChecksum if true, will set an invalid checksum in the - // header. - badChecksum bool - }{ - {unicastV6, true, false, false}, - {unicastV6, true, true, false}, - {unicastV6, false, false, true}, - {unicastV6, false, true, true}, - {multicastV6, false, false, false}, - {multicastV6, false, true, false}, - } - checksumErrors := uint64(0) - for _, tc := range testCases { - t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t badChecksum:%t", tc.flow, tc.icmpRequired, tc.largePayload, tc.badChecksum), func(t *testing.T) { - payload := newPayload() - if tc.largePayload { - payload = newMinPayload(1280) - } - c.injectPacket(tc.flow, payload, tc.badChecksum) - if tc.badChecksum { - checksumErrors++ - if got, want := c.s.Stats().UDP.ChecksumErrors.Value(), checksumErrors; got != want { - t.Fatalf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - } - if !tc.icmpRequired { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - if p, ok := c.linkEP.ReadContext(ctx); ok { - t.Fatalf("unexpected packet received: %+v", p) - } - return - } - - // ICMP required. - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - p, ok := c.linkEP.ReadContext(ctx) - if !ok { - t.Fatalf("packet wasn't written out") - return - } - - vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) - pkt := vv.ToView() - if got, want := len(pkt), header.IPv6MinimumMTU; got > want { - t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) - } - - hdr := header.IPv6(pkt) - checker.IPv6(t, hdr, checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6DstUnreachable), - checker.ICMPv6Code(header.ICMPv6PortUnreachable))) - - icmpPkt := header.ICMPv6(hdr.Payload()) - payloadIPHeader := header.IPv6(icmpPkt.Payload()) - wantLen := len(payload) - if tc.largePayload { - wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize - } - // In case of large payloads the IP packet may be truncated. Update - // the length field before retrieving the udp datagram payload. - payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize)) - - origDgram := header.UDP(payloadIPHeader.Payload()) - if got, want := len(origDgram.Payload()), wantLen; got != want { - t.Fatalf("unexpected payload length got: %d, want: %d", got, want) - } - if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { - t.Fatalf("unexpected payload got: %v, want: %v", got, want) - } - }) - } -} - -// TestIncrementMalformedPacketsReceived verifies if the malformed received -// global and endpoint stats are incremented. -func TestIncrementMalformedPacketsReceived(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - payload := newPayload() - h := unicastV6.header4Tuple(incoming) - buf := c.buildV6Packet(payload, &h) - - // Invalidate the UDP header length field. - u := header.UDP(buf[header.IPv6MinimumSize:]) - u.SetLength(u.Length() + 1) - - c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - const want = 1 - if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want { - t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %d, want = %d", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want) - } -} - -// TestShortHeader verifies that when a packet with a too-short UDP header is -// received, the malformed received global stat gets incremented. -func TestShortHeader(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - h := unicastV6.header4Tuple(incoming) - - // Allocate a buffer for an IPv6 and too-short UDP header. - const udpSize = header.UDPMinimumSize - 1 - buf := buffer.NewView(header.IPv6MinimumSize + udpSize) - // Initialize the IP header. - ip := header.IPv6(buf) - ip.Encode(&header.IPv6Fields{ - TrafficClass: testTOS, - PayloadLength: uint16(udpSize), - TransportProtocol: udp.ProtocolNumber, - HopLimit: 65, - SrcAddr: h.srcAddr.Addr, - DstAddr: h.dstAddr.Addr, - }) - - // Initialize the UDP header. - udpHdr := header.UDP(buffer.NewView(header.UDPMinimumSize)) - udpHdr.Encode(&header.UDPFields{ - SrcPort: h.srcAddr.Port, - DstPort: h.dstAddr.Port, - Length: header.UDPMinimumSize, - }) - // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(udpHdr))) - udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum)) - // Copy all but the last byte of the UDP header into the packet. - copy(buf[header.IPv6MinimumSize:], udpHdr) - - // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - if got, want := c.s.Stats().MalformedRcvdPackets.Value(), uint64(1); got != want { - t.Errorf("got c.s.Stats().MalformedRcvdPackets.Value() = %d, want = %d", got, want) - } -} - -// TestBadChecksumErrors verifies if a checksum error is detected, -// global and endpoint stats are incremented. -func TestBadChecksumErrors(t *testing.T) { - for _, flow := range []testFlow{unicastV4, unicastV6} { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(flow.sockProto()) - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - payload := newPayload() - c.injectPacket(flow, payload, true /* badChecksum */) - - const want = 1 - if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { - t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) - } - } -} - -// TestPayloadModifiedV4 verifies if a checksum error is detected, -// global and endpoint stats are incremented. -func TestPayloadModifiedV4(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv4.ProtocolNumber) - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - payload := newPayload() - h := unicastV4.header4Tuple(incoming) - buf := c.buildV4Packet(payload, &h) - // Modify the payload so that the checksum value in the UDP header will be - // incorrect. - buf[len(buf)-1]++ - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - const want = 1 - if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { - t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) - } -} - -// TestPayloadModifiedV6 verifies if a checksum error is detected, -// global and endpoint stats are incremented. -func TestPayloadModifiedV6(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - payload := newPayload() - h := unicastV6.header4Tuple(incoming) - buf := c.buildV6Packet(payload, &h) - // Modify the payload so that the checksum value in the UDP header will be - // incorrect. - buf[len(buf)-1]++ - c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - const want = 1 - if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { - t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) - } -} - -// TestChecksumZeroV4 verifies if the checksum value is zero, global and -// endpoint states are *not* incremented (UDP checksum is optional on IPv4). -func TestChecksumZeroV4(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv4.ProtocolNumber) - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - payload := newPayload() - h := unicastV4.header4Tuple(incoming) - buf := c.buildV4Packet(payload, &h) - // Set the checksum field in the UDP header to zero. - u := header.UDP(buf[header.IPv4MinimumSize:]) - u.SetChecksum(0) - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - const want = 0 - if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { - t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) - } -} - -// TestChecksumZeroV6 verifies if the checksum value is zero, global and -// endpoint states are incremented (UDP checksum is *not* optional on IPv6). -func TestChecksumZeroV6(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - payload := newPayload() - h := unicastV6.header4Tuple(incoming) - buf := c.buildV6Packet(payload, &h) - // Set the checksum field in the UDP header to zero. - u := header.UDP(buf[header.IPv6MinimumSize:]) - u.SetChecksum(0) - c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - const want = 1 - if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { - t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) - } -} - -// TestShutdownRead verifies endpoint read shutdown and error -// stats increment on packet receive. -func TestShutdownRead(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %s", err) - } - - if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } - - testFailingRead(c, unicastV6, true /* expectReadError */) - - var want uint64 = 1 - if got := c.s.Stats().UDP.ReceiveBufferErrors.Value(); got != want { - t.Errorf("got stats.UDP.ReceiveBufferErrors.Value() = %v, want = %v", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ClosedReceiver.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.ClosedReceiver stats = %v, want = %v", got, want) - } -} - -// TestShutdownWrite verifies endpoint write shutdown and error -// stats increment on packet write. -func TestShutdownWrite(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %s", err) - } - - if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } - - testFailingWrite(c, unicastV6, &tcpip.ErrClosedForSend{}) -} - -func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err tcpip.Error) { - got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - switch err.(type) { - case nil: - want.PacketsSent.IncrementBy(incr) - case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue: - want.WriteErrors.InvalidArgs.IncrementBy(incr) - case *tcpip.ErrClosedForSend: - want.WriteErrors.WriteClosed.IncrementBy(incr) - case *tcpip.ErrInvalidEndpointState: - want.WriteErrors.InvalidEndpointState.IncrementBy(incr) - case *tcpip.ErrNoRoute, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable: - want.SendErrors.NoRoute.IncrementBy(incr) - default: - want.SendErrors.SendToNetworkFailed.IncrementBy(incr) - } - if got != want { - c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want) - } -} - -func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err tcpip.Error) { - got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - switch err.(type) { - case nil, *tcpip.ErrWouldBlock: - case *tcpip.ErrClosedForReceive: - want.ReadErrors.ReadClosed.IncrementBy(incr) - default: - c.t.Errorf("Endpoint error missing stats update err %v", err) - } - if got != want { - c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want) - } -} - -func TestOutgoingSubnetBroadcast(t *testing.T) { - const nicID1 = 1 - - ipv4Addr := tcpip.AddressWithPrefix{ - Address: "\xc0\xa8\x01\x3a", - PrefixLen: 24, - } - ipv4Subnet := ipv4Addr.Subnet() - ipv4SubnetBcast := ipv4Subnet.Broadcast() - ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01") - ipv4AddrPrefix31 := tcpip.AddressWithPrefix{ - Address: "\xc0\xa8\x01\x3a", - PrefixLen: 31, - } - ipv4Subnet31 := ipv4AddrPrefix31.Subnet() - ipv4Subnet31Bcast := ipv4Subnet31.Broadcast() - ipv4AddrPrefix32 := tcpip.AddressWithPrefix{ - Address: "\xc0\xa8\x01\x3a", - PrefixLen: 32, - } - ipv4Subnet32 := ipv4AddrPrefix32.Subnet() - ipv4Subnet32Bcast := ipv4Subnet32.Broadcast() - ipv6Addr := tcpip.AddressWithPrefix{ - Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - PrefixLen: 64, - } - ipv6Subnet := ipv6Addr.Subnet() - ipv6SubnetBcast := ipv6Subnet.Broadcast() - remNetAddr := tcpip.AddressWithPrefix{ - Address: "\x64\x0a\x7b\x18", - PrefixLen: 24, - } - remNetSubnet := remNetAddr.Subnet() - remNetSubnetBcast := remNetSubnet.Broadcast() - - tests := []struct { - name string - nicAddr tcpip.ProtocolAddress - routes []tcpip.Route - remoteAddr tcpip.Address - requiresBroadcastOpt bool - }{ - { - name: "IPv4 Broadcast to local subnet", - nicAddr: tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: ipv4Addr, - }, - routes: []tcpip.Route{ - { - Destination: ipv4Subnet, - NIC: nicID1, - }, - }, - remoteAddr: ipv4SubnetBcast, - requiresBroadcastOpt: true, - }, - { - name: "IPv4 Broadcast to local /31 subnet", - nicAddr: tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: ipv4AddrPrefix31, - }, - routes: []tcpip.Route{ - { - Destination: ipv4Subnet31, - NIC: nicID1, - }, - }, - remoteAddr: ipv4Subnet31Bcast, - requiresBroadcastOpt: false, - }, - { - name: "IPv4 Broadcast to local /32 subnet", - nicAddr: tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: ipv4AddrPrefix32, - }, - routes: []tcpip.Route{ - { - Destination: ipv4Subnet32, - NIC: nicID1, - }, - }, - remoteAddr: ipv4Subnet32Bcast, - requiresBroadcastOpt: false, - }, - // IPv6 has no notion of a broadcast. - { - name: "IPv6 'Broadcast' to local subnet", - nicAddr: tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: ipv6Addr, - }, - routes: []tcpip.Route{ - { - Destination: ipv6Subnet, - NIC: nicID1, - }, - }, - remoteAddr: ipv6SubnetBcast, - requiresBroadcastOpt: false, - }, - { - name: "IPv4 Broadcast to remote subnet", - nicAddr: tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: ipv4Addr, - }, - routes: []tcpip.Route{ - { - Destination: remNetSubnet, - Gateway: ipv4Gateway, - NIC: nicID1, - }, - }, - remoteAddr: remNetSubnetBcast, - // TODO(gvisor.dev/issue/3938): Once we support marking a route as - // broadcast, this test should require the broadcast option to be set. - requiresBroadcastOpt: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - e := channel.New(0, defaultMTU, "") - if err := s.CreateNIC(nicID1, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) - } - if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err) - } - - s.SetRouteTable(test.routes) - - var netProto tcpip.NetworkProtocolNumber - switch l := len(test.remoteAddr); l { - case header.IPv4AddressSize: - netProto = header.IPv4ProtocolNumber - case header.IPv6AddressSize: - netProto = header.IPv6ProtocolNumber - default: - t.Fatalf("got unexpected address length = %d bytes", l) - } - - wq := waiter.Queue{} - ep, err := s.NewEndpoint(udp.ProtocolNumber, netProto, &wq) - if err != nil { - t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netProto, err) - } - defer ep.Close() - - var r bytes.Reader - data := []byte{1, 2, 3, 4} - to := tcpip.FullAddress{ - Addr: test.remoteAddr, - Port: 80, - } - opts := tcpip.WriteOptions{To: &to} - expectedErrWithoutBcastOpt := func(err tcpip.Error) tcpip.Error { - if _, ok := err.(*tcpip.ErrBroadcastDisabled); ok { - return nil - } - return &tcpip.ErrBroadcastDisabled{} - } - if !test.requiresBroadcastOpt { - expectedErrWithoutBcastOpt = nil - } - - r.Reset(data) - { - n, err := ep.Write(&r, opts) - if expectedErrWithoutBcastOpt != nil { - if want := expectedErrWithoutBcastOpt(err); want != nil { - t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, want) - } - } else if err != nil { - t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err) - } - } - - ep.SocketOptions().SetBroadcast(true) - - r.Reset(data) - if n, err := ep.Write(&r, opts); err != nil { - t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err) - } - - ep.SocketOptions().SetBroadcast(false) - - r.Reset(data) - { - n, err := ep.Write(&r, opts) - if expectedErrWithoutBcastOpt != nil { - if want := expectedErrWithoutBcastOpt(err); want != nil { - t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, want) - } - } else if err != nil { - t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err) - } - } - }) - } -} |