diff options
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/sentry/socket/netfilter/netfilter.go | 36 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/stack.go | 9 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables.go | 42 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables_types.go | 15 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 23 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/transport/packet/endpoint.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/snd.go | 31 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 5 | ||||
-rw-r--r-- | pkg/tmutex/BUILD | 17 | ||||
-rw-r--r-- | pkg/tmutex/tmutex.go | 81 | ||||
-rw-r--r-- | pkg/tmutex/tmutex_test.go | 258 |
14 files changed, 98 insertions, 439 deletions
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 47ff48c00..66015e2bc 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -144,31 +144,27 @@ func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen } func findTable(stk *stack.Stack, tablename linux.TableName) (stack.Table, error) { - ipt := stk.IPTables() - table, ok := ipt.Tables[tablename.String()] + table, ok := stk.IPTables().GetTable(tablename.String()) if !ok { return stack.Table{}, fmt.Errorf("couldn't find table %q", tablename) } return table, nil } -// FillDefaultIPTables sets stack's IPTables to the default tables and -// populates them with metadata. -func FillDefaultIPTables(stk *stack.Stack) { - ipt := stack.DefaultTables() - - // In order to fill in the metadata, we have to translate ipt from its - // netstack format to Linux's giant-binary-blob format. - for name, table := range ipt.Tables { - _, metadata, err := convertNetstackToBinary(name, table) - if err != nil { - panic(fmt.Errorf("Unable to set default IP tables: %v", err)) +// FillIPTablesMetadata populates stack's IPTables with metadata. +func FillIPTablesMetadata(stk *stack.Stack) { + stk.IPTables().ModifyTables(func(tables map[string]stack.Table) { + // In order to fill in the metadata, we have to translate ipt from its + // netstack format to Linux's giant-binary-blob format. + for name, table := range tables { + _, metadata, err := convertNetstackToBinary(name, table) + if err != nil { + panic(fmt.Errorf("Unable to set default IP tables: %v", err)) + } + table.SetMetadata(metadata) + tables[name] = table } - table.SetMetadata(metadata) - ipt.Tables[name] = table - } - - stk.SetIPTables(ipt) + }) } // convertNetstackToBinary converts the iptables as stored in netstack to the @@ -573,15 +569,13 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { // - There are no chains without an unconditional final rule. // - There are no chains without an unconditional underflow rule. - ipt := stk.IPTables() table.SetMetadata(metadata{ HookEntry: replace.HookEntry, Underflow: replace.Underflow, NumEntries: replace.NumEntries, Size: replace.Size, }) - ipt.Tables[replace.Name.String()] = table - stk.SetIPTables(ipt) + stk.IPTables().ReplaceTable(replace.Name.String(), table) return nil } diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index f5fa18136..9b44c2b89 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -362,14 +362,13 @@ func (s *Stack) RouteTable() []inet.Route { } // IPTables returns the stack's iptables. -func (s *Stack) IPTables() (stack.IPTables, error) { +func (s *Stack) IPTables() (*stack.IPTables, error) { return s.Stack.IPTables(), nil } -// FillDefaultIPTables sets the stack's iptables to the default tables, which -// allow and do not modify all traffic. -func (s *Stack) FillDefaultIPTables() { - netfilter.FillDefaultIPTables(s.Stack) +// FillIPTablesMetadata populates stack's IPTables with metadata. +func (s *Stack) FillIPTablesMetadata() { + netfilter.FillIPTablesMetadata(s.Stack) } // Resume implements inet.Stack.Resume. diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index d989dbe91..4e9b404c8 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -43,11 +43,11 @@ const HookUnset = -1 // DefaultTables returns a default set of tables. Each chain is set to accept // all packets. -func DefaultTables() IPTables { +func DefaultTables() *IPTables { // TODO(gvisor.dev/issue/170): We may be able to swap out some strings for // iotas. - return IPTables{ - Tables: map[string]Table{ + return &IPTables{ + tables: map[string]Table{ TablenameNat: Table{ Rules: []Rule{ Rule{Target: AcceptTarget{}}, @@ -106,7 +106,7 @@ func DefaultTables() IPTables { UserChains: map[string]int{}, }, }, - Priorities: map[Hook][]string{ + priorities: map[Hook][]string{ Input: []string{TablenameNat, TablenameFilter}, Prerouting: []string{TablenameMangle, TablenameNat}, Output: []string{TablenameMangle, TablenameNat, TablenameFilter}, @@ -158,6 +158,36 @@ func EmptyNatTable() Table { } } +// GetTable returns table by name. +func (it *IPTables) GetTable(name string) (Table, bool) { + it.mu.RLock() + defer it.mu.RUnlock() + t, ok := it.tables[name] + return t, ok +} + +// ReplaceTable replaces or inserts table by name. +func (it *IPTables) ReplaceTable(name string, table Table) { + it.mu.Lock() + defer it.mu.Unlock() + it.tables[name] = table +} + +// ModifyTables acquires write-lock and calls fn with internal name-to-table +// map. This function can be used to update multiple tables atomically. +func (it *IPTables) ModifyTables(fn func(map[string]Table)) { + it.mu.Lock() + defer it.mu.Unlock() + fn(it.tables) +} + +// GetPriorities returns slice of priorities associated with hook. +func (it *IPTables) GetPriorities(hook Hook) []string { + it.mu.RLock() + defer it.mu.RUnlock() + return it.priorities[hook] +} + // A chainVerdict is what a table decides should be done with a packet. type chainVerdict int @@ -184,8 +214,8 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr it.connections.HandlePacket(pkt, hook, gso, r) // Go through each table containing the hook. - for _, tablename := range it.Priorities[hook] { - table := it.Tables[tablename] + for _, tablename := range it.GetPriorities(hook) { + table, _ := it.GetTable(tablename) ruleIdx := table.BuiltinChains[hook] switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { // If the table returns Accept, move on to the next table. diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index af72b9c46..4a6a5c6f1 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -16,6 +16,7 @@ package stack import ( "strings" + "sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -78,13 +79,17 @@ const ( // IPTables holds all the tables for a netstack. type IPTables struct { - // Tables maps table names to tables. User tables have arbitrary names. - Tables map[string]Table + // mu protects tables and priorities. + mu sync.RWMutex - // Priorities maps each hook to a list of table names. The order of the + // tables maps table names to tables. User tables have arbitrary names. mu + // needs to be locked for accessing. + tables map[string]Table + + // priorities maps each hook to a list of table names. The order of the // list is the order in which each table should be visited for that - // hook. - Priorities map[Hook][]string + // hook. mu needs to be locked for accessing. + priorities map[Hook][]string connections ConnTrackTable } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 8af06cb9a..294ce8775 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -424,12 +424,8 @@ type Stack struct { // handleLocal allows non-loopback interfaces to loop packets. handleLocal bool - // tablesMu protects iptables. - tablesMu sync.RWMutex - - // tables are the iptables packet filtering and manipulation rules. The are - // protected by tablesMu.` - tables IPTables + // tables are the iptables packet filtering and manipulation rules. + tables *IPTables // resumableEndpoints is a list of endpoints that need to be resumed if the // stack is being restored. @@ -676,6 +672,7 @@ func New(opts Options) *Stack { clock: clock, stats: opts.Stats.FillIn(), handleLocal: opts.HandleLocal, + tables: DefaultTables(), icmpRateLimiter: NewICMPRateLimiter(), seed: generateRandUint32(), ndpConfigs: opts.NDPConfigs, @@ -1741,18 +1738,8 @@ func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, } // IPTables returns the stack's iptables. -func (s *Stack) IPTables() IPTables { - s.tablesMu.RLock() - t := s.tables - s.tablesMu.RUnlock() - return t -} - -// SetIPTables sets the stack's iptables. -func (s *Stack) SetIPTables(ipt IPTables) { - s.tablesMu.Lock() - s.tables = ipt - s.tablesMu.Unlock() +func (s *Stack) IPTables() *IPTables { + return s.tables } // ICMPLimit returns the maximum number of ICMP messages that can be sent diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 29ff68df3..3bc72bc19 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -140,11 +140,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } -// IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (stack.IPTables, error) { - return e.stack.IPTables(), nil -} - // Read reads data from the endpoint. This method does not block if // there is no data pending. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index bab2d63ae..baf08eda6 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -132,11 +132,6 @@ func (ep *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (ep *endpoint) ModerateRecvBuf(copied int) {} -// IPTables implements tcpip.Endpoint.IPTables. -func (ep *endpoint) IPTables() (stack.IPTables, error) { - return ep.stack.IPTables(), nil -} - // Read implements tcpip.Endpoint.Read. func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { ep.rcvMu.Lock() diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 25a17940d..21c34fac2 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -166,11 +166,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } -// IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (stack.IPTables, error) { - return e.stack.IPTables(), nil -} - // Read implements tcpip.Endpoint.Read. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { if !e.associated { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index d048ef90c..edca98160 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1172,11 +1172,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } -// IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (stack.IPTables, error) { - return e.stack.IPTables(), nil -} - // Read reads data from the endpoint. func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { e.LockUser() diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 3a19c4468..acacb42e4 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -618,6 +618,20 @@ func (s *sender) splitSeg(seg *segment, size int) { nSeg.data.TrimFront(size) nSeg.sequenceNumber.UpdateForward(seqnum.Size(size)) s.writeList.InsertAfter(seg, nSeg) + + // The segment being split does not carry PUSH flag because it is + // followed by the newly split segment. + // RFC1122 section 4.2.2.2: MUST set the PSH bit in the last buffered + // segment (i.e., when there is no more queued data to be sent). + // Linux removes PSH flag only when the segment is being split over MSS + // and retains it when we are splitting the segment over lack of sender + // window space. + // ref: net/ipv4/tcp_output.c::tcp_write_xmit(), tcp_mss_split_point() + // ref: net/ipv4/tcp_output.c::tcp_write_wakeup(), tcp_snd_wnd_test() + if seg.data.Size() > s.maxPayloadSize { + seg.flags ^= header.TCPFlagPsh + } + seg.data.CapLength(size) } @@ -739,7 +753,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se if !s.isAssignedSequenceNumber(seg) { // Merge segments if allowed. if seg.data.Size() != 0 { - available := int(seg.sequenceNumber.Size(end)) + available := int(s.sndNxt.Size(end)) if available > limit { available = limit } @@ -782,8 +796,11 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se // sent all at once. return false } - if atomic.LoadUint32(&s.ep.cork) != 0 { - // Hold back the segment until full. + // With TCP_CORK, hold back until minimum of the available + // send space and MSS. + // TODO(gvisor.dev/issue/2833): Drain the held segments after a + // timeout. + if seg.data.Size() < s.maxPayloadSize && atomic.LoadUint32(&s.ep.cork) != 0 { return false } } @@ -843,9 +860,17 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se if available == 0 { return false } + + // The segment size limit is computed as a function of sender congestion + // window and MSS. When sender congestion window is > 1, this limit can + // be larger than MSS. Ensure that the currently available send space + // is not greater than minimum of this limit and MSS. if available > limit { available = limit } + if available > s.maxPayloadSize { + available = s.maxPayloadSize + } if seg.data.Size() > available { s.splitSeg(seg, available) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 79faa7869..663af8fec 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -247,11 +247,6 @@ func (e *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (e *endpoint) ModerateRecvBuf(copied int) {} -// IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (stack.IPTables, error) { - return e.stack.IPTables(), nil -} - // Read reads data from the endpoint. This method does not block if // there is no data pending. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { diff --git a/pkg/tmutex/BUILD b/pkg/tmutex/BUILD deleted file mode 100644 index 2dcba84ae..000000000 --- a/pkg/tmutex/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "tmutex", - srcs = ["tmutex.go"], - visibility = ["//:sandbox"], -) - -go_test( - name = "tmutex_test", - size = "medium", - srcs = ["tmutex_test.go"], - library = ":tmutex", - deps = ["//pkg/sync"], -) diff --git a/pkg/tmutex/tmutex.go b/pkg/tmutex/tmutex.go deleted file mode 100644 index c4685020d..000000000 --- a/pkg/tmutex/tmutex.go +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package tmutex provides the implementation of a mutex that implements an -// efficient TryLock function in addition to Lock and Unlock. -package tmutex - -import ( - "sync/atomic" -) - -// Mutex is a mutual exclusion primitive that implements TryLock in addition -// to Lock and Unlock. -type Mutex struct { - v int32 - ch chan struct{} -} - -// Init initializes the mutex. -func (m *Mutex) Init() { - m.v = 1 - m.ch = make(chan struct{}, 1) -} - -// Lock acquires the mutex. If it is currently held by another goroutine, Lock -// will wait until it has a chance to acquire it. -func (m *Mutex) Lock() { - // Uncontended case. - if atomic.AddInt32(&m.v, -1) == 0 { - return - } - - for { - // Try to acquire the mutex again, at the same time making sure - // that m.v is negative, which indicates to the owner of the - // lock that it is contended, which will force it to try to wake - // someone up when it releases the mutex. - if v := atomic.LoadInt32(&m.v); v >= 0 && atomic.SwapInt32(&m.v, -1) == 1 { - return - } - - // Wait for the mutex to be released before trying again. - <-m.ch - } -} - -// TryLock attempts to acquire the mutex without blocking. If the mutex is -// currently held by another goroutine, it fails to acquire it and returns -// false. -func (m *Mutex) TryLock() bool { - v := atomic.LoadInt32(&m.v) - if v <= 0 { - return false - } - return atomic.CompareAndSwapInt32(&m.v, 1, 0) -} - -// Unlock releases the mutex. -func (m *Mutex) Unlock() { - if atomic.SwapInt32(&m.v, 1) == 0 { - // There were no pending waiters. - return - } - - // Wake some waiter up. - select { - case m.ch <- struct{}{}: - default: - } -} diff --git a/pkg/tmutex/tmutex_test.go b/pkg/tmutex/tmutex_test.go deleted file mode 100644 index 05540696a..000000000 --- a/pkg/tmutex/tmutex_test.go +++ /dev/null @@ -1,258 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tmutex - -import ( - "fmt" - "runtime" - "sync/atomic" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sync" -) - -func TestBasicLock(t *testing.T) { - var m Mutex - m.Init() - - m.Lock() - - // Try blocking lock the mutex from a different goroutine. This must - // not block because the mutex is held. - ch := make(chan struct{}, 1) - go func() { - m.Lock() - ch <- struct{}{} - m.Unlock() - ch <- struct{}{} - }() - - select { - case <-ch: - t.Fatalf("Lock succeeded on locked mutex") - case <-time.After(100 * time.Millisecond): - } - - // Unlock the mutex and make sure that the goroutine waiting on Lock() - // unblocks and succeeds. - m.Unlock() - - select { - case <-ch: - case <-time.After(100 * time.Millisecond): - t.Fatalf("Lock failed to acquire unlocked mutex") - } - - // Make sure we can lock and unlock again. - m.Lock() - m.Unlock() -} - -func TestTryLock(t *testing.T) { - var m Mutex - m.Init() - - // Try to lock. It should succeed. - if !m.TryLock() { - t.Fatalf("TryLock failed on unlocked mutex") - } - - // Try to lock again, it should now fail. - if m.TryLock() { - t.Fatalf("TryLock succeeded on locked mutex") - } - - // Try blocking lock the mutex from a different goroutine. This must - // not block because the mutex is held. - ch := make(chan struct{}, 1) - go func() { - m.Lock() - ch <- struct{}{} - m.Unlock() - }() - - select { - case <-ch: - t.Fatalf("Lock succeeded on locked mutex") - case <-time.After(100 * time.Millisecond): - } - - // Unlock the mutex and make sure that the goroutine waiting on Lock() - // unblocks and succeeds. - m.Unlock() - - select { - case <-ch: - case <-time.After(100 * time.Millisecond): - t.Fatalf("Lock failed to acquire unlocked mutex") - } -} - -func TestMutualExclusion(t *testing.T) { - var m Mutex - m.Init() - - // Test mutual exclusion by running "gr" goroutines concurrently, and - // have each one increment a counter "iters" times within the critical - // section established by the mutex. - // - // If at the end the counter is not gr * iters, then we know that - // goroutines ran concurrently within the critical section. - // - // If one of the goroutines doesn't complete, it's likely a bug that - // causes to it to wait forever. - const gr = 1000 - const iters = 100000 - v := 0 - var wg sync.WaitGroup - for i := 0; i < gr; i++ { - wg.Add(1) - go func() { - for j := 0; j < iters; j++ { - m.Lock() - v++ - m.Unlock() - } - wg.Done() - }() - } - - wg.Wait() - - if v != gr*iters { - t.Fatalf("Bad count: got %v, want %v", v, gr*iters) - } -} - -func TestMutualExclusionWithTryLock(t *testing.T) { - var m Mutex - m.Init() - - // Similar to the previous, with the addition of some goroutines that - // only increment the count if TryLock succeeds. - const gr = 1000 - const iters = 100000 - total := int64(gr * iters) - var tryTotal int64 - v := int64(0) - var wg sync.WaitGroup - for i := 0; i < gr; i++ { - wg.Add(2) - go func() { - for j := 0; j < iters; j++ { - m.Lock() - v++ - m.Unlock() - } - wg.Done() - }() - go func() { - local := int64(0) - for j := 0; j < iters; j++ { - if m.TryLock() { - v++ - m.Unlock() - local++ - } - } - atomic.AddInt64(&tryTotal, local) - wg.Done() - }() - } - - wg.Wait() - - t.Logf("tryTotal = %d", tryTotal) - total += tryTotal - - if v != total { - t.Fatalf("Bad count: got %v, want %v", v, total) - } -} - -// BenchmarkTmutex is equivalent to TestMutualExclusion, with the following -// differences: -// -// - The number of goroutines is variable, with the maximum value depending on -// GOMAXPROCS. -// -// - The number of iterations per benchmark is controlled by the benchmarking -// framework. -// -// - Care is taken to ensure that all goroutines participating in the benchmark -// have been created before the benchmark begins. -func BenchmarkTmutex(b *testing.B) { - for n, max := 1, 4*runtime.GOMAXPROCS(0); n > 0 && n <= max; n *= 2 { - b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { - var m Mutex - m.Init() - - var ready sync.WaitGroup - begin := make(chan struct{}) - var end sync.WaitGroup - for i := 0; i < n; i++ { - ready.Add(1) - end.Add(1) - go func() { - ready.Done() - <-begin - for j := 0; j < b.N; j++ { - m.Lock() - m.Unlock() - } - end.Done() - }() - } - - ready.Wait() - b.ResetTimer() - close(begin) - end.Wait() - }) - } -} - -// BenchmarkSyncMutex is equivalent to BenchmarkTmutex, but uses sync.Mutex as -// a comparison point. -func BenchmarkSyncMutex(b *testing.B) { - for n, max := 1, 4*runtime.GOMAXPROCS(0); n > 0 && n <= max; n *= 2 { - b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { - var m sync.Mutex - - var ready sync.WaitGroup - begin := make(chan struct{}) - var end sync.WaitGroup - for i := 0; i < n; i++ { - ready.Add(1) - end.Add(1) - go func() { - ready.Done() - <-begin - for j := 0; j < b.N; j++ { - m.Lock() - m.Unlock() - } - end.Done() - }() - } - - ready.Wait() - b.ResetTimer() - close(begin) - end.Wait() - }) - } -} |