summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/forwarder_test.go8
-rw-r--r--pkg/tcpip/stack/nic.go6
-rw-r--r--pkg/tcpip/stack/stack_test.go112
-rw-r--r--pkg/tcpip/stack/transport_test.go4
4 files changed, 79 insertions, 51 deletions
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go
index 321b7524d..5a04590d5 100644
--- a/pkg/tcpip/stack/forwarder_test.go
+++ b/pkg/tcpip/stack/forwarder_test.go
@@ -473,7 +473,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
t.Fatal("packet not forwarded")
}
- b := p.Pkt.Header.View()
+ b := p.Pkt.Data.First()
if b[0] != 3 {
t.Fatalf("got b[0] = %d, want = 3", b[0])
}
@@ -517,7 +517,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
t.Fatal("packet not forwarded")
}
- b := p.Pkt.Header.View()
+ b := p.Pkt.Data.First()
if b[0] != 3 {
t.Fatalf("got b[0] = %d, want = 3", b[0])
}
@@ -564,7 +564,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
t.Fatal("packet not forwarded")
}
- b := p.Pkt.Header.View()
+ b := p.Pkt.Data.First()
if b[0] != 3 {
t.Fatalf("got b[0] = %d, want = 3", b[0])
}
@@ -619,7 +619,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
// The first 5 packets (address 3 to 7) should not be forwarded
// because their address resolutions are interrupted.
- b := p.Pkt.Header.View()
+ b := p.Pkt.Data.First()
if b[0] < 8 {
t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0])
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index cd9202aed..e46bd86c6 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -1246,10 +1246,10 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
}
func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
- // TODO(b/143425874) Decrease the TTL field in forwarded packets.
- pkt.Header = buffer.NewPrependableFromView(pkt.Data.First())
- pkt.Data.RemoveFirst()
+ // TODO(b/143425874): Decrease the TTL field in forwarded packets.
+ // pkt.Header should have enough capacity to hold the link's headers.
+ pkt.Header = buffer.NewPrependable(int(n.linkEP.MaxHeaderLength()))
if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
return
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index e15db40fb..9515426d6 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -2240,56 +2240,84 @@ func TestNICStats(t *testing.T) {
}
func TestNICForwarding(t *testing.T) {
- // Create a stack with the fake network protocol, two NICs, each with
- // an address.
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
- s.SetForwarding(true)
+ const nicID1 = 1
+ const nicID2 = 2
+ const dstAddr = tcpip.Address("\x03")
- ep1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep1); err != nil {
- t.Fatal("CreateNIC #1 failed:", err)
- }
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress #1 failed:", err)
+ tests := []struct {
+ name string
+ headerLen uint16
+ }{
+ {
+ name: "Zero header length",
+ },
+ {
+ name: "Non-zero header length",
+ headerLen: 16,
+ },
}
- ep2 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, ep2); err != nil {
- t.Fatal("CreateNIC #2 failed:", err)
- }
- if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
- t.Fatal("AddAddress #2 failed:", err)
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ s.SetForwarding(true)
- // Route all packets to address 3 to NIC 2.
- {
- subnet, err := tcpip.NewSubnet("\x03", "\xff")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 2}})
- }
+ ep1 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, ep1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+ if err := s.AddAddress(nicID1, fakeNetNumber, "\x01"); err != nil {
+ t.Fatalf("AddAddress(%d, %d, 0x01): %s", nicID1, fakeNetNumber, err)
+ }
- // Send a packet to address 3.
- buf := buffer.NewView(30)
- buf[0] = 3
- ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
+ ep2 := channelLinkWithHeaderLength{
+ Endpoint: channel.New(10, defaultMTU, ""),
+ headerLength: test.headerLen,
+ }
+ if err := s.CreateNIC(nicID2, &ep2); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ }
+ if err := s.AddAddress(nicID2, fakeNetNumber, "\x02"); err != nil {
+ t.Fatalf("AddAddress(%d, %d, 0x02): %s", nicID2, fakeNetNumber, err)
+ }
- if _, ok := ep2.Read(); !ok {
- t.Fatal("Packet not forwarded")
- }
+ // Route all packets to dstAddr to NIC 2.
+ {
+ subnet, err := tcpip.NewSubnet(dstAddr, "\xff")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID2}})
+ }
- // Test that forwarding increments Tx stats correctly.
- if got, want := s.NICInfo()[2].Stats.Tx.Packets.Value(), uint64(1); got != want {
- t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want)
- }
+ // Send a packet to dstAddr.
+ buf := buffer.NewView(30)
+ buf[0] = dstAddr[0]
+ ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
- if got, want := s.NICInfo()[2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want {
- t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
+ pkt, ok := ep2.Read()
+ if !ok {
+ t.Fatal("packet not forwarded")
+ }
+
+ // Test that the link's MaxHeaderLength is honoured.
+ if capacity, want := pkt.Pkt.Header.AvailableLength(), int(test.headerLen); capacity != want {
+ t.Errorf("got Header.AvailableLength() = %d, want = %d", capacity, want)
+ }
+
+ // Test that forwarding increments Tx stats correctly.
+ if got, want := s.NICInfo()[nicID2].Stats.Tx.Packets.Value(), uint64(1); got != want {
+ t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.NICInfo()[nicID2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want {
+ t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
+ }
+ })
}
}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 5d1da2f8b..3609a25b6 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -641,10 +641,10 @@ func TestTransportForwarding(t *testing.T) {
t.Fatal("Response packet not forwarded")
}
- if dst := p.Pkt.Header.View()[0]; dst != 3 {
+ if dst := p.Pkt.Data.First()[0]; dst != 3 {
t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst)
}
- if src := p.Pkt.Header.View()[1]; src != 1 {
+ if src := p.Pkt.Data.First()[1]; src != 1 {
t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src)
}
}