summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/BUILD60
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go266
-rwxr-xr-xpkg/tcpip/stack/stack_state_autogen.go59
-rw-r--r--pkg/tcpip/stack/stack_test.go1151
-rw-r--r--pkg/tcpip/stack/transport_test.go538
5 files changed, 59 insertions, 2015 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
deleted file mode 100644
index 28d11c797..000000000
--- a/pkg/tcpip/stack/BUILD
+++ /dev/null
@@ -1,60 +0,0 @@
-package(licenses = ["notice"])
-
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
-
-go_library(
- name = "stack",
- srcs = [
- "linkaddrcache.go",
- "nic.go",
- "registration.go",
- "route.go",
- "stack.go",
- "stack_global_state.go",
- "transport_demuxer.go",
- ],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/stack",
- visibility = [
- "//visibility:public",
- ],
- deps = [
- "//pkg/ilist",
- "//pkg/sleep",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/hash/jenkins",
- "//pkg/tcpip/header",
- "//pkg/tcpip/ports",
- "//pkg/tcpip/seqnum",
- "//pkg/waiter",
- ],
-)
-
-go_test(
- name = "stack_x_test",
- size = "small",
- srcs = [
- "stack_test.go",
- "transport_test.go",
- ],
- deps = [
- ":stack",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/loopback",
- "//pkg/waiter",
- ],
-)
-
-go_test(
- name = "stack_test",
- size = "small",
- srcs = ["linkaddrcache_test.go"],
- embed = [":stack"],
- deps = [
- "//pkg/sleep",
- "//pkg/tcpip",
- ],
-)
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
deleted file mode 100644
index 924f4d240..000000000
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ /dev/null
@@ -1,266 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package stack
-
-import (
- "fmt"
- "sync"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/sleep"
- "gvisor.dev/gvisor/pkg/tcpip"
-)
-
-type testaddr struct {
- addr tcpip.FullAddress
- linkAddr tcpip.LinkAddress
-}
-
-var testaddrs []testaddr
-
-type testLinkAddressResolver struct {
- cache *linkAddrCache
- delay time.Duration
-}
-
-func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error {
- go func() {
- if r.delay > 0 {
- time.Sleep(r.delay)
- }
- r.fakeRequest(addr)
- }()
- return nil
-}
-
-func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) {
- for _, ta := range testaddrs {
- if ta.addr.Addr == addr {
- r.cache.add(ta.addr, ta.linkAddr)
- break
- }
- }
-}
-
-func (*testLinkAddressResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
- if addr == "broadcast" {
- return "mac_broadcast", true
- }
- return "", false
-}
-
-func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
- return 1
-}
-
-func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressResolver) (tcpip.LinkAddress, *tcpip.Error) {
- w := sleep.Waker{}
- s := sleep.Sleeper{}
- s.AddWaker(&w, 123)
- defer s.Done()
-
- for {
- if got, _, err := c.get(addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock {
- return got, err
- }
- s.Fetch(true)
- }
-}
-
-func init() {
- for i := 0; i < 4*linkAddrCacheSize; i++ {
- addr := fmt.Sprintf("Addr%06d", i)
- testaddrs = append(testaddrs, testaddr{
- addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)},
- linkAddr: tcpip.LinkAddress("Link" + addr),
- })
- }
-}
-
-func TestCacheOverflow(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
- for i := len(testaddrs) - 1; i >= 0; i-- {
- e := testaddrs[i]
- c.add(e.addr, e.linkAddr)
- got, _, err := c.get(e.addr, nil, "", nil, nil)
- if err != nil {
- t.Errorf("insert %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err)
- }
- if got != e.linkAddr {
- t.Errorf("insert %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr)
- }
- }
- // Expect to find at least half of the most recent entries.
- for i := 0; i < linkAddrCacheSize/2; i++ {
- e := testaddrs[i]
- got, _, err := c.get(e.addr, nil, "", nil, nil)
- if err != nil {
- t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err)
- }
- if got != e.linkAddr {
- t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr)
- }
- }
- // The earliest entries should no longer be in the cache.
- for i := len(testaddrs) - 1; i >= len(testaddrs)-linkAddrCacheSize; i-- {
- e := testaddrs[i]
- if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
- t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err)
- }
- }
-}
-
-func TestCacheConcurrent(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
-
- var wg sync.WaitGroup
- for r := 0; r < 16; r++ {
- wg.Add(1)
- go func() {
- for _, e := range testaddrs {
- c.add(e.addr, e.linkAddr)
- c.get(e.addr, nil, "", nil, nil) // make work for gotsan
- }
- wg.Done()
- }()
- }
- wg.Wait()
-
- // All goroutines add in the same order and add more values than
- // can fit in the cache, so our eviction strategy requires that
- // the last entry be present and the first be missing.
- e := testaddrs[len(testaddrs)-1]
- got, _, err := c.get(e.addr, nil, "", nil, nil)
- if err != nil {
- t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
- }
- if got != e.linkAddr {
- t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
- }
-
- e = testaddrs[0]
- if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
- t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
- }
-}
-
-func TestCacheAgeLimit(t *testing.T) {
- c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3)
- e := testaddrs[0]
- c.add(e.addr, e.linkAddr)
- time.Sleep(50 * time.Millisecond)
- if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
- t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
- }
-}
-
-func TestCacheReplace(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
- e := testaddrs[0]
- l2 := e.linkAddr + "2"
- c.add(e.addr, e.linkAddr)
- got, _, err := c.get(e.addr, nil, "", nil, nil)
- if err != nil {
- t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
- }
- if got != e.linkAddr {
- t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
- }
-
- c.add(e.addr, l2)
- got, _, err = c.get(e.addr, nil, "", nil, nil)
- if err != nil {
- t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
- }
- if got != l2 {
- t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, l2)
- }
-}
-
-func TestCacheResolution(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, 250*time.Millisecond, 1)
- linkRes := &testLinkAddressResolver{cache: c}
- for i, ta := range testaddrs {
- got, err := getBlocking(c, ta.addr, linkRes)
- if err != nil {
- t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(ta.addr.Addr), got, err)
- }
- if got != ta.linkAddr {
- t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(ta.addr.Addr), got, ta.linkAddr)
- }
- }
-
- // Check that after resolved, address stays in the cache and never returns WouldBlock.
- for i := 0; i < 10; i++ {
- e := testaddrs[len(testaddrs)-1]
- got, _, err := c.get(e.addr, linkRes, "", nil, nil)
- if err != nil {
- t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
- }
- if got != e.linkAddr {
- t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
- }
- }
-}
-
-func TestCacheResolutionFailed(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, 10*time.Millisecond, 5)
- linkRes := &testLinkAddressResolver{cache: c}
-
- // First, sanity check that resolution is working...
- e := testaddrs[0]
- got, err := getBlocking(c, e.addr, linkRes)
- if err != nil {
- t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
- }
- if got != e.linkAddr {
- t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
- }
-
- e.addr.Addr += "2"
- if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress {
- t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
- }
-}
-
-func TestCacheResolutionTimeout(t *testing.T) {
- resolverDelay := 500 * time.Millisecond
- expiration := resolverDelay / 10
- c := newLinkAddrCache(expiration, 1*time.Millisecond, 3)
- linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay}
-
- e := testaddrs[0]
- if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress {
- t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
- }
-}
-
-// TestStaticResolution checks that static link addresses are resolved immediately and don't
-// send resolution requests.
-func TestStaticResolution(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, time.Millisecond, 1)
- linkRes := &testLinkAddressResolver{cache: c, delay: time.Minute}
-
- addr := tcpip.Address("broadcast")
- want := tcpip.LinkAddress("mac_broadcast")
- got, _, err := c.get(tcpip.FullAddress{Addr: addr}, linkRes, "", nil, nil)
- if err != nil {
- t.Errorf("c.get(%q)=%q, got error: %v", string(addr), string(got), err)
- }
- if got != want {
- t.Errorf("c.get(%q)=%q, want %q", string(addr), string(got), string(want))
- }
-}
diff --git a/pkg/tcpip/stack/stack_state_autogen.go b/pkg/tcpip/stack/stack_state_autogen.go
new file mode 100755
index 000000000..2a15cfa0a
--- /dev/null
+++ b/pkg/tcpip/stack/stack_state_autogen.go
@@ -0,0 +1,59 @@
+// automatically generated by stateify.
+
+package stack
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+func (x *TransportEndpointID) beforeSave() {}
+func (x *TransportEndpointID) save(m state.Map) {
+ x.beforeSave()
+ m.Save("LocalPort", &x.LocalPort)
+ m.Save("LocalAddress", &x.LocalAddress)
+ m.Save("RemotePort", &x.RemotePort)
+ m.Save("RemoteAddress", &x.RemoteAddress)
+}
+
+func (x *TransportEndpointID) afterLoad() {}
+func (x *TransportEndpointID) load(m state.Map) {
+ m.Load("LocalPort", &x.LocalPort)
+ m.Load("LocalAddress", &x.LocalAddress)
+ m.Load("RemotePort", &x.RemotePort)
+ m.Load("RemoteAddress", &x.RemoteAddress)
+}
+
+func (x *GSOType) save(m state.Map) {
+ m.SaveValue("", (int)(*x))
+}
+
+func (x *GSOType) load(m state.Map) {
+ m.LoadValue("", new(int), func(y interface{}) { *x = (GSOType)(y.(int)) })
+}
+
+func (x *GSO) beforeSave() {}
+func (x *GSO) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Type", &x.Type)
+ m.Save("NeedsCsum", &x.NeedsCsum)
+ m.Save("CsumOffset", &x.CsumOffset)
+ m.Save("MSS", &x.MSS)
+ m.Save("L3HdrLen", &x.L3HdrLen)
+ m.Save("MaxSize", &x.MaxSize)
+}
+
+func (x *GSO) afterLoad() {}
+func (x *GSO) load(m state.Map) {
+ m.Load("Type", &x.Type)
+ m.Load("NeedsCsum", &x.NeedsCsum)
+ m.Load("CsumOffset", &x.CsumOffset)
+ m.Load("MSS", &x.MSS)
+ m.Load("L3HdrLen", &x.L3HdrLen)
+ m.Load("MaxSize", &x.MaxSize)
+}
+
+func init() {
+ state.Register("stack.TransportEndpointID", (*TransportEndpointID)(nil), state.Fns{Save: (*TransportEndpointID).save, Load: (*TransportEndpointID).load})
+ state.Register("stack.GSOType", (*GSOType)(nil), state.Fns{Save: (*GSOType).save, Load: (*GSOType).load})
+ state.Register("stack.GSO", (*GSO)(nil), state.Fns{Save: (*GSO).save, Load: (*GSO).load})
+}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
deleted file mode 100644
index 959071dbe..000000000
--- a/pkg/tcpip/stack/stack_test.go
+++ /dev/null
@@ -1,1151 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package stack_test contains tests for the stack. It is in its own package so
-// that the tests can also validate that all definitions needed to implement
-// transport and network protocols are properly exported by the stack package.
-package stack_test
-
-import (
- "bytes"
- "fmt"
- "math"
- "strings"
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
-)
-
-const (
- fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32
- fakeNetHeaderLen = 12
-
- // fakeControlProtocol is used for control packets that represent
- // destination port unreachable.
- fakeControlProtocol tcpip.TransportProtocolNumber = 2
-
- // defaultMTU is the MTU, in bytes, used throughout the tests, except
- // where another value is explicitly used. It is chosen to match the MTU
- // of loopback interfaces on linux systems.
- defaultMTU = 65536
-)
-
-// fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and
-// received packets; the counts of all endpoints are aggregated in the protocol
-// descriptor.
-//
-// Headers of this protocol are fakeNetHeaderLen bytes, but we currently only
-// use the first three: destination address, source address, and transport
-// protocol. They're all one byte fields to simplify parsing.
-type fakeNetworkEndpoint struct {
- nicid tcpip.NICID
- id stack.NetworkEndpointID
- proto *fakeNetworkProtocol
- dispatcher stack.TransportDispatcher
- linkEP stack.LinkEndpoint
-}
-
-func (f *fakeNetworkEndpoint) MTU() uint32 {
- return f.linkEP.MTU() - uint32(f.MaxHeaderLength())
-}
-
-func (f *fakeNetworkEndpoint) NICID() tcpip.NICID {
- return f.nicid
-}
-
-func (*fakeNetworkEndpoint) DefaultTTL() uint8 {
- return 123
-}
-
-func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID {
- return &f.id
-}
-
-func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
- // Increment the received packet count in the protocol descriptor.
- f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++
-
- // Consume the network header.
- b := vv.First()
- vv.TrimFront(fakeNetHeaderLen)
-
- // Handle control packets.
- if b[2] == uint8(fakeControlProtocol) {
- nb := vv.First()
- if len(nb) < fakeNetHeaderLen {
- return
- }
-
- vv.TrimFront(fakeNetHeaderLen)
- f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, vv)
- return
- }
-
- // Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), buffer.View([]byte{}), vv)
-}
-
-func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
- return f.linkEP.MaxHeaderLength() + fakeNetHeaderLen
-}
-
-func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
- return 0
-}
-
-func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
- return f.linkEP.Capabilities()
-}
-
-func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, _ uint8, loop stack.PacketLooping) *tcpip.Error {
- // Increment the sent packet count in the protocol descriptor.
- f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++
-
- // Add the protocol's header to the packet and send it to the link
- // endpoint.
- b := hdr.Prepend(fakeNetHeaderLen)
- b[0] = r.RemoteAddress[0]
- b[1] = f.id.LocalAddress[0]
- b[2] = byte(protocol)
-
- if loop&stack.PacketLoop != 0 {
- views := make([]buffer.View, 1, 1+len(payload.Views()))
- views[0] = hdr.View()
- views = append(views, payload.Views()...)
- vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views)
- f.HandlePacket(r, vv)
- }
- if loop&stack.PacketOut == 0 {
- return nil
- }
-
- return f.linkEP.WritePacket(r, gso, hdr, payload, fakeNetNumber)
-}
-
-func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
- return tcpip.ErrNotSupported
-}
-
-func (*fakeNetworkEndpoint) Close() {}
-
-type fakeNetGoodOption bool
-
-type fakeNetBadOption bool
-
-type fakeNetInvalidValueOption int
-
-type fakeNetOptions struct {
- good bool
-}
-
-// fakeNetworkProtocol is a network-layer protocol descriptor. It aggregates the
-// number of packets sent and received via endpoints of this protocol. The index
-// where packets are added is given by the packet's destination address MOD 10.
-type fakeNetworkProtocol struct {
- packetCount [10]int
- sendPacketCount [10]int
- opts fakeNetOptions
-}
-
-func (f *fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
- return fakeNetNumber
-}
-
-func (f *fakeNetworkProtocol) MinimumPacketSize() int {
- return fakeNetHeaderLen
-}
-
-func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
- return tcpip.Address(v[1:2]), tcpip.Address(v[0:1])
-}
-
-func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
- return &fakeNetworkEndpoint{
- nicid: nicid,
- id: stack.NetworkEndpointID{addr},
- proto: f,
- dispatcher: dispatcher,
- linkEP: linkEP,
- }, nil
-}
-
-func (f *fakeNetworkProtocol) SetOption(option interface{}) *tcpip.Error {
- switch v := option.(type) {
- case fakeNetGoodOption:
- f.opts.good = bool(v)
- return nil
- case fakeNetInvalidValueOption:
- return tcpip.ErrInvalidOptionValue
- default:
- return tcpip.ErrUnknownProtocolOption
- }
-}
-
-func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error {
- switch v := option.(type) {
- case *fakeNetGoodOption:
- *v = fakeNetGoodOption(f.opts.good)
- return nil
- default:
- return tcpip.ErrUnknownProtocolOption
- }
-}
-
-func TestNetworkReceive(t *testing.T) {
- // Create a stack with the fake network protocol, one nic, and two
- // addresses attached to it: 1 & 2.
- id, linkEP := channel.New(10, defaultMTU, "")
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- if err := s.CreateNIC(1, id); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
-
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
-
- if err := s.AddAddress(1, fakeNetNumber, "\x02"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
-
- fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
-
- buf := buffer.NewView(30)
-
- // Make sure packet with wrong address is not delivered.
- buf[0] = 3
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 0 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
- }
- if fakeNet.packetCount[2] != 0 {
- t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 0)
- }
-
- // Make sure packet is delivered to first endpoint.
- buf[0] = 1
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
- if fakeNet.packetCount[2] != 0 {
- t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 0)
- }
-
- // Make sure packet is delivered to second endpoint.
- buf[0] = 2
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
- if fakeNet.packetCount[2] != 1 {
- t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1)
- }
-
- // Make sure packet is not delivered if protocol number is wrong.
- linkEP.Inject(fakeNetNumber-1, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
- if fakeNet.packetCount[2] != 1 {
- t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1)
- }
-
- // Make sure packet that is too small is dropped.
- buf.CapLength(2)
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
- if fakeNet.packetCount[2] != 1 {
- t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1)
- }
-}
-
-func sendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, payload buffer.View) {
- r, err := s.FindRoute(0, "", addr, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute failed: %v", err)
- }
- defer r.Release()
-
- hdr := buffer.NewPrependable(int(r.MaxHeaderLength()))
- if err := r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123); err != nil {
- t.Errorf("WritePacket failed: %v", err)
- }
-}
-
-func TestNetworkSend(t *testing.T) {
- // Create a stack with the fake network protocol, one nic, and one
- // address: 1. The route table sends all packets through the only
- // existing nic.
- id, linkEP := channel.New(10, defaultMTU, "")
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- if err := s.CreateNIC(1, id); err != nil {
- t.Fatalf("NewNIC failed: %v", err)
- }
-
- s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}})
-
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
-
- // Make sure that the link-layer endpoint received the outbound packet.
- sendTo(t, s, "\x03", nil)
- if c := linkEP.Drain(); c != 1 {
- t.Errorf("packetCount = %d, want %d", c, 1)
- }
-}
-
-func TestNetworkSendMultiRoute(t *testing.T) {
- // Create a stack with the fake network protocol, two nics, and two
- // addresses per nic, the first nic has odd address, the second one has
- // even addresses.
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
-
- id1, linkEP1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id1); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
-
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
-
- if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
-
- id2, linkEP2 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, id2); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
-
- if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
-
- if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
-
- // Set a route table that sends all packets with odd destination
- // addresses through the first NIC, and all even destination address
- // through the second one.
- s.SetRouteTable([]tcpip.Route{
- {"\x01", "\x01", "\x00", 1},
- {"\x00", "\x01", "\x00", 2},
- })
-
- // Send a packet to an odd destination.
- sendTo(t, s, "\x05", nil)
-
- if c := linkEP1.Drain(); c != 1 {
- t.Errorf("packetCount = %d, want %d", c, 1)
- }
-
- // Send a packet to an even destination.
- sendTo(t, s, "\x06", nil)
-
- if c := linkEP2.Drain(); c != 1 {
- t.Errorf("packetCount = %d, want %d", c, 1)
- }
-}
-
-func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) {
- r, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute failed: %v", err)
- }
-
- defer r.Release()
-
- if r.LocalAddress != expectedSrcAddr {
- t.Fatalf("Bad source address: expected %v, got %v", expectedSrcAddr, r.LocalAddress)
- }
-
- if r.RemoteAddress != dstAddr {
- t.Fatalf("Bad destination address: expected %v, got %v", dstAddr, r.RemoteAddress)
- }
-}
-
-func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr tcpip.Address) {
- _, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
- if err != tcpip.ErrNoRoute {
- t.Fatalf("FindRoute returned unexpected error, expected tcpip.ErrNoRoute, got %v", err)
- }
-}
-
-func TestRoutes(t *testing.T) {
- // Create a stack with the fake network protocol, two nics, and two
- // addresses per nic, the first nic has odd address, the second one has
- // even addresses.
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
-
- id1, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id1); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
-
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
-
- if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
-
- id2, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, id2); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
-
- if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
-
- if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
-
- // Set a route table that sends all packets with odd destination
- // addresses through the first NIC, and all even destination address
- // through the second one.
- s.SetRouteTable([]tcpip.Route{
- {"\x01", "\x01", "\x00", 1},
- {"\x00", "\x01", "\x00", 2},
- })
-
- // Test routes to odd address.
- testRoute(t, s, 0, "", "\x05", "\x01")
- testRoute(t, s, 0, "\x01", "\x05", "\x01")
- testRoute(t, s, 1, "\x01", "\x05", "\x01")
- testRoute(t, s, 0, "\x03", "\x05", "\x03")
- testRoute(t, s, 1, "\x03", "\x05", "\x03")
-
- // Test routes to even address.
- testRoute(t, s, 0, "", "\x06", "\x02")
- testRoute(t, s, 0, "\x02", "\x06", "\x02")
- testRoute(t, s, 2, "\x02", "\x06", "\x02")
- testRoute(t, s, 0, "\x04", "\x06", "\x04")
- testRoute(t, s, 2, "\x04", "\x06", "\x04")
-
- // Try to send to odd numbered address from even numbered ones, then
- // vice-versa.
- testNoRoute(t, s, 0, "\x02", "\x05")
- testNoRoute(t, s, 2, "\x02", "\x05")
- testNoRoute(t, s, 0, "\x04", "\x05")
- testNoRoute(t, s, 2, "\x04", "\x05")
-
- testNoRoute(t, s, 0, "\x01", "\x06")
- testNoRoute(t, s, 1, "\x01", "\x06")
- testNoRoute(t, s, 0, "\x03", "\x06")
- testNoRoute(t, s, 1, "\x03", "\x06")
-}
-
-func TestAddressRemoval(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
-
- id, linkEP := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
-
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
-
- fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
-
- buf := buffer.NewView(30)
-
- // Write a packet, and check that it gets delivered.
- fakeNet.packetCount[1] = 0
- buf[0] = 1
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
-
- // Remove the address, then check that packet doesn't get delivered
- // anymore.
- if err := s.RemoveAddress(1, "\x01"); err != nil {
- t.Fatalf("RemoveAddress failed: %v", err)
- }
-
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
-
- // Check that removing the same address fails.
- if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress {
- t.Fatalf("RemoveAddress failed: %v", err)
- }
-}
-
-func TestDelayedRemovalDueToRoute(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
-
- id, linkEP := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
-
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {"\x00", "\x00", "\x00", 1},
- })
-
- fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
-
- buf := buffer.NewView(30)
-
- // Write a packet, and check that it gets delivered.
- fakeNet.packetCount[1] = 0
- buf[0] = 1
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
-
- // Get a route, check that packet is still deliverable.
- r, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute failed: %v", err)
- }
-
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 2 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 2)
- }
-
- // Remove the address, then check that packet is still deliverable
- // because the route is keeping the address alive.
- if err := s.RemoveAddress(1, "\x01"); err != nil {
- t.Fatalf("RemoveAddress failed: %v", err)
- }
-
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 3 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3)
- }
-
- // Check that removing the same address fails.
- if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress {
- t.Fatalf("RemoveAddress failed: %v", err)
- }
-
- // Release the route, then check that packet is not deliverable anymore.
- r.Release()
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 3 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3)
- }
-}
-
-func TestPromiscuousMode(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
-
- id, linkEP := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {"\x00", "\x00", "\x00", 1},
- })
-
- fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
-
- buf := buffer.NewView(30)
-
- // Write a packet, and check that it doesn't get delivered as we don't
- // have a matching endpoint.
- fakeNet.packetCount[1] = 0
- buf[0] = 1
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 0 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
- }
-
- // Set promiscuous mode, then check that packet is delivered.
- if err := s.SetPromiscuousMode(1, true); err != nil {
- t.Fatalf("SetPromiscuousMode failed: %v", err)
- }
-
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
-
- // Check that we can't get a route as there is no local address.
- _, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */)
- if err != tcpip.ErrNoRoute {
- t.Fatalf("FindRoute returned unexpected status: expected %v, got %v", tcpip.ErrNoRoute, err)
- }
-
- // Set promiscuous mode to false, then check that packet can't be
- // delivered anymore.
- if err := s.SetPromiscuousMode(1, false); err != nil {
- t.Fatalf("SetPromiscuousMode failed: %v", err)
- }
-
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
-}
-
-func TestAddressSpoofing(t *testing.T) {
- srcAddr := tcpip.Address("\x01")
- dstAddr := tcpip.Address("\x02")
-
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
-
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
-
- if err := s.AddAddress(1, fakeNetNumber, dstAddr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {"\x00", "\x00", "\x00", 1},
- })
-
- // With address spoofing disabled, FindRoute does not permit an address
- // that was not added to the NIC to be used as the source.
- r, err := s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
- if err == nil {
- t.Errorf("FindRoute succeeded with route %+v when it should have failed", r)
- }
-
- // With address spoofing enabled, FindRoute permits any address to be used
- // as the source.
- if err := s.SetSpoofing(1, true); err != nil {
- t.Fatalf("SetSpoofing failed: %v", err)
- }
- r, err = s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute failed: %v", err)
- }
- if r.LocalAddress != srcAddr {
- t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, srcAddr)
- }
- if r.RemoteAddress != dstAddr {
- t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr)
- }
-}
-
-func TestBroadcastNeedsNoRoute(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
-
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
- s.SetRouteTable([]tcpip.Route{})
-
- // If there is no endpoint, it won't work.
- if _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable {
- t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable)
- }
-
- if err := s.AddAddress(1, fakeNetNumber, header.IPv4Any); err != nil {
- t.Fatalf("AddAddress(%v, %v) failed: %v", fakeNetNumber, header.IPv4Any, err)
- }
- r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(1, %v, %v, %v) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
- }
-
- if r.LocalAddress != header.IPv4Any {
- t.Errorf("Bad local address: got %v, want = %v", r.LocalAddress, header.IPv4Any)
- }
-
- if r.RemoteAddress != header.IPv4Broadcast {
- t.Errorf("Bad remote address: got %v, want = %v", r.RemoteAddress, header.IPv4Broadcast)
- }
-
- // If the NIC doesn't exist, it won't work.
- if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable {
- t.Fatalf("got FindRoute(2, %v, %v, %v) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable)
- }
-}
-
-func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
- for _, tc := range []struct {
- name string
- routeNeeded bool
- address tcpip.Address
- }{
- // IPv4 multicast address range: 224.0.0.0 - 239.255.255.255
- // <=> 0xe0.0x00.0x00.0x00 - 0xef.0xff.0xff.0xff
- {"IPv4 Multicast 1", false, "\xe0\x00\x00\x00"},
- {"IPv4 Multicast 2", false, "\xef\xff\xff\xff"},
- {"IPv4 Unicast 1", true, "\xdf\xff\xff\xff"},
- {"IPv4 Unicast 2", true, "\xf0\x00\x00\x00"},
- {"IPv4 Unicast 3", true, "\x00\x00\x00\x00"},
-
- // IPv6 multicast address is 0xff[8] + flags[4] + scope[4] + groupId[112]
- {"IPv6 Multicast 1", false, "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
- {"IPv6 Multicast 2", false, "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
- {"IPv6 Multicast 3", false, "\xff\x0f\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"},
-
- // IPv6 link-local address starts with fe80::/10.
- {"IPv6 Unicast Link-Local 1", false, "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
- {"IPv6 Unicast Link-Local 2", false, "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"},
- {"IPv6 Unicast Link-Local 3", false, "\xfe\x80\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff"},
- {"IPv6 Unicast Link-Local 4", false, "\xfe\xbf\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
- {"IPv6 Unicast Link-Local 5", false, "\xfe\xbf\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"},
-
- // IPv6 addresses that are neither multicast nor link-local.
- {"IPv6 Unicast Not Link-Local 1", true, "\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
- {"IPv6 Unicast Not Link-Local 2", true, "\xf0\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"},
- {"IPv6 Unicast Not Link-local 3", true, "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
- {"IPv6 Unicast Not Link-Local 4", true, "\xfe\xc0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
- {"IPv6 Unicast Not Link-Local 5", true, "\xfe\xdf\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
- {"IPv6 Unicast Not Link-Local 6", true, "\xfd\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
- {"IPv6 Unicast Not Link-Local 7", true, "\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
- } {
- t.Run(tc.name, func(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
-
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
-
- s.SetRouteTable([]tcpip.Route{})
-
- var anyAddr tcpip.Address
- if len(tc.address) == header.IPv4AddressSize {
- anyAddr = header.IPv4Any
- } else {
- anyAddr = header.IPv6Any
- }
-
- want := tcpip.ErrNetworkUnreachable
- if tc.routeNeeded {
- want = tcpip.ErrNoRoute
- }
-
- // If there is no endpoint, it won't work.
- if _, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); err != want {
- t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, want)
- }
-
- if err := s.AddAddress(1, fakeNetNumber, anyAddr); err != nil {
- t.Fatalf("AddAddress(%v, %v) failed: %v", fakeNetNumber, anyAddr, err)
- }
-
- if r, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); tc.routeNeeded {
- // Route table is empty but we need a route, this should cause an error.
- if err != tcpip.ErrNoRoute {
- t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, tcpip.ErrNoRoute)
- }
- } else {
- if err != nil {
- t.Fatalf("FindRoute(1, %v, %v, %v) failed: %v", anyAddr, tc.address, fakeNetNumber, err)
- }
- if r.LocalAddress != anyAddr {
- t.Errorf("Bad local address: got %v, want = %v", r.LocalAddress, anyAddr)
- }
- if r.RemoteAddress != tc.address {
- t.Errorf("Bad remote address: got %v, want = %v", r.RemoteAddress, tc.address)
- }
- }
- // If the NIC doesn't exist, it won't work.
- if _, err := s.FindRoute(2, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); err != want {
- t.Fatalf("got FindRoute(2, %v, %v, %v) = %v want = %v", anyAddr, tc.address, fakeNetNumber, err, want)
- }
- })
- }
-}
-
-// Set the subnet, then check that packet is delivered.
-func TestSubnetAcceptsMatchingPacket(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
-
- id, linkEP := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {"\x00", "\x00", "\x00", 1},
- })
-
- fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
-
- buf := buffer.NewView(30)
-
- buf[0] = 1
- fakeNet.packetCount[1] = 0
- subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0"))
- if err != nil {
- t.Fatalf("NewSubnet failed: %v", err)
- }
- if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil {
- t.Fatalf("AddSubnet failed: %v", err)
- }
-
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
-}
-
-// Set destination outside the subnet, then check it doesn't get delivered.
-func TestSubnetRejectsNonmatchingPacket(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
-
- id, linkEP := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {"\x00", "\x00", "\x00", 1},
- })
-
- fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
-
- buf := buffer.NewView(30)
-
- buf[0] = 1
- fakeNet.packetCount[1] = 0
- subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0"))
- if err != nil {
- t.Fatalf("NewSubnet failed: %v", err)
- }
- if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil {
- t.Fatalf("AddSubnet failed: %v", err)
- }
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 0 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
- }
-}
-
-func TestNetworkOptions(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, []string{}, stack.Options{})
-
- // Try an unsupported network protocol.
- if err := s.SetNetworkProtocolOption(tcpip.NetworkProtocolNumber(99999), fakeNetGoodOption(false)); err != tcpip.ErrUnknownProtocol {
- t.Fatalf("SetNetworkProtocolOption(fakeNet2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err)
- }
-
- testCases := []struct {
- option interface{}
- wantErr *tcpip.Error
- verifier func(t *testing.T, p stack.NetworkProtocol)
- }{
- {fakeNetGoodOption(true), nil, func(t *testing.T, p stack.NetworkProtocol) {
- t.Helper()
- fakeNet := p.(*fakeNetworkProtocol)
- if fakeNet.opts.good != true {
- t.Fatalf("fakeNet.opts.good = false, want = true")
- }
- var v fakeNetGoodOption
- if err := s.NetworkProtocolOption(fakeNetNumber, &v); err != nil {
- t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) = %v, want = nil, where v is option %T", v, err)
- }
- if v != true {
- t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) returned v = %v, want = true", v)
- }
- }},
- {fakeNetBadOption(true), tcpip.ErrUnknownProtocolOption, nil},
- {fakeNetInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil},
- }
- for _, tc := range testCases {
- if got := s.SetNetworkProtocolOption(fakeNetNumber, tc.option); got != tc.wantErr {
- t.Errorf("s.SetNetworkProtocolOption(fakeNet, %v) = %v, want = %v", tc.option, got, tc.wantErr)
- }
- if tc.verifier != nil {
- tc.verifier(t, s.NetworkProtocolInstance(fakeNetNumber))
- }
- }
-}
-
-func TestSubnetAddRemove(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
-
- addr := tcpip.Address("\x01\x01\x01\x01")
- mask := tcpip.AddressMask(strings.Repeat("\xff", len(addr)))
- subnet, err := tcpip.NewSubnet(addr, mask)
- if err != nil {
- t.Fatalf("NewSubnet failed: %v", err)
- }
-
- if contained, err := s.ContainsSubnet(1, subnet); err != nil {
- t.Fatalf("ContainsSubnet failed: %v", err)
- } else if contained {
- t.Fatal("got s.ContainsSubnet(...) = true, want = false")
- }
-
- if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil {
- t.Fatalf("AddSubnet failed: %v", err)
- }
-
- if contained, err := s.ContainsSubnet(1, subnet); err != nil {
- t.Fatalf("ContainsSubnet failed: %v", err)
- } else if !contained {
- t.Fatal("got s.ContainsSubnet(...) = false, want = true")
- }
-
- if err := s.RemoveSubnet(1, subnet); err != nil {
- t.Fatalf("RemoveSubnet failed: %v", err)
- }
-
- if contained, err := s.ContainsSubnet(1, subnet); err != nil {
- t.Fatalf("ContainsSubnet failed: %v", err)
- } else if contained {
- t.Fatal("got s.ContainsSubnet(...) = true, want = false")
- }
-}
-
-func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
- for _, addrLen := range []int{4, 16} {
- t.Run(fmt.Sprintf("addrLen=%d", addrLen), func(t *testing.T) {
- for canBe := 0; canBe < 3; canBe++ {
- t.Run(fmt.Sprintf("canBe=%d", canBe), func(t *testing.T) {
- for never := 0; never < 3; never++ {
- t.Run(fmt.Sprintf("never=%d", never), func(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
- // Insert <canBe> primary and <never> never-primary addresses.
- // Each one will add a network endpoint to the NIC.
- primaryAddrAdded := make(map[tcpip.Address]tcpip.Subnet)
- for i := 0; i < canBe+never; i++ {
- var behavior stack.PrimaryEndpointBehavior
- if i < canBe {
- behavior = stack.CanBePrimaryEndpoint
- } else {
- behavior = stack.NeverPrimaryEndpoint
- }
- // Add an address and in case of a primary one also add a
- // subnet.
- address := tcpip.Address(bytes.Repeat([]byte{byte(i)}, addrLen))
- if err := s.AddAddressWithOptions(1, fakeNetNumber, address, behavior); err != nil {
- t.Fatalf("AddAddressWithOptions failed: %v", err)
- }
- if behavior == stack.CanBePrimaryEndpoint {
- mask := tcpip.AddressMask(strings.Repeat("\xff", len(address)))
- subnet, err := tcpip.NewSubnet(address, mask)
- if err != nil {
- t.Fatalf("NewSubnet failed: %v", err)
- }
- if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil {
- t.Fatalf("AddSubnet failed: %v", err)
- }
- // Remember the address/subnet.
- primaryAddrAdded[address] = subnet
- }
- }
- // Check that GetMainNICAddress returns an address if at least
- // one primary address was added. In that case make sure the
- // address/subnet matches what we added.
- if len(primaryAddrAdded) == 0 {
- // No primary addresses present, expect an error.
- if _, _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress {
- t.Fatalf("got s.GetMainNICAddress(...) = %v, wanted = %v", err, tcpip.ErrNoLinkAddress)
- }
- } else {
- // At least one primary address was added, expect a valid
- // address and subnet.
- gotAddress, gotSubnet, err := s.GetMainNICAddress(1, fakeNetNumber)
- if err != nil {
- t.Fatalf("GetMainNICAddress failed: %v", err)
- }
- expectedSubnet, ok := primaryAddrAdded[gotAddress]
- if !ok {
- t.Fatalf("GetMainNICAddress: got address = %v, wanted any in {%v}", gotAddress, primaryAddrAdded)
- }
- if gotSubnet != expectedSubnet {
- t.Fatalf("GetMainNICAddress: got subnet = %v, wanted %v", gotSubnet, expectedSubnet)
- }
- }
- })
- }
- })
- }
- })
- }
-}
-
-func TestGetMainNICAddressAddRemove(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
-
- for _, tc := range []struct {
- name string
- address tcpip.Address
- }{
- {"IPv4", "\x01\x01\x01\x01"},
- {"IPv6", "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"},
- } {
- t.Run(tc.name, func(t *testing.T) {
- address := tc.address
- mask := tcpip.AddressMask(strings.Repeat("\xff", len(address)))
- subnet, err := tcpip.NewSubnet(address, mask)
- if err != nil {
- t.Fatalf("NewSubnet failed: %v", err)
- }
-
- if err := s.AddAddress(1, fakeNetNumber, address); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
-
- if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil {
- t.Fatalf("AddSubnet failed: %v", err)
- }
-
- // Check that we get the right initial address and subnet.
- if gotAddress, gotSubnet, err := s.GetMainNICAddress(1, fakeNetNumber); err != nil {
- t.Fatalf("GetMainNICAddress failed: %v", err)
- } else if gotAddress != address {
- t.Fatalf("got GetMainNICAddress = (%v, ...), want = (%v, ...)", gotAddress, address)
- } else if gotSubnet != subnet {
- t.Fatalf("got GetMainNICAddress = (..., %v), want = (..., %v)", gotSubnet, subnet)
- }
-
- if err := s.RemoveSubnet(1, subnet); err != nil {
- t.Fatalf("RemoveSubnet failed: %v", err)
- }
-
- if err := s.RemoveAddress(1, address); err != nil {
- t.Fatalf("RemoveAddress failed: %v", err)
- }
-
- // Check that we get an error after removal.
- if _, _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress {
- t.Fatalf("got s.GetMainNICAddress(...) = %v, want = %v", err, tcpip.ErrNoLinkAddress)
- }
- })
- }
-}
-
-func TestNICStats(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id1, linkEP1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id1); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
- // Route all packets for address \x01 to NIC 1.
- s.SetRouteTable([]tcpip.Route{
- {"\x01", "\xff", "\x00", 1},
- })
-
- // Send a packet to address 1.
- buf := buffer.NewView(30)
- linkEP1.Inject(fakeNetNumber, buf.ToVectorisedView())
- if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want {
- t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want)
- }
-
- if got, want := s.NICInfo()[1].Stats.Rx.Bytes.Value(), uint64(len(buf)); got != want {
- t.Errorf("got Rx.Bytes.Value() = %d, want = %d", got, want)
- }
-
- payload := buffer.NewView(10)
- // Write a packet out via the address for NIC 1
- sendTo(t, s, "\x01", payload)
- want := uint64(linkEP1.Drain())
- if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want {
- t.Errorf("got Tx.Packets.Value() = %d, linkEP1.Drain() = %d", got, want)
- }
-
- if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)); got != want {
- t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
- }
-}
-
-func TestNICForwarding(t *testing.T) {
- // Create a stack with the fake network protocol, two NICs, each with
- // an address.
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- s.SetForwarding(true)
-
- id1, linkEP1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id1); err != nil {
- t.Fatalf("CreateNIC #1 failed: %v", err)
- }
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress #1 failed: %v", err)
- }
-
- id2, linkEP2 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, id2); err != nil {
- t.Fatalf("CreateNIC #2 failed: %v", err)
- }
- if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
- t.Fatalf("AddAddress #2 failed: %v", err)
- }
-
- // Route all packets to address 3 to NIC 2.
- s.SetRouteTable([]tcpip.Route{
- {"\x03", "\xff", "\x00", 2},
- })
-
- // Send a packet to address 3.
- buf := buffer.NewView(30)
- buf[0] = 3
- linkEP1.Inject(fakeNetNumber, buf.ToVectorisedView())
-
- select {
- case <-linkEP2.C:
- default:
- t.Fatal("Packet not forwarded")
- }
-
- // Test that forwarding increments Tx stats correctly.
- if got, want := s.NICInfo()[2].Stats.Tx.Packets.Value(), uint64(1); got != want {
- t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want)
- }
-
- if got, want := s.NICInfo()[2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want {
- t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
- }
-}
-
-func init() {
- stack.RegisterNetworkProtocolFactory("fakeNet", func() stack.NetworkProtocol {
- return &fakeNetworkProtocol{}
- })
-}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
deleted file mode 100644
index b418db046..000000000
--- a/pkg/tcpip/stack/transport_test.go
+++ /dev/null
@@ -1,538 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package stack_test
-
-import (
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-const (
- fakeTransNumber tcpip.TransportProtocolNumber = 1
- fakeTransHeaderLen = 3
-)
-
-// fakeTransportEndpoint is a transport-layer protocol endpoint. It counts
-// received packets; the counts of all endpoints are aggregated in the protocol
-// descriptor.
-//
-// Headers of this protocol are fakeTransHeaderLen bytes, but we currently don't
-// use it.
-type fakeTransportEndpoint struct {
- id stack.TransportEndpointID
- stack *stack.Stack
- netProto tcpip.NetworkProtocolNumber
- proto *fakeTransportProtocol
- peerAddr tcpip.Address
- route stack.Route
-
- // acceptQueue is non-nil iff bound.
- acceptQueue []fakeTransportEndpoint
-}
-
-func newFakeTransportEndpoint(stack *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber) tcpip.Endpoint {
- return &fakeTransportEndpoint{stack: stack, netProto: netProto, proto: proto}
-}
-
-func (f *fakeTransportEndpoint) Close() {
- f.route.Release()
-}
-
-func (*fakeTransportEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
- return mask
-}
-
-func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
- return buffer.View{}, tcpip.ControlMessages{}, nil
-}
-
-func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
- if len(f.route.RemoteAddress) == 0 {
- return 0, nil, tcpip.ErrNoRoute
- }
-
- hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength()))
- v, err := p.Get(p.Size())
- if err != nil {
- return 0, nil, err
- }
- if err := f.route.WritePacket(nil /* gso */, hdr, buffer.View(v).ToVectorisedView(), fakeTransNumber, 123); err != nil {
- return 0, nil, err
- }
-
- return uintptr(len(v)), nil, nil
-}
-
-func (f *fakeTransportEndpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
- return 0, tcpip.ControlMessages{}, nil
-}
-
-// SetSockOpt sets a socket option. Currently not supported.
-func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error {
- return tcpip.ErrInvalidEndpointState
-}
-
-// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
- return -1, tcpip.ErrUnknownProtocolOption
-}
-
-// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (*fakeTransportEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch opt.(type) {
- case tcpip.ErrorOption:
- return nil
- }
- return tcpip.ErrInvalidEndpointState
-}
-
-func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
- f.peerAddr = addr.Addr
-
- // Find the route.
- r, err := f.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- return tcpip.ErrNoRoute
- }
- defer r.Release()
-
- // Try to register so that we can start receiving packets.
- f.id.RemoteAddress = addr.Addr
- err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f, false)
- if err != nil {
- return err
- }
-
- f.route = r.Clone()
-
- return nil
-}
-
-func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error {
- return nil
-}
-
-func (*fakeTransportEndpoint) Shutdown(tcpip.ShutdownFlags) *tcpip.Error {
- return nil
-}
-
-func (*fakeTransportEndpoint) Reset() {
-}
-
-func (*fakeTransportEndpoint) Listen(int) *tcpip.Error {
- return nil
-}
-
-func (f *fakeTransportEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
- if len(f.acceptQueue) == 0 {
- return nil, nil, nil
- }
- a := f.acceptQueue[0]
- f.acceptQueue = f.acceptQueue[1:]
- return &a, nil, nil
-}
-
-func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error {
- if err := f.stack.RegisterTransportEndpoint(
- a.NIC,
- []tcpip.NetworkProtocolNumber{fakeNetNumber},
- fakeTransNumber,
- stack.TransportEndpointID{LocalAddress: a.Addr},
- f,
- false,
- ); err != nil {
- return err
- }
- f.acceptQueue = []fakeTransportEndpoint{}
- return nil
-}
-
-func (*fakeTransportEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
- return tcpip.FullAddress{}, nil
-}
-
-func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
- return tcpip.FullAddress{}, nil
-}
-
-func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.VectorisedView) {
- // Increment the number of received packets.
- f.proto.packetCount++
- if f.acceptQueue != nil {
- f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{
- id: id,
- stack: f.stack,
- netProto: f.netProto,
- proto: f.proto,
- peerAddr: r.RemoteAddress,
- route: r.Clone(),
- })
- }
-}
-
-func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, buffer.VectorisedView) {
- // Increment the number of received control packets.
- f.proto.controlCount++
-}
-
-func (f *fakeTransportEndpoint) State() uint32 {
- return 0
-}
-
-func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {
-}
-
-type fakeTransportGoodOption bool
-
-type fakeTransportBadOption bool
-
-type fakeTransportInvalidValueOption int
-
-type fakeTransportProtocolOptions struct {
- good bool
-}
-
-// fakeTransportProtocol is a transport-layer protocol descriptor. It
-// aggregates the number of packets received via endpoints of this protocol.
-type fakeTransportProtocol struct {
- packetCount int
- controlCount int
- opts fakeTransportProtocolOptions
-}
-
-func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber {
- return fakeTransNumber
-}
-
-func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return newFakeTransportEndpoint(stack, f, netProto), nil
-}
-
-func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return nil, tcpip.ErrUnknownProtocol
-}
-
-func (*fakeTransportProtocol) MinimumPacketSize() int {
- return fakeTransHeaderLen
-}
-
-func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcpip.Error) {
- return 0, 0, nil
-}
-
-func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool {
- return true
-}
-
-func (f *fakeTransportProtocol) SetOption(option interface{}) *tcpip.Error {
- switch v := option.(type) {
- case fakeTransportGoodOption:
- f.opts.good = bool(v)
- return nil
- case fakeTransportInvalidValueOption:
- return tcpip.ErrInvalidOptionValue
- default:
- return tcpip.ErrUnknownProtocolOption
- }
-}
-
-func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error {
- switch v := option.(type) {
- case *fakeTransportGoodOption:
- *v = fakeTransportGoodOption(f.opts.good)
- return nil
- default:
- return tcpip.ErrUnknownProtocolOption
- }
-}
-
-func TestTransportReceive(t *testing.T) {
- id, linkEP := channel.New(10, defaultMTU, "")
- s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
- if err := s.CreateNIC(1, id); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
-
- s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}})
-
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
-
- // Create endpoint and connect to remote address.
- wq := waiter.Queue{}
- ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil {
- t.Fatalf("Connect failed: %v", err)
- }
-
- fakeTrans := s.TransportProtocolInstance(fakeTransNumber).(*fakeTransportProtocol)
-
- // Create buffer that will hold the packet.
- buf := buffer.NewView(30)
-
- // Make sure packet with wrong protocol is not delivered.
- buf[0] = 1
- buf[2] = 0
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeTrans.packetCount != 0 {
- t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0)
- }
-
- // Make sure packet from the wrong source is not delivered.
- buf[0] = 1
- buf[1] = 3
- buf[2] = byte(fakeTransNumber)
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeTrans.packetCount != 0 {
- t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0)
- }
-
- // Make sure packet is delivered.
- buf[0] = 1
- buf[1] = 2
- buf[2] = byte(fakeTransNumber)
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeTrans.packetCount != 1 {
- t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1)
- }
-}
-
-func TestTransportControlReceive(t *testing.T) {
- id, linkEP := channel.New(10, defaultMTU, "")
- s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
- if err := s.CreateNIC(1, id); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
-
- s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}})
-
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
-
- // Create endpoint and connect to remote address.
- wq := waiter.Queue{}
- ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil {
- t.Fatalf("Connect failed: %v", err)
- }
-
- fakeTrans := s.TransportProtocolInstance(fakeTransNumber).(*fakeTransportProtocol)
-
- // Create buffer that will hold the control packet.
- buf := buffer.NewView(2*fakeNetHeaderLen + 30)
-
- // Outer packet contains the control protocol number.
- buf[0] = 1
- buf[1] = 0xfe
- buf[2] = uint8(fakeControlProtocol)
-
- // Make sure packet with wrong protocol is not delivered.
- buf[fakeNetHeaderLen+0] = 0
- buf[fakeNetHeaderLen+1] = 1
- buf[fakeNetHeaderLen+2] = 0
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeTrans.controlCount != 0 {
- t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0)
- }
-
- // Make sure packet from the wrong source is not delivered.
- buf[fakeNetHeaderLen+0] = 3
- buf[fakeNetHeaderLen+1] = 1
- buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeTrans.controlCount != 0 {
- t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0)
- }
-
- // Make sure packet is delivered.
- buf[fakeNetHeaderLen+0] = 2
- buf[fakeNetHeaderLen+1] = 1
- buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeTrans.controlCount != 1 {
- t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1)
- }
-}
-
-func TestTransportSend(t *testing.T) {
- id, _ := channel.New(10, defaultMTU, "")
- s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
- if err := s.CreateNIC(1, id); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
-
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
-
- s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}})
-
- // Create endpoint and bind it.
- wq := waiter.Queue{}
- ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil {
- t.Fatalf("Connect failed: %v", err)
- }
-
- // Create buffer that will hold the payload.
- view := buffer.NewView(30)
- _, _, err = ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{})
- if err != nil {
- t.Fatalf("write failed: %v", err)
- }
-
- fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
-
- if fakeNet.sendPacketCount[2] != 1 {
- t.Errorf("sendPacketCount = %d, want %d", fakeNet.sendPacketCount[2], 1)
- }
-}
-
-func TestTransportOptions(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
-
- // Try an unsupported transport protocol.
- if err := s.SetTransportProtocolOption(tcpip.TransportProtocolNumber(99999), fakeTransportGoodOption(false)); err != tcpip.ErrUnknownProtocol {
- t.Fatalf("SetTransportProtocolOption(fakeTrans2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err)
- }
-
- testCases := []struct {
- option interface{}
- wantErr *tcpip.Error
- verifier func(t *testing.T, p stack.TransportProtocol)
- }{
- {fakeTransportGoodOption(true), nil, func(t *testing.T, p stack.TransportProtocol) {
- t.Helper()
- fakeTrans := p.(*fakeTransportProtocol)
- if fakeTrans.opts.good != true {
- t.Fatalf("fakeTrans.opts.good = false, want = true")
- }
- var v fakeTransportGoodOption
- if err := s.TransportProtocolOption(fakeTransNumber, &v); err != nil {
- t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) = %v, want = nil, where v is option %T", v, err)
- }
- if v != true {
- t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) returned v = %v, want = true", v)
- }
-
- }},
- {fakeTransportBadOption(true), tcpip.ErrUnknownProtocolOption, nil},
- {fakeTransportInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil},
- }
- for _, tc := range testCases {
- if got := s.SetTransportProtocolOption(fakeTransNumber, tc.option); got != tc.wantErr {
- t.Errorf("s.SetTransportProtocolOption(fakeTrans, %v) = %v, want = %v", tc.option, got, tc.wantErr)
- }
- if tc.verifier != nil {
- tc.verifier(t, s.TransportProtocolInstance(fakeTransNumber))
- }
- }
-}
-
-func TestTransportForwarding(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
- s.SetForwarding(true)
-
- // TODO(b/123449044): Change this to a channel NIC.
- id1 := loopback.New()
- if err := s.CreateNIC(1, id1); err != nil {
- t.Fatalf("CreateNIC #1 failed: %v", err)
- }
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress #1 failed: %v", err)
- }
-
- id2, linkEP2 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, id2); err != nil {
- t.Fatalf("CreateNIC #2 failed: %v", err)
- }
- if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
- t.Fatalf("AddAddress #2 failed: %v", err)
- }
-
- // Route all packets to address 3 to NIC 2 and all packets to address
- // 1 to NIC 1.
- s.SetRouteTable([]tcpip.Route{
- {"\x03", "\xff", "\x00", 2},
- {"\x01", "\xff", "\x00", 1},
- })
-
- wq := waiter.Queue{}
- ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- if err := ep.Bind(tcpip.FullAddress{Addr: "\x01", NIC: 1}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Send a packet to address 1 from address 3.
- req := buffer.NewView(30)
- req[0] = 1
- req[1] = 3
- req[2] = byte(fakeTransNumber)
- linkEP2.Inject(fakeNetNumber, req.ToVectorisedView())
-
- aep, _, err := ep.Accept()
- if err != nil || aep == nil {
- t.Fatalf("Accept failed: %v, %v", aep, err)
- }
-
- resp := buffer.NewView(30)
- if _, _, err := aep.Write(tcpip.SlicePayload(resp), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- var p channel.PacketInfo
- select {
- case p = <-linkEP2.C:
- default:
- t.Fatal("Response packet not forwarded")
- }
-
- if dst := p.Header[0]; dst != 3 {
- t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst)
- }
- if src := p.Header[1]; src != 1 {
- t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src)
- }
-}
-
-func init() {
- stack.RegisterTransportProtocolFactory("fakeTrans", func() stack.TransportProtocol {
- return &fakeTransportProtocol{}
- })
-}