diff options
Diffstat (limited to 'pkg/tcpip/link/waitable')
-rw-r--r-- | pkg/tcpip/link/waitable/BUILD | 30 | ||||
-rw-r--r-- | pkg/tcpip/link/waitable/waitable.go | 149 | ||||
-rw-r--r-- | pkg/tcpip/link/waitable/waitable_test.go | 173 |
3 files changed, 352 insertions, 0 deletions
diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD new file mode 100644 index 000000000..0956d2c65 --- /dev/null +++ b/pkg/tcpip/link/waitable/BUILD @@ -0,0 +1,30 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "waitable", + srcs = [ + "waitable.go", + ], + visibility = ["//visibility:public"], + deps = [ + "//pkg/gate", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/stack", + ], +) + +go_test( + name = "waitable_test", + srcs = [ + "waitable_test.go", + ], + library = ":waitable", + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go new file mode 100644 index 000000000..2b3741276 --- /dev/null +++ b/pkg/tcpip/link/waitable/waitable.go @@ -0,0 +1,149 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package waitable provides the implementation of data-link layer endpoints +// that wrap other endpoints, and can wait for inflight calls to WritePacket or +// DeliverNetworkPacket to finish (and new ones to be prevented). +// +// Waitable endpoints can be used in the networking stack by calling New(eID) to +// create a new endpoint, where eID is the ID of the endpoint being wrapped, +// and then passing it as an argument to Stack.CreateNIC(). +package waitable + +import ( + "gvisor.dev/gvisor/pkg/gate" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// Endpoint is a waitable link-layer endpoint. +type Endpoint struct { + dispatchGate gate.Gate + dispatcher stack.NetworkDispatcher + + writeGate gate.Gate + lower stack.LinkEndpoint +} + +// New creates a new waitable link-layer endpoint. It wraps around another +// endpoint and allows the caller to block new write/dispatch calls and wait for +// the inflight ones to finish before returning. +func New(lower stack.LinkEndpoint) *Endpoint { + return &Endpoint{ + lower: lower, + } +} + +// DeliverNetworkPacket implements stack.NetworkDispatcher.DeliverNetworkPacket. +// It is called by the link-layer endpoint being wrapped when a packet arrives, +// and only forwards to the actual dispatcher if Wait or WaitDispatch haven't +// been called. +func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { + if !e.dispatchGate.Enter() { + return + } + + e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt) + e.dispatchGate.Leave() +} + +// Attach implements stack.LinkEndpoint.Attach. It saves the dispatcher and +// registers with the lower endpoint as its dispatcher so that "e" is called +// for inbound packets. +func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.dispatcher = dispatcher + e.lower.Attach(e) +} + +// IsAttached implements stack.LinkEndpoint.IsAttached. +func (e *Endpoint) IsAttached() bool { + return e.dispatcher != nil +} + +// MTU implements stack.LinkEndpoint.MTU. It just forwards the request to the +// lower endpoint. +func (e *Endpoint) MTU() uint32 { + return e.lower.MTU() +} + +// Capabilities implements stack.LinkEndpoint.Capabilities. It just forwards the +// request to the lower endpoint. +func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.lower.Capabilities() +} + +// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It just +// forwards the request to the lower endpoint. +func (e *Endpoint) MaxHeaderLength() uint16 { + return e.lower.MaxHeaderLength() +} + +// LinkAddress implements stack.LinkEndpoint.LinkAddress. It just forwards the +// request to the lower endpoint. +func (e *Endpoint) LinkAddress() tcpip.LinkAddress { + return e.lower.LinkAddress() +} + +// WritePacket implements stack.LinkEndpoint.WritePacket. It is called by +// higher-level protocols to write packets. It only forwards packets to the +// lower endpoint if Wait or WaitWrite haven't been called. +func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { + if !e.writeGate.Enter() { + return nil + } + + err := e.lower.WritePacket(r, gso, protocol, pkt) + e.writeGate.Leave() + return err +} + +// WritePackets implements stack.LinkEndpoint.WritePackets. It is called by +// higher-level protocols to write packets. It only forwards packets to the +// lower endpoint if Wait or WaitWrite haven't been called. +func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + if !e.writeGate.Enter() { + return pkts.Len(), nil + } + + n, err := e.lower.WritePackets(r, gso, pkts, protocol) + e.writeGate.Leave() + return n, err +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + if !e.writeGate.Enter() { + return nil + } + + err := e.lower.WriteRawPacket(vv) + e.writeGate.Leave() + return err +} + +// WaitWrite prevents new calls to WritePacket from reaching the lower endpoint, +// and waits for inflight ones to finish before returning. +func (e *Endpoint) WaitWrite() { + e.writeGate.Close() +} + +// WaitDispatch prevents new calls to DeliverNetworkPacket from reaching the +// actual dispatcher, and waits for inflight ones to finish before returning. +func (e *Endpoint) WaitDispatch() { + e.dispatchGate.Close() +} + +// Wait implements stack.LinkEndpoint.Wait. +func (e *Endpoint) Wait() {} diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go new file mode 100644 index 000000000..54eb5322b --- /dev/null +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -0,0 +1,173 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package waitable + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +type countedEndpoint struct { + dispatchCount int + writeCount int + attachCount int + + mtu uint32 + capabilities stack.LinkEndpointCapabilities + hdrLen uint16 + linkAddr tcpip.LinkAddress + + dispatcher stack.NetworkDispatcher +} + +func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { + e.dispatchCount++ +} + +func (e *countedEndpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.attachCount++ + e.dispatcher = dispatcher +} + +// IsAttached implements stack.LinkEndpoint.IsAttached. +func (e *countedEndpoint) IsAttached() bool { + return e.dispatcher != nil +} + +func (e *countedEndpoint) MTU() uint32 { + return e.mtu +} + +func (e *countedEndpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.capabilities +} + +func (e *countedEndpoint) MaxHeaderLength() uint16 { + return e.hdrLen +} + +func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress { + return e.linkAddr +} + +func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { + e.writeCount++ + return nil +} + +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + e.writeCount += pkts.Len() + return pkts.Len(), nil +} + +func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { + e.writeCount++ + return nil +} + +// Wait implements stack.LinkEndpoint.Wait. +func (*countedEndpoint) Wait() {} + +func TestWaitWrite(t *testing.T) { + ep := &countedEndpoint{} + wep := New(ep) + + // Write and check that it goes through. + wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{}) + if want := 1; ep.writeCount != want { + t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) + } + + // Wait on dispatches, then try to write. It must go through. + wep.WaitDispatch() + wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{}) + if want := 2; ep.writeCount != want { + t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) + } + + // Wait on writes, then try to write. It must not go through. + wep.WaitWrite() + wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{}) + if want := 2; ep.writeCount != want { + t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) + } +} + +func TestWaitDispatch(t *testing.T) { + ep := &countedEndpoint{} + wep := New(ep) + + // Check that attach happens. + wep.Attach(ep) + if want := 1; ep.attachCount != want { + t.Fatalf("Unexpected attachCount: got=%v, want=%v", ep.attachCount, want) + } + + // Dispatch and check that it goes through. + ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, stack.PacketBuffer{}) + if want := 1; ep.dispatchCount != want { + t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) + } + + // Wait on writes, then try to dispatch. It must go through. + wep.WaitWrite() + ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, stack.PacketBuffer{}) + if want := 2; ep.dispatchCount != want { + t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) + } + + // Wait on dispatches, then try to dispatch. It must not go through. + wep.WaitDispatch() + ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, stack.PacketBuffer{}) + if want := 2; ep.dispatchCount != want { + t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) + } +} + +func TestOtherMethods(t *testing.T) { + const ( + mtu = 0xdead + capabilities = 0xbeef + hdrLen = 0x1234 + linkAddr = "test address" + ) + ep := &countedEndpoint{ + mtu: mtu, + capabilities: capabilities, + hdrLen: hdrLen, + linkAddr: linkAddr, + } + wep := New(ep) + + if v := wep.MTU(); v != mtu { + t.Fatalf("Unexpected mtu: got=%v, want=%v", v, mtu) + } + + if v := wep.Capabilities(); v != capabilities { + t.Fatalf("Unexpected capabilities: got=%v, want=%v", v, capabilities) + } + + if v := wep.MaxHeaderLength(); v != hdrLen { + t.Fatalf("Unexpected MaxHeaderLength: got=%v, want=%v", v, hdrLen) + } + + if v := wep.LinkAddress(); v != linkAddr { + t.Fatalf("Unexpected LinkAddress: got=%q, want=%q", v, linkAddr) + } +} |