summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go81
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go15
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go2
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go2
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go6
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go2
-rw-r--r--pkg/tcpip/transport/udp/protocol.go6
7 files changed, 91 insertions, 23 deletions
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index d87bfe048..b3b7a1d0e 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -46,17 +46,23 @@ const (
stateClosed
)
-// endpoint represents an ICMP (ping) endpoint. This struct serves as the
-// interface between users of the endpoint and the protocol implementation; it
-// is legal to have concurrent goroutines make calls into the endpoint, they
-// are properly synchronized.
+// endpoint represents an ICMP endpoint. This struct serves as the interface
+// between users of the endpoint and the protocol implementation; it is legal to
+// have concurrent goroutines make calls into the endpoint, they are properly
+// synchronized.
+//
+// +stateify savable
type endpoint struct {
- // The following fields are initialized at creation time and do not
- // change throughout the lifetime of the endpoint.
+ // The following fields are initialized at creation time and are
+ // immutable.
stack *stack.Stack `state:"manual"`
netProto tcpip.NetworkProtocolNumber
transProto tcpip.TransportProtocolNumber
waiterQueue *waiter.Queue
+ // raw indicates whether the endpoint is intended for use by a raw
+ // socket, which returns the network layer header along with the
+ // payload. It is immutable.
+ raw bool
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.
@@ -80,15 +86,26 @@ type endpoint struct {
route stack.Route `state:"manual"`
}
-func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
- return &endpoint{
+func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, raw bool) (*endpoint, *tcpip.Error) {
+ e := &endpoint{
stack: stack,
netProto: netProto,
transProto: transProto,
waiterQueue: waiterQueue,
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024,
+ raw: raw,
+ }
+
+ // Raw endpoints must be immediately bound because they receive all
+ // ICMP traffic starting from when they're created via socket().
+ if raw {
+ if err := e.bindLocked(tcpip.FullAddress{}, nil); err != nil {
+ return nil, err
+ }
}
+
+ return e, nil
}
// Close puts the endpoint in a closed state and frees all resources
@@ -98,7 +115,11 @@ func (e *endpoint) Close() {
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
switch e.state {
case stateBound, stateConnected:
- e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e)
+ if e.raw {
+ e.stack.UnregisterRawTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e)
+ } else {
+ e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e)
+ }
}
// Close the receive list and drain it.
@@ -285,10 +306,10 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c
switch e.netProto {
case header.IPv4ProtocolNumber:
- err = sendPing4(route, e.id.LocalPort, v)
+ err = e.send4(route, v)
case header.IPv6ProtocolNumber:
- err = sendPing6(route, e.id.LocalPort, v)
+ err = send6(route, e.id.LocalPort, v)
}
if err != nil {
@@ -346,13 +367,19 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
}
-func sendPing4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
+func (e *endpoint) send4(r *stack.Route, data buffer.View) *tcpip.Error {
+ if e.raw {
+ hdr := buffer.NewPrependable(len(data) + int(r.MaxHeaderLength()))
+ return r.WritePacket(hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL())
+ }
+
if len(data) < header.ICMPv4EchoMinimumSize {
return tcpip.ErrInvalidEndpointState
}
- // Set the ident. Sequence number is provided by the user.
- binary.BigEndian.PutUint16(data[header.ICMPv4MinimumSize:], ident)
+ // Set the ident to the user-specified port. Sequence number should
+ // already be set by the user.
+ binary.BigEndian.PutUint16(data[header.ICMPv4MinimumSize:], e.id.LocalPort)
hdr := buffer.NewPrependable(header.ICMPv4EchoMinimumSize + int(r.MaxHeaderLength()))
@@ -371,7 +398,7 @@ func sendPing4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
return r.WritePacket(hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL())
}
-func sendPing6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
+func send6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
if len(data) < header.ICMPv6EchoMinimumSize {
return tcpip.ErrInvalidEndpointState
}
@@ -412,6 +439,11 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ // TODO: We don't yet support connect on a raw socket.
+ if e.raw {
+ return tcpip.ErrNotSupported
+ }
+
e.mu.Lock()
defer e.mu.Unlock()
@@ -515,6 +547,11 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
}
func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
+ if e.raw {
+ err := e.stack.RegisterRawTransportEndpoint(nicid, netProtos, e.transProto, e, false)
+ return stack.TransportEndpointID{}, err
+ }
+
if id.LocalPort != 0 {
// The endpoint already has a local port, just attempt to
// register it.
@@ -657,7 +694,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) {
e.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
@@ -675,9 +712,17 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
Addr: id.RemoteAddress,
},
}
- pkt.data = vv.Clone(pkt.views[:])
+
+ if e.raw {
+ combinedVV := netHeader.ToVectorisedView()
+ combinedVV.Append(vv)
+ pkt.data = combinedVV.Clone(pkt.views[:])
+ } else {
+ pkt.data = vv.Clone(pkt.views[:])
+ }
+
e.rcvList.PushBack(pkt)
- e.rcvBufSize += vv.Size()
+ e.rcvBufSize += pkt.data.Size()
pkt.timestamp = e.stack.NowNanoseconds()
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 9f0a2bf71..36b70988a 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -47,6 +47,7 @@ const (
ProtocolNumber6 = header.ICMPv6ProtocolNumber
)
+// protocol implements stack.TransportProtocol.
type protocol struct {
number tcpip.TransportProtocolNumber
}
@@ -66,12 +67,22 @@ func (p *protocol) netProto() tcpip.NetworkProtocolNumber {
panic(fmt.Sprint("unknown protocol number: ", p.number))
}
-// NewEndpoint creates a new icmp endpoint.
+// NewEndpoint creates a new icmp endpoint. It implements
+// stack.TransportProtocol.NewEndpoint.
func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
if netProto != p.netProto() {
return nil, tcpip.ErrUnknownProtocol
}
- return newEndpoint(stack, netProto, p.number, waiterQueue), nil
+ return newEndpoint(stack, netProto, p.number, waiterQueue, false)
+}
+
+// NewRawEndpoint creates a new raw icmp endpoint. It implements
+// stack.TransportProtocol.NewRawEndpoint.
+func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ if netProto != p.netProto() {
+ return nil, tcpip.ErrUnknownProtocol
+ }
+ return newEndpoint(stack, netProto, p.number, waiterQueue, true)
}
// MinimumPacketSize returns the minimum valid icmp packet size.
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index a8618bb4a..c48a27d8f 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1441,7 +1441,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) {
s := newSegment(r, id, vv)
if !s.parse() {
e.stack.Stats().MalformedRcvdPackets.Increment()
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 2f90839e9..ca53a076f 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -63,7 +63,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward
//
// This function is expected to be passed as an argument to the
// stack.SetTransportProtocolHandler function.
-func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) bool {
+func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
s := newSegment(r, id, vv)
defer s.decRef()
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 753e1419e..639ad3fae 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -101,6 +101,12 @@ func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolN
return newEndpoint(stack, netProto, waiterQueue), nil
}
+// NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently
+// unsupported. It implements stack.TransportProtocol.NewRawEndpoint.
+func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return nil, tcpip.ErrUnknownProtocol
+}
+
// MinimumPacketSize returns the minimum valid tcp packet size.
func (*protocol) MinimumPacketSize() int {
return header.TCPMinimumSize
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 05d35e526..44b9cdf6a 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -934,7 +934,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) {
// Get the header then trim it from the view.
hdr := header.UDP(vv.First())
if int(hdr.Length()) > vv.Size() {
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index b3fbed6e4..616a9f388 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -48,6 +48,12 @@ func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolN
return newEndpoint(stack, netProto, waiterQueue), nil
}
+// NewRawEndpoint creates a new raw UDP endpoint. Raw UDP sockets are currently
+// unsupported. It implements stack.TransportProtocol.NewRawEndpoint.
+func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return nil, tcpip.ErrUnknownProtocol
+}
+
// MinimumPacketSize returns the minimum valid udp packet size.
func (*protocol) MinimumPacketSize() int {
return header.UDPMinimumSize