summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/raw
diff options
context:
space:
mode:
authorgVisor bot <gvisor-bot@google.com>2019-07-13 01:17:01 +0000
committergVisor bot <gvisor-bot@google.com>2019-07-13 01:17:01 +0000
commit5ae571182025283ed8272a3e5456d795f4ef0c80 (patch)
tree64d43beec76fe84277ed4d0d47f61a3936be0da8 /pkg/tcpip/transport/raw
parente9e36c40dfab1748fed1ef18270540133c91ee24 (diff)
parent9b4d3280e172063a6563d9e72a75b500442ed9b9 (diff)
Merge 9b4d3280 (automated)
Diffstat (limited to 'pkg/tcpip/transport/raw')
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go100
-rw-r--r--pkg/tcpip/transport/raw/protocol.go32
-rwxr-xr-xpkg/tcpip/transport/raw/raw_state_autogen.go2
3 files changed, 111 insertions, 23 deletions
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 42aded77f..a29587658 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -67,6 +67,7 @@ type endpoint struct {
netProto tcpip.NetworkProtocolNumber
transProto tcpip.TransportProtocolNumber
waiterQueue *waiter.Queue
+ associated bool
// The following fields are used to manage the receive queue and are
// protected by rcvMu.
@@ -97,8 +98,12 @@ type endpoint struct {
}
// NewEndpoint returns a raw endpoint for the given protocols.
-// TODO(b/129292371): IP_HDRINCL, IPPROTO_RAW, and AF_PACKET.
+// TODO(b/129292371): IP_HDRINCL and AF_PACKET.
func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return newEndpoint(stack, netProto, transProto, waiterQueue, true /* associated */)
+}
+
+func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) {
if netProto != header.IPv4ProtocolNumber {
return nil, tcpip.ErrUnknownProtocol
}
@@ -110,6 +115,16 @@ func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans
waiterQueue: waiterQueue,
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024,
+ associated: associated,
+ }
+
+ // Unassociated endpoints are write-only and users call Write() with IP
+ // headers included. Because they're write-only, We don't need to
+ // register with the stack.
+ if !associated {
+ ep.rcvBufSizeMax = 0
+ ep.waiterQueue = nil
+ return ep, nil
}
if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil {
@@ -124,7 +139,7 @@ func (ep *endpoint) Close() {
ep.mu.Lock()
defer ep.mu.Unlock()
- if ep.closed {
+ if ep.closed || !ep.associated {
return
}
@@ -142,8 +157,11 @@ func (ep *endpoint) Close() {
if ep.connected {
ep.route.Release()
+ ep.connected = false
}
+ ep.closed = true
+
ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
@@ -152,6 +170,10 @@ func (ep *endpoint) ModerateRecvBuf(copied int) {}
// Read implements tcpip.Endpoint.Read.
func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ if !ep.associated {
+ return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidOptionValue
+ }
+
ep.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -192,6 +214,33 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintp
return 0, nil, tcpip.ErrInvalidEndpointState
}
+ payloadBytes, err := payload.Get(payload.Size())
+ if err != nil {
+ ep.mu.RUnlock()
+ return 0, nil, err
+ }
+
+ // If this is an unassociated socket and callee provided a nonzero
+ // destination address, route using that address.
+ if !ep.associated {
+ ip := header.IPv4(payloadBytes)
+ if !ip.IsValid(payload.Size()) {
+ ep.mu.RUnlock()
+ return 0, nil, tcpip.ErrInvalidOptionValue
+ }
+ dstAddr := ip.DestinationAddress()
+ // Update dstAddr with the address in the IP header, unless
+ // opts.To is set (e.g. if sendto specifies a specific
+ // address).
+ if dstAddr != tcpip.Address([]byte{0, 0, 0, 0}) && opts.To == nil {
+ opts.To = &tcpip.FullAddress{
+ NIC: 0, // NIC is unset.
+ Addr: dstAddr, // The address from the payload.
+ Port: 0, // There are no ports here.
+ }
+ }
+ }
+
// Did the user caller provide a destination? If not, use the connected
// destination.
if opts.To == nil {
@@ -216,12 +265,12 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintp
return 0, nil, tcpip.ErrInvalidEndpointState
}
- n, ch, err := ep.finishWrite(payload, savedRoute)
+ n, ch, err := ep.finishWrite(payloadBytes, savedRoute)
ep.mu.Unlock()
return n, ch, err
}
- n, ch, err := ep.finishWrite(payload, &ep.route)
+ n, ch, err := ep.finishWrite(payloadBytes, &ep.route)
ep.mu.RUnlock()
return n, ch, err
}
@@ -248,7 +297,7 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintp
return 0, nil, err
}
- n, ch, err := ep.finishWrite(payload, &route)
+ n, ch, err := ep.finishWrite(payloadBytes, &route)
route.Release()
ep.mu.RUnlock()
return n, ch, err
@@ -256,7 +305,7 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintp
// finishWrite writes the payload to a route. It resolves the route if
// necessary. It's really just a helper to make defer unnecessary in Write.
-func (ep *endpoint) finishWrite(payload tcpip.Payload, route *stack.Route) (uintptr, <-chan struct{}, *tcpip.Error) {
+func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (uintptr, <-chan struct{}, *tcpip.Error) {
// We may need to resolve the route (match a link layer address to the
// network address). If that requires blocking (e.g. to use ARP),
// return a channel on which the caller can wait.
@@ -269,13 +318,14 @@ func (ep *endpoint) finishWrite(payload tcpip.Payload, route *stack.Route) (uint
}
}
- payloadBytes, err := payload.Get(payload.Size())
- if err != nil {
- return 0, nil, err
- }
-
switch ep.netProto {
case header.IPv4ProtocolNumber:
+ if !ep.associated {
+ if err := route.WriteHeaderIncludedPacket(buffer.View(payloadBytes).ToVectorisedView()); err != nil {
+ return 0, nil, err
+ }
+ break
+ }
hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength()))
if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), ep.transProto, route.DefaultTTL()); err != nil {
return 0, nil, err
@@ -335,15 +385,17 @@ func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
defer route.Release()
- // Re-register the endpoint with the appropriate NIC.
- if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil {
- return err
+ if ep.associated {
+ // Re-register the endpoint with the appropriate NIC.
+ if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil {
+ return err
+ }
+ ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep)
+ ep.registeredNIC = nic
}
- ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep)
- // Save the route and NIC we've connected via.
+ // Save the route we've connected via.
ep.route = route.Clone()
- ep.registeredNIC = nic
ep.connected = true
return nil
@@ -386,14 +438,16 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrBadLocalAddress
}
- // Re-register the endpoint with the appropriate NIC.
- if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil {
- return err
+ if ep.associated {
+ // Re-register the endpoint with the appropriate NIC.
+ if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil {
+ return err
+ }
+ ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep)
+ ep.registeredNIC = addr.NIC
+ ep.boundNIC = addr.NIC
}
- ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep)
- ep.registeredNIC = addr.NIC
- ep.boundNIC = addr.NIC
ep.boundAddr = addr.Addr
ep.bound = true
diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go
new file mode 100644
index 000000000..783c21e6b
--- /dev/null
+++ b/pkg/tcpip/transport/raw/protocol.go
@@ -0,0 +1,32 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package raw
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+type factory struct{}
+
+// NewUnassociatedRawEndpoint implements stack.UnassociatedEndpointFactory.
+func (factory) NewUnassociatedRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return newEndpoint(stack, netProto, transProto, waiterQueue, false /* associated */)
+}
+
+func init() {
+ stack.RegisterUnassociatedFactory(factory{})
+}
diff --git a/pkg/tcpip/transport/raw/raw_state_autogen.go b/pkg/tcpip/transport/raw/raw_state_autogen.go
index bdbf6207a..947fcc1bb 100755
--- a/pkg/tcpip/transport/raw/raw_state_autogen.go
+++ b/pkg/tcpip/transport/raw/raw_state_autogen.go
@@ -32,6 +32,7 @@ func (x *endpoint) save(m state.Map) {
m.Save("netProto", &x.netProto)
m.Save("transProto", &x.transProto)
m.Save("waiterQueue", &x.waiterQueue)
+ m.Save("associated", &x.associated)
m.Save("rcvList", &x.rcvList)
m.Save("rcvBufSize", &x.rcvBufSize)
m.Save("rcvClosed", &x.rcvClosed)
@@ -48,6 +49,7 @@ func (x *endpoint) load(m state.Map) {
m.Load("netProto", &x.netProto)
m.Load("transProto", &x.transProto)
m.Load("waiterQueue", &x.waiterQueue)
+ m.Load("associated", &x.associated)
m.Load("rcvList", &x.rcvList)
m.Load("rcvBufSize", &x.rcvBufSize)
m.Load("rcvClosed", &x.rcvClosed)