summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/tests
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2020-09-28 16:22:09 -0700
committergVisor bot <gvisor-bot@google.com>2020-09-28 16:24:04 -0700
commita5acc0616c9552c7252e3f133f9ad4648422cc5f (patch)
treec62a2f05c8f8aee20671c39862ffe6721df35332 /pkg/tcpip/tests
parenta0e0ba690f3f4946890010e12084db7f081d5bc7 (diff)
Support creating protocol instances with Stack ref
Network or transport protocols may want to reach the stack. Support this by letting the stack create the protocol instances so it can pass a reference to itself at protocol creation time. Note, protocols do not yet use the stack in this CL but later CLs will make use of the stack from protocols. PiperOrigin-RevId: 334260210
Diffstat (limited to 'pkg/tcpip/tests')
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go6
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go22
2 files changed, 13 insertions, 15 deletions
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index fecbe7ba7..f35dcc084 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -120,8 +120,8 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
if err := s.CreateNIC(nicID, loopback.New()); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
@@ -203,7 +203,7 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) {
otherAddr := tcpip.Address(addrBytes)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
})
if err := s.CreateNIC(nicID, loopback.New()); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index 659acbc7a..72d86b5ab 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -140,11 +140,9 @@ func TestPingMulticastBroadcast(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- ipv4Proto := ipv4.NewProtocol()
- ipv6Proto := ipv6.NewProtocol()
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4Proto, ipv6Proto},
- TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4(), icmp.NewProtocol6()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6},
})
// We only expect a single packet in response to our ICMP Echo Request.
e := channel.New(1, defaultMTU, "")
@@ -176,18 +174,18 @@ func TestPingMulticastBroadcast(t *testing.T) {
var rxICMP func(*channel.Endpoint, tcpip.Address)
var expectedSrc tcpip.Address
var expectedDst tcpip.Address
- var proto stack.NetworkProtocol
+ var protoNum tcpip.NetworkProtocolNumber
switch l := len(test.dstAddr); l {
case header.IPv4AddressSize:
rxICMP = rxIPv4ICMP
expectedSrc = ipv4Addr.Address
expectedDst = remoteIPv4Addr
- proto = ipv4Proto
+ protoNum = header.IPv4ProtocolNumber
case header.IPv6AddressSize:
rxICMP = rxIPv6ICMP
expectedSrc = ipv6Addr.Address
expectedDst = remoteIPv6Addr
- proto = ipv6Proto
+ protoNum = header.IPv6ProtocolNumber
default:
t.Fatalf("got unexpected address length = %d bytes", l)
}
@@ -205,7 +203,7 @@ func TestPingMulticastBroadcast(t *testing.T) {
t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, expectedDst)
}
- src, dst := proto.ParseAddresses(pkt.Pkt.NetworkHeader().View())
+ src, dst := s.NetworkProtocolInstance(protoNum).ParseAddresses(pkt.Pkt.NetworkHeader().View())
if src != expectedSrc {
t.Errorf("got pkt source = %s, want = %s", src, expectedSrc)
}
@@ -380,8 +378,8 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
e := channel.New(0, defaultMTU, "")
if err := s.CreateNIC(nicID, e); err != nil {
@@ -466,8 +464,8 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
if err := s.CreateNIC(nicID, loopback.New()); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)