summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/link/waitable
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/link/waitable')
-rw-r--r--pkg/tcpip/link/waitable/BUILD12
-rw-r--r--pkg/tcpip/link/waitable/waitable.go38
-rw-r--r--pkg/tcpip/link/waitable/waitable_test.go39
3 files changed, 61 insertions, 28 deletions
diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD
index 0746dc8ec..ee84c3d96 100644
--- a/pkg/tcpip/link/waitable/BUILD
+++ b/pkg/tcpip/link/waitable/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -8,14 +7,12 @@ go_library(
srcs = [
"waitable.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/waitable",
- visibility = [
- "//visibility:public",
- ],
+ visibility = ["//visibility:public"],
deps = [
"//pkg/gate",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
)
@@ -25,10 +22,11 @@ go_test(
srcs = [
"waitable_test.go",
],
- embed = [":waitable"],
+ library = ":waitable",
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index a04fc1062..b152a0f26 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/gate"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -50,12 +51,21 @@ func New(lower stack.LinkEndpoint) *Endpoint {
// It is called by the link-layer endpoint being wrapped when a packet arrives,
// and only forwards to the actual dispatcher if Wait or WaitDispatch haven't
// been called.
-func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) {
+func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
if !e.dispatchGate.Enter() {
return
}
- e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, vv, linkHeader)
+ e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt)
+ e.dispatchGate.Leave()
+}
+
+// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
+func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ if !e.dispatchGate.Enter() {
+ return
+ }
+ e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt)
e.dispatchGate.Leave()
}
@@ -99,12 +109,12 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
// WritePacket implements stack.LinkEndpoint.WritePacket. It is called by
// higher-level protocols to write packets. It only forwards packets to the
// lower endpoint if Wait or WaitWrite haven't been called.
-func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
if !e.writeGate.Enter() {
return nil
}
- err := e.lower.WritePacket(r, gso, hdr, payload, protocol)
+ err := e.lower.WritePacket(r, gso, protocol, pkt)
e.writeGate.Leave()
return err
}
@@ -112,23 +122,23 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
// WritePackets implements stack.LinkEndpoint.WritePackets. It is called by
// higher-level protocols to write packets. It only forwards packets to the
// lower endpoint if Wait or WaitWrite haven't been called.
-func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
if !e.writeGate.Enter() {
- return len(hdrs), nil
+ return pkts.Len(), nil
}
- n, err := e.lower.WritePackets(r, gso, hdrs, payload, protocol)
+ n, err := e.lower.WritePackets(r, gso, pkts, protocol)
e.writeGate.Leave()
return n, err
}
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *Endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error {
+func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
if !e.writeGate.Enter() {
return nil
}
- err := e.lower.WriteRawPacket(packet)
+ err := e.lower.WriteRawPacket(vv)
e.writeGate.Leave()
return err
}
@@ -147,3 +157,13 @@ func (e *Endpoint) WaitDispatch() {
// Wait implements stack.LinkEndpoint.Wait.
func (e *Endpoint) Wait() {}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
+ return e.lower.ARPHardwareType()
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.lower.AddHeader(local, remote, protocol, pkt)
+}
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index 5f0f8fa2d..94827fc56 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -35,10 +36,14 @@ type countedEndpoint struct {
dispatcher stack.NetworkDispatcher
}
-func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) {
+func (e *countedEndpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
e.dispatchCount++
}
+func (e *countedEndpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func (e *countedEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.attachCount++
e.dispatcher = dispatcher
@@ -65,45 +70,55 @@ func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress {
return e.linkAddr
}
-func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
e.writeCount++
return nil
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
- e.writeCount += len(hdrs)
- return len(hdrs), nil
+func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ e.writeCount += pkts.Len()
+ return pkts.Len(), nil
}
-func (e *countedEndpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error {
+func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
e.writeCount++
return nil
}
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType {
+ panic("unimplemented")
+}
+
// Wait implements stack.LinkEndpoint.Wait.
func (*countedEndpoint) Wait() {}
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *countedEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func TestWaitWrite(t *testing.T) {
ep := &countedEndpoint{}
wep := New(ep)
// Write and check that it goes through.
- wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0)
+ wep.WritePacket(nil, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
if want := 1; ep.writeCount != want {
t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
}
// Wait on dispatches, then try to write. It must go through.
wep.WaitDispatch()
- wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0)
+ wep.WritePacket(nil, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
if want := 2; ep.writeCount != want {
t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
}
// Wait on writes, then try to write. It must not go through.
wep.WaitWrite()
- wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0)
+ wep.WritePacket(nil, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
if want := 2; ep.writeCount != want {
t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
}
@@ -120,21 +135,21 @@ func TestWaitDispatch(t *testing.T) {
}
// Dispatch and check that it goes through.
- ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}, buffer.View{})
+ ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
if want := 1; ep.dispatchCount != want {
t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
}
// Wait on writes, then try to dispatch. It must go through.
wep.WaitWrite()
- ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}, buffer.View{})
+ ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
if want := 2; ep.dispatchCount != want {
t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
}
// Wait on dispatches, then try to dispatch. It must not go through.
wep.WaitDispatch()
- ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}, buffer.View{})
+ ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
if want := 2; ep.dispatchCount != want {
t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
}