summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/link/waitable
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/link/waitable')
-rw-r--r--pkg/tcpip/link/waitable/BUILD30
-rw-r--r--pkg/tcpip/link/waitable/waitable.go149
-rw-r--r--pkg/tcpip/link/waitable/waitable_test.go173
3 files changed, 352 insertions, 0 deletions
diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD
new file mode 100644
index 000000000..0956d2c65
--- /dev/null
+++ b/pkg/tcpip/link/waitable/BUILD
@@ -0,0 +1,30 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "waitable",
+ srcs = [
+ "waitable.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/gate",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/stack",
+ ],
+)
+
+go_test(
+ name = "waitable_test",
+ srcs = [
+ "waitable_test.go",
+ ],
+ library = ":waitable",
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
new file mode 100644
index 000000000..2b3741276
--- /dev/null
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -0,0 +1,149 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package waitable provides the implementation of data-link layer endpoints
+// that wrap other endpoints, and can wait for inflight calls to WritePacket or
+// DeliverNetworkPacket to finish (and new ones to be prevented).
+//
+// Waitable endpoints can be used in the networking stack by calling New(eID) to
+// create a new endpoint, where eID is the ID of the endpoint being wrapped,
+// and then passing it as an argument to Stack.CreateNIC().
+package waitable
+
+import (
+ "gvisor.dev/gvisor/pkg/gate"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// Endpoint is a waitable link-layer endpoint.
+type Endpoint struct {
+ dispatchGate gate.Gate
+ dispatcher stack.NetworkDispatcher
+
+ writeGate gate.Gate
+ lower stack.LinkEndpoint
+}
+
+// New creates a new waitable link-layer endpoint. It wraps around another
+// endpoint and allows the caller to block new write/dispatch calls and wait for
+// the inflight ones to finish before returning.
+func New(lower stack.LinkEndpoint) *Endpoint {
+ return &Endpoint{
+ lower: lower,
+ }
+}
+
+// DeliverNetworkPacket implements stack.NetworkDispatcher.DeliverNetworkPacket.
+// It is called by the link-layer endpoint being wrapped when a packet arrives,
+// and only forwards to the actual dispatcher if Wait or WaitDispatch haven't
+// been called.
+func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
+ if !e.dispatchGate.Enter() {
+ return
+ }
+
+ e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt)
+ e.dispatchGate.Leave()
+}
+
+// Attach implements stack.LinkEndpoint.Attach. It saves the dispatcher and
+// registers with the lower endpoint as its dispatcher so that "e" is called
+// for inbound packets.
+func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.dispatcher = dispatcher
+ e.lower.Attach(e)
+}
+
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (e *Endpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
+// MTU implements stack.LinkEndpoint.MTU. It just forwards the request to the
+// lower endpoint.
+func (e *Endpoint) MTU() uint32 {
+ return e.lower.MTU()
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities. It just forwards the
+// request to the lower endpoint.
+func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.lower.Capabilities()
+}
+
+// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It just
+// forwards the request to the lower endpoint.
+func (e *Endpoint) MaxHeaderLength() uint16 {
+ return e.lower.MaxHeaderLength()
+}
+
+// LinkAddress implements stack.LinkEndpoint.LinkAddress. It just forwards the
+// request to the lower endpoint.
+func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
+ return e.lower.LinkAddress()
+}
+
+// WritePacket implements stack.LinkEndpoint.WritePacket. It is called by
+// higher-level protocols to write packets. It only forwards packets to the
+// lower endpoint if Wait or WaitWrite haven't been called.
+func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
+ if !e.writeGate.Enter() {
+ return nil
+ }
+
+ err := e.lower.WritePacket(r, gso, protocol, pkt)
+ e.writeGate.Leave()
+ return err
+}
+
+// WritePackets implements stack.LinkEndpoint.WritePackets. It is called by
+// higher-level protocols to write packets. It only forwards packets to the
+// lower endpoint if Wait or WaitWrite haven't been called.
+func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ if !e.writeGate.Enter() {
+ return pkts.Len(), nil
+ }
+
+ n, err := e.lower.WritePackets(r, gso, pkts, protocol)
+ e.writeGate.Leave()
+ return n, err
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ if !e.writeGate.Enter() {
+ return nil
+ }
+
+ err := e.lower.WriteRawPacket(vv)
+ e.writeGate.Leave()
+ return err
+}
+
+// WaitWrite prevents new calls to WritePacket from reaching the lower endpoint,
+// and waits for inflight ones to finish before returning.
+func (e *Endpoint) WaitWrite() {
+ e.writeGate.Close()
+}
+
+// WaitDispatch prevents new calls to DeliverNetworkPacket from reaching the
+// actual dispatcher, and waits for inflight ones to finish before returning.
+func (e *Endpoint) WaitDispatch() {
+ e.dispatchGate.Close()
+}
+
+// Wait implements stack.LinkEndpoint.Wait.
+func (e *Endpoint) Wait() {}
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
new file mode 100644
index 000000000..54eb5322b
--- /dev/null
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -0,0 +1,173 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package waitable
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+type countedEndpoint struct {
+ dispatchCount int
+ writeCount int
+ attachCount int
+
+ mtu uint32
+ capabilities stack.LinkEndpointCapabilities
+ hdrLen uint16
+ linkAddr tcpip.LinkAddress
+
+ dispatcher stack.NetworkDispatcher
+}
+
+func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
+ e.dispatchCount++
+}
+
+func (e *countedEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.attachCount++
+ e.dispatcher = dispatcher
+}
+
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (e *countedEndpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
+func (e *countedEndpoint) MTU() uint32 {
+ return e.mtu
+}
+
+func (e *countedEndpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.capabilities
+}
+
+func (e *countedEndpoint) MaxHeaderLength() uint16 {
+ return e.hdrLen
+}
+
+func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress {
+ return e.linkAddr
+}
+
+func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
+ e.writeCount++
+ return nil
+}
+
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ e.writeCount += pkts.Len()
+ return pkts.Len(), nil
+}
+
+func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
+ e.writeCount++
+ return nil
+}
+
+// Wait implements stack.LinkEndpoint.Wait.
+func (*countedEndpoint) Wait() {}
+
+func TestWaitWrite(t *testing.T) {
+ ep := &countedEndpoint{}
+ wep := New(ep)
+
+ // Write and check that it goes through.
+ wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{})
+ if want := 1; ep.writeCount != want {
+ t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
+ }
+
+ // Wait on dispatches, then try to write. It must go through.
+ wep.WaitDispatch()
+ wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{})
+ if want := 2; ep.writeCount != want {
+ t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
+ }
+
+ // Wait on writes, then try to write. It must not go through.
+ wep.WaitWrite()
+ wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{})
+ if want := 2; ep.writeCount != want {
+ t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
+ }
+}
+
+func TestWaitDispatch(t *testing.T) {
+ ep := &countedEndpoint{}
+ wep := New(ep)
+
+ // Check that attach happens.
+ wep.Attach(ep)
+ if want := 1; ep.attachCount != want {
+ t.Fatalf("Unexpected attachCount: got=%v, want=%v", ep.attachCount, want)
+ }
+
+ // Dispatch and check that it goes through.
+ ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, stack.PacketBuffer{})
+ if want := 1; ep.dispatchCount != want {
+ t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
+ }
+
+ // Wait on writes, then try to dispatch. It must go through.
+ wep.WaitWrite()
+ ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, stack.PacketBuffer{})
+ if want := 2; ep.dispatchCount != want {
+ t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
+ }
+
+ // Wait on dispatches, then try to dispatch. It must not go through.
+ wep.WaitDispatch()
+ ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, stack.PacketBuffer{})
+ if want := 2; ep.dispatchCount != want {
+ t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
+ }
+}
+
+func TestOtherMethods(t *testing.T) {
+ const (
+ mtu = 0xdead
+ capabilities = 0xbeef
+ hdrLen = 0x1234
+ linkAddr = "test address"
+ )
+ ep := &countedEndpoint{
+ mtu: mtu,
+ capabilities: capabilities,
+ hdrLen: hdrLen,
+ linkAddr: linkAddr,
+ }
+ wep := New(ep)
+
+ if v := wep.MTU(); v != mtu {
+ t.Fatalf("Unexpected mtu: got=%v, want=%v", v, mtu)
+ }
+
+ if v := wep.Capabilities(); v != capabilities {
+ t.Fatalf("Unexpected capabilities: got=%v, want=%v", v, capabilities)
+ }
+
+ if v := wep.MaxHeaderLength(); v != hdrLen {
+ t.Fatalf("Unexpected MaxHeaderLength: got=%v, want=%v", v, hdrLen)
+ }
+
+ if v := wep.LinkAddress(); v != linkAddr {
+ t.Fatalf("Unexpected LinkAddress: got=%q, want=%q", v, linkAddr)
+ }
+}