diff options
author | Ben Burkert <ben@benburkert.com> | 2019-04-18 17:48:54 -0700 |
---|---|---|
committer | Shentubot <shentubot@google.com> | 2019-04-18 17:49:57 -0700 |
commit | cec2cdc12f30e87e5b0f6750fe1c132d89fcfb6d (patch) | |
tree | 76ccb6543b9f59649e31582f2a9337e1643562bd | |
parent | f4d434c18002c96511decf8ff1ebdbede46ca6a1 (diff) |
tcpip/transport/udp: add Forwarder type
Add a UDP forwarder for intercepting and forwarding UDP sessions.
Change-Id: I2d83c900c1931adfc59a532dd4f6b33a0db406c9
PiperOrigin-RevId: 244293576
-rw-r--r-- | pkg/tcpip/adapters/gonet/gonet_test.go | 56 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/forwarder.go | 96 |
3 files changed, 153 insertions, 0 deletions
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index 05a730a05..ab3da2e4e 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -222,6 +222,62 @@ func TestCloseReaderWithForwarder(t *testing.T) { sender.close() } +func TestUDPForwarder(t *testing.T) { + s, terr := newLoopbackStack() + if terr != nil { + t.Fatalf("newLoopbackStack() = %v", terr) + } + + ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) + addr1 := tcpip.FullAddress{NICID, ip1, 11211} + s.AddAddress(NICID, ipv4.ProtocolNumber, ip1) + ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4()) + addr2 := tcpip.FullAddress{NICID, ip2, 11311} + s.AddAddress(NICID, ipv4.ProtocolNumber, ip2) + + done := make(chan struct{}) + fwd := udp.NewForwarder(s, func(r *udp.ForwarderRequest) { + defer close(done) + + var wq waiter.Queue + ep, err := r.CreateEndpoint(&wq) + if err != nil { + t.Fatalf("r.CreateEndpoint() = %v", err) + } + defer ep.Close() + + c := NewConn(&wq, ep) + + buf := make([]byte, 256) + n, e := c.Read(buf) + if e != nil { + t.Errorf("c.Read() = %v", e) + } + + if _, e := c.Write(buf[:n]); e != nil { + t.Errorf("c.Write() = %v", e) + } + }) + s.SetTransportProtocolHandler(udp.ProtocolNumber, fwd.HandlePacket) + + c2, err := NewPacketConn(s, addr2, ipv4.ProtocolNumber) + if err != nil { + t.Fatal("NewPacketConn(port 5):", err) + } + + sent := "abc123" + sendAddr := fullToUDPAddr(addr1) + if n, err := c2.WriteTo([]byte(sent), sendAddr); err != nil || n != len(sent) { + t.Errorf("c1.WriteTo(%q, %v) = %d, %v, want = %d, %v", sent, sendAddr, n, err, len(sent), nil) + } + + buf := make([]byte, 256) + n, recvAddr, err := c2.ReadFrom(buf) + if err != nil || recvAddr.String() != sendAddr.String() { + t.Errorf("c1.ReadFrom() = %d, %v, %v, want = %d, %v, %v", n, recvAddr, err, len(sent), sendAddr, nil) + } +} + // TestDeadlineChange tests that changing the deadline affects currently blocked reads. func TestDeadlineChange(t *testing.T) { s, err := newLoopbackStack() diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index d271490c1..361132a25 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -20,6 +20,7 @@ go_library( srcs = [ "endpoint.go", "endpoint_state.go", + "forwarder.go", "protocol.go", "udp_packet_list.go", ], diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go new file mode 100644 index 000000000..d80c47e34 --- /dev/null +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -0,0 +1,96 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package udp + +import ( + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +// Forwarder is a session request forwarder, which allows clients to decide +// what to do with a session request, for example: ignore it, or process it. +// +// The canonical way of using it is to pass the Forwarder.HandlePacket function +// to stack.SetTransportProtocolHandler. +type Forwarder struct { + handler func(*ForwarderRequest) + + stack *stack.Stack +} + +// NewForwarder allocates and initializes a new forwarder. +func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder { + return &Forwarder{ + stack: s, + handler: handler, + } +} + +// HandlePacket handles all packets. +// +// This function is expected to be passed as an argument to the +// stack.SetTransportProtocolHandler function. +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { + f.handler(&ForwarderRequest{ + stack: f.stack, + route: r, + id: id, + vv: vv, + }) + + return true +} + +// ForwarderRequest represents a session request received by the forwarder and +// passed to the client. Clients may optionally create an endpoint to represent +// it via CreateEndpoint. +type ForwarderRequest struct { + stack *stack.Stack + route *stack.Route + id stack.TransportEndpointID + vv buffer.VectorisedView +} + +// ID returns the 4-tuple (src address, src port, dst address, dst port) that +// represents the session request. +func (r *ForwarderRequest) ID() stack.TransportEndpointID { + return r.id +} + +// CreateEndpoint creates a connected UDP endpoint for the session request. +func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + ep := newEndpoint(r.stack, r.route.NetProto, queue) + if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort); err != nil { + ep.Close() + return nil, err + } + + ep.id = r.id + ep.route = r.route.Clone() + ep.dstPort = r.id.RemotePort + ep.regNICID = r.route.NICID() + + ep.state = stateConnected + + ep.rcvMu.Lock() + ep.rcvReady = true + ep.rcvMu.Unlock() + + ep.HandlePacket(r.route, r.id, r.vv) + + return ep, nil +} |