diff options
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 81 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/protocol.go | 15 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/forwarder.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/protocol.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/protocol.go | 6 |
7 files changed, 91 insertions, 23 deletions
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index d87bfe048..b3b7a1d0e 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -46,17 +46,23 @@ const ( stateClosed ) -// endpoint represents an ICMP (ping) endpoint. This struct serves as the -// interface between users of the endpoint and the protocol implementation; it -// is legal to have concurrent goroutines make calls into the endpoint, they -// are properly synchronized. +// endpoint represents an ICMP endpoint. This struct serves as the interface +// between users of the endpoint and the protocol implementation; it is legal to +// have concurrent goroutines make calls into the endpoint, they are properly +// synchronized. +// +// +stateify savable type endpoint struct { - // The following fields are initialized at creation time and do not - // change throughout the lifetime of the endpoint. + // The following fields are initialized at creation time and are + // immutable. stack *stack.Stack `state:"manual"` netProto tcpip.NetworkProtocolNumber transProto tcpip.TransportProtocolNumber waiterQueue *waiter.Queue + // raw indicates whether the endpoint is intended for use by a raw + // socket, which returns the network layer header along with the + // payload. It is immutable. + raw bool // The following fields are used to manage the receive queue, and are // protected by rcvMu. @@ -80,15 +86,26 @@ type endpoint struct { route stack.Route `state:"manual"` } -func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) *endpoint { - return &endpoint{ +func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, raw bool) (*endpoint, *tcpip.Error) { + e := &endpoint{ stack: stack, netProto: netProto, transProto: transProto, waiterQueue: waiterQueue, rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, + raw: raw, + } + + // Raw endpoints must be immediately bound because they receive all + // ICMP traffic starting from when they're created via socket(). + if raw { + if err := e.bindLocked(tcpip.FullAddress{}, nil); err != nil { + return nil, err + } } + + return e, nil } // Close puts the endpoint in a closed state and frees all resources @@ -98,7 +115,11 @@ func (e *endpoint) Close() { e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite switch e.state { case stateBound, stateConnected: - e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e) + if e.raw { + e.stack.UnregisterRawTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e) + } else { + e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e) + } } // Close the receive list and drain it. @@ -285,10 +306,10 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c switch e.netProto { case header.IPv4ProtocolNumber: - err = sendPing4(route, e.id.LocalPort, v) + err = e.send4(route, v) case header.IPv6ProtocolNumber: - err = sendPing6(route, e.id.LocalPort, v) + err = send6(route, e.id.LocalPort, v) } if err != nil { @@ -346,13 +367,19 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } } -func sendPing4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { +func (e *endpoint) send4(r *stack.Route, data buffer.View) *tcpip.Error { + if e.raw { + hdr := buffer.NewPrependable(len(data) + int(r.MaxHeaderLength())) + return r.WritePacket(hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL()) + } + if len(data) < header.ICMPv4EchoMinimumSize { return tcpip.ErrInvalidEndpointState } - // Set the ident. Sequence number is provided by the user. - binary.BigEndian.PutUint16(data[header.ICMPv4MinimumSize:], ident) + // Set the ident to the user-specified port. Sequence number should + // already be set by the user. + binary.BigEndian.PutUint16(data[header.ICMPv4MinimumSize:], e.id.LocalPort) hdr := buffer.NewPrependable(header.ICMPv4EchoMinimumSize + int(r.MaxHeaderLength())) @@ -371,7 +398,7 @@ func sendPing4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { return r.WritePacket(hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL()) } -func sendPing6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { +func send6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { if len(data) < header.ICMPv6EchoMinimumSize { return tcpip.ErrInvalidEndpointState } @@ -412,6 +439,11 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t // Connect connects the endpoint to its peer. Specifying a NIC is optional. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + // TODO: We don't yet support connect on a raw socket. + if e.raw { + return tcpip.ErrNotSupported + } + e.mu.Lock() defer e.mu.Unlock() @@ -515,6 +547,11 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { } func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { + if e.raw { + err := e.stack.RegisterRawTransportEndpoint(nicid, netProtos, e.transProto, e, false) + return stack.TransportEndpointID{}, err + } + if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. @@ -657,7 +694,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) { e.rcvMu.Lock() // Drop the packet if our buffer is currently full. @@ -675,9 +712,17 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv Addr: id.RemoteAddress, }, } - pkt.data = vv.Clone(pkt.views[:]) + + if e.raw { + combinedVV := netHeader.ToVectorisedView() + combinedVV.Append(vv) + pkt.data = combinedVV.Clone(pkt.views[:]) + } else { + pkt.data = vv.Clone(pkt.views[:]) + } + e.rcvList.PushBack(pkt) - e.rcvBufSize += vv.Size() + e.rcvBufSize += pkt.data.Size() pkt.timestamp = e.stack.NowNanoseconds() diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 9f0a2bf71..36b70988a 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -47,6 +47,7 @@ const ( ProtocolNumber6 = header.ICMPv6ProtocolNumber ) +// protocol implements stack.TransportProtocol. type protocol struct { number tcpip.TransportProtocolNumber } @@ -66,12 +67,22 @@ func (p *protocol) netProto() tcpip.NetworkProtocolNumber { panic(fmt.Sprint("unknown protocol number: ", p.number)) } -// NewEndpoint creates a new icmp endpoint. +// NewEndpoint creates a new icmp endpoint. It implements +// stack.TransportProtocol.NewEndpoint. func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { if netProto != p.netProto() { return nil, tcpip.ErrUnknownProtocol } - return newEndpoint(stack, netProto, p.number, waiterQueue), nil + return newEndpoint(stack, netProto, p.number, waiterQueue, false) +} + +// NewRawEndpoint creates a new raw icmp endpoint. It implements +// stack.TransportProtocol.NewRawEndpoint. +func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + if netProto != p.netProto() { + return nil, tcpip.ErrUnknownProtocol + } + return newEndpoint(stack, netProto, p.number, waiterQueue, true) } // MinimumPacketSize returns the minimum valid icmp packet size. diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index a8618bb4a..c48a27d8f 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1441,7 +1441,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) { s := newSegment(r, id, vv) if !s.parse() { e.stack.Stats().MalformedRcvdPackets.Increment() diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 2f90839e9..ca53a076f 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -63,7 +63,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) bool { +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { s := newSegment(r, id, vv) defer s.decRef() diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 753e1419e..639ad3fae 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -101,6 +101,12 @@ func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolN return newEndpoint(stack, netProto, waiterQueue), nil } +// NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently +// unsupported. It implements stack.TransportProtocol.NewRawEndpoint. +func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return nil, tcpip.ErrUnknownProtocol +} + // MinimumPacketSize returns the minimum valid tcp packet size. func (*protocol) MinimumPacketSize() int { return header.TCPMinimumSize diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 05d35e526..44b9cdf6a 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -934,7 +934,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) { // Get the header then trim it from the view. hdr := header.UDP(vv.First()) if int(hdr.Length()) > vv.Size() { diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index b3fbed6e4..616a9f388 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -48,6 +48,12 @@ func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolN return newEndpoint(stack, netProto, waiterQueue), nil } +// NewRawEndpoint creates a new raw UDP endpoint. Raw UDP sockets are currently +// unsupported. It implements stack.TransportProtocol.NewRawEndpoint. +func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return nil, tcpip.ErrUnknownProtocol +} + // MinimumPacketSize returns the minimum valid udp packet size. func (*protocol) MinimumPacketSize() int { return header.UDPMinimumSize |