diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/forwarder_test.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 112 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 4 |
4 files changed, 79 insertions, 51 deletions
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index 321b7524d..5a04590d5 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -473,7 +473,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.First() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -517,7 +517,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.First() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -564,7 +564,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.First() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -619,7 +619,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // The first 5 packets (address 3 to 7) should not be forwarded // because their address resolutions are interrupted. - b := p.Pkt.Header.View() + b := p.Pkt.Data.First() if b[0] < 8 { t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0]) } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index cd9202aed..e46bd86c6 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1246,10 +1246,10 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link } func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { - // TODO(b/143425874) Decrease the TTL field in forwarded packets. - pkt.Header = buffer.NewPrependableFromView(pkt.Data.First()) - pkt.Data.RemoveFirst() + // TODO(b/143425874): Decrease the TTL field in forwarded packets. + // pkt.Header should have enough capacity to hold the link's headers. + pkt.Header = buffer.NewPrependable(int(n.linkEP.MaxHeaderLength())) if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index e15db40fb..9515426d6 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -2240,56 +2240,84 @@ func TestNICStats(t *testing.T) { } func TestNICForwarding(t *testing.T) { - // Create a stack with the fake network protocol, two NICs, each with - // an address. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - s.SetForwarding(true) + const nicID1 = 1 + const nicID2 = 2 + const dstAddr = tcpip.Address("\x03") - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC #1 failed:", err) - } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress #1 failed:", err) + tests := []struct { + name string + headerLen uint16 + }{ + { + name: "Zero header length", + }, + { + name: "Non-zero header length", + headerLen: 16, + }, } - ep2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, ep2); err != nil { - t.Fatal("CreateNIC #2 failed:", err) - } - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress #2 failed:", err) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + s.SetForwarding(true) - // Route all packets to address 3 to NIC 2. - { - subnet, err := tcpip.NewSubnet("\x03", "\xff") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 2}}) - } + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicID1, ep1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + if err := s.AddAddress(nicID1, fakeNetNumber, "\x01"); err != nil { + t.Fatalf("AddAddress(%d, %d, 0x01): %s", nicID1, fakeNetNumber, err) + } - // Send a packet to address 3. - buf := buffer.NewView(30) - buf[0] = 3 - ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) + ep2 := channelLinkWithHeaderLength{ + Endpoint: channel.New(10, defaultMTU, ""), + headerLength: test.headerLen, + } + if err := s.CreateNIC(nicID2, &ep2); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + } + if err := s.AddAddress(nicID2, fakeNetNumber, "\x02"); err != nil { + t.Fatalf("AddAddress(%d, %d, 0x02): %s", nicID2, fakeNetNumber, err) + } - if _, ok := ep2.Read(); !ok { - t.Fatal("Packet not forwarded") - } + // Route all packets to dstAddr to NIC 2. + { + subnet, err := tcpip.NewSubnet(dstAddr, "\xff") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID2}}) + } - // Test that forwarding increments Tx stats correctly. - if got, want := s.NICInfo()[2].Stats.Tx.Packets.Value(), uint64(1); got != want { - t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want) - } + // Send a packet to dstAddr. + buf := buffer.NewView(30) + buf[0] = dstAddr[0] + ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) - if got, want := s.NICInfo()[2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want { - t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) + pkt, ok := ep2.Read() + if !ok { + t.Fatal("packet not forwarded") + } + + // Test that the link's MaxHeaderLength is honoured. + if capacity, want := pkt.Pkt.Header.AvailableLength(), int(test.headerLen); capacity != want { + t.Errorf("got Header.AvailableLength() = %d, want = %d", capacity, want) + } + + // Test that forwarding increments Tx stats correctly. + if got, want := s.NICInfo()[nicID2].Stats.Tx.Packets.Value(), uint64(1); got != want { + t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want) + } + + if got, want := s.NICInfo()[nicID2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want { + t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) + } + }) } } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 5d1da2f8b..3609a25b6 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -641,10 +641,10 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - if dst := p.Pkt.Header.View()[0]; dst != 3 { + if dst := p.Pkt.Data.First()[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := p.Pkt.Header.View()[1]; src != 1 { + if src := p.Pkt.Data.First()[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } |