diff options
author | Ghanan Gowripalan <ghanan@google.com> | 2020-09-28 16:22:09 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2020-09-28 16:24:04 -0700 |
commit | a5acc0616c9552c7252e3f133f9ad4648422cc5f (patch) | |
tree | c62a2f05c8f8aee20671c39862ffe6721df35332 /pkg/tcpip/tests | |
parent | a0e0ba690f3f4946890010e12084db7f081d5bc7 (diff) |
Support creating protocol instances with Stack ref
Network or transport protocols may want to reach the stack. Support this
by letting the stack create the protocol instances so it can pass a
reference to itself at protocol creation time.
Note, protocols do not yet use the stack in this CL but later CLs will
make use of the stack from protocols.
PiperOrigin-RevId: 334260210
Diffstat (limited to 'pkg/tcpip/tests')
-rw-r--r-- | pkg/tcpip/tests/integration/loopback_test.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/multicast_broadcast_test.go | 22 |
2 files changed, 13 insertions, 15 deletions
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index fecbe7ba7..f35dcc084 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -120,8 +120,8 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) if err := s.CreateNIC(nicID, loopback.New()); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) @@ -203,7 +203,7 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) { otherAddr := tcpip.Address(addrBytes) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, }) if err := s.CreateNIC(nicID, loopback.New()); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 659acbc7a..72d86b5ab 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -140,11 +140,9 @@ func TestPingMulticastBroadcast(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ipv4Proto := ipv4.NewProtocol() - ipv6Proto := ipv6.NewProtocol() s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4Proto, ipv6Proto}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4(), icmp.NewProtocol6()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, }) // We only expect a single packet in response to our ICMP Echo Request. e := channel.New(1, defaultMTU, "") @@ -176,18 +174,18 @@ func TestPingMulticastBroadcast(t *testing.T) { var rxICMP func(*channel.Endpoint, tcpip.Address) var expectedSrc tcpip.Address var expectedDst tcpip.Address - var proto stack.NetworkProtocol + var protoNum tcpip.NetworkProtocolNumber switch l := len(test.dstAddr); l { case header.IPv4AddressSize: rxICMP = rxIPv4ICMP expectedSrc = ipv4Addr.Address expectedDst = remoteIPv4Addr - proto = ipv4Proto + protoNum = header.IPv4ProtocolNumber case header.IPv6AddressSize: rxICMP = rxIPv6ICMP expectedSrc = ipv6Addr.Address expectedDst = remoteIPv6Addr - proto = ipv6Proto + protoNum = header.IPv6ProtocolNumber default: t.Fatalf("got unexpected address length = %d bytes", l) } @@ -205,7 +203,7 @@ func TestPingMulticastBroadcast(t *testing.T) { t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, expectedDst) } - src, dst := proto.ParseAddresses(pkt.Pkt.NetworkHeader().View()) + src, dst := s.NetworkProtocolInstance(protoNum).ParseAddresses(pkt.Pkt.NetworkHeader().View()) if src != expectedSrc { t.Errorf("got pkt source = %s, want = %s", src, expectedSrc) } @@ -380,8 +378,8 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) e := channel.New(0, defaultMTU, "") if err := s.CreateNIC(nicID, e); err != nil { @@ -466,8 +464,8 @@ func TestReuseAddrAndBroadcast(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) if err := s.CreateNIC(nicID, loopback.New()); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) |