summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack/stack_test.go
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2020-11-12 17:30:31 -0800
committergVisor bot <gvisor-bot@google.com>2020-11-12 17:33:21 -0800
commit1a972411b36b8ad2543d3ea614c92e60ccbdffab (patch)
tree43fd70a53b1ee47469ba3876920eaaee9863c813 /pkg/tcpip/stack/stack_test.go
parentae7ab0a330aaa1676d1fe066e3f5ac5fe805ec1c (diff)
Move packet handling to NetworkEndpoint
The NIC should not hold network-layer state or logic - network packet handling/forwarding should be performed at the network layer instead of the NIC. Fixes #4688 PiperOrigin-RevId: 342166985
Diffstat (limited to 'pkg/tcpip/stack/stack_test.go')
-rw-r--r--pkg/tcpip/stack/stack_test.go96
1 files changed, 10 insertions, 86 deletions
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index dedfdd435..61db3164b 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -112,7 +112,15 @@ func (*fakeNetworkEndpoint) DefaultTTL() uint8 {
func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Increment the received packet count in the protocol descriptor.
netHdr := pkt.NetworkHeader().View()
- f.proto.packetCount[int(netHdr[dstAddrOffset])%len(f.proto.packetCount)]++
+
+ dst := tcpip.Address(netHdr[dstAddrOffset:][:1])
+ addressEndpoint := f.AcquireAssignedAddress(dst, f.nic.Promiscuous(), stack.CanBePrimaryEndpoint)
+ if addressEndpoint == nil {
+ return
+ }
+ addressEndpoint.DecRef()
+
+ f.proto.packetCount[int(dst[0])%len(f.proto.packetCount)]++
// Handle control packets.
if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) {
@@ -159,9 +167,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
hdr[protocolNumberOffset] = byte(params.Protocol)
if r.Loop&stack.PacketLoop != 0 {
- pkt := pkt.Clone()
- r.PopulatePacketInfo(pkt)
- f.HandlePacket(pkt)
+ f.HandlePacket(pkt.Clone())
}
if r.Loop&stack.PacketOut == 0 {
return nil
@@ -2214,88 +2220,6 @@ func TestNICStats(t *testing.T) {
}
}
-func TestNICForwarding(t *testing.T) {
- const nicID1 = 1
- const nicID2 = 2
- const dstAddr = tcpip.Address("\x03")
-
- tests := []struct {
- name string
- headerLen uint16
- }{
- {
- name: "Zero header length",
- },
- {
- name: "Non-zero header length",
- headerLen: 16,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
- s.SetForwarding(fakeNetNumber, true)
-
- ep1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID1, ep1); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
- }
- if err := s.AddAddress(nicID1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress(%d, %d, 0x01): %s", nicID1, fakeNetNumber, err)
- }
-
- ep2 := channelLinkWithHeaderLength{
- Endpoint: channel.New(10, defaultMTU, ""),
- headerLength: test.headerLen,
- }
- if err := s.CreateNIC(nicID2, &ep2); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
- }
- if err := s.AddAddress(nicID2, fakeNetNumber, "\x02"); err != nil {
- t.Fatalf("AddAddress(%d, %d, 0x02): %s", nicID2, fakeNetNumber, err)
- }
-
- // Route all packets to dstAddr to NIC 2.
- {
- subnet, err := tcpip.NewSubnet(dstAddr, "\xff")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID2}})
- }
-
- // Send a packet to dstAddr.
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = dstAddr[0]
- ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- pkt, ok := ep2.Read()
- if !ok {
- t.Fatal("packet not forwarded")
- }
-
- // Test that the link's MaxHeaderLength is honoured.
- if capacity, want := pkt.Pkt.AvailableHeaderBytes(), int(test.headerLen); capacity != want {
- t.Errorf("got LinkHeader.AvailableLength() = %d, want = %d", capacity, want)
- }
-
- // Test that forwarding increments Tx stats correctly.
- if got, want := s.NICInfo()[nicID2].Stats.Tx.Packets.Value(), uint64(1); got != want {
- t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want)
- }
-
- if got, want := s.NICInfo()[nicID2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want {
- t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
- }
- })
- }
-}
-
// TestNICContextPreservation tests that you can read out via stack.NICInfo the
// Context data you pass via NICContext.Context in stack.CreateNICWithOptions.
func TestNICContextPreservation(t *testing.T) {