From ac2200b8a9c269926d2eb98a7c23be79b4738fcf Mon Sep 17 00:00:00 2001 From: Chris Kuiper Date: Mon, 26 Aug 2019 12:28:26 -0700 Subject: Prevent a network endpoint to send/rcv if its address was removed This addresses the problem where an endpoint has its address removed but still has outstanding references held by routes used in connected TCP/UDP sockets which prevent the removal of the endpoint. The fix adds a new "expired" flag to the referenced network endpoint, which is set when an endpoint has its address removed. Incoming packets are not delivered to an expired endpoint (unless in promiscuous mode), while sending outgoing packets triggers an error to the caller (unless in spoofing mode). In addition, a few helper functions were added to stack_test.go to reduce code duplications. PiperOrigin-RevId: 265514326 --- pkg/tcpip/stack/stack_test.go | 507 +++++++++++++++++++++++++++++++++--------- 1 file changed, 398 insertions(+), 109 deletions(-) (limited to 'pkg/tcpip/stack/stack_test.go') diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 137c6183e..4debd1eec 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -181,6 +181,10 @@ func (f *fakeNetworkProtocol) DefaultPrefixLen() int { return fakeDefaultPrefixLen } +func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int { + return f.packetCount[int(intfAddr)%len(f.packetCount)] +} + func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) } @@ -289,16 +293,75 @@ func TestNetworkReceive(t *testing.T) { } } -func sendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, payload buffer.View) { +func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Error { r, err := s.FindRoute(0, "", addr, fakeNetNumber, false /* multicastLoop */) if err != nil { - t.Fatal("FindRoute failed:", err) + return err } defer r.Release() + return send(r, payload) +} +func send(r stack.Route, payload buffer.View) *tcpip.Error { hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - if err := r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123); err != nil { - t.Error("WritePacket failed:", err) + return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123) +} + +func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, linkEP *channel.Endpoint, payload buffer.View) { + t.Helper() + linkEP.Drain() + if err := sendTo(s, addr, payload); err != nil { + t.Error("sendTo failed:", err) + } + if got, want := linkEP.Drain(), 1; got != want { + t.Errorf("sendTo packet count: got = %d, want %d", got, want) + } +} + +func testSend(t *testing.T, r stack.Route, linkEP *channel.Endpoint, payload buffer.View) { + t.Helper() + linkEP.Drain() + if err := send(r, payload); err != nil { + t.Error("send failed:", err) + } + if got, want := linkEP.Drain(), 1; got != want { + t.Errorf("send packet count: got = %d, want %d", got, want) + } +} + +func testFailingSend(t *testing.T, r stack.Route, linkEP *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { + t.Helper() + if gotErr := send(r, payload); gotErr != wantErr { + t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) + } +} + +func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, linkEP *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { + t.Helper() + if gotErr := sendTo(s, addr, payload); gotErr != wantErr { + t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr) + } +} + +func testRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View) { + t.Helper() + // testRecvInternal injects one packet, and we expect to receive it. + want := fakeNet.PacketCount(localAddrByte) + 1 + testRecvInternal(t, fakeNet, localAddrByte, linkEP, buf, want) +} + +func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View) { + t.Helper() + // testRecvInternal injects one packet, and we do NOT expect to receive it. + want := fakeNet.PacketCount(localAddrByte) + testRecvInternal(t, fakeNet, localAddrByte, linkEP, buf, want) +} + +func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View, want int) { + t.Helper() + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + if got := fakeNet.PacketCount(localAddrByte); got != want { + t.Errorf("receive packet count: got = %d, want %d", got, want) } } @@ -325,10 +388,7 @@ func TestNetworkSend(t *testing.T) { } // Make sure that the link-layer endpoint received the outbound packet. - sendTo(t, s, "\x03", nil) - if c := linkEP.Drain(); c != 1 { - t.Errorf("packetCount = %d, want %d", c, 1) - } + testSendTo(t, s, "\x03", linkEP, nil) } func TestNetworkSendMultiRoute(t *testing.T) { @@ -382,18 +442,10 @@ func TestNetworkSendMultiRoute(t *testing.T) { } // Send a packet to an odd destination. - sendTo(t, s, "\x05", nil) - - if c := linkEP1.Drain(); c != 1 { - t.Errorf("packetCount = %d, want %d", c, 1) - } + testSendTo(t, s, "\x05", linkEP1, nil) // Send a packet to an even destination. - sendTo(t, s, "\x06", nil) - - if c := linkEP2.Drain(); c != 1 { - t.Errorf("packetCount = %d, want %d", c, 1) - } + testSendTo(t, s, "\x06", linkEP2, nil) } func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) { @@ -498,6 +550,10 @@ func TestRoutes(t *testing.T) { } func TestAddressRemoval(t *testing.T) { + const localAddrByte byte = 0x01 + localAddr := tcpip.Address([]byte{localAddrByte}) + remoteAddr := tcpip.Address("\x02") + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) id, linkEP := channel.New(10, defaultMTU, "") @@ -505,51 +561,56 @@ func TestAddressRemoval(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { + if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { t.Fatal("AddAddress failed:", err) } + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) buf := buffer.NewView(30) - // Write a packet, and check that it gets delivered. - fakeNet.packetCount[1] = 0 - buf[0] = 1 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } + // Send and receive packets, and verify they are received. + buf[0] = localAddrByte + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testSendTo(t, s, remoteAddr, linkEP, nil) - // Remove the address, then check that packet doesn't get delivered - // anymore. - if err := s.RemoveAddress(1, "\x01"); err != nil { + // Remove the address, then check that send/receive doesn't work anymore. + if err := s.RemoveAddress(1, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) // Check that removing the same address fails. - if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress { + if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress) } } -func TestDelayedRemovalDueToRoute(t *testing.T) { +func TestAddressRemovalWithRouteHeld(t *testing.T) { + const localAddrByte byte = 0x01 + localAddr := tcpip.Address([]byte{localAddrByte}) + remoteAddr := tcpip.Address("\x02") + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) id, linkEP := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id); err != nil { t.Fatal("CreateNIC failed:", err) } + fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) + buf := buffer.NewView(30) - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { + if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { t.Fatal("AddAddress failed:", err) } - { subnet, err := tcpip.NewSubnet("\x00", "\x00") if err != nil { @@ -558,50 +619,227 @@ func TestDelayedRemovalDueToRoute(t *testing.T) { s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) } - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - // Write a packet, and check that it gets delivered. - fakeNet.packetCount[1] = 0 - buf[0] = 1 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } - - // Get a route, check that packet is still deliverable. - r, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) + r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */) if err != nil { t.Fatal("FindRoute failed:", err) } - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 2 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 2) - } + // Send and receive packets, and verify they are received. + buf[0] = localAddrByte + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testSend(t, r, linkEP, nil) + testSendTo(t, s, remoteAddr, linkEP, nil) - // Remove the address, then check that packet is still deliverable - // because the route is keeping the address alive. - if err := s.RemoveAddress(1, "\x01"); err != nil { + // Remove the address, then check that send/receive doesn't work anymore. + if err := s.RemoveAddress(1, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 3 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3) - } + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingSend(t, r, linkEP, nil, tcpip.ErrInvalidEndpointState) + testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) // Check that removing the same address fails. - if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress { + if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress) } +} - // Release the route, then check that packet is not deliverable anymore. - r.Release() - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 3 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3) +func verifyAddress(t *testing.T, s *stack.Stack, nicid tcpip.NICID, addr tcpip.Address) { + t.Helper() + info, ok := s.NICInfo()[nicid] + if !ok { + t.Fatalf("NICInfo() failed to find nicid=%d", nicid) + } + if len(addr) == 0 { + // No address given, verify that there is no address assigned to the NIC. + for _, a := range info.ProtocolAddresses { + if a.Protocol == fakeNetNumber && a.AddressWithPrefix != (tcpip.AddressWithPrefix{}) { + t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, (tcpip.AddressWithPrefix{})) + } + } + return + } + // Address given, verify the address is assigned to the NIC and no other + // address is. + found := false + for _, a := range info.ProtocolAddresses { + if a.Protocol == fakeNetNumber { + if a.AddressWithPrefix.Address == addr { + found = true + } else { + t.Errorf("verify address: got = %s, want = %s", a.AddressWithPrefix.Address, addr) + } + } + } + if !found { + t.Errorf("verify address: couldn't find %s on the NIC", addr) + } +} + +func TestEndpointExpiration(t *testing.T) { + const ( + localAddrByte byte = 0x01 + remoteAddr tcpip.Address = "\x03" + noAddr tcpip.Address = "" + nicid tcpip.NICID = 1 + ) + localAddr := tcpip.Address([]byte{localAddrByte}) + + for _, promiscuous := range []bool{true, false} { + for _, spoofing := range []bool{true, false} { + t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) { + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + + id, linkEP := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, id); err != nil { + t.Fatal("CreateNIC failed:", err) + } + + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } + + fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) + buf := buffer.NewView(30) + buf[0] = localAddrByte + + if promiscuous { + if err := s.SetPromiscuousMode(nicid, true); err != nil { + t.Fatal("SetPromiscuousMode failed:", err) + } + } + + if spoofing { + if err := s.SetSpoofing(nicid, true); err != nil { + t.Fatal("SetSpoofing failed:", err) + } + } + + // 1. No Address yet, send should only work for spoofing, receive for + // promiscuous mode. + //----------------------- + verifyAddress(t, s, nicid, noAddr) + if promiscuous { + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + } else { + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + } + if spoofing { + // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. + // testSendTo(t, s, remoteAddr, linkEP, nil) + } else { + testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + } + + // 2. Add Address, everything should work. + //----------------------- + if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + t.Fatal("AddAddress failed:", err) + } + verifyAddress(t, s, nicid, localAddr) + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testSendTo(t, s, remoteAddr, linkEP, nil) + + // 3. Remove the address, send should only work for spoofing, receive + // for promiscuous mode. + //----------------------- + if err := s.RemoveAddress(nicid, localAddr); err != nil { + t.Fatal("RemoveAddress failed:", err) + } + verifyAddress(t, s, nicid, noAddr) + if promiscuous { + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + } else { + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + } + if spoofing { + // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. + // testSendTo(t, s, remoteAddr, linkEP, nil) + } else { + testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + } + + // 4. Add Address back, everything should work again. + //----------------------- + if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + t.Fatal("AddAddress failed:", err) + } + verifyAddress(t, s, nicid, localAddr) + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testSendTo(t, s, remoteAddr, linkEP, nil) + + // 5. Take a reference to the endpoint by getting a route. Verify that + // we can still send/receive, including sending using the route. + //----------------------- + r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatal("FindRoute failed:", err) + } + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testSendTo(t, s, remoteAddr, linkEP, nil) + testSend(t, r, linkEP, nil) + + // 6. Remove the address. Send should only work for spoofing, receive + // for promiscuous mode. + //----------------------- + if err := s.RemoveAddress(nicid, localAddr); err != nil { + t.Fatal("RemoveAddress failed:", err) + } + verifyAddress(t, s, nicid, noAddr) + if promiscuous { + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + } else { + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + } + if spoofing { + testSend(t, r, linkEP, nil) + testSendTo(t, s, remoteAddr, linkEP, nil) + } else { + testFailingSend(t, r, linkEP, nil, tcpip.ErrInvalidEndpointState) + testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + } + + // 7. Add Address back, everything should work again. + //----------------------- + if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + t.Fatal("AddAddress failed:", err) + } + verifyAddress(t, s, nicid, localAddr) + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testSendTo(t, s, remoteAddr, linkEP, nil) + testSend(t, r, linkEP, nil) + + // 8. Remove the route, sendTo/recv should still work. + //----------------------- + r.Release() + verifyAddress(t, s, nicid, localAddr) + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testSendTo(t, s, remoteAddr, linkEP, nil) + + // 9. Remove the address. Send should only work for spoofing, receive + // for promiscuous mode. + //----------------------- + if err := s.RemoveAddress(nicid, localAddr); err != nil { + t.Fatal("RemoveAddress failed:", err) + } + verifyAddress(t, s, nicid, noAddr) + if promiscuous { + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + } else { + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + } + if spoofing { + // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. + // testSendTo(t, s, remoteAddr, linkEP, nil) + } else { + testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + } + }) + } } } @@ -627,22 +865,15 @@ func TestPromiscuousMode(t *testing.T) { // Write a packet, and check that it doesn't get delivered as we don't // have a matching endpoint. - fakeNet.packetCount[1] = 0 - buf[0] = 1 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 0 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) - } + const localAddrByte byte = 0x01 + buf[0] = localAddrByte + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) // Set promiscuous mode, then check that packet is delivered. if err := s.SetPromiscuousMode(1, true); err != nil { t.Fatal("SetPromiscuousMode failed:", err) } - - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } + testRecv(t, fakeNet, localAddrByte, linkEP, buf) // Check that we can't get a route as there is no local address. _, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) @@ -655,25 +886,22 @@ func TestPromiscuousMode(t *testing.T) { if err := s.SetPromiscuousMode(1, false); err != nil { t.Fatal("SetPromiscuousMode failed:", err) } - - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) } -func TestAddressSpoofing(t *testing.T) { - srcAddr := tcpip.Address("\x01") - dstAddr := tcpip.Address("\x02") +func TestSpoofingWithAddress(t *testing.T) { + localAddr := tcpip.Address("\x01") + nonExistentLocalAddr := tcpip.Address("\x02") + dstAddr := tcpip.Address("\x03") s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") + id, linkEP := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id); err != nil { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, dstAddr); err != nil { + if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { t.Fatal("AddAddress failed:", err) } @@ -687,7 +915,7 @@ func TestAddressSpoofing(t *testing.T) { // With address spoofing disabled, FindRoute does not permit an address // that was not added to the NIC to be used as the source. - r, err := s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) if err == nil { t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) } @@ -697,16 +925,81 @@ func TestAddressSpoofing(t *testing.T) { if err := s.SetSpoofing(1, true); err != nil { t.Fatal("SetSpoofing failed:", err) } - r, err = s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatal("FindRoute failed:", err) + } + if r.LocalAddress != nonExistentLocalAddr { + t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr) + } + if r.RemoteAddress != dstAddr { + t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) + } + // Sending a packet works. + testSendTo(t, s, dstAddr, linkEP, nil) + testSend(t, r, linkEP, nil) + + // FindRoute should also work with a local address that exists on the NIC. + r, err = s.FindRoute(0, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) if err != nil { t.Fatal("FindRoute failed:", err) } - if r.LocalAddress != srcAddr { - t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, srcAddr) + if r.LocalAddress != localAddr { + t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr) } if r.RemoteAddress != dstAddr { t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) } + // Sending a packet using the route works. + testSend(t, r, linkEP, nil) +} + +func TestSpoofingNoAddress(t *testing.T) { + nonExistentLocalAddr := tcpip.Address("\x01") + dstAddr := tcpip.Address("\x02") + + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + + id, linkEP := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, id); err != nil { + t.Fatal("CreateNIC failed:", err) + } + + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } + + // With address spoofing disabled, FindRoute does not permit an address + // that was not added to the NIC to be used as the source. + r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + if err == nil { + t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) + } + // Sending a packet fails. + testFailingSendTo(t, s, dstAddr, linkEP, nil, tcpip.ErrNoRoute) + + // With address spoofing enabled, FindRoute permits any address to be used + // as the source. + if err := s.SetSpoofing(1, true); err != nil { + t.Fatal("SetSpoofing failed:", err) + } + r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatal("FindRoute failed:", err) + } + if r.LocalAddress != nonExistentLocalAddr { + t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr) + } + if r.RemoteAddress != dstAddr { + t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) + } + // Sending a packet works. + // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. + // testSendTo(t, s, remoteAddr, linkEP, nil) } func TestBroadcastNeedsNoRoute(t *testing.T) { @@ -856,8 +1149,8 @@ func TestSubnetAcceptsMatchingPacket(t *testing.T) { buf := buffer.NewView(30) - buf[0] = 1 - fakeNet.packetCount[1] = 0 + const localAddrByte byte = 0x01 + buf[0] = localAddrByte subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0")) if err != nil { t.Fatal("NewSubnet failed:", err) @@ -866,10 +1159,7 @@ func TestSubnetAcceptsMatchingPacket(t *testing.T) { t.Fatal("AddSubnet failed:", err) } - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } + testRecv(t, fakeNet, localAddrByte, linkEP, buf) } // Set the subnet, then check that CheckLocalAddress returns the correct NIC. @@ -939,8 +1229,8 @@ func TestSubnetRejectsNonmatchingPacket(t *testing.T) { buf := buffer.NewView(30) - buf[0] = 1 - fakeNet.packetCount[1] = 0 + const localAddrByte byte = 0x01 + buf[0] = localAddrByte subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0")) if err != nil { t.Fatal("NewSubnet failed:", err) @@ -948,10 +1238,7 @@ func TestSubnetRejectsNonmatchingPacket(t *testing.T) { if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { t.Fatal("AddSubnet failed:", err) } - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 0 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) - } + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) } func TestNetworkOptions(t *testing.T) { @@ -1305,7 +1592,7 @@ func TestNICStats(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) id1, linkEP1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id1); err != nil { - t.Fatal("CreateNIC failed:", err) + t.Fatal("CreateNIC failed: ", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { t.Fatal("AddAddress failed:", err) @@ -1332,7 +1619,9 @@ func TestNICStats(t *testing.T) { payload := buffer.NewView(10) // Write a packet out via the address for NIC 1 - sendTo(t, s, "\x01", payload) + if err := sendTo(s, "\x01", payload); err != nil { + t.Fatal("sendTo failed: ", err) + } want := uint64(linkEP1.Drain()) if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want { t.Errorf("got Tx.Packets.Value() = %d, linkEP1.Drain() = %d", got, want) -- cgit v1.2.3 From 7bf1d426d5f8ba5a3ddfc8f0cd73cc12e8a83c16 Mon Sep 17 00:00:00 2001 From: Chris Kuiper Date: Wed, 4 Sep 2019 14:18:02 -0700 Subject: Handle subnet and broadcast addresses correctly with NIC.subnets This also renames "subnet" to "addressRange" to avoid any more confusion with an interface IP's subnet. Lastly, this also removes the Stack.ContainsSubnet(..) API since it isn't used by anyone. Plus the same information can be obtained from Stack.NICAddressRanges(). PiperOrigin-RevId: 267229843 --- pkg/tcpip/stack/nic.go | 63 +++++++++++----------- pkg/tcpip/stack/stack.go | 32 ++++------- pkg/tcpip/stack/stack_test.go | 121 ++++++++++++++++++++++++++---------------- pkg/tcpip/tcpip.go | 9 ++++ 4 files changed, 129 insertions(+), 96 deletions(-) (limited to 'pkg/tcpip/stack/stack_test.go') diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index ae56e0ffd..43719085e 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -36,13 +36,13 @@ type NIC struct { demux *transportDemuxer - mu sync.RWMutex - spoofing bool - promiscuous bool - primary map[tcpip.NetworkProtocolNumber]*ilist.List - endpoints map[NetworkEndpointID]*referencedNetworkEndpoint - subnets []tcpip.Subnet - mcastJoins map[NetworkEndpointID]int32 + mu sync.RWMutex + spoofing bool + promiscuous bool + primary map[tcpip.NetworkProtocolNumber]*ilist.List + endpoints map[NetworkEndpointID]*referencedNetworkEndpoint + addressRanges []tcpip.Subnet + mcastJoins map[NetworkEndpointID]int32 stats NICStats } @@ -224,7 +224,17 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t // the caller or if the address is found in the NIC's subnets. createTempEP := spoofingOrPromiscuous if !createTempEP { - for _, sn := range n.subnets { + for _, sn := range n.addressRanges { + // Skip the subnet address. + if address == sn.ID() { + continue + } + // For now just skip the broadcast address, until we support it. + // FIXME(b/137608825): Add support for sending/receiving directed + // (subnet) broadcast. + if address == sn.Broadcast() { + continue + } if sn.Contains(address) { createTempEP = true break @@ -381,45 +391,38 @@ func (n *NIC) Addresses() []tcpip.ProtocolAddress { return addrs } -// AddSubnet adds a new subnet to n, so that it starts accepting packets -// targeted at the given address and network protocol. -func (n *NIC) AddSubnet(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) { +// AddAddressRange adds a range of addresses to n, so that it starts accepting +// packets targeted at the given addresses and network protocol. The range is +// given by a subnet address, and all addresses contained in the subnet are +// used except for the subnet address itself and the subnet's broadcast +// address. +func (n *NIC) AddAddressRange(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) { n.mu.Lock() - n.subnets = append(n.subnets, subnet) + n.addressRanges = append(n.addressRanges, subnet) n.mu.Unlock() } -// RemoveSubnet removes the given subnet from n. -func (n *NIC) RemoveSubnet(subnet tcpip.Subnet) { +// RemoveAddressRange removes the given address range from n. +func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) { n.mu.Lock() // Use the same underlying array. - tmp := n.subnets[:0] - for _, sub := range n.subnets { + tmp := n.addressRanges[:0] + for _, sub := range n.addressRanges { if sub != subnet { tmp = append(tmp, sub) } } - n.subnets = tmp + n.addressRanges = tmp n.mu.Unlock() } -// ContainsSubnet reports whether this NIC contains the given subnet. -func (n *NIC) ContainsSubnet(subnet tcpip.Subnet) bool { - for _, s := range n.Subnets() { - if s == subnet { - return true - } - } - return false -} - // Subnets returns the Subnets associated with this NIC. -func (n *NIC) Subnets() []tcpip.Subnet { +func (n *NIC) AddressRanges() []tcpip.Subnet { n.mu.RLock() defer n.mu.RUnlock() - sns := make([]tcpip.Subnet, 0, len(n.subnets)+len(n.endpoints)) + sns := make([]tcpip.Subnet, 0, len(n.addressRanges)+len(n.endpoints)) for nid := range n.endpoints { sn, err := tcpip.NewSubnet(nid.LocalAddress, tcpip.AddressMask(strings.Repeat("\xff", len(nid.LocalAddress)))) if err != nil { @@ -429,7 +432,7 @@ func (n *NIC) Subnets() []tcpip.Subnet { } sns = append(sns, sn) } - return append(sns, n.subnets...) + return append(sns, n.addressRanges...) } func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 1d5e84a8b..6beca6ae8 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -702,14 +702,14 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool { } // NICSubnets returns a map of NICIDs to their associated subnets. -func (s *Stack) NICSubnets() map[tcpip.NICID][]tcpip.Subnet { +func (s *Stack) NICAddressRanges() map[tcpip.NICID][]tcpip.Subnet { s.mu.RLock() defer s.mu.RUnlock() nics := map[tcpip.NICID][]tcpip.Subnet{} for id, nic := range s.nics { - nics[id] = append(nics[id], nic.Subnets()...) + nics[id] = append(nics[id], nic.AddressRanges()...) } return nics } @@ -810,45 +810,35 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc return nic.AddAddress(protocolAddress, peb) } -// AddSubnet adds a subnet range to the specified NIC. -func (s *Stack) AddSubnet(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) *tcpip.Error { +// AddAddressRange adds a range of addresses to the specified NIC. The range is +// given by a subnet address, and all addresses contained in the subnet are +// used except for the subnet address itself and the subnet's broadcast +// address. +func (s *Stack) AddAddressRange(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) *tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() if nic, ok := s.nics[id]; ok { - nic.AddSubnet(protocol, subnet) + nic.AddAddressRange(protocol, subnet) return nil } return tcpip.ErrUnknownNICID } -// RemoveSubnet removes the subnet range from the specified NIC. -func (s *Stack) RemoveSubnet(id tcpip.NICID, subnet tcpip.Subnet) *tcpip.Error { +// RemoveAddressRange removes the range of addresses from the specified NIC. +func (s *Stack) RemoveAddressRange(id tcpip.NICID, subnet tcpip.Subnet) *tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() if nic, ok := s.nics[id]; ok { - nic.RemoveSubnet(subnet) + nic.RemoveAddressRange(subnet) return nil } return tcpip.ErrUnknownNICID } -// ContainsSubnet reports whether the specified NIC contains the specified -// subnet. -func (s *Stack) ContainsSubnet(id tcpip.NICID, subnet tcpip.Subnet) (bool, *tcpip.Error) { - s.mu.RLock() - defer s.mu.RUnlock() - - if nic, ok := s.nics[id]; ok { - return nic.ContainsSubnet(subnet), nil - } - - return false, tcpip.ErrUnknownNICID -} - // RemoveAddress removes an existing network-layer address from the specified // NIC. func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 4debd1eec..c6a8160af 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -1128,8 +1128,8 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { } } -// Set the subnet, then check that packet is delivered. -func TestSubnetAcceptsMatchingPacket(t *testing.T) { +// Add a range of addresses, then check that a packet is delivered. +func TestAddressRangeAcceptsMatchingPacket(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) id, linkEP := channel.New(10, defaultMTU, "") @@ -1155,14 +1155,45 @@ func TestSubnetAcceptsMatchingPacket(t *testing.T) { if err != nil { t.Fatal("NewSubnet failed:", err) } - if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { - t.Fatal("AddSubnet failed:", err) + if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil { + t.Fatal("AddAddressRange failed:", err) } testRecv(t, fakeNet, localAddrByte, linkEP, buf) } -// Set the subnet, then check that CheckLocalAddress returns the correct NIC. +func testNicForAddressRange(t *testing.T, nicID tcpip.NICID, s *stack.Stack, subnet tcpip.Subnet, rangeExists bool) { + t.Helper() + + // Loop over all addresses and check them. + numOfAddresses := 1 << uint(8-subnet.Prefix()) + if numOfAddresses < 1 || numOfAddresses > 255 { + t.Fatalf("got numOfAddresses = %d, want = [1 .. 255] (subnet=%s)", numOfAddresses, subnet) + } + + addrBytes := []byte(subnet.ID()) + for i := 0; i < numOfAddresses; i++ { + addr := tcpip.Address(addrBytes) + wantNicID := nicID + // The subnet and broadcast addresses are skipped. + if !rangeExists || addr == subnet.ID() || addr == subnet.Broadcast() { + wantNicID = 0 + } + if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, addr); gotNicID != wantNicID { + t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, addr, gotNicID, wantNicID) + } + addrBytes[0]++ + } + + // Trying the next address should always fail since it is outside the range. + if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addrBytes)); gotNicID != 0 { + t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addrBytes), gotNicID, 0) + } +} + +// Set a range of addresses, then remove it again, and check at each step that +// CheckLocalAddress returns the correct NIC for each address or zero if not +// existent. func TestCheckLocalAddressForSubnet(t *testing.T) { const nicID tcpip.NICID = 1 s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) @@ -1181,35 +1212,28 @@ func TestCheckLocalAddressForSubnet(t *testing.T) { } subnet, err := tcpip.NewSubnet(tcpip.Address("\xa0"), tcpip.AddressMask("\xf0")) - if err != nil { t.Fatal("NewSubnet failed:", err) } - if err := s.AddSubnet(nicID, fakeNetNumber, subnet); err != nil { - t.Fatal("AddSubnet failed:", err) - } - // Loop over all subnet addresses and check them. - numOfAddresses := 1 << uint(8-subnet.Prefix()) - if numOfAddresses < 1 || numOfAddresses > 255 { - t.Fatalf("got numOfAddresses = %d, want = [1 .. 255] (subnet=%s)", numOfAddresses, subnet) - } - addr := []byte(subnet.ID()) - for i := 0; i < numOfAddresses; i++ { - if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addr)); gotNicID != nicID { - t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addr), gotNicID, nicID) - } - addr[0]++ + testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */) + + if err := s.AddAddressRange(nicID, fakeNetNumber, subnet); err != nil { + t.Fatal("AddAddressRange failed:", err) } - // Trying the next address should fail since it is outside the subnet range. - if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addr)); gotNicID != 0 { - t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addr), gotNicID, 0) + testNicForAddressRange(t, nicID, s, subnet, true /* rangeExists */) + + if err := s.RemoveAddressRange(nicID, subnet); err != nil { + t.Fatal("RemoveAddressRange failed:", err) } + + testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */) } -// Set destination outside the subnet, then check it doesn't get delivered. -func TestSubnetRejectsNonmatchingPacket(t *testing.T) { +// Set a range of addresses, then send a packet to a destination outside the +// range and then check it doesn't get delivered. +func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) id, linkEP := channel.New(10, defaultMTU, "") @@ -1235,8 +1259,8 @@ func TestSubnetRejectsNonmatchingPacket(t *testing.T) { if err != nil { t.Fatal("NewSubnet failed:", err) } - if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { - t.Fatal("AddSubnet failed:", err) + if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil { + t.Fatal("AddAddressRange failed:", err) } testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) } @@ -1281,7 +1305,20 @@ func TestNetworkOptions(t *testing.T) { } } -func TestSubnetAddRemove(t *testing.T) { +func stackContainsAddressRange(s *stack.Stack, id tcpip.NICID, addrRange tcpip.Subnet) bool { + ranges, ok := s.NICAddressRanges()[id] + if !ok { + return false + } + for _, r := range ranges { + if r == addrRange { + return true + } + } + return false +} + +func TestAddresRangeAddRemove(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) id, _ := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id); err != nil { @@ -1290,35 +1327,29 @@ func TestSubnetAddRemove(t *testing.T) { addr := tcpip.Address("\x01\x01\x01\x01") mask := tcpip.AddressMask(strings.Repeat("\xff", len(addr))) - subnet, err := tcpip.NewSubnet(addr, mask) + addrRange, err := tcpip.NewSubnet(addr, mask) if err != nil { t.Fatal("NewSubnet failed:", err) } - if contained, err := s.ContainsSubnet(1, subnet); err != nil { - t.Fatal("ContainsSubnet failed:", err) - } else if contained { - t.Fatal("got s.ContainsSubnet(...) = true, want = false") + if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want { + t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want) } - if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { - t.Fatal("AddSubnet failed:", err) + if err := s.AddAddressRange(1, fakeNetNumber, addrRange); err != nil { + t.Fatal("AddAddressRange failed:", err) } - if contained, err := s.ContainsSubnet(1, subnet); err != nil { - t.Fatal("ContainsSubnet failed:", err) - } else if !contained { - t.Fatal("got s.ContainsSubnet(...) = false, want = true") + if got, want := stackContainsAddressRange(s, 1, addrRange), true; got != want { + t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want) } - if err := s.RemoveSubnet(1, subnet); err != nil { - t.Fatal("RemoveSubnet failed:", err) + if err := s.RemoveAddressRange(1, addrRange); err != nil { + t.Fatal("RemoveAddressRange failed:", err) } - if contained, err := s.ContainsSubnet(1, subnet); err != nil { - t.Fatal("ContainsSubnet failed:", err) - } else if contained { - t.Fatal("got s.ContainsSubnet(...) = true, want = false") + if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want { + t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want) } } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 05aa42c98..418e771d2 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -219,6 +219,15 @@ func (s *Subnet) Mask() AddressMask { return s.mask } +// Broadcast returns the subnet's broadcast address. +func (s *Subnet) Broadcast() Address { + addr := []byte(s.address) + for i := range addr { + addr[i] |= ^s.mask[i] + } + return Address(addr) +} + // NICID is a number that uniquely identifies a NIC. type NICID int32 -- cgit v1.2.3 From fe1f5210774d015d653df164d6f676658863780c Mon Sep 17 00:00:00 2001 From: Ian Gudger Date: Fri, 6 Sep 2019 17:59:46 -0700 Subject: Remove reundant global tcpip.LinkEndpointID. PiperOrigin-RevId: 267709597 --- pkg/tcpip/link/channel/channel.go | 6 +- pkg/tcpip/link/fdbased/endpoint.go | 18 +- pkg/tcpip/link/fdbased/endpoint_test.go | 3 +- pkg/tcpip/link/loopback/loopback.go | 4 +- pkg/tcpip/link/muxed/injectable.go | 5 +- pkg/tcpip/link/muxed/injectable_test.go | 4 +- pkg/tcpip/link/sharedmem/sharedmem.go | 8 +- pkg/tcpip/link/sharedmem/sharedmem_test.go | 4 +- pkg/tcpip/link/sniffer/sniffer.go | 18 +- pkg/tcpip/link/waitable/waitable.go | 7 +- pkg/tcpip/link/waitable/waitable_test.go | 6 +- pkg/tcpip/network/arp/arp_test.go | 10 +- pkg/tcpip/network/ipv4/ipv4_test.go | 22 +- pkg/tcpip/network/ipv6/icmp_test.go | 22 +- pkg/tcpip/network/ipv6/ndp_test.go | 15 +- pkg/tcpip/stack/registration.go | 28 --- pkg/tcpip/stack/stack.go | 27 +- pkg/tcpip/stack/stack_test.go | 278 ++++++++++----------- pkg/tcpip/stack/transport_test.go | 24 +- pkg/tcpip/tcpip.go | 3 - pkg/tcpip/transport/tcp/testing/context/context.go | 9 +- pkg/tcpip/transport/udp/udp_test.go | 9 +- runsc/boot/network.go | 18 +- 23 files changed, 250 insertions(+), 298 deletions(-) (limited to 'pkg/tcpip/stack/stack_test.go') diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index c40744b8e..eec430d0a 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -44,14 +44,12 @@ type Endpoint struct { } // New creates a new channel endpoint. -func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) (tcpip.LinkEndpointID, *Endpoint) { - e := &Endpoint{ +func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint { + return &Endpoint{ C: make(chan PacketInfo, size), mtu: mtu, linkAddr: linkAddr, } - - return stack.RegisterLinkEndpoint(e), e } // Drain removes all outbound packets from the channel and counts them. diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 77f988b9f..adcf21371 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -165,7 +165,7 @@ type Options struct { // // Makes fd non-blocking, but does not take ownership of fd, which must remain // open for the lifetime of the returned endpoint. -func New(opts *Options) (tcpip.LinkEndpointID, error) { +func New(opts *Options) (stack.LinkEndpoint, error) { caps := stack.LinkEndpointCapabilities(0) if opts.RXChecksumOffload { caps |= stack.CapabilityRXChecksumOffload @@ -190,7 +190,7 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) { } if len(opts.FDs) == 0 { - return 0, fmt.Errorf("opts.FD is empty, at least one FD must be specified") + return nil, fmt.Errorf("opts.FD is empty, at least one FD must be specified") } e := &endpoint{ @@ -207,12 +207,12 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) { for i := 0; i < len(e.fds); i++ { fd := e.fds[i] if err := syscall.SetNonblock(fd, true); err != nil { - return 0, fmt.Errorf("syscall.SetNonblock(%v) failed: %v", fd, err) + return nil, fmt.Errorf("syscall.SetNonblock(%v) failed: %v", fd, err) } isSocket, err := isSocketFD(fd) if err != nil { - return 0, err + return nil, err } if isSocket { if opts.GSOMaxSize != 0 { @@ -222,12 +222,12 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) { } inboundDispatcher, err := createInboundDispatcher(e, fd, isSocket) if err != nil { - return 0, fmt.Errorf("createInboundDispatcher(...) = %v", err) + return nil, fmt.Errorf("createInboundDispatcher(...) = %v", err) } e.inboundDispatchers = append(e.inboundDispatchers, inboundDispatcher) } - return stack.RegisterLinkEndpoint(e), nil + return e, nil } func createInboundDispatcher(e *endpoint, fd int, isSocket bool) (linkDispatcher, error) { @@ -435,14 +435,12 @@ func (e *InjectableEndpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buf } // NewInjectable creates a new fd-based InjectableEndpoint. -func NewInjectable(fd int, mtu uint32, capabilities stack.LinkEndpointCapabilities) (tcpip.LinkEndpointID, *InjectableEndpoint) { +func NewInjectable(fd int, mtu uint32, capabilities stack.LinkEndpointCapabilities) *InjectableEndpoint { syscall.SetNonblock(fd, true) - e := &InjectableEndpoint{endpoint: endpoint{ + return &InjectableEndpoint{endpoint: endpoint{ fds: []int{fd}, mtu: mtu, caps: capabilities, }} - - return stack.RegisterLinkEndpoint(e), e } diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index e305252d6..04406bc9a 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -68,11 +68,10 @@ func newContext(t *testing.T, opt *Options) *context { } opt.FDs = []int{fds[1]} - epID, err := New(opt) + ep, err := New(opt) if err != nil { t.Fatalf("Failed to create FD endpoint: %v", err) } - ep := stack.FindLinkEndpoint(epID).(*endpoint) c := &context{ t: t, diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index ab6a53988..e121ea1a5 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -32,8 +32,8 @@ type endpoint struct { // New creates a new loopback endpoint. This link-layer endpoint just turns // outbound packets into inbound packets. -func New() tcpip.LinkEndpointID { - return stack.RegisterLinkEndpoint(&endpoint{}) +func New() stack.LinkEndpoint { + return &endpoint{} } // Attach implements stack.LinkEndpoint.Attach. It just saves the stack network- diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index a577a3d52..3ed7b98d1 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -105,9 +105,8 @@ func (m *InjectableEndpoint) WriteRawPacket(dest tcpip.Address, packet []byte) * } // NewInjectableEndpoint creates a new multi-endpoint injectable endpoint. -func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) (tcpip.LinkEndpointID, *InjectableEndpoint) { - e := &InjectableEndpoint{ +func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint { + return &InjectableEndpoint{ routes: routes, } - return stack.RegisterLinkEndpoint(e), e } diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go index 174b9330f..3086fec00 100644 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ b/pkg/tcpip/link/muxed/injectable_test.go @@ -87,8 +87,8 @@ func makeTestInjectableEndpoint(t *testing.T) (*InjectableEndpoint, *os.File, tc if err != nil { t.Fatal("Failed to create socket pair:", err) } - _, underlyingEndpoint := fdbased.NewInjectable(pair[1], 6500, stack.CapabilityNone) + underlyingEndpoint := fdbased.NewInjectable(pair[1], 6500, stack.CapabilityNone) routes := map[tcpip.Address]stack.InjectableLinkEndpoint{dstIP: underlyingEndpoint} - _, endpoint := NewInjectableEndpoint(routes) + endpoint := NewInjectableEndpoint(routes) return endpoint, os.NewFile(uintptr(pair[0]), "test route end"), dstIP } diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 834ea5c40..ba387af73 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -94,7 +94,7 @@ type endpoint struct { // New creates a new shared-memory-based endpoint. Buffers will be broken up // into buffers of "bufferSize" bytes. -func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (tcpip.LinkEndpointID, error) { +func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (stack.LinkEndpoint, error) { e := &endpoint{ mtu: mtu, bufferSize: bufferSize, @@ -102,15 +102,15 @@ func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (tc } if err := e.tx.init(bufferSize, &tx); err != nil { - return 0, err + return nil, err } if err := e.rx.init(bufferSize, &rx); err != nil { e.tx.cleanup() - return 0, err + return nil, err } - return stack.RegisterLinkEndpoint(e), nil + return e, nil } // Close frees all resources associated with the endpoint. diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 98036f367..0e9ba0846 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -119,12 +119,12 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress initQueue(t, &c.txq, &c.txCfg) initQueue(t, &c.rxq, &c.rxCfg) - id, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg) + ep, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg) if err != nil { t.Fatalf("New failed: %v", err) } - c.ep = stack.FindLinkEndpoint(id).(*endpoint) + c.ep = ep.(*endpoint) c.ep.Attach(c) return c diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 36c8c46fc..e7b6d7912 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -58,10 +58,10 @@ type endpoint struct { // New creates a new sniffer link-layer endpoint. It wraps around another // endpoint and logs packets and they traverse the endpoint. -func New(lower tcpip.LinkEndpointID) tcpip.LinkEndpointID { - return stack.RegisterLinkEndpoint(&endpoint{ - lower: stack.FindLinkEndpoint(lower), - }) +func New(lower stack.LinkEndpoint) stack.LinkEndpoint { + return &endpoint{ + lower: lower, + } } func zoneOffset() (int32, error) { @@ -102,15 +102,15 @@ func writePCAPHeader(w io.Writer, maxLen uint32) error { // snapLen is the maximum amount of a packet to be saved. Packets with a length // less than or equal too snapLen will be saved in their entirety. Longer // packets will be truncated to snapLen. -func NewWithFile(lower tcpip.LinkEndpointID, file *os.File, snapLen uint32) (tcpip.LinkEndpointID, error) { +func NewWithFile(lower stack.LinkEndpoint, file *os.File, snapLen uint32) (stack.LinkEndpoint, error) { if err := writePCAPHeader(file, snapLen); err != nil { - return 0, err + return nil, err } - return stack.RegisterLinkEndpoint(&endpoint{ - lower: stack.FindLinkEndpoint(lower), + return &endpoint{ + lower: lower, file: file, maxPCAPLen: snapLen, - }), nil + }, nil } // DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index 3b6ac2ff7..408cc62f7 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -40,11 +40,10 @@ type Endpoint struct { // New creates a new waitable link-layer endpoint. It wraps around another // endpoint and allows the caller to block new write/dispatch calls and wait for // the inflight ones to finish before returning. -func New(lower tcpip.LinkEndpointID) (tcpip.LinkEndpointID, *Endpoint) { - e := &Endpoint{ - lower: stack.FindLinkEndpoint(lower), +func New(lower stack.LinkEndpoint) *Endpoint { + return &Endpoint{ + lower: lower, } - return stack.RegisterLinkEndpoint(e), e } // DeliverNetworkPacket implements stack.NetworkDispatcher.DeliverNetworkPacket. diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index 56e18ecb0..1031438b1 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -72,7 +72,7 @@ func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.P func TestWaitWrite(t *testing.T) { ep := &countedEndpoint{} - _, wep := New(stack.RegisterLinkEndpoint(ep)) + wep := New(ep) // Write and check that it goes through. wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0) @@ -97,7 +97,7 @@ func TestWaitWrite(t *testing.T) { func TestWaitDispatch(t *testing.T) { ep := &countedEndpoint{} - _, wep := New(stack.RegisterLinkEndpoint(ep)) + wep := New(ep) // Check that attach happens. wep.Attach(ep) @@ -139,7 +139,7 @@ func TestOtherMethods(t *testing.T) { hdrLen: hdrLen, linkAddr: linkAddr, } - _, wep := New(stack.RegisterLinkEndpoint(ep)) + wep := New(ep) if v := wep.MTU(); v != mtu { t.Fatalf("Unexpected mtu: got=%v, want=%v", v, mtu) diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 4c4b54469..387fca96e 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -47,11 +47,13 @@ func newTestContext(t *testing.T) *testContext { s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, []string{icmp.ProtocolName4}, stack.Options{}) const defaultMTU = 65536 - id, linkEP := channel.New(256, defaultMTU, stackLinkAddr) + ep := channel.New(256, defaultMTU, stackLinkAddr) + wep := stack.LinkEndpoint(ep) + if testing.Verbose() { - id = sniffer.New(id) + wep = sniffer.New(ep) } - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, wep); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -73,7 +75,7 @@ func newTestContext(t *testing.T) *testContext { return &testContext{ t: t, s: s, - linkEP: linkEP, + linkEP: ep, } } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 1b5a55bea..ae827ca27 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -36,11 +36,11 @@ func TestExcludeBroadcast(t *testing.T) { s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}, stack.Options{}) const defaultMTU = 65536 - id, _ := channel.New(256, defaultMTU, "") + ep := stack.LinkEndpoint(channel.New(256, defaultMTU, "")) if testing.Verbose() { - id = sniffer.New(id) + ep = sniffer.New(ep) } - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, ep); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -184,15 +184,12 @@ type errorChannel struct { // newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket // will return successive errors from packetCollectorErrors until the list is // empty and then return nil each time. -func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) (tcpip.LinkEndpointID, *errorChannel) { - _, e := channel.New(size, mtu, linkAddr) - ec := errorChannel{ - Endpoint: e, +func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel { + return &errorChannel{ + Endpoint: channel.New(size, mtu, linkAddr), Ch: make(chan packetInfo, size), packetCollectorErrors: packetCollectorErrors, } - - return stack.RegisterLinkEndpoint(e), &ec } // packetInfo holds all the information about an outbound packet. @@ -242,9 +239,8 @@ type context struct { func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context { // Make the packet and write it. s := stack.New([]string{ipv4.ProtocolName}, []string{}, stack.Options{}) - _, linkEP := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors) - linkEPId := stack.RegisterLinkEndpoint(linkEP) - s.CreateNIC(1, linkEPId) + ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors) + s.CreateNIC(1, ep) const ( src = "\x10\x00\x00\x01" dst = "\x10\x00\x00\x02" @@ -266,7 +262,7 @@ func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32 } return context{ Route: r, - linkEP: linkEP, + linkEP: ep, } } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 227a65cf2..a6a1a5232 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -83,8 +83,7 @@ func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.Li func TestICMPCounts(t *testing.T) { s := stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}) { - id := stack.RegisterLinkEndpoint(&stubLinkEndpoint{}) - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { t.Fatalf("CreateNIC(_) = %s", err) } if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { @@ -211,14 +210,13 @@ func newTestContext(t *testing.T) *testContext { } const defaultMTU = 65536 - _, linkEP0 := channel.New(256, defaultMTU, linkAddr0) - c.linkEP0 = linkEP0 - wrappedEP0 := endpointWithResolutionCapability{LinkEndpoint: linkEP0} - id0 := stack.RegisterLinkEndpoint(wrappedEP0) + c.linkEP0 = channel.New(256, defaultMTU, linkAddr0) + + wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0}) if testing.Verbose() { - id0 = sniffer.New(id0) + wrappedEP0 = sniffer.New(wrappedEP0) } - if err := c.s0.CreateNIC(1, id0); err != nil { + if err := c.s0.CreateNIC(1, wrappedEP0); err != nil { t.Fatalf("CreateNIC s0: %v", err) } if err := c.s0.AddAddress(1, ProtocolNumber, lladdr0); err != nil { @@ -228,11 +226,9 @@ func newTestContext(t *testing.T) *testContext { t.Fatalf("AddAddress sn lladdr0: %v", err) } - _, linkEP1 := channel.New(256, defaultMTU, linkAddr1) - c.linkEP1 = linkEP1 - wrappedEP1 := endpointWithResolutionCapability{LinkEndpoint: linkEP1} - id1 := stack.RegisterLinkEndpoint(wrappedEP1) - if err := c.s1.CreateNIC(1, id1); err != nil { + c.linkEP1 = channel.New(256, defaultMTU, linkAddr1) + wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1}) + if err := c.s1.CreateNIC(1, wrappedEP1); err != nil { t.Fatalf("CreateNIC failed: %v", err) } if err := c.s1.AddAddress(1, ProtocolNumber, lladdr1); err != nil { diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 8e4cf0e74..571915d3f 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -32,15 +32,14 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack t.Helper() s := stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}) - { - id := stack.RegisterLinkEndpoint(&stubLinkEndpoint{}) - if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err) - } + + if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err) } + { subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr)))) if err != nil { diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 67b70b2ee..88a698b18 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -15,8 +15,6 @@ package stack import ( - "sync" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -379,10 +377,6 @@ var ( networkProtocols = make(map[string]NetworkProtocolFactory) unassociatedFactory UnassociatedEndpointFactory - - linkEPMu sync.RWMutex - nextLinkEndpointID tcpip.LinkEndpointID = 1 - linkEndpoints = make(map[tcpip.LinkEndpointID]LinkEndpoint) ) // RegisterTransportProtocolFactory registers a new transport protocol factory @@ -406,28 +400,6 @@ func RegisterUnassociatedFactory(f UnassociatedEndpointFactory) { unassociatedFactory = f } -// RegisterLinkEndpoint register a link-layer protocol endpoint and returns an -// ID that can be used to refer to it. -func RegisterLinkEndpoint(linkEP LinkEndpoint) tcpip.LinkEndpointID { - linkEPMu.Lock() - defer linkEPMu.Unlock() - - v := nextLinkEndpointID - nextLinkEndpointID++ - - linkEndpoints[v] = linkEP - - return v -} - -// FindLinkEndpoint finds the link endpoint associated with the given ID. -func FindLinkEndpoint(id tcpip.LinkEndpointID) LinkEndpoint { - linkEPMu.RLock() - defer linkEPMu.RUnlock() - - return linkEndpoints[id] -} - // GSOType is the type of GSO segments. // // +stateify savable diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 6beca6ae8..a961e8ebe 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -620,12 +620,7 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network // createNIC creates a NIC with the provided id and link-layer endpoint, and // optionally enable it. -func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, enabled, loopback bool) *tcpip.Error { - ep := FindLinkEndpoint(linkEP) - if ep == nil { - return tcpip.ErrBadLinkEndpoint - } - +func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, loopback bool) *tcpip.Error { s.mu.Lock() defer s.mu.Unlock() @@ -645,33 +640,33 @@ func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpoint } // CreateNIC creates a NIC with the provided id and link-layer endpoint. -func (s *Stack) CreateNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, "", linkEP, true, false) +func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, "", ep, true, false) } // CreateNamedNIC creates a NIC with the provided id and link-layer endpoint, // and a human-readable name. -func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, name, linkEP, true, false) +func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, name, ep, true, false) } // CreateNamedLoopbackNIC creates a NIC with the provided id and link-layer // endpoint, and a human-readable name. -func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, name, linkEP, true, true) +func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, name, ep, true, true) } // CreateDisabledNIC creates a NIC with the provided id and link-layer endpoint, // but leave it disable. Stack.EnableNIC must be called before the link-layer // endpoint starts delivering packets to it. -func (s *Stack) CreateDisabledNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, "", linkEP, false, false) +func (s *Stack) CreateDisabledNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, "", ep, false, false) } // CreateDisabledNamedNIC is a combination of CreateNamedNIC and // CreateDisabledNIC. -func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, name, linkEP, false, false) +func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, name, ep, false, false) } // EnableNIC enables the given NIC so that the link-layer endpoint can start diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index c6a8160af..0c26c9911 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -60,11 +60,11 @@ type fakeNetworkEndpoint struct { prefixLen int proto *fakeNetworkProtocol dispatcher stack.TransportDispatcher - linkEP stack.LinkEndpoint + ep stack.LinkEndpoint } func (f *fakeNetworkEndpoint) MTU() uint32 { - return f.linkEP.MTU() - uint32(f.MaxHeaderLength()) + return f.ep.MTU() - uint32(f.MaxHeaderLength()) } func (f *fakeNetworkEndpoint) NICID() tcpip.NICID { @@ -108,7 +108,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedV } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { - return f.linkEP.MaxHeaderLength() + fakeNetHeaderLen + return f.ep.MaxHeaderLength() + fakeNetHeaderLen } func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { @@ -116,7 +116,7 @@ func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProto } func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { - return f.linkEP.Capabilities() + return f.ep.Capabilities() } func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, _ uint8, loop stack.PacketLooping) *tcpip.Error { @@ -141,7 +141,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr bu return nil } - return f.linkEP.WritePacket(r, gso, hdr, payload, fakeNetNumber) + return f.ep.WritePacket(r, gso, hdr, payload, fakeNetNumber) } func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { @@ -189,14 +189,14 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) } -func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { return &fakeNetworkEndpoint{ nicid: nicid, id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, prefixLen: addrWithPrefix.PrefixLen, proto: f, dispatcher: dispatcher, - linkEP: linkEP, + ep: ep, }, nil } @@ -225,9 +225,9 @@ func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error { func TestNetworkReceive(t *testing.T) { // Create a stack with the fake network protocol, one nic, and two // addresses attached to it: 1 & 2. - id, linkEP := channel.New(10, defaultMTU, "") + ep := channel.New(10, defaultMTU, "") s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -245,7 +245,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet with wrong address is not delivered. buf[0] = 3 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 0 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) } @@ -255,7 +255,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to first endpoint. buf[0] = 1 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -265,7 +265,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to second endpoint. buf[0] = 2 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -274,7 +274,7 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is not delivered if protocol number is wrong. - linkEP.Inject(fakeNetNumber-1, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber-1, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -284,7 +284,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet that is too small is dropped. buf.CapLength(2) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -307,59 +307,59 @@ func send(r stack.Route, payload buffer.View) *tcpip.Error { return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123) } -func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, linkEP *channel.Endpoint, payload buffer.View) { +func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) { t.Helper() - linkEP.Drain() + ep.Drain() if err := sendTo(s, addr, payload); err != nil { t.Error("sendTo failed:", err) } - if got, want := linkEP.Drain(), 1; got != want { + if got, want := ep.Drain(), 1; got != want { t.Errorf("sendTo packet count: got = %d, want %d", got, want) } } -func testSend(t *testing.T, r stack.Route, linkEP *channel.Endpoint, payload buffer.View) { +func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View) { t.Helper() - linkEP.Drain() + ep.Drain() if err := send(r, payload); err != nil { t.Error("send failed:", err) } - if got, want := linkEP.Drain(), 1; got != want { + if got, want := ep.Drain(), 1; got != want { t.Errorf("send packet count: got = %d, want %d", got, want) } } -func testFailingSend(t *testing.T, r stack.Route, linkEP *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { +func testFailingSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { t.Helper() if gotErr := send(r, payload); gotErr != wantErr { t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) } } -func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, linkEP *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { +func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { t.Helper() if gotErr := sendTo(s, addr, payload); gotErr != wantErr { t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr) } } -func testRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View) { +func testRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) { t.Helper() // testRecvInternal injects one packet, and we expect to receive it. want := fakeNet.PacketCount(localAddrByte) + 1 - testRecvInternal(t, fakeNet, localAddrByte, linkEP, buf, want) + testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want) } -func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View) { +func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) { t.Helper() // testRecvInternal injects one packet, and we do NOT expect to receive it. want := fakeNet.PacketCount(localAddrByte) - testRecvInternal(t, fakeNet, localAddrByte, linkEP, buf, want) + testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want) } -func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View, want int) { +func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { t.Helper() - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) if got := fakeNet.PacketCount(localAddrByte); got != want { t.Errorf("receive packet count: got = %d, want %d", got, want) } @@ -369,9 +369,9 @@ func TestNetworkSend(t *testing.T) { // Create a stack with the fake network protocol, one nic, and one // address: 1. The route table sends all packets through the only // existing nic. - id, linkEP := channel.New(10, defaultMTU, "") + ep := channel.New(10, defaultMTU, "") s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("NewNIC failed:", err) } @@ -388,7 +388,7 @@ func TestNetworkSend(t *testing.T) { } // Make sure that the link-layer endpoint received the outbound packet. - testSendTo(t, s, "\x03", linkEP, nil) + testSendTo(t, s, "\x03", ep, nil) } func TestNetworkSendMultiRoute(t *testing.T) { @@ -397,8 +397,8 @@ func TestNetworkSendMultiRoute(t *testing.T) { // even addresses. s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id1, linkEP1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -410,8 +410,8 @@ func TestNetworkSendMultiRoute(t *testing.T) { t.Fatal("AddAddress failed:", err) } - id2, linkEP2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { + ep2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, ep2); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -442,10 +442,10 @@ func TestNetworkSendMultiRoute(t *testing.T) { } // Send a packet to an odd destination. - testSendTo(t, s, "\x05", linkEP1, nil) + testSendTo(t, s, "\x05", ep1, nil) // Send a packet to an even destination. - testSendTo(t, s, "\x06", linkEP2, nil) + testSendTo(t, s, "\x06", ep2, nil) } func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) { @@ -478,8 +478,8 @@ func TestRoutes(t *testing.T) { // even addresses. s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id1, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -491,8 +491,8 @@ func TestRoutes(t *testing.T) { t.Fatal("AddAddress failed:", err) } - id2, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { + ep2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, ep2); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -556,8 +556,8 @@ func TestAddressRemoval(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -578,15 +578,15 @@ func TestAddressRemoval(t *testing.T) { // Send and receive packets, and verify they are received. buf[0] = localAddrByte - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) // Remove the address, then check that send/receive doesn't work anymore. if err := s.RemoveAddress(1, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) // Check that removing the same address fails. if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { @@ -601,9 +601,9 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { - t.Fatal("CreateNIC failed:", err) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { + t.Fatalf("CreateNIC failed: %v", err) } fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) buf := buffer.NewView(30) @@ -626,17 +626,17 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { // Send and receive packets, and verify they are received. buf[0] = localAddrByte - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSend(t, r, linkEP, nil) - testSendTo(t, s, remoteAddr, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSend(t, r, ep, nil) + testSendTo(t, s, remoteAddr, ep, nil) // Remove the address, then check that send/receive doesn't work anymore. if err := s.RemoveAddress(1, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) - testFailingSend(t, r, linkEP, nil, tcpip.ErrInvalidEndpointState) - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) + testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) // Check that removing the same address fails. if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { @@ -690,8 +690,8 @@ func TestEndpointExpiration(t *testing.T) { t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -724,15 +724,15 @@ func TestEndpointExpiration(t *testing.T) { //----------------------- verifyAddress(t, s, nicid, noAddr) if promiscuous { - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) } else { - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } if spoofing { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, linkEP, nil) + // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) } // 2. Add Address, everything should work. @@ -741,8 +741,8 @@ func TestEndpointExpiration(t *testing.T) { t.Fatal("AddAddress failed:", err) } verifyAddress(t, s, nicid, localAddr) - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) // 3. Remove the address, send should only work for spoofing, receive // for promiscuous mode. @@ -752,15 +752,15 @@ func TestEndpointExpiration(t *testing.T) { } verifyAddress(t, s, nicid, noAddr) if promiscuous { - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) } else { - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } if spoofing { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, linkEP, nil) + // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) } // 4. Add Address back, everything should work again. @@ -769,8 +769,8 @@ func TestEndpointExpiration(t *testing.T) { t.Fatal("AddAddress failed:", err) } verifyAddress(t, s, nicid, localAddr) - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) // 5. Take a reference to the endpoint by getting a route. Verify that // we can still send/receive, including sending using the route. @@ -779,9 +779,9 @@ func TestEndpointExpiration(t *testing.T) { if err != nil { t.Fatal("FindRoute failed:", err) } - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) - testSend(t, r, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) + testSend(t, r, ep, nil) // 6. Remove the address. Send should only work for spoofing, receive // for promiscuous mode. @@ -791,16 +791,16 @@ func TestEndpointExpiration(t *testing.T) { } verifyAddress(t, s, nicid, noAddr) if promiscuous { - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) } else { - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } if spoofing { - testSend(t, r, linkEP, nil) - testSendTo(t, s, remoteAddr, linkEP, nil) + testSend(t, r, ep, nil) + testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSend(t, r, linkEP, nil, tcpip.ErrInvalidEndpointState) - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) } // 7. Add Address back, everything should work again. @@ -809,16 +809,16 @@ func TestEndpointExpiration(t *testing.T) { t.Fatal("AddAddress failed:", err) } verifyAddress(t, s, nicid, localAddr) - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) - testSend(t, r, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) + testSend(t, r, ep, nil) // 8. Remove the route, sendTo/recv should still work. //----------------------- r.Release() verifyAddress(t, s, nicid, localAddr) - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) // 9. Remove the address. Send should only work for spoofing, receive // for promiscuous mode. @@ -828,15 +828,15 @@ func TestEndpointExpiration(t *testing.T) { } verifyAddress(t, s, nicid, noAddr) if promiscuous { - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) } else { - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } if spoofing { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, linkEP, nil) + // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) } }) } @@ -846,8 +846,8 @@ func TestEndpointExpiration(t *testing.T) { func TestPromiscuousMode(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -867,13 +867,13 @@ func TestPromiscuousMode(t *testing.T) { // have a matching endpoint. const localAddrByte byte = 0x01 buf[0] = localAddrByte - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) // Set promiscuous mode, then check that packet is delivered. if err := s.SetPromiscuousMode(1, true); err != nil { t.Fatal("SetPromiscuousMode failed:", err) } - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) // Check that we can't get a route as there is no local address. _, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) @@ -886,7 +886,7 @@ func TestPromiscuousMode(t *testing.T) { if err := s.SetPromiscuousMode(1, false); err != nil { t.Fatal("SetPromiscuousMode failed:", err) } - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } func TestSpoofingWithAddress(t *testing.T) { @@ -896,8 +896,8 @@ func TestSpoofingWithAddress(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -936,8 +936,8 @@ func TestSpoofingWithAddress(t *testing.T) { t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) } // Sending a packet works. - testSendTo(t, s, dstAddr, linkEP, nil) - testSend(t, r, linkEP, nil) + testSendTo(t, s, dstAddr, ep, nil) + testSend(t, r, ep, nil) // FindRoute should also work with a local address that exists on the NIC. r, err = s.FindRoute(0, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) @@ -951,7 +951,7 @@ func TestSpoofingWithAddress(t *testing.T) { t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) } // Sending a packet using the route works. - testSend(t, r, linkEP, nil) + testSend(t, r, ep, nil) } func TestSpoofingNoAddress(t *testing.T) { @@ -960,8 +960,8 @@ func TestSpoofingNoAddress(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -980,7 +980,7 @@ func TestSpoofingNoAddress(t *testing.T) { t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) } // Sending a packet fails. - testFailingSendTo(t, s, dstAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, dstAddr, ep, nil, tcpip.ErrNoRoute) // With address spoofing enabled, FindRoute permits any address to be used // as the source. @@ -999,14 +999,14 @@ func TestSpoofingNoAddress(t *testing.T) { } // Sending a packet works. // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, linkEP, nil) + // testSendTo(t, s, remoteAddr, ep, nil) } func TestBroadcastNeedsNoRoute(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } s.SetRouteTable([]tcpip.Route{}) @@ -1076,8 +1076,8 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { t.Run(tc.name, func(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1132,8 +1132,8 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { func TestAddressRangeAcceptsMatchingPacket(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1159,7 +1159,7 @@ func TestAddressRangeAcceptsMatchingPacket(t *testing.T) { t.Fatal("AddAddressRange failed:", err) } - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) } func testNicForAddressRange(t *testing.T, nicID tcpip.NICID, s *stack.Stack, subnet tcpip.Subnet, rangeExists bool) { @@ -1198,8 +1198,8 @@ func TestCheckLocalAddressForSubnet(t *testing.T) { const nicID tcpip.NICID = 1 s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1236,8 +1236,8 @@ func TestCheckLocalAddressForSubnet(t *testing.T) { func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1262,7 +1262,7 @@ func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) { if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil { t.Fatal("AddAddressRange failed:", err) } - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } func TestNetworkOptions(t *testing.T) { @@ -1320,8 +1320,8 @@ func stackContainsAddressRange(s *stack.Stack, id tcpip.NICID, addrRange tcpip.S func TestAddresRangeAddRemove(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1361,8 +1361,8 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { for never := 0; never < 3; never++ { t.Run(fmt.Sprintf("never=%d", never), func(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } // Insert primary and never-primary addresses. @@ -1426,8 +1426,8 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { func TestGetMainNICAddressAddRemove(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1501,8 +1501,8 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto func TestAddAddress(t *testing.T) { const nicid = 1 s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1526,8 +1526,8 @@ func TestAddAddress(t *testing.T) { func TestAddProtocolAddress(t *testing.T) { const nicid = 1 s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1558,8 +1558,8 @@ func TestAddProtocolAddress(t *testing.T) { func TestAddAddressWithOptions(t *testing.T) { const nicid = 1 s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1587,8 +1587,8 @@ func TestAddAddressWithOptions(t *testing.T) { func TestAddProtocolAddressWithOptions(t *testing.T) { const nicid = 1 s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1621,8 +1621,8 @@ func TestAddProtocolAddressWithOptions(t *testing.T) { func TestNICStats(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id1, linkEP1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC failed: ", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { @@ -1639,7 +1639,7 @@ func TestNICStats(t *testing.T) { // Send a packet to address 1. buf := buffer.NewView(30) - linkEP1.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep1.Inject(fakeNetNumber, buf.ToVectorisedView()) if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) } @@ -1653,9 +1653,9 @@ func TestNICStats(t *testing.T) { if err := sendTo(s, "\x01", payload); err != nil { t.Fatal("sendTo failed: ", err) } - want := uint64(linkEP1.Drain()) + want := uint64(ep1.Drain()) if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want { - t.Errorf("got Tx.Packets.Value() = %d, linkEP1.Drain() = %d", got, want) + t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want) } if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)); got != want { @@ -1669,16 +1669,16 @@ func TestNICForwarding(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) s.SetForwarding(true) - id1, linkEP1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC #1 failed:", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { t.Fatal("AddAddress #1 failed:", err) } - id2, linkEP2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { + ep2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, ep2); err != nil { t.Fatal("CreateNIC #2 failed:", err) } if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { @@ -1697,10 +1697,10 @@ func TestNICForwarding(t *testing.T) { // Send a packet to address 3. buf := buffer.NewView(30) buf[0] = 3 - linkEP1.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep1.Inject(fakeNetNumber, buf.ToVectorisedView()) select { - case <-linkEP2.C: + case <-ep2.C: default: t.Fatal("Packet not forwarded") } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index ca185279e..87d1e0d0d 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -278,9 +278,9 @@ func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error { } func TestTransportReceive(t *testing.T) { - id, linkEP := channel.New(10, defaultMTU, "") + linkEP := channel.New(10, defaultMTU, "") s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -340,9 +340,9 @@ func TestTransportReceive(t *testing.T) { } func TestTransportControlReceive(t *testing.T) { - id, linkEP := channel.New(10, defaultMTU, "") + linkEP := channel.New(10, defaultMTU, "") s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -408,9 +408,9 @@ func TestTransportControlReceive(t *testing.T) { } func TestTransportSend(t *testing.T) { - id, _ := channel.New(10, defaultMTU, "") + linkEP := channel.New(10, defaultMTU, "") s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -497,16 +497,16 @@ func TestTransportForwarding(t *testing.T) { s.SetForwarding(true) // TODO(b/123449044): Change this to a channel NIC. - id1 := loopback.New() - if err := s.CreateNIC(1, id1); err != nil { + ep1 := loopback.New() + if err := s.CreateNIC(1, ep1); err != nil { t.Fatalf("CreateNIC #1 failed: %v", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { t.Fatalf("AddAddress #1 failed: %v", err) } - id2, linkEP2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { + ep2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, ep2); err != nil { t.Fatalf("CreateNIC #2 failed: %v", err) } if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { @@ -545,7 +545,7 @@ func TestTransportForwarding(t *testing.T) { req[0] = 1 req[1] = 3 req[2] = byte(fakeTransNumber) - linkEP2.Inject(fakeNetNumber, req.ToVectorisedView()) + ep2.Inject(fakeNetNumber, req.ToVectorisedView()) aep, _, err := ep.Accept() if err != nil || aep == nil { @@ -559,7 +559,7 @@ func TestTransportForwarding(t *testing.T) { var p channel.PacketInfo select { - case p = <-linkEP2.C: + case p = <-ep2.C: default: t.Fatal("Response packet not forwarded") } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 418e771d2..ebf8a2d04 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -600,9 +600,6 @@ func (r Route) String() string { return out.String() } -// LinkEndpointID represents a data link layer endpoint. -type LinkEndpointID uint64 - // TransportProtocolNumber is the number of a transport protocol. type TransportProtocolNumber uint32 diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 18c707a57..16783e716 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -150,11 +150,12 @@ func New(t *testing.T, mtu uint32) *Context { // Some of the congestion control tests send up to 640 packets, we so // set the channel size to 1000. - id, linkEP := channel.New(1000, mtu, "") + ep := channel.New(1000, mtu, "") + wep := stack.LinkEndpoint(ep) if testing.Verbose() { - id = sniffer.New(id) + wep = sniffer.New(ep) } - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, wep); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -180,7 +181,7 @@ func New(t *testing.T, mtu uint32) *Context { return &Context{ t: t, s: s, - linkEP: linkEP, + linkEP: ep, WindowScale: uint8(tcp.FindWndScale(tcp.DefaultReceiveBufferSize)), } } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 995d6e8a1..c6deab892 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -275,12 +275,13 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext { t.Helper() s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{udp.ProtocolName}, stack.Options{}) + ep := channel.New(256, mtu, "") + wep := stack.LinkEndpoint(ep) - id, linkEP := channel.New(256, mtu, "") if testing.Verbose() { - id = sniffer.New(id) + wep = sniffer.New(ep) } - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, wep); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -306,7 +307,7 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext { return &testContext{ t: t, s: s, - linkEP: linkEP, + linkEP: ep, } } diff --git a/runsc/boot/network.go b/runsc/boot/network.go index ea0d9f790..32cba5ac1 100644 --- a/runsc/boot/network.go +++ b/runsc/boot/network.go @@ -121,10 +121,10 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct nicID++ nicids[link.Name] = nicID - linkEP := loopback.New() + ep := loopback.New() log.Infof("Enabling loopback interface %q with id %d on addresses %+v", link.Name, nicID, link.Addresses) - if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses, true /* loopback */); err != nil { + if err := n.createNICWithAddrs(nicID, link.Name, ep, link.Addresses, true /* loopback */); err != nil { return err } @@ -156,7 +156,7 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct } mac := tcpip.LinkAddress(link.LinkAddress) - linkEP, err := fdbased.New(&fdbased.Options{ + ep, err := fdbased.New(&fdbased.Options{ FDs: FDs, MTU: uint32(link.MTU), EthernetHeader: true, @@ -170,7 +170,7 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct } log.Infof("Enabling interface %q with id %d on addresses %+v (%v) w/ %d channels", link.Name, nicID, link.Addresses, mac, link.NumChannels) - if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses, false /* loopback */); err != nil { + if err := n.createNICWithAddrs(nicID, link.Name, ep, link.Addresses, false /* loopback */); err != nil { return err } @@ -203,14 +203,14 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct // createNICWithAddrs creates a NIC in the network stack and adds the given // addresses. -func (n *Network) createNICWithAddrs(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, addrs []net.IP, loopback bool) error { +func (n *Network) createNICWithAddrs(id tcpip.NICID, name string, ep stack.LinkEndpoint, addrs []net.IP, loopback bool) error { if loopback { - if err := n.Stack.CreateNamedLoopbackNIC(id, name, sniffer.New(linkEP)); err != nil { - return fmt.Errorf("CreateNamedLoopbackNIC(%v, %v, %v) failed: %v", id, name, linkEP, err) + if err := n.Stack.CreateNamedLoopbackNIC(id, name, sniffer.New(ep)); err != nil { + return fmt.Errorf("CreateNamedLoopbackNIC(%v, %v) failed: %v", id, name, err) } } else { - if err := n.Stack.CreateNamedNIC(id, name, sniffer.New(linkEP)); err != nil { - return fmt.Errorf("CreateNamedNIC(%v, %v, %v) failed: %v", id, name, linkEP, err) + if err := n.Stack.CreateNamedNIC(id, name, sniffer.New(ep)); err != nil { + return fmt.Errorf("CreateNamedNIC(%v, %v) failed: %v", id, name, err) } } -- cgit v1.2.3 From 6704d625ef2726a54078f4179d0eaa2c820cc401 Mon Sep 17 00:00:00 2001 From: Chris Kuiper Date: Tue, 24 Sep 2019 13:18:19 -0700 Subject: Return only primary addresses in Stack.NICInfo() Non-primary addresses are used for endpoints created to accept multicast and broadcast packets, as well as "helper" endpoints (0.0.0.0) that allow sending packets when no proper address has been assigned yet (e.g., for DHCP). These addresses are not real addresses from a user point of view and should not be part of the NICInfo() value. Also see b/127321246 for more info. This switches NICInfo() to call a new NIC.PrimaryAddresses() function. To still allow an option to get all addresses (mostly for testing) I added Stack.GetAllAddresses() and NIC.AllAddresses(). In addition, the return value for GetMainNICAddress() was changed for the case where the NIC has no primary address. Instead of returning an error here, it now returns an empty AddressWithPrefix() value. The rational for this change is that it is a valid case for a NIC to have no primary addresses. Lastly, I refactored the code based on the new additions. PiperOrigin-RevId: 270971764 --- pkg/tcpip/stack/nic.go | 65 +++++++++++++++++++++---------------------- pkg/tcpip/stack/stack.go | 34 ++++++++++++++++------ pkg/tcpip/stack/stack_test.go | 64 +++++++++++++++++++++++------------------- 3 files changed, 94 insertions(+), 69 deletions(-) (limited to 'pkg/tcpip/stack/stack_test.go') diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index a719058b4..0e8a23f00 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -148,37 +148,6 @@ func (n *NIC) setSpoofing(enable bool) { n.mu.Unlock() } -func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) { - n.mu.RLock() - defer n.mu.RUnlock() - - var r *referencedNetworkEndpoint - - // Check for a primary endpoint. - if list, ok := n.primary[protocol]; ok { - for e := list.Front(); e != nil; e = e.Next() { - ref := e.(*referencedNetworkEndpoint) - if ref.getKind() == permanent && ref.tryIncRef() { - r = ref - break - } - } - - } - - if r == nil { - return tcpip.AddressWithPrefix{}, tcpip.ErrNoLinkAddress - } - - addressWithPrefix := tcpip.AddressWithPrefix{ - Address: r.ep.ID().LocalAddress, - PrefixLen: r.ep.PrefixLen(), - } - r.decRef() - - return addressWithPrefix, nil -} - // primaryEndpoint returns the primary endpoint of n for the given network // protocol. func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint { @@ -398,10 +367,12 @@ func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo return err } -// Addresses returns the addresses associated with this NIC. -func (n *NIC) Addresses() []tcpip.ProtocolAddress { +// AllAddresses returns all addresses (primary and non-primary) associated with +// this NIC. +func (n *NIC) AllAddresses() []tcpip.ProtocolAddress { n.mu.RLock() defer n.mu.RUnlock() + addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints)) for nid, ref := range n.endpoints { // Don't include expired or tempory endpoints to avoid confusion and @@ -421,6 +392,34 @@ func (n *NIC) Addresses() []tcpip.ProtocolAddress { return addrs } +// PrimaryAddresses returns the primary addresses associated with this NIC. +func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress { + n.mu.RLock() + defer n.mu.RUnlock() + + var addrs []tcpip.ProtocolAddress + for proto, list := range n.primary { + for e := list.Front(); e != nil; e = e.Next() { + ref := e.(*referencedNetworkEndpoint) + // Don't include expired or tempory endpoints to avoid confusion and + // prevent the caller from using those. + switch ref.getKind() { + case permanentExpired, temporary: + continue + } + + addrs = append(addrs, tcpip.ProtocolAddress{ + Protocol: proto, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: ref.ep.ID().LocalAddress, + PrefixLen: ref.ep.PrefixLen(), + }, + }) + } + } + return addrs +} + // AddAddressRange adds a range of addresses to n, so that it starts accepting // packets targeted at the given addresses and network protocol. The range is // given by a subnet address, and all addresses contained in the subnet are diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 1fe21b68e..f7ba3cb0f 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -738,7 +738,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { nics[id] = NICInfo{ Name: nic.name, LinkAddress: nic.linkEP.LinkAddress(), - ProtocolAddresses: nic.Addresses(), + ProtocolAddresses: nic.PrimaryAddresses(), Flags: flags, MTU: nic.linkEP.MTU(), Stats: nic.stats, @@ -845,19 +845,37 @@ func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { return tcpip.ErrUnknownNICID } -// GetMainNICAddress returns the first primary address (and the subnet that -// contains it) for the given NIC and protocol. Returns an arbitrary endpoint's -// address if no primary addresses exist. Returns an error if the NIC doesn't -// exist or has no endpoints. +// AllAddresses returns a map of NICIDs to their protocol addresses (primary +// and non-primary). +func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress { + s.mu.RLock() + defer s.mu.RUnlock() + + nics := make(map[tcpip.NICID][]tcpip.ProtocolAddress) + for id, nic := range s.nics { + nics[id] = nic.AllAddresses() + } + return nics +} + +// GetMainNICAddress returns the first primary address and prefix for the given +// NIC and protocol. Returns an error if the NIC doesn't exist and an empty +// value if the NIC doesn't have a primary address for the given protocol. func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() - if nic, ok := s.nics[id]; ok { - return nic.getMainNICAddress(protocol) + nic, ok := s.nics[id] + if !ok { + return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID } - return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID + for _, a := range nic.PrimaryAddresses() { + if a.Protocol == protocol { + return a.AddressWithPrefix, nil + } + } + return tcpip.AddressWithPrefix{}, nil } func (s *Stack) getRefEP(nic *NIC, localAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) { diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 0c26c9911..7aa10bce9 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -930,10 +930,10 @@ func TestSpoofingWithAddress(t *testing.T) { t.Fatal("FindRoute failed:", err) } if r.LocalAddress != nonExistentLocalAddr { - t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr) + t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr) } if r.RemoteAddress != dstAddr { - t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) + t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr) } // Sending a packet works. testSendTo(t, s, dstAddr, ep, nil) @@ -945,10 +945,10 @@ func TestSpoofingWithAddress(t *testing.T) { t.Fatal("FindRoute failed:", err) } if r.LocalAddress != localAddr { - t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr) + t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr) } if r.RemoteAddress != dstAddr { - t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) + t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr) } // Sending a packet using the route works. testSend(t, r, ep, nil) @@ -992,10 +992,10 @@ func TestSpoofingNoAddress(t *testing.T) { t.Fatal("FindRoute failed:", err) } if r.LocalAddress != nonExistentLocalAddr { - t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr) + t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr) } if r.RemoteAddress != dstAddr { - t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) + t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr) } // Sending a packet works. // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. @@ -1400,20 +1400,20 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { // Check that GetMainNICAddress returns an address if at least // one primary address was added. In that case make sure the // address/prefixLen matches what we added. + gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber) + if err != nil { + t.Fatal("GetMainNICAddress failed:", err) + } if len(primaryAddrAdded) == 0 { - // No primary addresses present, expect an error. - if _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got s.GetMainNICAddress(...) = %v, wanted = %s", err, tcpip.ErrNoLinkAddress) + // No primary addresses present. + if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr { + t.Fatalf("GetMainNICAddress: got addr = %s, want = %s", gotAddr, wantAddr) } } else { - // At least one primary address was added, expect a valid - // address and prefixLen. - gotAddressWithPefix, err := s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatal("GetMainNICAddress failed:", err) - } - if _, ok := primaryAddrAdded[gotAddressWithPefix]; !ok { - t.Fatalf("GetMainNICAddress: got addressWithPrefix = %v, wanted any in {%v}", gotAddressWithPefix, primaryAddrAdded) + // At least one primary address was added, verify the returned + // address is in the list of primary addresses we added. + if _, ok := primaryAddrAdded[gotAddr]; !ok { + t.Fatalf("GetMainNICAddress: got = %s, want any in {%v}", gotAddr, primaryAddrAdded) } } }) @@ -1452,19 +1452,25 @@ func TestGetMainNICAddressAddRemove(t *testing.T) { } // Check that we get the right initial address and prefix length. - if gotAddressWithPrefix, err := s.GetMainNICAddress(1, fakeNetNumber); err != nil { + gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber) + if err != nil { t.Fatal("GetMainNICAddress failed:", err) - } else if gotAddressWithPrefix != protocolAddress.AddressWithPrefix { - t.Fatalf("got GetMainNICAddress = %+v, want = %+v", gotAddressWithPrefix, protocolAddress.AddressWithPrefix) + } + if wantAddr := protocolAddress.AddressWithPrefix; gotAddr != wantAddr { + t.Fatalf("got s.GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr) } if err := s.RemoveAddress(1, protocolAddress.AddressWithPrefix.Address); err != nil { t.Fatal("RemoveAddress failed:", err) } - // Check that we get an error after removal. - if _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got s.GetMainNICAddress(...) = %v, want = %s", err, tcpip.ErrNoLinkAddress) + // Check that we get no address after removal. + gotAddr, err = s.GetMainNICAddress(1, fakeNetNumber) + if err != nil { + t.Fatal("GetMainNICAddress failed:", err) + } + if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr { + t.Fatalf("got GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr) } }) } @@ -1479,8 +1485,10 @@ func (g *addressGenerator) next(addrLen int) tcpip.Address { } func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.ProtocolAddress) { + t.Helper() + if len(gotAddresses) != len(expectedAddresses) { - t.Fatalf("got len(addresses) = %d, wanted = %d", len(gotAddresses), len(expectedAddresses)) + t.Fatalf("got len(addresses) = %d, want = %d", len(gotAddresses), len(expectedAddresses)) } sort.Slice(gotAddresses, func(i, j int) bool { @@ -1519,7 +1527,7 @@ func TestAddAddress(t *testing.T) { }) } - gotAddresses := s.NICInfo()[nicid].ProtocolAddresses + gotAddresses := s.AllAddresses()[nicid] verifyAddresses(t, expectedAddresses, gotAddresses) } @@ -1551,7 +1559,7 @@ func TestAddProtocolAddress(t *testing.T) { } } - gotAddresses := s.NICInfo()[nicid].ProtocolAddresses + gotAddresses := s.AllAddresses()[nicid] verifyAddresses(t, expectedAddresses, gotAddresses) } @@ -1580,7 +1588,7 @@ func TestAddAddressWithOptions(t *testing.T) { } } - gotAddresses := s.NICInfo()[nicid].ProtocolAddresses + gotAddresses := s.AllAddresses()[nicid] verifyAddresses(t, expectedAddresses, gotAddresses) } @@ -1615,7 +1623,7 @@ func TestAddProtocolAddressWithOptions(t *testing.T) { } } - gotAddresses := s.NICInfo()[nicid].ProtocolAddresses + gotAddresses := s.AllAddresses()[nicid] verifyAddresses(t, expectedAddresses, gotAddresses) } -- cgit v1.2.3 From 59ccbb10446063f5347fb026e35549bc2f677971 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Wed, 25 Sep 2019 12:56:00 -0700 Subject: Remove centralized registration of protocols. Also removes the need for protocol names. PiperOrigin-RevId: 271186030 --- pkg/tcpip/adapters/gonet/gonet_test.go | 5 +- pkg/tcpip/network/arp/arp.go | 16 ++- pkg/tcpip/network/arp/arp_test.go | 5 +- pkg/tcpip/network/ip_test.go | 10 +- pkg/tcpip/network/ipv4/ipv4.go | 44 +++----- pkg/tcpip/network/ipv4/ipv4_test.go | 9 +- pkg/tcpip/network/ipv6/icmp_test.go | 15 ++- pkg/tcpip/network/ipv6/ipv6.go | 24 ++--- pkg/tcpip/network/ipv6/ipv6_test.go | 34 ++++--- pkg/tcpip/network/ipv6/ndp_test.go | 5 +- pkg/tcpip/sample/tun_tcp_connect/main.go | 5 +- pkg/tcpip/sample/tun_tcp_echo/main.go | 5 +- pkg/tcpip/stack/registration.go | 36 ------- pkg/tcpip/stack/stack.go | 45 ++++----- pkg/tcpip/stack/stack_test.go | 111 +++++++++++++++------ pkg/tcpip/stack/transport_test.go | 35 +++++-- pkg/tcpip/transport/icmp/protocol.go | 27 ++--- pkg/tcpip/transport/raw/protocol.go | 9 +- pkg/tcpip/transport/tcp/protocol.go | 22 ++-- pkg/tcpip/transport/tcp/tcp_test.go | 21 ++-- pkg/tcpip/transport/tcp/testing/context/context.go | 5 +- pkg/tcpip/transport/udp/protocol.go | 12 +-- pkg/tcpip/transport/udp/udp_test.go | 5 +- runsc/boot/BUILD | 1 + runsc/boot/loader.go | 17 ++-- 25 files changed, 278 insertions(+), 245 deletions(-) (limited to 'pkg/tcpip/stack/stack_test.go') diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index 672f026b2..8ced960bb 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -60,7 +60,10 @@ func TestTimeouts(t *testing.T) { func newLoopbackStack() (*stack.Stack, *tcpip.Error) { // Create the stack and add a NIC. - s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName, udp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()}, + }) if err := s.CreateNIC(NICID, loopback.New()); err != nil { return nil, err diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index fd6395fc1..26cf1c528 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -16,9 +16,9 @@ // IPv4 addresses into link-local MAC addresses, and advertises IPv4 // addresses of its stack with the local network. // -// To use it in the networking stack, pass arp.ProtocolName as one of the -// network protocols when calling stack.New. Then add an "arp" address to -// every NIC on the stack that should respond to ARP requests. That is: +// To use it in the networking stack, pass arp.NewProtocol() as one of the +// network protocols when calling stack.New. Then add an "arp" address to every +// NIC on the stack that should respond to ARP requests. That is: // // if err := s.AddAddress(1, arp.ProtocolNumber, "arp"); err != nil { // // handle err @@ -33,9 +33,6 @@ import ( ) const ( - // ProtocolName is the string representation of the ARP protocol name. - ProtocolName = "arp" - // ProtocolNumber is the ARP protocol number. ProtocolNumber = header.ARPProtocolNumber @@ -200,8 +197,7 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) -func init() { - stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol { - return &protocol{} - }) +// NewProtocol returns an ARP network protocol. +func NewProtocol() stack.NetworkProtocol { + return &protocol{} } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 387fca96e..88b57ec03 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -44,7 +44,10 @@ type testContext struct { } func newTestContext(t *testing.T) *testContext { - s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, []string{icmp.ProtocolName4}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()}, + }) const defaultMTU = 65536 ep := channel.New(256, defaultMTU, stackLinkAddr) diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 6a40e7ee3..a9741622e 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -172,7 +172,10 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prepen } func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { - s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, + }) s.CreateNIC(1, loopback.New()) s.AddAddress(1, ipv4.ProtocolNumber, local) s.SetRouteTable([]tcpip.Route{{ @@ -185,7 +188,10 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { } func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { - s := stack.New([]string{ipv6.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, + }) s.CreateNIC(1, loopback.New()) s.AddAddress(1, ipv6.ProtocolNumber, local) s.SetRouteTable([]tcpip.Route{{ diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index b7a06f525..b7b07a6c1 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -14,9 +14,9 @@ // Package ipv4 contains the implementation of the ipv4 network protocol. To use // it in the networking stack, this package must be added to the project, and -// activated on the stack by passing ipv4.ProtocolName (or "ipv4") as one of the -// network protocols when calling stack.New(). Then endpoints can be created -// by passing ipv4.ProtocolNumber as the network protocol number when calling +// activated on the stack by passing ipv4.NewProtocol() as one of the network +// protocols when calling stack.New(). Then endpoints can be created by passing +// ipv4.ProtocolNumber as the network protocol number when calling // Stack.NewEndpoint(). package ipv4 @@ -32,9 +32,6 @@ import ( ) const ( - // ProtocolName is the string representation of the ipv4 protocol name. - ProtocolName = "ipv4" - // ProtocolNumber is the ipv4 protocol number. ProtocolNumber = header.IPv4ProtocolNumber @@ -53,6 +50,7 @@ type endpoint struct { linkEP stack.LinkEndpoint dispatcher stack.TransportDispatcher fragmentation *fragmentation.Fragmentation + protocol *protocol } // NewEndpoint creates a new ipv4 endpoint. @@ -64,6 +62,7 @@ func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWi linkEP: linkEP, dispatcher: dispatcher, fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), + protocol: p, } return e, nil @@ -204,7 +203,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen if length > header.IPv4MaximumHeaderSize+8 { // Packets of 68 bytes or less are required by RFC 791 to not be // fragmented, so we only assign ids to larger packets. - id = atomic.AddUint32(&ids[hashRoute(r, protocol)%buckets], 1) + id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, protocol, e.protocol.hashIV)%buckets], 1) } ip.Encode(&header.IPv4Fields{ IHL: header.IPv4MinimumSize, @@ -267,7 +266,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect if payload.Size() > header.IPv4MaximumHeaderSize+8 { // Packets of 68 bytes or less are required by RFC 791 to not be // fragmented, so we only assign ids to larger packets. - id = atomic.AddUint32(&ids[hashRoute(r, 0 /* protocol */)%buckets], 1) + id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1) } ip.SetID(uint16(id)) } @@ -325,14 +324,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { // Close cleans up resources associated with the endpoint. func (e *endpoint) Close() {} -type protocol struct{} - -// NewProtocol creates a new protocol ipv4 protocol descriptor. This is exported -// only for tests that short-circuit the stack. Regular use of the protocol is -// done via the stack, which gets a protocol descriptor from the init() function -// below. -func NewProtocol() stack.NetworkProtocol { - return &protocol{} +type protocol struct { + ids []uint32 + hashIV uint32 } // Number returns the ipv4 protocol number. @@ -378,7 +372,7 @@ func calculateMTU(mtu uint32) uint32 { // hashRoute calculates a hash value for the given route. It uses the source & // destination address, the transport protocol number, and a random initial // value (generated once on initialization) to generate the hash. -func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber) uint32 { +func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 { t := r.LocalAddress a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 t = r.RemoteAddress @@ -386,22 +380,16 @@ func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber) uint32 { return hash.Hash3Words(a, b, uint32(protocol), hashIV) } -var ( - ids []uint32 - hashIV uint32 -) - -func init() { - ids = make([]uint32, buckets) +// NewProtocol returns an IPv4 network protocol. +func NewProtocol() stack.NetworkProtocol { + ids := make([]uint32, buckets) // Randomly initialize hashIV and the ids. r := hash.RandN32(1 + buckets) for i := range ids { ids[i] = r[i] } - hashIV = r[buckets] + hashIV := r[buckets] - stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol { - return &protocol{} - }) + return &protocol{ids: ids, hashIV: hashIV} } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index ae827ca27..b6641ccc3 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -33,7 +33,10 @@ import ( ) func TestExcludeBroadcast(t *testing.T) { - s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) const defaultMTU = 65536 ep := stack.LinkEndpoint(channel.New(256, defaultMTU, "")) @@ -238,7 +241,9 @@ type context struct { func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context { // Make the packet and write it. - s := stack.New([]string{ipv4.ProtocolName}, []string{}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + }) ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors) s.CreateNIC(1, ep) const ( diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 653d984e9..01f5a17ec 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -81,7 +81,10 @@ func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.Li } func TestICMPCounts(t *testing.T) { - s := stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + }) { if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { t.Fatalf("CreateNIC(_) = %s", err) @@ -205,8 +208,14 @@ func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapab func newTestContext(t *testing.T) *testContext { c := &testContext{ - s0: stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}), - s1: stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}), + s0: stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + }), + s1: stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + }), } const defaultMTU = 65536 diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 331a8bdaa..7de6a4546 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -14,9 +14,9 @@ // Package ipv6 contains the implementation of the ipv6 network protocol. To use // it in the networking stack, this package must be added to the project, and -// activated on the stack by passing ipv6.ProtocolName (or "ipv6") as one of the -// network protocols when calling stack.New(). Then endpoints can be created -// by passing ipv6.ProtocolNumber as the network protocol number when calling +// activated on the stack by passing ipv6.NewProtocol() as one of the network +// protocols when calling stack.New(). Then endpoints can be created by passing +// ipv6.ProtocolNumber as the network protocol number when calling // Stack.NewEndpoint(). package ipv6 @@ -28,9 +28,6 @@ import ( ) const ( - // ProtocolName is the string representation of the ipv6 protocol name. - ProtocolName = "ipv6" - // ProtocolNumber is the ipv6 protocol number. ProtocolNumber = header.IPv6ProtocolNumber @@ -160,14 +157,6 @@ func (*endpoint) Close() {} type protocol struct{} -// NewProtocol creates a new protocol ipv6 protocol descriptor. This is exported -// only for tests that short-circuit the stack. Regular use of the protocol is -// done via the stack, which gets a protocol descriptor from the init() function -// below. -func NewProtocol() stack.NetworkProtocol { - return &protocol{} -} - // Number returns the ipv6 protocol number. func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber @@ -221,8 +210,7 @@ func calculateMTU(mtu uint32) uint32 { return maxPayloadSize } -func init() { - stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol { - return &protocol{} - }) +// NewProtocol returns an IPv6 network protocol. +func NewProtocol() stack.NetworkProtocol { + return &protocol{} } diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 57bcd5455..78c674c2c 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -124,17 +124,20 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst // UDP packets destined to the IPv6 link-local all-nodes multicast address. func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { tests := []struct { - name string - protocolName string - rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) + name string + protocolFactory stack.TransportProtocol + rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) }{ - {"ICMP", icmp.ProtocolName6, testReceiveICMP}, - {"UDP", udp.ProtocolName, testReceiveUDP}, + {"ICMP", icmp.NewProtocol6(), testReceiveICMP}, + {"UDP", udp.NewProtocol(), testReceiveUDP}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - s := stack.New([]string{ProtocolName}, []string{test.protocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{test.protocolFactory}, + }) e := channel.New(10, 1280, linkAddr1) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(_) = %s", err) @@ -152,19 +155,22 @@ func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { // address. func TestReceiveOnSolicitedNodeAddr(t *testing.T) { tests := []struct { - name string - protocolName string - rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) + name string + protocolFactory stack.TransportProtocol + rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) }{ - {"ICMP", icmp.ProtocolName6, testReceiveICMP}, - {"UDP", udp.ProtocolName, testReceiveUDP}, + {"ICMP", icmp.NewProtocol6(), testReceiveICMP}, + {"UDP", udp.NewProtocol(), testReceiveUDP}, } snmc := header.SolicitedNodeAddr(addr2) for _, test := range tests { t.Run(test.name, func(t *testing.T) { - s := stack.New([]string{ProtocolName}, []string{test.protocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{test.protocolFactory}, + }) e := channel.New(10, 1280, linkAddr1) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(_) = %s", err) @@ -237,7 +243,9 @@ func TestAddIpv6Address(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - s := stack.New([]string{ProtocolName}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { t.Fatalf("CreateNIC(_) = %s", err) } diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 571915d3f..e30791fe3 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -31,7 +31,10 @@ import ( func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) { t.Helper() - s := stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + }) if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { t.Fatalf("CreateNIC(_) = %s", err) diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go index f12189580..2239c1e66 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ b/pkg/tcpip/sample/tun_tcp_connect/main.go @@ -126,7 +126,10 @@ func main() { // Create the stack with ipv4 and tcp protocols, then add a tun-based // NIC and ipv4 address. - s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + }) mtu, err := rawfile.GetMTU(tunName) if err != nil { diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go index 329941775..bca73cbb1 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ b/pkg/tcpip/sample/tun_tcp_echo/main.go @@ -111,7 +111,10 @@ func main() { // Create the stack with ip and tcp protocols, then add a tun-based // NIC and address. - s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName, arp.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + }) mtu, err := rawfile.GetMTU(tunName) if err != nil { diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 07e4c770d..80101d4bb 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -366,14 +366,6 @@ type LinkAddressCache interface { RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) } -// TransportProtocolFactory functions are used by the stack to instantiate -// transport protocols. -type TransportProtocolFactory func() TransportProtocol - -// NetworkProtocolFactory provides methods to be used by the stack to -// instantiate network protocols. -type NetworkProtocolFactory func() NetworkProtocol - // UnassociatedEndpointFactory produces endpoints for writing packets not // associated with a particular transport protocol. Such endpoints can be used // to write arbitrary packets that include the IP header. @@ -381,34 +373,6 @@ type UnassociatedEndpointFactory interface { NewUnassociatedRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) } -var ( - transportProtocols = make(map[string]TransportProtocolFactory) - networkProtocols = make(map[string]NetworkProtocolFactory) - - unassociatedFactory UnassociatedEndpointFactory -) - -// RegisterTransportProtocolFactory registers a new transport protocol factory -// with the stack so that it becomes available to users of the stack. This -// function is intended to be called by init() functions of the protocols. -func RegisterTransportProtocolFactory(name string, p TransportProtocolFactory) { - transportProtocols[name] = p -} - -// RegisterNetworkProtocolFactory registers a new network protocol factory with -// the stack so that it becomes available to users of the stack. This function -// is intended to be called by init() functions of the protocols. -func RegisterNetworkProtocolFactory(name string, p NetworkProtocolFactory) { - networkProtocols[name] = p -} - -// RegisterUnassociatedFactory registers a factory to produce endpoints not -// associated with any particular transport protocol. This function is intended -// to be called by init() functions of the protocols. -func RegisterUnassociatedFactory(f UnassociatedEndpointFactory) { - unassociatedFactory = f -} - // GSOType is the type of GSO segments. // // +stateify savable diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index f7ba3cb0f..18d1704a5 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -17,11 +17,6 @@ // // For consumers, the only function of interest is New(), everything else is // provided by the tcpip/public package. -// -// For protocol implementers, RegisterTransportProtocolFactory() and -// RegisterNetworkProtocolFactory() are used to register protocol factories with -// the stack, which will then be used to instantiate protocol objects when -// consumers interact with the stack. package stack import ( @@ -351,6 +346,9 @@ type Stack struct { networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver + // unassociatedFactory creates unassociated endpoints. If nil, raw + // endpoints are disabled. It is set during Stack creation and is + // immutable. unassociatedFactory UnassociatedEndpointFactory demux *transportDemuxer @@ -359,10 +357,6 @@ type Stack struct { linkAddrCache *linkAddrCache - // raw indicates whether raw sockets may be created. It is set during - // Stack creation and is immutable. - raw bool - mu sync.RWMutex nics map[tcpip.NICID]*NIC forwarding bool @@ -398,6 +392,12 @@ type Stack struct { // Options contains optional Stack configuration. type Options struct { + // NetworkProtocols lists the network protocols to enable. + NetworkProtocols []NetworkProtocol + + // TransportProtocols lists the transport protocols to enable. + TransportProtocols []TransportProtocol + // Clock is an optional clock source used for timestampping packets. // // If no Clock is specified, the clock source will be time.Now. @@ -411,8 +411,9 @@ type Options struct { // stack (false). HandleLocal bool - // Raw indicates whether raw sockets may be created. - Raw bool + // UnassociatedFactory produces unassociated endpoints raw endpoints. + // Raw endpoints are enabled only if this is non-nil. + UnassociatedFactory UnassociatedEndpointFactory } // New allocates a new networking stack with only the requested networking and @@ -422,7 +423,7 @@ type Options struct { // SetNetworkProtocolOption/SetTransportProtocolOption methods provided by the // stack. Please refer to individual protocol implementations as to what options // are supported. -func New(network []string, transport []string, opts Options) *Stack { +func New(opts Options) *Stack { clock := opts.Clock if clock == nil { clock = &tcpip.StdClock{} @@ -438,17 +439,11 @@ func New(network []string, transport []string, opts Options) *Stack { clock: clock, stats: opts.Stats.FillIn(), handleLocal: opts.HandleLocal, - raw: opts.Raw, icmpRateLimiter: NewICMPRateLimiter(), } // Add specified network protocols. - for _, name := range network { - netProtoFactory, ok := networkProtocols[name] - if !ok { - continue - } - netProto := netProtoFactory() + for _, netProto := range opts.NetworkProtocols { s.networkProtocols[netProto.Number()] = netProto if r, ok := netProto.(LinkAddressResolver); ok { s.linkAddrResolvers[r.LinkAddressProtocol()] = r @@ -456,18 +451,14 @@ func New(network []string, transport []string, opts Options) *Stack { } // Add specified transport protocols. - for _, name := range transport { - transProtoFactory, ok := transportProtocols[name] - if !ok { - continue - } - transProto := transProtoFactory() + for _, transProto := range opts.TransportProtocols { s.transportProtocols[transProto.Number()] = &transportProtocolState{ proto: transProto, } } - s.unassociatedFactory = unassociatedFactory + // Add the factory for unassociated endpoints, if present. + s.unassociatedFactory = opts.UnassociatedFactory // Create the global transport demuxer. s.demux = newTransportDemuxer(s) @@ -602,7 +593,7 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp // protocol. Raw endpoints receive all traffic for a given protocol regardless // of address. func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { - if !s.raw { + if s.unassociatedFactory == nil { return nil, tcpip.ErrNotPermitted } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 7aa10bce9..d2dede8a9 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -222,11 +222,17 @@ func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error { } } +func fakeNetFactory() stack.NetworkProtocol { + return &fakeNetworkProtocol{} +} + func TestNetworkReceive(t *testing.T) { // Create a stack with the fake network protocol, one nic, and two // addresses attached to it: 1 & 2. ep := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -370,7 +376,9 @@ func TestNetworkSend(t *testing.T) { // address: 1. The route table sends all packets through the only // existing nic. ep := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) if err := s.CreateNIC(1, ep); err != nil { t.Fatal("NewNIC failed:", err) } @@ -395,7 +403,9 @@ func TestNetworkSendMultiRoute(t *testing.T) { // Create a stack with the fake network protocol, two nics, and two // addresses per nic, the first nic has odd address, the second one has // even addresses. - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep1); err != nil { @@ -476,7 +486,9 @@ func TestRoutes(t *testing.T) { // Create a stack with the fake network protocol, two nics, and two // addresses per nic, the first nic has odd address, the second one has // even addresses. - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep1); err != nil { @@ -554,7 +566,9 @@ func TestAddressRemoval(t *testing.T) { localAddr := tcpip.Address([]byte{localAddrByte}) remoteAddr := tcpip.Address("\x02") - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { @@ -599,7 +613,9 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { localAddr := tcpip.Address([]byte{localAddrByte}) remoteAddr := tcpip.Address("\x02") - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { @@ -688,7 +704,9 @@ func TestEndpointExpiration(t *testing.T) { for _, promiscuous := range []bool{true, false} { for _, spoofing := range []bool{true, false} { t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicid, ep); err != nil { @@ -844,7 +862,9 @@ func TestEndpointExpiration(t *testing.T) { } func TestPromiscuousMode(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { @@ -894,7 +914,9 @@ func TestSpoofingWithAddress(t *testing.T) { nonExistentLocalAddr := tcpip.Address("\x02") dstAddr := tcpip.Address("\x03") - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { @@ -958,7 +980,9 @@ func TestSpoofingNoAddress(t *testing.T) { nonExistentLocalAddr := tcpip.Address("\x01") dstAddr := tcpip.Address("\x02") - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { @@ -1003,7 +1027,9 @@ func TestSpoofingNoAddress(t *testing.T) { } func TestBroadcastNeedsNoRoute(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { @@ -1074,7 +1100,9 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { {"IPv6 Unicast Not Link-Local 7", true, "\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, } { t.Run(tc.name, func(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { @@ -1130,7 +1158,9 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { // Add a range of addresses, then check that a packet is delivered. func TestAddressRangeAcceptsMatchingPacket(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { @@ -1196,7 +1226,9 @@ func testNicForAddressRange(t *testing.T, nicID tcpip.NICID, s *stack.Stack, sub // existent. func TestCheckLocalAddressForSubnet(t *testing.T) { const nicID tcpip.NICID = 1 - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicID, ep); err != nil { @@ -1234,7 +1266,9 @@ func TestCheckLocalAddressForSubnet(t *testing.T) { // Set a range of addresses, then send a packet to a destination outside the // range and then check it doesn't get delivered. func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { @@ -1266,7 +1300,10 @@ func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) { } func TestNetworkOptions(t *testing.T) { - s := stack.New([]string{"fakeNet"}, []string{}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{}, + }) // Try an unsupported network protocol. if err := s.SetNetworkProtocolOption(tcpip.NetworkProtocolNumber(99999), fakeNetGoodOption(false)); err != tcpip.ErrUnknownProtocol { @@ -1319,7 +1356,9 @@ func stackContainsAddressRange(s *stack.Stack, id tcpip.NICID, addrRange tcpip.S } func TestAddresRangeAddRemove(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) @@ -1360,7 +1399,9 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { t.Run(fmt.Sprintf("canBe=%d", canBe), func(t *testing.T) { for never := 0; never < 3; never++ { t.Run(fmt.Sprintf("never=%d", never), func(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) @@ -1425,7 +1466,9 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { } func TestGetMainNICAddressAddRemove(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) @@ -1508,7 +1551,9 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto func TestAddAddress(t *testing.T) { const nicid = 1 - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) @@ -1533,7 +1578,9 @@ func TestAddAddress(t *testing.T) { func TestAddProtocolAddress(t *testing.T) { const nicid = 1 - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) @@ -1565,7 +1612,9 @@ func TestAddProtocolAddress(t *testing.T) { func TestAddAddressWithOptions(t *testing.T) { const nicid = 1 - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) @@ -1594,7 +1643,9 @@ func TestAddAddressWithOptions(t *testing.T) { func TestAddProtocolAddressWithOptions(t *testing.T) { const nicid = 1 - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) @@ -1628,7 +1679,9 @@ func TestAddProtocolAddressWithOptions(t *testing.T) { } func TestNICStats(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC failed: ", err) @@ -1674,7 +1727,9 @@ func TestNICStats(t *testing.T) { func TestNICForwarding(t *testing.T) { // Create a stack with the fake network protocol, two NICs, each with // an address. - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) s.SetForwarding(true) ep1 := channel.New(10, defaultMTU, "") @@ -1722,9 +1777,3 @@ func TestNICForwarding(t *testing.T) { t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) } } - -func init() { - stack.RegisterNetworkProtocolFactory("fakeNet", func() stack.NetworkProtocol { - return &fakeNetworkProtocol{} - }) -} diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 0e69ac7c8..56e8a5d9b 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -282,9 +282,16 @@ func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error { } } +func fakeTransFactory() stack.TransportProtocol { + return &fakeTransportProtocol{} +} + func TestTransportReceive(t *testing.T) { linkEP := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + }) if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -346,7 +353,10 @@ func TestTransportReceive(t *testing.T) { func TestTransportControlReceive(t *testing.T) { linkEP := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + }) if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -414,7 +424,10 @@ func TestTransportControlReceive(t *testing.T) { func TestTransportSend(t *testing.T) { linkEP := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + }) if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -457,7 +470,10 @@ func TestTransportSend(t *testing.T) { } func TestTransportOptions(t *testing.T) { - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + }) // Try an unsupported transport protocol. if err := s.SetTransportProtocolOption(tcpip.TransportProtocolNumber(99999), fakeTransportGoodOption(false)); err != tcpip.ErrUnknownProtocol { @@ -498,7 +514,10 @@ func TestTransportOptions(t *testing.T) { } func TestTransportForwarding(t *testing.T) { - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + }) s.SetForwarding(true) // TODO(b/123449044): Change this to a channel NIC. @@ -576,9 +595,3 @@ func TestTransportForwarding(t *testing.T) { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } - -func init() { - stack.RegisterTransportProtocolFactory("fakeTrans", func() stack.TransportProtocol { - return &fakeTransportProtocol{} - }) -} diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 1eb790932..bfb16f7c3 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -14,10 +14,9 @@ // Package icmp contains the implementation of the ICMP and IPv6-ICMP transport // protocols for use in ping. To use it in the networking stack, this package -// must be added to the project, and -// activated on the stack by passing icmp.ProtocolName (or "icmp") and/or -// icmp.ProtocolName6 (or "icmp6") as one of the transport protocols when -// calling stack.New(). Then endpoints can be created by passing +// must be added to the project, and activated on the stack by passing +// icmp.NewProtocol4() and/or icmp.NewProtocol6() as one of the transport +// protocols when calling stack.New(). Then endpoints can be created by passing // icmp.ProtocolNumber or icmp.ProtocolNumber6 as the transport protocol number // when calling Stack.NewEndpoint(). package icmp @@ -34,15 +33,9 @@ import ( ) const ( - // ProtocolName4 is the string representation of the icmp protocol name. - ProtocolName4 = "icmp4" - // ProtocolNumber4 is the ICMP protocol number. ProtocolNumber4 = header.ICMPv4ProtocolNumber - // ProtocolName6 is the string representation of the icmp protocol name. - ProtocolName6 = "icmp6" - // ProtocolNumber6 is the IPv6-ICMP protocol number. ProtocolNumber6 = header.ICMPv6ProtocolNumber ) @@ -125,12 +118,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -func init() { - stack.RegisterTransportProtocolFactory(ProtocolName4, func() stack.TransportProtocol { - return &protocol{ProtocolNumber4} - }) +// NewProtocol4 returns an ICMPv4 transport protocol. +func NewProtocol4() stack.TransportProtocol { + return &protocol{ProtocolNumber4} +} - stack.RegisterTransportProtocolFactory(ProtocolName6, func() stack.TransportProtocol { - return &protocol{ProtocolNumber6} - }) +// NewProtocol6 returns an ICMPv6 transport protocol. +func NewProtocol6() stack.TransportProtocol { + return &protocol{ProtocolNumber6} } diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go index 783c21e6b..a2512d666 100644 --- a/pkg/tcpip/transport/raw/protocol.go +++ b/pkg/tcpip/transport/raw/protocol.go @@ -20,13 +20,10 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -type factory struct{} +// EndpointFactory implements stack.UnassociatedEndpointFactory. +type EndpointFactory struct{} // NewUnassociatedRawEndpoint implements stack.UnassociatedEndpointFactory. -func (factory) NewUnassociatedRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (EndpointFactory) NewUnassociatedRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { return newEndpoint(stack, netProto, transProto, waiterQueue, false /* associated */) } - -func init() { - stack.RegisterUnassociatedFactory(factory{}) -} diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 2a13b2022..d5d8ab96a 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -14,7 +14,7 @@ // Package tcp contains the implementation of the TCP transport protocol. To use // it in the networking stack, this package must be added to the project, and -// activated on the stack by passing tcp.ProtocolName (or "tcp") as one of the +// activated on the stack by passing tcp.NewProtocol() as one of the // transport protocols when calling stack.New(). Then endpoints can be created // by passing tcp.ProtocolNumber as the transport protocol number when calling // Stack.NewEndpoint(). @@ -34,9 +34,6 @@ import ( ) const ( - // ProtocolName is the string representation of the tcp protocol name. - ProtocolName = "tcp" - // ProtocolNumber is the tcp protocol number. ProtocolNumber = header.TCPProtocolNumber @@ -254,13 +251,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { } } -func init() { - stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol { - return &protocol{ - sendBufferSize: SendBufferSizeOption{MinBufferSize, DefaultSendBufferSize, MaxBufferSize}, - recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize}, - congestionControl: ccReno, - availableCongestionControl: []string{ccReno, ccCubic}, - } - }) +// NewProtocol returns a TCP transport protocol. +func NewProtocol() stack.TransportProtocol { + return &protocol{ + sendBufferSize: SendBufferSizeOption{MinBufferSize, DefaultSendBufferSize, MaxBufferSize}, + recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize}, + congestionControl: ccReno, + availableCongestionControl: []string{ccReno, ccCubic}, + } } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 7fa5cfb6e..2be094876 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -2873,7 +2873,10 @@ func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { } func TestDefaultBufferSizes(t *testing.T) { - s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + }) // Check the default values. ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) @@ -2919,7 +2922,10 @@ func TestDefaultBufferSizes(t *testing.T) { } func TestMinMaxBufferSizes(t *testing.T) { - s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + }) // Check the default values. ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) @@ -2965,10 +2971,13 @@ func TestMinMaxBufferSizes(t *testing.T) { } func makeStack() (*stack.Stack, *tcpip.Error) { - s := stack.New([]string{ - ipv4.ProtocolName, - ipv6.ProtocolName, - }, []string{tcp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ + ipv4.NewProtocol(), + ipv6.NewProtocol(), + }, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + }) id := loopback.New() if testing.Verbose() { diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 78eff5c3a..d3f1d2cdf 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -137,7 +137,10 @@ type Context struct { // New allocates and initializes a test context containing a new // stack and a link-layer endpoint. func New(t *testing.T, mtu uint32) *Context { - s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + }) // Allow minimum send/receive buffer sizes to be 1 during tests. if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize, 10 * tcp.DefaultSendBufferSize}); err != nil { diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 068d9a272..f5cc932dd 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -14,7 +14,7 @@ // Package udp contains the implementation of the UDP transport protocol. To use // it in the networking stack, this package must be added to the project, and -// activated on the stack by passing udp.ProtocolName (or "udp") as one of the +// activated on the stack by passing udp.NewProtocol() as one of the // transport protocols when calling stack.New(). Then endpoints can be created // by passing udp.ProtocolNumber as the transport protocol number when calling // Stack.NewEndpoint(). @@ -30,9 +30,6 @@ import ( ) const ( - // ProtocolName is the string representation of the udp protocol name. - ProtocolName = "udp" - // ProtocolNumber is the udp protocol number. ProtocolNumber = header.UDPProtocolNumber ) @@ -182,8 +179,7 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -func init() { - stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol { - return &protocol{} - }) +// NewProtocol returns a UDP transport protocol. +func NewProtocol() stack.TransportProtocol { + return &protocol{} } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index c6deab892..2ec27be4d 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -274,7 +274,10 @@ type testContext struct { func newDualTestContext(t *testing.T, mtu uint32) *testContext { t.Helper() - s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{udp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) ep := channel.New(256, mtu, "") wep := stack.LinkEndpoint(ep) diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index 54d1ab129..d90381c0f 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -80,6 +80,7 @@ go_library( "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", + "//pkg/tcpip/transport/raw", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/urpc", diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index d824d7dc5..adf345490 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -54,6 +54,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/raw" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/runsc/boot/filter" @@ -911,15 +912,17 @@ func newEmptyNetworkStack(conf *Config, clock tcpip.Clock) (inet.Stack, error) { case NetworkNone, NetworkSandbox: // NetworkNone sets up loopback using netstack. - netProtos := []string{ipv4.ProtocolName, ipv6.ProtocolName, arp.ProtocolName} - protoNames := []string{tcp.ProtocolName, udp.ProtocolName, icmp.ProtocolName4} - s := epsocket.Stack{stack.New(netProtos, protoNames, stack.Options{ - Clock: clock, - Stats: epsocket.Metrics, - HandleLocal: true, + netProtos := []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()} + transProtos := []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol(), icmp.NewProtocol4()} + s := epsocket.Stack{stack.New(stack.Options{ + NetworkProtocols: netProtos, + TransportProtocols: transProtos, + Clock: clock, + Stats: epsocket.Metrics, + HandleLocal: true, // Enable raw sockets for users with sufficient // privileges. - Raw: true, + UnassociatedFactory: raw.EndpointFactory{}, })} // Enable SACK Recovery. -- cgit v1.2.3