diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/header/icmpv4.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/header/ipv4.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/header/ipv6.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/header/tcp.go | 17 | ||||
-rw-r--r-- | pkg/tcpip/network/arp/arp.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/network/ip_test.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/icmp.go | 24 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 49 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 7 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 20 | ||||
-rw-r--r-- | pkg/tcpip/stack/route.go | 12 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/protocol.go | 7 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint.go | 100 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/protocol.go | 32 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 26 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/testing/context/context.go | 6 |
21 files changed, 301 insertions, 72 deletions
diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index c081de61f..c52c0d851 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -24,15 +24,11 @@ import ( type ICMPv4 []byte const ( - // ICMPv4MinimumSize is the minimum size of a valid ICMP packet. - ICMPv4MinimumSize = 4 - - // ICMPv4EchoMinimumSize is the minimum size of a valid ICMP echo packet. - ICMPv4EchoMinimumSize = 6 + // ICMPv4PayloadOffset defines the start of ICMP payload. + ICMPv4PayloadOffset = 4 - // ICMPv4DstUnreachableMinimumSize is the minimum size of a valid ICMP - // destination unreachable packet. - ICMPv4DstUnreachableMinimumSize = ICMPv4MinimumSize + 4 + // ICMPv4MinimumSize is the minimum size of a valid ICMP packet. + ICMPv4MinimumSize = 8 // ICMPv4ProtocolNumber is the ICMP transport protocol number. ICMPv4ProtocolNumber tcpip.TransportProtocolNumber = 1 @@ -104,5 +100,5 @@ func (ICMPv4) SetDestinationPort(uint16) { // Payload implements Transport.Payload. func (b ICMPv4) Payload() []byte { - return b[ICMPv4MinimumSize:] + return b[ICMPv4PayloadOffset:] } diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 7da4c4845..94a3af289 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -85,6 +85,10 @@ const ( // units, the header cannot exceed 15*4 = 60 bytes. IPv4MaximumHeaderSize = 60 + // MinIPFragmentPayloadSize is the minimum number of payload bytes that + // the first fragment must carry when an IPv4 packet is fragmented. + MinIPFragmentPayloadSize = 8 + // IPv4AddressSize is the size, in bytes, of an IPv4 address. IPv4AddressSize = 4 @@ -268,6 +272,10 @@ func (b IPv4) IsValid(pktSize int) bool { return false } + if IPVersion(b) != IPv4Version { + return false + } + return true } diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 7163eaa36..95fe8bfc3 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -184,6 +184,10 @@ func (b IPv6) IsValid(pktSize int) bool { return false } + if IPVersion(b) != IPv6Version { + return false + } + return true } diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go index 1141443bb..82cfe785c 100644 --- a/pkg/tcpip/header/tcp.go +++ b/pkg/tcpip/header/tcp.go @@ -176,6 +176,21 @@ const ( // TCPProtocolNumber is TCP's transport protocol number. TCPProtocolNumber tcpip.TransportProtocolNumber = 6 + + // TCPMinimumMSS is the minimum acceptable value for MSS. This is the + // same as the value TCP_MIN_MSS defined net/tcp.h. + TCPMinimumMSS = IPv4MaximumHeaderSize + TCPHeaderMaximumSize + MinIPFragmentPayloadSize - IPv4MinimumSize - TCPMinimumSize + + // TCPMaximumMSS is the maximum acceptable value for MSS. + TCPMaximumMSS = 0xffff + + // TCPDefaultMSS is the MSS value that should be used if an MSS option + // is not received from the peer. It's also the value returned by + // TCP_MAXSEG option for a socket in an unconnected state. + // + // Per RFC 1122, page 85: "If an MSS option is not received at + // connection setup, TCP MUST assume a default send MSS of 536." + TCPDefaultMSS = 536 ) // SourcePort returns the "source port" field of the tcp header. @@ -306,7 +321,7 @@ func ParseSynOptions(opts []byte, isAck bool) TCPSynOptions { synOpts := TCPSynOptions{ // Per RFC 1122, page 85: "If an MSS option is not received at // connection setup, TCP MUST assume a default send MSS of 536." - MSS: 536, + MSS: TCPDefaultMSS, // If no window scale option is specified, WS in options is // returned as -1; this is because the absence of the option // indicates that the we cannot use window scaling on the diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index ca3d6c0bf..cb35635fc 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -83,6 +83,10 @@ func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buf return tcpip.ErrNotSupported } +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { + return tcpip.ErrNotSupported +} + func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { v := vv.First() h := header.ARP(v) diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index db65ee7cc..8ff428445 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -282,10 +282,10 @@ func TestIPv4ReceiveControl(t *testing.T) { {"Truncated (10 bytes missing)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 10}, {"Truncated (missing IPv4 header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.IPv4MinimumSize + 8}, {"Truncated (missing 'extra info')", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 4 + header.IPv4MinimumSize + 8}, - {"Truncated (missing ICMP header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.ICMPv4DstUnreachableMinimumSize + header.IPv4MinimumSize + 8}, + {"Truncated (missing ICMP header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.ICMPv4MinimumSize + header.IPv4MinimumSize + 8}, {"Port unreachable", 1, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0}, {"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0}, - {"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4DstUnreachableMinimumSize + 8}, + {"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + 8}, } r, err := buildIPv4Route(localIpv4Addr, "\x0a\x00\x00\xbb") if err != nil { @@ -301,7 +301,7 @@ func TestIPv4ReceiveControl(t *testing.T) { } defer ep.Close() - const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize + 4 + const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize view := buffer.NewView(dataOffset + 8) // Create the outer IPv4 header. @@ -319,10 +319,10 @@ func TestIPv4ReceiveControl(t *testing.T) { icmp := header.ICMPv4(view[header.IPv4MinimumSize:]) icmp.SetType(header.ICMPv4DstUnreachable) icmp.SetCode(c.code) - copy(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:], []byte{0xde, 0xad, 0xbe, 0xef}) + copy(view[header.IPv4MinimumSize+header.ICMPv4PayloadOffset:], []byte{0xde, 0xad, 0xbe, 0xef}) // Create the inner IPv4 header. - ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize+4:]) + ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:]) ip.Encode(&header.IPv4Fields{ IHL: header.IPv4MinimumSize, TotalLength: 100, diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index bc7f1c42a..fbef6947d 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -68,10 +68,6 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V switch h.Type() { case header.ICMPv4Echo: received.Echo.Increment() - if len(v) < header.ICMPv4EchoMinimumSize { - received.Invalid.Increment() - return - } // Only send a reply if the checksum is valid. wantChecksum := h.Checksum() @@ -93,9 +89,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) vv := vv.Clone(nil) - vv.TrimFront(header.ICMPv4EchoMinimumSize) - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv4EchoMinimumSize) - pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4EchoMinimumSize)) + vv.TrimFront(header.ICMPv4MinimumSize) + hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize) + pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) copy(pkt, h) pkt.SetType(header.ICMPv4EchoReply) pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0))) @@ -108,25 +104,19 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V case header.ICMPv4EchoReply: received.EchoReply.Increment() - if len(v) < header.ICMPv4EchoMinimumSize { - received.Invalid.Increment() - return - } + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) case header.ICMPv4DstUnreachable: received.DstUnreachable.Increment() - if len(v) < header.ICMPv4DstUnreachableMinimumSize { - received.Invalid.Increment() - return - } - vv.TrimFront(header.ICMPv4DstUnreachableMinimumSize) + + vv.TrimFront(header.ICMPv4MinimumSize) switch h.Code() { case header.ICMPv4PortUnreachable: e.handleControl(stack.ControlPortUnreachable, 0, vv) case header.ICMPv4FragmentationNeeded: - mtu := uint32(binary.BigEndian.Uint16(v[header.ICMPv4DstUnreachableMinimumSize-2:])) + mtu := uint32(binary.BigEndian.Uint16(v[header.ICMPv4PayloadOffset+2:])) e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 1e3a7425a..e44a73d96 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -232,6 +232,55 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen return nil } +// WriteHeaderIncludedPacket writes a packet already containing a network +// header through the given route. +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { + // The packet already has an IP header, but there are a few required + // checks. + ip := header.IPv4(payload.First()) + if !ip.IsValid(payload.Size()) { + return tcpip.ErrInvalidOptionValue + } + + // Always set the total length. + ip.SetTotalLength(uint16(payload.Size())) + + // Set the source address when zero. + if ip.SourceAddress() == tcpip.Address(([]byte{0, 0, 0, 0})) { + ip.SetSourceAddress(r.LocalAddress) + } + + // Set the destination. If the packet already included a destination, + // it will be part of the route. + ip.SetDestinationAddress(r.RemoteAddress) + + // Set the packet ID when zero. + if ip.ID() == 0 { + id := uint32(0) + if payload.Size() > header.IPv4MaximumHeaderSize+8 { + // Packets of 68 bytes or less are required by RFC 791 to not be + // fragmented, so we only assign ids to larger packets. + id = atomic.AddUint32(&ids[hashRoute(r, 0 /* protocol */)%buckets], 1) + } + ip.SetID(uint16(id)) + } + + // Always set the checksum. + ip.SetChecksum(0) + ip.SetChecksum(^ip.CalculateChecksum()) + + if loop&stack.PacketLoop != 0 { + e.HandlePacket(r, payload) + } + if loop&stack.PacketOut == 0 { + return nil + } + + hdr := buffer.NewPrependableFromView(payload.ToView()) + r.Stats().IP.PacketsSent.Increment() + return e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) +} + // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 27367d6c5..e3e8739fd 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -120,6 +120,13 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen return e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber) } +// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet +// supported by IPv6. +func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { + // TODO(b/119580726): Support IPv6 header-included packets. + return tcpip.ErrNotSupported +} + // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 0ecaa0833..462265281 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -174,6 +174,10 @@ type NetworkEndpoint interface { // protocol. WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop PacketLooping) *tcpip.Error + // WriteHeaderIncludedPacket writes a packet that includes a network + // header to the given destination address. + WriteHeaderIncludedPacket(r *Route, payload buffer.VectorisedView, loop PacketLooping) *tcpip.Error + // ID returns the network protocol endpoint ID. ID() *NetworkEndpointID @@ -357,10 +361,19 @@ type TransportProtocolFactory func() TransportProtocol // instantiate network protocols. type NetworkProtocolFactory func() NetworkProtocol +// UnassociatedEndpointFactory produces endpoints for writing packets not +// associated with a particular transport protocol. Such endpoints can be used +// to write arbitrary packets that include the IP header. +type UnassociatedEndpointFactory interface { + NewUnassociatedRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) +} + var ( transportProtocols = make(map[string]TransportProtocolFactory) networkProtocols = make(map[string]NetworkProtocolFactory) + unassociatedFactory UnassociatedEndpointFactory + linkEPMu sync.RWMutex nextLinkEndpointID tcpip.LinkEndpointID = 1 linkEndpoints = make(map[tcpip.LinkEndpointID]LinkEndpoint) @@ -380,6 +393,13 @@ func RegisterNetworkProtocolFactory(name string, p NetworkProtocolFactory) { networkProtocols[name] = p } +// RegisterUnassociatedFactory registers a factory to produce endpoints not +// associated with any particular transport protocol. This function is intended +// to be called by init() functions of the protocols. +func RegisterUnassociatedFactory(f UnassociatedEndpointFactory) { + unassociatedFactory = f +} + // RegisterLinkEndpoint register a link-layer protocol endpoint and returns an // ID that can be used to refer to it. func RegisterLinkEndpoint(linkEP LinkEndpoint) tcpip.LinkEndpointID { diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 36d7b6ac7..391ab4344 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -163,6 +163,18 @@ func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.Vec return err } +// WriteHeaderIncludedPacket writes a packet already containing a network +// header through the given route. +func (r *Route) WriteHeaderIncludedPacket(payload buffer.VectorisedView) *tcpip.Error { + if err := r.ref.ep.WriteHeaderIncludedPacket(r, payload, r.loop); err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() + return err + } + r.ref.nic.stats.Tx.Packets.Increment() + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(payload.Size())) + return nil +} + // DefaultTTL returns the default TTL of the underlying network endpoint. func (r *Route) DefaultTTL() uint8 { return r.ref.ep.DefaultTTL() diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 2d7f56ca9..3e8fb2a6c 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -340,6 +340,8 @@ type Stack struct { networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver + unassociatedFactory UnassociatedEndpointFactory + demux *transportDemuxer stats tcpip.Stats @@ -442,6 +444,8 @@ func New(network []string, transport []string, opts Options) *Stack { } } + s.unassociatedFactory = unassociatedFactory + // Create the global transport demuxer. s.demux = newTransportDemuxer(s) @@ -574,11 +578,15 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp // NewRawEndpoint creates a new raw transport layer endpoint of the given // protocol. Raw endpoints receive all traffic for a given protocol regardless // of address. -func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { if !s.raw { return nil, tcpip.ErrNotPermitted } + if !associated { + return s.unassociatedFactory.NewUnassociatedRawEndpoint(s, network, transport, waiterQueue) + } + t, ok := s.transportProtocols[transport] if !ok { return nil, tcpip.ErrUnknownProtocol diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 69884af03..959071dbe 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -137,6 +137,10 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr bu return f.linkEP.WritePacket(r, gso, hdr, payload, fakeNetNumber) } +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { + return tcpip.ErrNotSupported +} + func (*fakeNetworkEndpoint) Close() {} type fakeNetGoodOption bool diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index c61f96fb0..c4076666a 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -496,6 +496,10 @@ type AvailableCongestionControlOption string // buffer moderation. type ModerateReceiveBufferOption bool +// MaxSegOption is used by SetSockOpt/GetSockOpt to set/get the current +// Maximum Segment Size(MSS) value as specified using the TCP_MAXSEG option. +type MaxSegOption int + // MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default // TTL value for multicast messages. The default is 1. type MulticastTTLOption uint8 diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index ab9e80747..a80ceafd0 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -291,7 +291,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c switch e.netProto { case header.IPv4ProtocolNumber: - err = e.send4(route, v) + err = send4(route, e.id.LocalPort, v) case header.IPv6ProtocolNumber: err = send6(route, e.id.LocalPort, v) @@ -352,20 +352,20 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } } -func (e *endpoint) send4(r *stack.Route, data buffer.View) *tcpip.Error { - if len(data) < header.ICMPv4EchoMinimumSize { +func send4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { + if len(data) < header.ICMPv4MinimumSize { return tcpip.ErrInvalidEndpointState } // Set the ident to the user-specified port. Sequence number should // already be set by the user. - binary.BigEndian.PutUint16(data[header.ICMPv4MinimumSize:], e.id.LocalPort) + binary.BigEndian.PutUint16(data[header.ICMPv4PayloadOffset:], ident) - hdr := buffer.NewPrependable(header.ICMPv4EchoMinimumSize + int(r.MaxHeaderLength())) + hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(r.MaxHeaderLength())) - icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4EchoMinimumSize)) + icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) copy(icmpv4, data) - data = data[header.ICMPv4EchoMinimumSize:] + data = data[header.ICMPv4MinimumSize:] // Linux performs these basic checks. if icmpv4.Type() != header.ICMPv4Echo || icmpv4.Code() != 0 { diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index c89538131..7fdba5d56 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -90,19 +90,18 @@ func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProt func (p *protocol) MinimumPacketSize() int { switch p.number { case ProtocolNumber4: - return header.ICMPv4EchoMinimumSize + return header.ICMPv4MinimumSize case ProtocolNumber6: return header.ICMPv6EchoMinimumSize } panic(fmt.Sprint("unknown protocol number: ", p.number)) } -// ParsePorts returns the source and destination ports stored in the given icmp -// packet. +// ParsePorts in case of ICMP sets src to 0, dst to ICMP ID, and err to nil. func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { switch p.number { case ProtocolNumber4: - return 0, binary.BigEndian.Uint16(v[header.ICMPv4MinimumSize:]), nil + return 0, binary.BigEndian.Uint16(v[header.ICMPv4PayloadOffset:]), nil case ProtocolNumber6: return 0, binary.BigEndian.Uint16(v[header.ICMPv6MinimumSize:]), nil } diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD index 34a14bf7f..bc4b255b4 100644 --- a/pkg/tcpip/transport/raw/BUILD +++ b/pkg/tcpip/transport/raw/BUILD @@ -21,6 +21,7 @@ go_library( "endpoint.go", "endpoint_state.go", "packet_list.go", + "protocol.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/raw", imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 42aded77f..a29587658 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -67,6 +67,7 @@ type endpoint struct { netProto tcpip.NetworkProtocolNumber transProto tcpip.TransportProtocolNumber waiterQueue *waiter.Queue + associated bool // The following fields are used to manage the receive queue and are // protected by rcvMu. @@ -97,8 +98,12 @@ type endpoint struct { } // NewEndpoint returns a raw endpoint for the given protocols. -// TODO(b/129292371): IP_HDRINCL, IPPROTO_RAW, and AF_PACKET. +// TODO(b/129292371): IP_HDRINCL and AF_PACKET. func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return newEndpoint(stack, netProto, transProto, waiterQueue, true /* associated */) +} + +func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { if netProto != header.IPv4ProtocolNumber { return nil, tcpip.ErrUnknownProtocol } @@ -110,6 +115,16 @@ func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans waiterQueue: waiterQueue, rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, + associated: associated, + } + + // Unassociated endpoints are write-only and users call Write() with IP + // headers included. Because they're write-only, We don't need to + // register with the stack. + if !associated { + ep.rcvBufSizeMax = 0 + ep.waiterQueue = nil + return ep, nil } if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil { @@ -124,7 +139,7 @@ func (ep *endpoint) Close() { ep.mu.Lock() defer ep.mu.Unlock() - if ep.closed { + if ep.closed || !ep.associated { return } @@ -142,8 +157,11 @@ func (ep *endpoint) Close() { if ep.connected { ep.route.Release() + ep.connected = false } + ep.closed = true + ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } @@ -152,6 +170,10 @@ func (ep *endpoint) ModerateRecvBuf(copied int) {} // Read implements tcpip.Endpoint.Read. func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + if !ep.associated { + return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidOptionValue + } + ep.rcvMu.Lock() // If there's no data to read, return that read would block or that the @@ -192,6 +214,33 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintp return 0, nil, tcpip.ErrInvalidEndpointState } + payloadBytes, err := payload.Get(payload.Size()) + if err != nil { + ep.mu.RUnlock() + return 0, nil, err + } + + // If this is an unassociated socket and callee provided a nonzero + // destination address, route using that address. + if !ep.associated { + ip := header.IPv4(payloadBytes) + if !ip.IsValid(payload.Size()) { + ep.mu.RUnlock() + return 0, nil, tcpip.ErrInvalidOptionValue + } + dstAddr := ip.DestinationAddress() + // Update dstAddr with the address in the IP header, unless + // opts.To is set (e.g. if sendto specifies a specific + // address). + if dstAddr != tcpip.Address([]byte{0, 0, 0, 0}) && opts.To == nil { + opts.To = &tcpip.FullAddress{ + NIC: 0, // NIC is unset. + Addr: dstAddr, // The address from the payload. + Port: 0, // There are no ports here. + } + } + } + // Did the user caller provide a destination? If not, use the connected // destination. if opts.To == nil { @@ -216,12 +265,12 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintp return 0, nil, tcpip.ErrInvalidEndpointState } - n, ch, err := ep.finishWrite(payload, savedRoute) + n, ch, err := ep.finishWrite(payloadBytes, savedRoute) ep.mu.Unlock() return n, ch, err } - n, ch, err := ep.finishWrite(payload, &ep.route) + n, ch, err := ep.finishWrite(payloadBytes, &ep.route) ep.mu.RUnlock() return n, ch, err } @@ -248,7 +297,7 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintp return 0, nil, err } - n, ch, err := ep.finishWrite(payload, &route) + n, ch, err := ep.finishWrite(payloadBytes, &route) route.Release() ep.mu.RUnlock() return n, ch, err @@ -256,7 +305,7 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintp // finishWrite writes the payload to a route. It resolves the route if // necessary. It's really just a helper to make defer unnecessary in Write. -func (ep *endpoint) finishWrite(payload tcpip.Payload, route *stack.Route) (uintptr, <-chan struct{}, *tcpip.Error) { +func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (uintptr, <-chan struct{}, *tcpip.Error) { // We may need to resolve the route (match a link layer address to the // network address). If that requires blocking (e.g. to use ARP), // return a channel on which the caller can wait. @@ -269,13 +318,14 @@ func (ep *endpoint) finishWrite(payload tcpip.Payload, route *stack.Route) (uint } } - payloadBytes, err := payload.Get(payload.Size()) - if err != nil { - return 0, nil, err - } - switch ep.netProto { case header.IPv4ProtocolNumber: + if !ep.associated { + if err := route.WriteHeaderIncludedPacket(buffer.View(payloadBytes).ToVectorisedView()); err != nil { + return 0, nil, err + } + break + } hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), ep.transProto, route.DefaultTTL()); err != nil { return 0, nil, err @@ -335,15 +385,17 @@ func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } defer route.Release() - // Re-register the endpoint with the appropriate NIC. - if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil { - return err + if ep.associated { + // Re-register the endpoint with the appropriate NIC. + if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil { + return err + } + ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep) + ep.registeredNIC = nic } - ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep) - // Save the route and NIC we've connected via. + // Save the route we've connected via. ep.route = route.Clone() - ep.registeredNIC = nic ep.connected = true return nil @@ -386,14 +438,16 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrBadLocalAddress } - // Re-register the endpoint with the appropriate NIC. - if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil { - return err + if ep.associated { + // Re-register the endpoint with the appropriate NIC. + if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil { + return err + } + ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep) + ep.registeredNIC = addr.NIC + ep.boundNIC = addr.NIC } - ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep) - ep.registeredNIC = addr.NIC - ep.boundNIC = addr.NIC ep.boundAddr = addr.Addr ep.bound = true diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go new file mode 100644 index 000000000..783c21e6b --- /dev/null +++ b/pkg/tcpip/transport/raw/protocol.go @@ -0,0 +1,32 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package raw + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/waiter" +) + +type factory struct{} + +// NewUnassociatedRawEndpoint implements stack.UnassociatedEndpointFactory. +func (factory) NewUnassociatedRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return newEndpoint(stack, netProto, transProto, waiterQueue, false /* associated */) +} + +func init() { + stack.RegisterUnassociatedFactory(factory{}) +} diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index cb40fea94..beb90afb5 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -117,6 +117,7 @@ const ( notifyDrain notifyReset notifyKeepaliveChanged + notifyMSSChanged ) // SACKInfo holds TCP SACK related information for a given endpoint. @@ -218,8 +219,6 @@ type endpoint struct { mu sync.RWMutex `state:"nosave"` id stack.TransportEndpointID - // state endpointState `state:".(endpointState)"` - // pState ProtocolState state EndpointState `state:".(EndpointState)"` isPortReserved bool `state:"manual"` @@ -313,6 +312,10 @@ type endpoint struct { // in SYN-RCVD state. synRcvdCount int + // userMSS if non-zero is the MSS value explicitly set by the user + // for this endpoint using the TCP_MAXSEG setsockopt. + userMSS int + // The following fields are used to manage the send buffer. When // segments are ready to be sent, they are added to sndQueue and the // protocol goroutine is signaled via sndWaker. @@ -917,6 +920,17 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } return nil + case tcpip.MaxSegOption: + userMSS := v + if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { + return tcpip.ErrInvalidOptionValue + } + e.mu.Lock() + e.userMSS = int(userMSS) + e.mu.Unlock() + e.notifyProtocolGoroutine(notifyMSSChanged) + return nil + case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. @@ -1096,6 +1110,14 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.lastErrorMu.Unlock() return err + case *tcpip.MaxSegOption: + // This is just stubbed out. Linux never returns the user_mss + // value as it either returns the defaultMSS or returns the + // actual current MSS. Netstack just returns the defaultMSS + // always for now. + *o = header.TCPDefaultMSS + return nil + case *tcpip.SendBufferSizeOption: e.sndBufMu.Lock() *o = tcpip.SendBufferSizeOption(e.sndBufSize) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 630dd7925..bcc0f3e28 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -271,7 +271,7 @@ func (c *Context) GetPacketNonBlocking() []byte { // SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint. func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byte, maxTotalSize int) { // Allocate a buffer data and headers. - buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(p1) + len(p2)) + buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p1) + len(p2)) if len(buf) > maxTotalSize { buf = buf[:maxTotalSize] } @@ -291,8 +291,8 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt icmp.SetType(typ) icmp.SetCode(code) - copy(icmp[header.ICMPv4MinimumSize:], p1) - copy(icmp[header.ICMPv4MinimumSize+len(p1):], p2) + copy(icmp[header.ICMPv4PayloadOffset:], p1) + copy(icmp[header.ICMPv4PayloadOffset+len(p1):], p2) // Inject packet. c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView()) |