summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/network/ipv4/BUILD2
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go3
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go119
3 files changed, 122 insertions, 2 deletions
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 4b21ee79c..5e7f10f4b 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -32,12 +32,14 @@ go_test(
"ipv4_test.go",
],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
"//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/loopback",
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/arp",
"//pkg/tcpip/network/internal/testutil",
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index cabe274d6..8a2140ebe 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -899,10 +899,9 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
// Close cleans up resources associated with the endpoint.
func (e *endpoint) Close() {
e.mu.Lock()
- defer e.mu.Unlock()
-
e.disableLocked()
e.mu.addressableEndpointState.Cleanup()
+ e.mu.Unlock()
e.protocol.forgetEndpoint(e.nic.ID())
}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 26d9696d7..cfed241bf 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -26,12 +26,14 @@ import (
"time"
"github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
@@ -2985,3 +2987,120 @@ func TestPacketQueing(t *testing.T) {
})
}
}
+
+// TestCloseLocking test that lock ordering is followed when closing an
+// endpoint.
+func TestCloseLocking(t *testing.T) {
+ const (
+ nicID1 = 1
+ nicID2 = 2
+
+ src = tcpip.Address("\x10\x00\x00\x01")
+ dst = tcpip.Address("\x10\x00\x00\x02")
+
+ iterations = 1000
+ )
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+
+ // Perform NAT so that the endoint tries to search for a sibling endpoint
+ // which ends up taking the protocol and endpoint lock (in that order).
+ table := stack.Table{
+ Rules: []stack.Rule{
+ {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &stack.RedirectTarget{Port: 5, NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &stack.ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ },
+ BuiltinChains: [stack.NumHooks]int{
+ stack.Prerouting: 0,
+ stack.Input: 1,
+ stack.Forward: stack.HookUnset,
+ stack.Output: 2,
+ stack.Postrouting: 3,
+ },
+ Underflows: [stack.NumHooks]int{
+ stack.Prerouting: 0,
+ stack.Input: 1,
+ stack.Forward: stack.HookUnset,
+ stack.Output: 2,
+ stack.Postrouting: 3,
+ },
+ }
+ if err := s.IPTables().ReplaceTable(stack.NATID, table, false /* ipv6 */); err != nil {
+ t.Fatalf("s.IPTables().ReplaceTable(...): %s", err)
+ }
+
+ e := channel.New(0, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+
+ if err := s.AddAddress(nicID1, ipv4.ProtocolNumber, src); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) failed: %s", nicID1, ipv4.ProtocolNumber, src, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID1,
+ }})
+
+ var wq waiter.Queue
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ep.Close()
+
+ addr := tcpip.FullAddress{NIC: nicID1, Addr: dst, Port: 53}
+ if err := ep.Connect(addr); err != nil {
+ t.Errorf("ep.Connect(%#v): %s", addr, err)
+ }
+
+ var wg sync.WaitGroup
+ defer wg.Wait()
+
+ // Writing packets should trigger NAT which requires the stack to search the
+ // protocol for network endpoints with the destination address.
+ //
+ // Creating and removing interfaces should modify the protocol and endpoint
+ // which requires taking the locks of each.
+ //
+ // We expect the protocol > endpoint lock ordering to be followed here.
+ wg.Add(2)
+ go func() {
+ defer wg.Done()
+
+ data := []byte{1, 2, 3, 4}
+
+ for i := 0; i < iterations; i++ {
+ var r bytes.Reader
+ r.Reset(data)
+ if n, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil {
+ t.Errorf("ep.Write(_, _): %s", err)
+ return
+ } else if want := int64(len(data)); n != want {
+ t.Errorf("got ep.Write(_, _) = (%d, _), want = (%d, _)", n, want)
+ return
+ }
+ }
+ }()
+ go func() {
+ defer wg.Done()
+
+ for i := 0; i < iterations; i++ {
+ if err := s.CreateNIC(nicID2, loopback.New()); err != nil {
+ t.Errorf("CreateNIC(%d, _): %s", nicID2, err)
+ return
+ }
+ if err := s.RemoveNIC(nicID2); err != nil {
+ t.Errorf("RemoveNIC(%d): %s", nicID2, err)
+ return
+ }
+ }
+ }()
+}