diff options
author | Bruno Dal Bo <brunodalbo@google.com> | 2020-06-22 10:30:21 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2020-06-22 10:31:45 -0700 |
commit | 282a6aea1b375d447fdf502c6660e92eb5e19cd4 (patch) | |
tree | 3c90df2c74369144427c289b3e3b320956e7b52d | |
parent | a480b4faf4befb029bf905fdb604996c8312a6a2 (diff) |
Extract common nested LinkEndpoint pattern
... and unify logic for detached netsted endpoints.
sniffer.go caused crashes if a packet delivery is attempted when the dispatcher
is nil.
Extracted the endpoint nesting logic into a common composable type so it can be
used by the Fuchsia Netstack (the pattern is widespread there).
PiperOrigin-RevId: 317682842
-rw-r--r-- | pkg/tcpip/link/nested/BUILD | 31 | ||||
-rw-r--r-- | pkg/tcpip/link/nested/nested.go | 131 | ||||
-rw-r--r-- | pkg/tcpip/link/nested/nested_test.go | 105 | ||||
-rw-r--r-- | pkg/tcpip/link/sniffer/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/link/sniffer/sniffer.go | 75 |
5 files changed, 285 insertions, 58 deletions
diff --git a/pkg/tcpip/link/nested/BUILD b/pkg/tcpip/link/nested/BUILD new file mode 100644 index 000000000..bdd5276ad --- /dev/null +++ b/pkg/tcpip/link/nested/BUILD @@ -0,0 +1,31 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "nested", + srcs = [ + "nested.go", + ], + visibility = ["//visibility:public"], + deps = [ + "//pkg/sync", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/stack", + ], +) + +go_test( + name = "nested_test", + size = "small", + srcs = [ + "nested_test.go", + ], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/header", + "//pkg/tcpip/link/nested", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go new file mode 100644 index 000000000..2998f9c4f --- /dev/null +++ b/pkg/tcpip/link/nested/nested.go @@ -0,0 +1,131 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package nested provides helpers to implement the pattern of nested +// stack.LinkEndpoints. +package nested + +import ( + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// Endpoint is a wrapper around stack.LinkEndpoint and stack.NetworkDispatcher +// that can be used to implement nesting safely by providing lifecycle +// concurrency guards. +// +// See the tests in this package for example usage. +type Endpoint struct { + child stack.LinkEndpoint + embedder stack.NetworkDispatcher + + // mu protects dispatcher. + mu sync.RWMutex + dispatcher stack.NetworkDispatcher +} + +var _ stack.GSOEndpoint = (*Endpoint)(nil) +var _ stack.LinkEndpoint = (*Endpoint)(nil) +var _ stack.NetworkDispatcher = (*Endpoint)(nil) + +// Init initializes a nested.Endpoint that uses embedder as the dispatcher for +// child on Attach. +// +// See the tests in this package for example usage. +func (e *Endpoint) Init(child stack.LinkEndpoint, embedder stack.NetworkDispatcher) { + e.child = child + e.embedder = embedder +} + +// DeliverNetworkPacket implements stack.NetworkDispatcher. +func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.mu.RLock() + d := e.dispatcher + e.mu.RUnlock() + if d != nil { + d.DeliverNetworkPacket(remote, local, protocol, pkt) + } +} + +// Attach implements stack.LinkEndpoint. +func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.mu.Lock() + e.dispatcher = dispatcher + e.mu.Unlock() + // If we're attaching to a valid dispatcher, pass embedder as the dispatcher + // to our child, otherwise detach the child by giving it a nil dispatcher. + var pass stack.NetworkDispatcher + if dispatcher != nil { + pass = e.embedder + } + e.child.Attach(pass) +} + +// IsAttached implements stack.LinkEndpoint. +func (e *Endpoint) IsAttached() bool { + e.mu.RLock() + isAttached := e.dispatcher != nil + e.mu.RUnlock() + return isAttached +} + +// MTU implements stack.LinkEndpoint. +func (e *Endpoint) MTU() uint32 { + return e.child.MTU() +} + +// Capabilities implements stack.LinkEndpoint. +func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.child.Capabilities() +} + +// MaxHeaderLength implements stack.LinkEndpoint. +func (e *Endpoint) MaxHeaderLength() uint16 { + return e.child.MaxHeaderLength() +} + +// LinkAddress implements stack.LinkEndpoint. +func (e *Endpoint) LinkAddress() tcpip.LinkAddress { + return e.child.LinkAddress() +} + +// WritePacket implements stack.LinkEndpoint. +func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + return e.child.WritePacket(r, gso, protocol, pkt) +} + +// WritePackets implements stack.LinkEndpoint. +func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + return e.child.WritePackets(r, gso, pkts, protocol) +} + +// WriteRawPacket implements stack.LinkEndpoint. +func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + return e.child.WriteRawPacket(vv) +} + +// Wait implements stack.LinkEndpoint. +func (e *Endpoint) Wait() { + e.child.Wait() +} + +// GSOMaxSize implements stack.GSOEndpoint. +func (e *Endpoint) GSOMaxSize() uint32 { + if e, ok := e.child.(stack.GSOEndpoint); ok { + return e.GSOMaxSize() + } + return 0 +} diff --git a/pkg/tcpip/link/nested/nested_test.go b/pkg/tcpip/link/nested/nested_test.go new file mode 100644 index 000000000..c1a219f02 --- /dev/null +++ b/pkg/tcpip/link/nested/nested_test.go @@ -0,0 +1,105 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nested_test + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/nested" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +type parentEndpoint struct { + nested.Endpoint +} + +var _ stack.LinkEndpoint = (*parentEndpoint)(nil) +var _ stack.NetworkDispatcher = (*parentEndpoint)(nil) + +type childEndpoint struct { + stack.LinkEndpoint + dispatcher stack.NetworkDispatcher +} + +var _ stack.LinkEndpoint = (*childEndpoint)(nil) + +func (c *childEndpoint) Attach(dispatcher stack.NetworkDispatcher) { + c.dispatcher = dispatcher +} + +func (c *childEndpoint) IsAttached() bool { + return c.dispatcher != nil +} + +type counterDispatcher struct { + count int +} + +var _ stack.NetworkDispatcher = (*counterDispatcher)(nil) + +func (d *counterDispatcher) DeliverNetworkPacket(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) { + d.count++ +} + +func TestNestedLinkEndpoint(t *testing.T) { + const emptyAddress = tcpip.LinkAddress("") + + var ( + childEP childEndpoint + nestedEP parentEndpoint + disp counterDispatcher + ) + nestedEP.Endpoint.Init(&childEP, &nestedEP) + + if childEP.IsAttached() { + t.Error("On init, childEP.IsAttached() = true, want = false") + } + if nestedEP.IsAttached() { + t.Error("On init, nestedEP.IsAttached() = true, want = false") + } + + nestedEP.Attach(&disp) + if disp.count != 0 { + t.Fatalf("After attach, got disp.count = %d, want = 0", disp.count) + } + if !childEP.IsAttached() { + t.Error("After attach, childEP.IsAttached() = false, want = true") + } + if !nestedEP.IsAttached() { + t.Error("After attach, nestedEP.IsAttached() = false, want = true") + } + + nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, &stack.PacketBuffer{}) + if disp.count != 1 { + t.Errorf("After first packet with dispatcher attached, got disp.count = %d, want = 1", disp.count) + } + + nestedEP.Attach(nil) + if childEP.IsAttached() { + t.Error("After detach, childEP.IsAttached() = true, want = false") + } + if nestedEP.IsAttached() { + t.Error("After detach, nestedEP.IsAttached() = true, want = false") + } + + disp.count = 0 + nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, &stack.PacketBuffer{}) + if disp.count != 0 { + t.Errorf("After second packet with dispatcher detached, got disp.count = %d, want = 0", disp.count) + } + +} diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD index 230a8d53a..7cbc305e7 100644 --- a/pkg/tcpip/link/sniffer/BUILD +++ b/pkg/tcpip/link/sniffer/BUILD @@ -14,6 +14,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/link/nested", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index f2e47b6a7..d9cd4e83a 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -31,6 +31,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/nested" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -48,18 +49,21 @@ var LogPackets uint32 = 1 var LogPacketsToPCAP uint32 = 1 type endpoint struct { - dispatcher stack.NetworkDispatcher - lower stack.LinkEndpoint + nested.Endpoint writer io.Writer maxPCAPLen uint32 } +var _ stack.GSOEndpoint = (*endpoint)(nil) +var _ stack.LinkEndpoint = (*endpoint)(nil) +var _ stack.NetworkDispatcher = (*endpoint)(nil) + // New creates a new sniffer link-layer endpoint. It wraps around another // endpoint and logs packets and they traverse the endpoint. func New(lower stack.LinkEndpoint) stack.LinkEndpoint { - return &endpoint{ - lower: lower, - } + sniffer := &endpoint{} + sniffer.Endpoint.Init(lower, sniffer) + return sniffer } func zoneOffset() (int32, error) { @@ -103,11 +107,12 @@ func NewWithWriter(lower stack.LinkEndpoint, writer io.Writer, snapLen uint32) ( if err := writePCAPHeader(writer, snapLen); err != nil { return nil, err } - return &endpoint{ - lower: lower, + sniffer := &endpoint{ writer: writer, maxPCAPLen: snapLen, - }, nil + } + sniffer.Endpoint.Init(lower, sniffer) + return sniffer, nil } // DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is @@ -115,50 +120,7 @@ func NewWithWriter(lower stack.LinkEndpoint, writer io.Writer, snapLen uint32) ( // logs the packet before forwarding to the actual dispatcher. func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { e.dumpPacket("recv", nil, protocol, pkt) - e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt) -} - -// Attach implements the stack.LinkEndpoint interface. It saves the dispatcher -// and registers with the lower endpoint as its dispatcher so that "e" is called -// for inbound packets. -func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { - e.dispatcher = dispatcher - e.lower.Attach(e) -} - -// IsAttached implements stack.LinkEndpoint.IsAttached. -func (e *endpoint) IsAttached() bool { - return e.dispatcher != nil -} - -// MTU implements stack.LinkEndpoint.MTU. It just forwards the request to the -// lower endpoint. -func (e *endpoint) MTU() uint32 { - return e.lower.MTU() -} - -// Capabilities implements stack.LinkEndpoint.Capabilities. It just forwards the -// request to the lower endpoint. -func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { - return e.lower.Capabilities() -} - -// MaxHeaderLength implements the stack.LinkEndpoint interface. It just forwards -// the request to the lower endpoint. -func (e *endpoint) MaxHeaderLength() uint16 { - return e.lower.MaxHeaderLength() -} - -func (e *endpoint) LinkAddress() tcpip.LinkAddress { - return e.lower.LinkAddress() -} - -// GSOMaxSize returns the maximum GSO packet size. -func (e *endpoint) GSOMaxSize() uint32 { - if gso, ok := e.lower.(stack.GSOEndpoint); ok { - return gso.GSOMaxSize() - } - return 0 + e.Endpoint.DeliverNetworkPacket(remote, local, protocol, pkt) } func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { @@ -203,7 +165,7 @@ func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.Netw // forwards the request to the lower endpoint. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { e.dumpPacket("send", gso, protocol, pkt) - return e.lower.WritePacket(r, gso, protocol, pkt) + return e.Endpoint.WritePacket(r, gso, protocol, pkt) } // WritePackets implements the stack.LinkEndpoint interface. It is called by @@ -213,7 +175,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { e.dumpPacket("send", gso, protocol, pkt) } - return e.lower.WritePackets(r, gso, pkts, protocol) + return e.Endpoint.WritePackets(r, gso, pkts, protocol) } // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. @@ -221,12 +183,9 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { e.dumpPacket("send", nil, 0, &stack.PacketBuffer{ Data: vv, }) - return e.lower.WriteRawPacket(vv) + return e.Endpoint.WriteRawPacket(vv) } -// Wait implements stack.LinkEndpoint.Wait. -func (e *endpoint) Wait() { e.lower.Wait() } - func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) { // Figure out the network layer info. var transProto uint8 |