diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/network/ipv4/BUILD | 2 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 3 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4_test.go | 119 |
3 files changed, 122 insertions, 2 deletions
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 4b21ee79c..5e7f10f4b 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -32,12 +32,14 @@ go_test( "ipv4_test.go", ], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/loopback", "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/arp", "//pkg/tcpip/network/internal/testutil", diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index cabe274d6..8a2140ebe 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -899,10 +899,9 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // Close cleans up resources associated with the endpoint. func (e *endpoint) Close() { e.mu.Lock() - defer e.mu.Unlock() - e.disableLocked() e.mu.addressableEndpointState.Cleanup() + e.mu.Unlock() e.protocol.forgetEndpoint(e.nic.ID()) } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 26d9696d7..cfed241bf 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -26,12 +26,14 @@ import ( "time" "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" @@ -2985,3 +2987,120 @@ func TestPacketQueing(t *testing.T) { }) } } + +// TestCloseLocking test that lock ordering is followed when closing an +// endpoint. +func TestCloseLocking(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + + src = tcpip.Address("\x10\x00\x00\x01") + dst = tcpip.Address("\x10\x00\x00\x02") + + iterations = 1000 + ) + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + + // Perform NAT so that the endoint tries to search for a sibling endpoint + // which ends up taking the protocol and endpoint lock (in that order). + table := stack.Table{ + Rules: []stack.Rule{ + {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &stack.RedirectTarget{Port: 5, NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &stack.ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + }, + BuiltinChains: [stack.NumHooks]int{ + stack.Prerouting: 0, + stack.Input: 1, + stack.Forward: stack.HookUnset, + stack.Output: 2, + stack.Postrouting: 3, + }, + Underflows: [stack.NumHooks]int{ + stack.Prerouting: 0, + stack.Input: 1, + stack.Forward: stack.HookUnset, + stack.Output: 2, + stack.Postrouting: 3, + }, + } + if err := s.IPTables().ReplaceTable(stack.NATID, table, false /* ipv6 */); err != nil { + t.Fatalf("s.IPTables().ReplaceTable(...): %s", err) + } + + e := channel.New(0, defaultMTU, "") + if err := s.CreateNIC(nicID1, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + + if err := s.AddAddress(nicID1, ipv4.ProtocolNumber, src); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) failed: %s", nicID1, ipv4.ProtocolNumber, src, err) + } + + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv4EmptySubnet, + NIC: nicID1, + }}) + + var wq waiter.Queue + ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatal(err) + } + defer ep.Close() + + addr := tcpip.FullAddress{NIC: nicID1, Addr: dst, Port: 53} + if err := ep.Connect(addr); err != nil { + t.Errorf("ep.Connect(%#v): %s", addr, err) + } + + var wg sync.WaitGroup + defer wg.Wait() + + // Writing packets should trigger NAT which requires the stack to search the + // protocol for network endpoints with the destination address. + // + // Creating and removing interfaces should modify the protocol and endpoint + // which requires taking the locks of each. + // + // We expect the protocol > endpoint lock ordering to be followed here. + wg.Add(2) + go func() { + defer wg.Done() + + data := []byte{1, 2, 3, 4} + + for i := 0; i < iterations; i++ { + var r bytes.Reader + r.Reset(data) + if n, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Errorf("ep.Write(_, _): %s", err) + return + } else if want := int64(len(data)); n != want { + t.Errorf("got ep.Write(_, _) = (%d, _), want = (%d, _)", n, want) + return + } + } + }() + go func() { + defer wg.Done() + + for i := 0; i < iterations; i++ { + if err := s.CreateNIC(nicID2, loopback.New()); err != nil { + t.Errorf("CreateNIC(%d, _): %s", nicID2, err) + return + } + if err := s.RemoveNIC(nicID2); err != nil { + t.Errorf("RemoveNIC(%d): %s", nicID2, err) + return + } + } + }() +} |