summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack/stack_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack/stack_test.go')
-rw-r--r--pkg/tcpip/stack/stack_test.go112
1 files changed, 42 insertions, 70 deletions
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 9515426d6..e15db40fb 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -2240,84 +2240,56 @@ func TestNICStats(t *testing.T) {
}
func TestNICForwarding(t *testing.T) {
- const nicID1 = 1
- const nicID2 = 2
- const dstAddr = tcpip.Address("\x03")
+ // Create a stack with the fake network protocol, two NICs, each with
+ // an address.
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ s.SetForwarding(true)
- tests := []struct {
- name string
- headerLen uint16
- }{
- {
- name: "Zero header length",
- },
- {
- name: "Non-zero header length",
- headerLen: 16,
- },
+ ep1 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep1); err != nil {
+ t.Fatal("CreateNIC #1 failed:", err)
+ }
+ if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+ t.Fatal("AddAddress #1 failed:", err)
}
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
- s.SetForwarding(true)
-
- ep1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID1, ep1); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
- }
- if err := s.AddAddress(nicID1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress(%d, %d, 0x01): %s", nicID1, fakeNetNumber, err)
- }
-
- ep2 := channelLinkWithHeaderLength{
- Endpoint: channel.New(10, defaultMTU, ""),
- headerLength: test.headerLen,
- }
- if err := s.CreateNIC(nicID2, &ep2); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
- }
- if err := s.AddAddress(nicID2, fakeNetNumber, "\x02"); err != nil {
- t.Fatalf("AddAddress(%d, %d, 0x02): %s", nicID2, fakeNetNumber, err)
- }
-
- // Route all packets to dstAddr to NIC 2.
- {
- subnet, err := tcpip.NewSubnet(dstAddr, "\xff")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID2}})
- }
+ ep2 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(2, ep2); err != nil {
+ t.Fatal("CreateNIC #2 failed:", err)
+ }
+ if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
+ t.Fatal("AddAddress #2 failed:", err)
+ }
- // Send a packet to dstAddr.
- buf := buffer.NewView(30)
- buf[0] = dstAddr[0]
- ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
+ // Route all packets to address 3 to NIC 2.
+ {
+ subnet, err := tcpip.NewSubnet("\x03", "\xff")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 2}})
+ }
- pkt, ok := ep2.Read()
- if !ok {
- t.Fatal("packet not forwarded")
- }
+ // Send a packet to address 3.
+ buf := buffer.NewView(30)
+ buf[0] = 3
+ ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
- // Test that the link's MaxHeaderLength is honoured.
- if capacity, want := pkt.Pkt.Header.AvailableLength(), int(test.headerLen); capacity != want {
- t.Errorf("got Header.AvailableLength() = %d, want = %d", capacity, want)
- }
+ if _, ok := ep2.Read(); !ok {
+ t.Fatal("Packet not forwarded")
+ }
- // Test that forwarding increments Tx stats correctly.
- if got, want := s.NICInfo()[nicID2].Stats.Tx.Packets.Value(), uint64(1); got != want {
- t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want)
- }
+ // Test that forwarding increments Tx stats correctly.
+ if got, want := s.NICInfo()[2].Stats.Tx.Packets.Value(), uint64(1); got != want {
+ t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want)
+ }
- if got, want := s.NICInfo()[nicID2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want {
- t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
- }
- })
+ if got, want := s.NICInfo()[2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want {
+ t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
}
}