summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorBen Burkert <ben@benburkert.com>2019-04-18 17:48:54 -0700
committerShentubot <shentubot@google.com>2019-04-18 17:49:57 -0700
commitcec2cdc12f30e87e5b0f6750fe1c132d89fcfb6d (patch)
tree76ccb6543b9f59649e31582f2a9337e1643562bd
parentf4d434c18002c96511decf8ff1ebdbede46ca6a1 (diff)
tcpip/transport/udp: add Forwarder type
Add a UDP forwarder for intercepting and forwarding UDP sessions. Change-Id: I2d83c900c1931adfc59a532dd4f6b33a0db406c9 PiperOrigin-RevId: 244293576
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go56
-rw-r--r--pkg/tcpip/transport/udp/BUILD1
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go96
3 files changed, 153 insertions, 0 deletions
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index 05a730a05..ab3da2e4e 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -222,6 +222,62 @@ func TestCloseReaderWithForwarder(t *testing.T) {
sender.close()
}
+func TestUDPForwarder(t *testing.T) {
+ s, terr := newLoopbackStack()
+ if terr != nil {
+ t.Fatalf("newLoopbackStack() = %v", terr)
+ }
+
+ ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
+ addr1 := tcpip.FullAddress{NICID, ip1, 11211}
+ s.AddAddress(NICID, ipv4.ProtocolNumber, ip1)
+ ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4())
+ addr2 := tcpip.FullAddress{NICID, ip2, 11311}
+ s.AddAddress(NICID, ipv4.ProtocolNumber, ip2)
+
+ done := make(chan struct{})
+ fwd := udp.NewForwarder(s, func(r *udp.ForwarderRequest) {
+ defer close(done)
+
+ var wq waiter.Queue
+ ep, err := r.CreateEndpoint(&wq)
+ if err != nil {
+ t.Fatalf("r.CreateEndpoint() = %v", err)
+ }
+ defer ep.Close()
+
+ c := NewConn(&wq, ep)
+
+ buf := make([]byte, 256)
+ n, e := c.Read(buf)
+ if e != nil {
+ t.Errorf("c.Read() = %v", e)
+ }
+
+ if _, e := c.Write(buf[:n]); e != nil {
+ t.Errorf("c.Write() = %v", e)
+ }
+ })
+ s.SetTransportProtocolHandler(udp.ProtocolNumber, fwd.HandlePacket)
+
+ c2, err := NewPacketConn(s, addr2, ipv4.ProtocolNumber)
+ if err != nil {
+ t.Fatal("NewPacketConn(port 5):", err)
+ }
+
+ sent := "abc123"
+ sendAddr := fullToUDPAddr(addr1)
+ if n, err := c2.WriteTo([]byte(sent), sendAddr); err != nil || n != len(sent) {
+ t.Errorf("c1.WriteTo(%q, %v) = %d, %v, want = %d, %v", sent, sendAddr, n, err, len(sent), nil)
+ }
+
+ buf := make([]byte, 256)
+ n, recvAddr, err := c2.ReadFrom(buf)
+ if err != nil || recvAddr.String() != sendAddr.String() {
+ t.Errorf("c1.ReadFrom() = %d, %v, %v, want = %d, %v, %v", n, recvAddr, err, len(sent), sendAddr, nil)
+ }
+}
+
// TestDeadlineChange tests that changing the deadline affects currently blocked reads.
func TestDeadlineChange(t *testing.T) {
s, err := newLoopbackStack()
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index d271490c1..361132a25 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -20,6 +20,7 @@ go_library(
srcs = [
"endpoint.go",
"endpoint_state.go",
+ "forwarder.go",
"protocol.go",
"udp_packet_list.go",
],
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
new file mode 100644
index 000000000..d80c47e34
--- /dev/null
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -0,0 +1,96 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package udp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// Forwarder is a session request forwarder, which allows clients to decide
+// what to do with a session request, for example: ignore it, or process it.
+//
+// The canonical way of using it is to pass the Forwarder.HandlePacket function
+// to stack.SetTransportProtocolHandler.
+type Forwarder struct {
+ handler func(*ForwarderRequest)
+
+ stack *stack.Stack
+}
+
+// NewForwarder allocates and initializes a new forwarder.
+func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder {
+ return &Forwarder{
+ stack: s,
+ handler: handler,
+ }
+}
+
+// HandlePacket handles all packets.
+//
+// This function is expected to be passed as an argument to the
+// stack.SetTransportProtocolHandler function.
+func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
+ f.handler(&ForwarderRequest{
+ stack: f.stack,
+ route: r,
+ id: id,
+ vv: vv,
+ })
+
+ return true
+}
+
+// ForwarderRequest represents a session request received by the forwarder and
+// passed to the client. Clients may optionally create an endpoint to represent
+// it via CreateEndpoint.
+type ForwarderRequest struct {
+ stack *stack.Stack
+ route *stack.Route
+ id stack.TransportEndpointID
+ vv buffer.VectorisedView
+}
+
+// ID returns the 4-tuple (src address, src port, dst address, dst port) that
+// represents the session request.
+func (r *ForwarderRequest) ID() stack.TransportEndpointID {
+ return r.id
+}
+
+// CreateEndpoint creates a connected UDP endpoint for the session request.
+func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ ep := newEndpoint(r.stack, r.route.NetProto, queue)
+ if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort); err != nil {
+ ep.Close()
+ return nil, err
+ }
+
+ ep.id = r.id
+ ep.route = r.route.Clone()
+ ep.dstPort = r.id.RemotePort
+ ep.regNICID = r.route.NICID()
+
+ ep.state = stateConnected
+
+ ep.rcvMu.Lock()
+ ep.rcvReady = true
+ ep.rcvMu.Unlock()
+
+ ep.HandlePacket(r.route, r.id, r.vv)
+
+ return ep, nil
+}