From 5075d0342f51b3e44ae47fc0901a59a4d762c638 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Tue, 29 Sep 2020 02:04:11 -0700 Subject: Trim Network/Transport Endpoint/Protocol * Remove Capabilities and NICID methods from NetworkEndpoint. * Remove linkEP and stack parameters from NetworkProtocol.NewEndpoint. The LinkEndpoint can be fetched from the NetworkInterface. The stack is passed to the NetworkProtocol when it is created so the NetworkEndpoint can get it from its protocol. * Remove stack parameter from TransportProtocol.NewEndpoint. Like the NetworkProtocol/Endpoint, the stack is passed to the TransportProtocol when it is created. PiperOrigin-RevId: 334332721 --- pkg/tcpip/network/ip_test.go | 155 +++++++++++++++++++++++++++---------------- 1 file changed, 96 insertions(+), 59 deletions(-) (limited to 'pkg/tcpip/network/ip_test.go') diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 66450f896..56a56362e 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -252,6 +252,8 @@ func buildDummyStack(t *testing.T) *stack.Stack { var _ stack.NetworkInterface = (*testInterface)(nil) type testInterface struct { + tester testObject + mu struct { sync.RWMutex disabled bool @@ -282,6 +284,10 @@ func (t *testInterface) setEnabled(v bool) { t.mu.disabled = !v } +func (t *testInterface) LinkEndpoint() stack.LinkEndpoint { + return &t.tester +} + func TestEnableWhenNICDisabled(t *testing.T) { tests := []struct { name string @@ -312,7 +318,7 @@ func TestEnableWhenNICDisabled(t *testing.T) { // We pass nil for all parameters except the NetworkInterface and Stack // since Enable only depends on these. - ep := p.NewEndpoint(&nic, nil, nil, nil, nil, s) + ep := p.NewEndpoint(&nic, nil, nil, nil) // The endpoint should initially be disabled, regardless the NIC's enabled // status. @@ -365,10 +371,15 @@ func TestEnableWhenNICDisabled(t *testing.T) { } func TestIPv4Send(t *testing.T) { - o := testObject{t: t, v4: true} s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) - ep := proto.NewEndpoint(&testInterface{}, nil, nil, nil, &o, s) + nic := testInterface{ + tester: testObject{ + t: t, + v4: true, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, nil) defer ep.Close() // Allocate and initialize the payload view. @@ -384,10 +395,10 @@ func TestIPv4Send(t *testing.T) { }) // Issue the write. - o.protocol = 123 - o.srcAddr = localIpv4Addr - o.dstAddr = remoteIpv4Addr - o.contents = payload + nic.tester.protocol = 123 + nic.tester.srcAddr = localIpv4Addr + nic.tester.dstAddr = remoteIpv4Addr + nic.tester.contents = payload r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) if err != nil { @@ -403,10 +414,15 @@ func TestIPv4Send(t *testing.T) { } func TestIPv4Receive(t *testing.T) { - o := testObject{t: t, v4: true} s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) - ep := proto.NewEndpoint(&testInterface{}, nil, nil, &o, nil, s) + nic := testInterface{ + tester: testObject{ + t: t, + v4: true, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) defer ep.Close() if err := ep.Enable(); err != nil { @@ -431,10 +447,10 @@ func TestIPv4Receive(t *testing.T) { } // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv4Addr - o.dstAddr = localIpv4Addr - o.contents = view[header.IPv4MinimumSize:totalLen] + nic.tester.protocol = 10 + nic.tester.srcAddr = remoteIpv4Addr + nic.tester.dstAddr = localIpv4Addr + nic.tester.contents = view[header.IPv4MinimumSize:totalLen] r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) if err != nil { @@ -447,8 +463,8 @@ func TestIPv4Receive(t *testing.T) { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } ep.HandlePacket(&r, pkt) - if o.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) + if nic.tester.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls) } } @@ -478,10 +494,14 @@ func TestIPv4ReceiveControl(t *testing.T) { } for _, c := range cases { t.Run(c.name, func(t *testing.T) { - o := testObject{t: t} s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) - ep := proto.NewEndpoint(&testInterface{}, nil, nil, &o, nil, s) + nic := testInterface{ + tester: testObject{ + t: t, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) defer ep.Close() if err := ep.Enable(); err != nil { @@ -528,26 +548,31 @@ func TestIPv4ReceiveControl(t *testing.T) { // Give packet to IPv4 endpoint, dispatcher will validate that // it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv4Addr - o.dstAddr = localIpv4Addr - o.contents = view[dataOffset:] - o.typ = c.expectedTyp - o.extra = c.expectedExtra + nic.tester.protocol = 10 + nic.tester.srcAddr = remoteIpv4Addr + nic.tester.dstAddr = localIpv4Addr + nic.tester.contents = view[dataOffset:] + nic.tester.typ = c.expectedTyp + nic.tester.extra = c.expectedExtra ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize)) - if want := c.expectedCount; o.controlCalls != want { - t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) + if want := c.expectedCount; nic.tester.controlCalls != want { + t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.tester.controlCalls, want) } }) } } func TestIPv4FragmentationReceive(t *testing.T) { - o := testObject{t: t, v4: true} s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) - ep := proto.NewEndpoint(&testInterface{}, nil, nil, &o, nil, s) + nic := testInterface{ + tester: testObject{ + t: t, + v4: true, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) defer ep.Close() if err := ep.Enable(); err != nil { @@ -590,10 +615,10 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv4Addr - o.dstAddr = localIpv4Addr - o.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) + nic.tester.protocol = 10 + nic.tester.srcAddr = remoteIpv4Addr + nic.tester.dstAddr = localIpv4Addr + nic.tester.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) if err != nil { @@ -608,8 +633,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } ep.HandlePacket(&r, pkt) - if o.dataCalls != 0 { - t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls) + if nic.tester.dataCalls != 0 { + t.Fatalf("Bad number of data calls: got %x, want 0", nic.tester.dataCalls) } // Send second segment. @@ -620,16 +645,20 @@ func TestIPv4FragmentationReceive(t *testing.T) { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } ep.HandlePacket(&r, pkt) - if o.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) + if nic.tester.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls) } } func TestIPv6Send(t *testing.T) { - o := testObject{t: t} s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) - ep := proto.NewEndpoint(&testInterface{}, nil, nil, &o, channel.New(0, 1280, ""), s) + nic := testInterface{ + tester: testObject{ + t: t, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, nil) defer ep.Close() if err := ep.Enable(); err != nil { @@ -649,10 +678,10 @@ func TestIPv6Send(t *testing.T) { }) // Issue the write. - o.protocol = 123 - o.srcAddr = localIpv6Addr - o.dstAddr = remoteIpv6Addr - o.contents = payload + nic.tester.protocol = 123 + nic.tester.srcAddr = localIpv6Addr + nic.tester.dstAddr = remoteIpv6Addr + nic.tester.contents = payload r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr) if err != nil { @@ -668,10 +697,14 @@ func TestIPv6Send(t *testing.T) { } func TestIPv6Receive(t *testing.T) { - o := testObject{t: t} s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) - ep := proto.NewEndpoint(&testInterface{}, nil, nil, &o, nil, s) + nic := testInterface{ + tester: testObject{ + t: t, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) defer ep.Close() if err := ep.Enable(); err != nil { @@ -695,10 +728,10 @@ func TestIPv6Receive(t *testing.T) { } // Give packet to ipv6 endpoint, dispatcher will validate that it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv6Addr - o.dstAddr = localIpv6Addr - o.contents = view[header.IPv6MinimumSize:totalLen] + nic.tester.protocol = 10 + nic.tester.srcAddr = remoteIpv6Addr + nic.tester.dstAddr = localIpv6Addr + nic.tester.contents = view[header.IPv6MinimumSize:totalLen] r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr) if err != nil { @@ -712,8 +745,8 @@ func TestIPv6Receive(t *testing.T) { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } ep.HandlePacket(&r, pkt) - if o.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) + if nic.tester.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls) } } @@ -752,10 +785,14 @@ func TestIPv6ReceiveControl(t *testing.T) { } for _, c := range cases { t.Run(c.name, func(t *testing.T) { - o := testObject{t: t} s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) - ep := proto.NewEndpoint(&testInterface{}, nil, nil, &o, nil, s) + nic := testInterface{ + tester: testObject{ + t: t, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) defer ep.Close() if err := ep.Enable(); err != nil { @@ -814,19 +851,19 @@ func TestIPv6ReceiveControl(t *testing.T) { // Give packet to IPv6 endpoint, dispatcher will validate that // it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv6Addr - o.dstAddr = localIpv6Addr - o.contents = view[dataOffset:] - o.typ = c.expectedTyp - o.extra = c.expectedExtra + nic.tester.protocol = 10 + nic.tester.srcAddr = remoteIpv6Addr + nic.tester.dstAddr = localIpv6Addr + nic.tester.contents = view[dataOffset:] + nic.tester.typ = c.expectedTyp + nic.tester.extra = c.expectedExtra // Set ICMPv6 checksum. icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{})) ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize)) - if want := c.expectedCount; o.controlCalls != want { - t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) + if want := c.expectedCount; nic.tester.controlCalls != want { + t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.tester.controlCalls, want) } }) } -- cgit v1.2.3