summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/header/icmpv4.go30
-rw-r--r--pkg/tcpip/header/icmpv6.go27
-rw-r--r--pkg/tcpip/header/ipv6.go38
-rw-r--r--pkg/tcpip/header/ipv6_extension_headers.go18
-rw-r--r--pkg/tcpip/header/udp.go29
-rw-r--r--pkg/tcpip/link/channel/channel.go14
-rw-r--r--pkg/tcpip/link/ethernet/ethernet.go99
-rw-r--r--pkg/tcpip/link/ethernet/ethernet_state_autogen.go3
-rw-r--r--pkg/tcpip/link/pipe/pipe.go110
-rw-r--r--pkg/tcpip/link/pipe/pipe_state_autogen.go3
-rw-r--r--pkg/tcpip/link/qdisc/fifo/endpoint.go8
-rw-r--r--pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go1
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go2
-rw-r--r--pkg/tcpip/link/tun/device.go4
-rw-r--r--pkg/tcpip/network/ip/generic_multicast_protocol.go49
-rw-r--r--pkg/tcpip/network/ipv4/igmp.go17
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go135
-rw-r--r--pkg/tcpip/network/ipv6/mld.go12
-rw-r--r--pkg/tcpip/sock_err_list.go193
-rw-r--r--pkg/tcpip/socketops.go158
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go135
-rw-r--r--pkg/tcpip/stack/neighbor_cache.go95
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go137
-rw-r--r--pkg/tcpip/stack/nic.go29
-rw-r--r--pkg/tcpip/stack/nud.go21
-rw-r--r--pkg/tcpip/stack/pending_packets.go2
-rw-r--r--pkg/tcpip/stack/registration.go32
-rw-r--r--pkg/tcpip/stack/route.go172
-rw-r--r--pkg/tcpip/stack/stack.go31
-rw-r--r--pkg/tcpip/tcpip.go49
-rw-r--r--pkg/tcpip/tcpip_state_autogen.go121
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go4
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go7
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go44
-rw-r--r--pkg/tcpip/transport/tcp/accept.go2
-rw-r--r--pkg/tcpip/transport/tcp/connect.go21
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go97
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go12
-rw-r--r--pkg/tcpip/transport/tcp/snd.go35
-rw-r--r--pkg/tcpip/transport/tcp/tcp_state_autogen.go210
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go120
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go2
-rw-r--r--pkg/tcpip/transport/udp/udp_state_autogen.go39
43 files changed, 1609 insertions, 758 deletions
diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go
index 2f13dea6a..5f9b8e9e2 100644
--- a/pkg/tcpip/header/icmpv4.go
+++ b/pkg/tcpip/header/icmpv4.go
@@ -16,6 +16,7 @@ package header
import (
"encoding/binary"
+ "fmt"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -199,17 +200,24 @@ func (b ICMPv4) SetSequence(sequence uint16) {
// ICMPv4Checksum calculates the ICMP checksum over the provided ICMP header,
// and payload.
func ICMPv4Checksum(h ICMPv4, vv buffer.VectorisedView) uint16 {
- // Calculate the IPv6 pseudo-header upper-layer checksum.
- xsum := uint16(0)
- for _, v := range vv.Views() {
- xsum = Checksum(v, xsum)
- }
+ xsum := ChecksumVV(vv, 0)
+
+ // h[2:4] is the checksum itself, skip it to avoid checksumming the checksum.
+ xsum = Checksum(h[:2], xsum)
+ xsum = Checksum(h[4:], xsum)
- // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum.
- h2, h3 := h[2], h[3]
- h[2], h[3] = 0, 0
- xsum = ^Checksum(h, xsum)
- h[2], h[3] = h2, h3
+ return ^xsum
+}
- return xsum
+// ICMPOriginFromNetProto returns the appropriate SockErrOrigin to use when
+// a packet having a `net` header causing an ICMP error.
+func ICMPOriginFromNetProto(net tcpip.NetworkProtocolNumber) tcpip.SockErrOrigin {
+ switch net {
+ case IPv4ProtocolNumber:
+ return tcpip.SockExtErrorOriginICMP
+ case IPv6ProtocolNumber:
+ return tcpip.SockExtErrorOriginICMP6
+ default:
+ panic(fmt.Sprintf("unsupported net proto to extract ICMP error origin: %d", net))
+ }
}
diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go
index 2eef64b4d..eca9750ab 100644
--- a/pkg/tcpip/header/icmpv6.go
+++ b/pkg/tcpip/header/icmpv6.go
@@ -265,22 +265,13 @@ func (b ICMPv6) Payload() []byte {
// ICMPv6Checksum calculates the ICMP checksum over the provided ICMPv6 header,
// IPv6 src/dst addresses and the payload.
func ICMPv6Checksum(h ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 {
- // Calculate the IPv6 pseudo-header upper-layer checksum.
- xsum := Checksum([]byte(src), 0)
- xsum = Checksum([]byte(dst), xsum)
- var upperLayerLength [4]byte
- binary.BigEndian.PutUint32(upperLayerLength[:], uint32(len(h)+vv.Size()))
- xsum = Checksum(upperLayerLength[:], xsum)
- xsum = Checksum([]byte{0, 0, 0, uint8(ICMPv6ProtocolNumber)}, xsum)
- for _, v := range vv.Views() {
- xsum = Checksum(v, xsum)
- }
-
- // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum.
- h2, h3 := h[2], h[3]
- h[2], h[3] = 0, 0
- xsum = ^Checksum(h, xsum)
- h[2], h[3] = h2, h3
-
- return xsum
+ xsum := PseudoHeaderChecksum(ICMPv6ProtocolNumber, src, dst, uint16(len(h)+vv.Size()))
+
+ xsum = ChecksumVV(vv, xsum)
+
+ // h[2:4] is the checksum itself, skip it to avoid checksumming the checksum.
+ xsum = Checksum(h[:2], xsum)
+ xsum = Checksum(h[4:], xsum)
+
+ return ^xsum
}
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index d522e5f10..5580d6a78 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -18,7 +18,6 @@ import (
"crypto/sha256"
"encoding/binary"
"fmt"
- "strings"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -153,13 +152,17 @@ const (
// IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the
// catch-all or wildcard subnet. That is, all IPv6 addresses are considered to
// be contained within this subnet.
-var IPv6EmptySubnet = func() tcpip.Subnet {
- subnet, err := tcpip.NewSubnet(IPv6Any, tcpip.AddressMask(IPv6Any))
- if err != nil {
- panic(err)
- }
- return subnet
-}()
+var IPv6EmptySubnet = tcpip.AddressWithPrefix{
+ Address: IPv6Any,
+ PrefixLen: 0,
+}.Subnet()
+
+// IPv4MappedIPv6Subnet is the prefix for an IPv4 mapped IPv6 address as defined
+// by RFC 4291 section 2.5.5.
+var IPv4MappedIPv6Subnet = tcpip.AddressWithPrefix{
+ Address: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00",
+ PrefixLen: 96,
+}.Subnet()
// IPv6LinkLocalPrefix is the prefix for IPv6 link-local addresses, as defined
// by RFC 4291 section 2.5.6.
@@ -293,7 +296,7 @@ func IsV4MappedAddress(addr tcpip.Address) bool {
return false
}
- return strings.HasPrefix(string(addr), "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff")
+ return IPv4MappedIPv6Subnet.Contains(addr)
}
// IsV6MulticastAddress determines if the provided address is an IPv6
@@ -399,17 +402,6 @@ func IsV6LinkLocalMulticastAddress(addr tcpip.Address) bool {
return IsV6MulticastAddress(addr) && addr[ipv6MulticastAddressScopeByteIdx]&ipv6MulticastAddressScopeMask == ipv6LinkLocalMulticastScope
}
-// IsV6UniqueLocalAddress determines if the provided address is an IPv6
-// unique-local address (within the prefix FC00::/7).
-func IsV6UniqueLocalAddress(addr tcpip.Address) bool {
- if len(addr) != IPv6AddressSize {
- return false
- }
- // According to RFC 4193 section 3.1, a unique local address has the prefix
- // FC00::/7.
- return (addr[0] & 0xfe) == 0xfc
-}
-
// AppendOpaqueInterfaceIdentifier appends a 64 bit opaque interface identifier
// (IID) to buf as outlined by RFC 7217 and returns the extended buffer.
//
@@ -456,9 +448,6 @@ const (
// LinkLocalScope indicates a link-local address.
LinkLocalScope IPv6AddressScope = iota
- // UniqueLocalScope indicates a unique-local address.
- UniqueLocalScope
-
// GlobalScope indicates a global address.
GlobalScope
)
@@ -476,9 +465,6 @@ func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, *tcpip.Error) {
case IsV6LinkLocalAddress(addr):
return LinkLocalScope, nil
- case IsV6UniqueLocalAddress(addr):
- return UniqueLocalScope, nil
-
default:
return GlobalScope, nil
}
diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go
index 1fbb2cc98..f18981332 100644
--- a/pkg/tcpip/header/ipv6_extension_headers.go
+++ b/pkg/tcpip/header/ipv6_extension_headers.go
@@ -18,6 +18,7 @@ import (
"bufio"
"bytes"
"encoding/binary"
+ "errors"
"fmt"
"io"
"math"
@@ -274,6 +275,10 @@ func ipv6UnknownActionFromIdentifier(id IPv6ExtHdrOptionIdentifier) IPv6OptionUn
return IPv6OptionUnknownAction((id & ipv6UnknownExtHdrOptionActionMask) >> ipv6UnknownExtHdrOptionActionShift)
}
+// ErrMalformedIPv6ExtHdrOption indicates that an IPv6 extension header option
+// is malformed.
+var ErrMalformedIPv6ExtHdrOption = errors.New("malformed IPv6 extension header option")
+
// IPv6UnknownExtHdrOption holds the identifier and data for an IPv6 extension
// header option that is unknown by the parsing utilities.
type IPv6UnknownExtHdrOption struct {
@@ -351,10 +356,15 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error
continue
case ipv6RouterAlertHopByHopOptionIdentifier:
var routerAlertValue [ipv6RouterAlertPayloadLength]byte
- if n, err := i.reader.Read(routerAlertValue[:]); err != nil {
- panic(fmt.Sprintf("error when reading RouterAlert option's data bytes: %s", err))
- } else if n != ipv6RouterAlertPayloadLength {
- return nil, true, fmt.Errorf("read %d bytes for RouterAlert option, expected %d", n, ipv6RouterAlertPayloadLength)
+ if n, err := io.ReadFull(&i.reader, routerAlertValue[:]); err != nil {
+ switch err {
+ case io.EOF, io.ErrUnexpectedEOF:
+ return nil, true, fmt.Errorf("got invalid length (%d) for router alert option (want = %d): %w", length, ipv6RouterAlertPayloadLength, ErrMalformedIPv6ExtHdrOption)
+ default:
+ return nil, true, fmt.Errorf("read %d out of %d option data bytes for router alert option: %w", n, ipv6RouterAlertPayloadLength, err)
+ }
+ } else if n != int(length) {
+ return nil, true, fmt.Errorf("got invalid length (%d) for router alert option (want = %d): %w", length, ipv6RouterAlertPayloadLength, ErrMalformedIPv6ExtHdrOption)
}
return &IPv6RouterAlertOption{Value: IPv6RouterAlertValue(binary.BigEndian.Uint16(routerAlertValue[:]))}, false, nil
default:
diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go
index a6d4fcd59..98bdd29db 100644
--- a/pkg/tcpip/header/udp.go
+++ b/pkg/tcpip/header/udp.go
@@ -36,10 +36,10 @@ const (
// UDPFields contains the fields of a UDP packet. It is used to describe the
// fields of a packet that needs to be encoded.
type UDPFields struct {
- // SrcPort is the "Source Port" field of a UDP packet.
+ // SrcPort is the "source port" field of a UDP packet.
SrcPort uint16
- // DstPort is the "Destination Port" field of a UDP packet.
+ // DstPort is the "destination port" field of a UDP packet.
DstPort uint16
// Length is the "length" field of a UDP packet.
@@ -64,57 +64,52 @@ const (
UDPProtocolNumber tcpip.TransportProtocolNumber = 17
)
-// SourcePort returns the "Source Port" field of the UDP header.
+// SourcePort returns the "source port" field of the udp header.
func (b UDP) SourcePort() uint16 {
return binary.BigEndian.Uint16(b[udpSrcPort:])
}
-// DestinationPort returns the "Destination Port" field of the UDP header.
+// DestinationPort returns the "destination port" field of the udp header.
func (b UDP) DestinationPort() uint16 {
return binary.BigEndian.Uint16(b[udpDstPort:])
}
-// Length returns the "Length" field of the UDP header.
+// Length returns the "length" field of the udp header.
func (b UDP) Length() uint16 {
return binary.BigEndian.Uint16(b[udpLength:])
}
// Payload returns the data contained in the UDP datagram.
func (b UDP) Payload() []byte {
- return b[:b.Length()][UDPMinimumSize:]
+ return b[UDPMinimumSize:]
}
-// Checksum returns the "checksum" field of the UDP header.
+// Checksum returns the "checksum" field of the udp header.
func (b UDP) Checksum() uint16 {
return binary.BigEndian.Uint16(b[udpChecksum:])
}
-// SetSourcePort sets the "source port" field of the UDP header.
+// SetSourcePort sets the "source port" field of the udp header.
func (b UDP) SetSourcePort(port uint16) {
binary.BigEndian.PutUint16(b[udpSrcPort:], port)
}
-// SetDestinationPort sets the "destination port" field of the UDP header.
+// SetDestinationPort sets the "destination port" field of the udp header.
func (b UDP) SetDestinationPort(port uint16) {
binary.BigEndian.PutUint16(b[udpDstPort:], port)
}
-// SetChecksum sets the "checksum" field of the UDP header.
+// SetChecksum sets the "checksum" field of the udp header.
func (b UDP) SetChecksum(checksum uint16) {
binary.BigEndian.PutUint16(b[udpChecksum:], checksum)
}
-// SetLength sets the "length" field of the UDP header.
+// SetLength sets the "length" field of the udp header.
func (b UDP) SetLength(length uint16) {
binary.BigEndian.PutUint16(b[udpLength:], length)
}
-// PayloadLength returns the length of the payload following the UDP header.
-func (b UDP) PayloadLength() uint16 {
- return b.Length() - UDPMinimumSize
-}
-
-// CalculateChecksum calculates the checksum of the UDP packet, given the
+// CalculateChecksum calculates the checksum of the udp packet, given the
// checksum of the network-layer pseudo-header and the checksum of the payload.
func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 {
// Calculate the rest of the checksum.
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index 0efbfb22b..d9f8e3b35 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -31,7 +31,7 @@ type PacketInfo struct {
Pkt *stack.PacketBuffer
Proto tcpip.NetworkProtocolNumber
GSO *stack.GSO
- Route *stack.Route
+ Route stack.RouteInfo
}
// Notification is the interface for receiving notification from the packet
@@ -230,15 +230,11 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
// WritePacket stores outbound packets into the channel.
func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- // Clone r then release its resource so we only get the relevant fields from
- // stack.Route without holding a reference to a NIC's endpoint.
- route := r.Clone()
- route.Release()
p := PacketInfo{
Pkt: pkt,
Proto: protocol,
GSO: gso,
- Route: route,
+ Route: r.GetFields(),
}
e.q.Write(p)
@@ -248,17 +244,13 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
// WritePackets stores outbound packets into the channel.
func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
- // Clone r then release its resource so we only get the relevant fields from
- // stack.Route without holding a reference to a NIC's endpoint.
- route := r.Clone()
- route.Release()
n := 0
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
p := PacketInfo{
Pkt: pkt,
Proto: protocol,
GSO: gso,
- Route: route,
+ Route: r.GetFields(),
}
if !e.q.Write(p) {
diff --git a/pkg/tcpip/link/ethernet/ethernet.go b/pkg/tcpip/link/ethernet/ethernet.go
new file mode 100644
index 000000000..beefcd008
--- /dev/null
+++ b/pkg/tcpip/link/ethernet/ethernet.go
@@ -0,0 +1,99 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ethernet provides an implementation of an ethernet link endpoint that
+// wraps an inner link endpoint.
+package ethernet
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/nested"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+var _ stack.NetworkDispatcher = (*Endpoint)(nil)
+var _ stack.LinkEndpoint = (*Endpoint)(nil)
+
+// New returns an ethernet link endpoint that wraps an inner link endpoint.
+func New(ep stack.LinkEndpoint) *Endpoint {
+ var e Endpoint
+ e.Endpoint.Init(ep, &e)
+ return &e
+}
+
+// Endpoint is an ethernet endpoint.
+//
+// It adds an ethernet header to packets before sending them out through its
+// inner link endpoint and consumes an ethernet header before sending the
+// packet to the stack.
+type Endpoint struct {
+ nested.Endpoint
+}
+
+// DeliverNetworkPacket implements stack.NetworkDispatcher.
+func (e *Endpoint) DeliverNetworkPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
+ if !ok {
+ return
+ }
+
+ eth := header.Ethernet(hdr)
+ if dst := eth.DestinationAddress(); dst == e.Endpoint.LinkAddress() || dst == header.EthernetBroadcastAddress || header.IsMulticastEthernetAddress(dst) {
+ e.Endpoint.DeliverNetworkPacket(eth.SourceAddress() /* remote */, dst /* local */, eth.Type() /* protocol */, pkt)
+ }
+}
+
+// Capabilities implements stack.LinkEndpoint.
+func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return stack.CapabilityResolutionRequired | e.Endpoint.Capabilities()
+}
+
+// WritePacket implements stack.LinkEndpoint.
+func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress(), proto, pkt)
+ return e.Endpoint.WritePacket(r, gso, proto, pkt)
+}
+
+// WritePackets implements stack.LinkEndpoint.
+func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ linkAddr := e.Endpoint.LinkAddress()
+
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ e.AddHeader(linkAddr, r.RemoteLinkAddress(), proto, pkt)
+ }
+
+ return e.Endpoint.WritePackets(r, gso, pkts, proto)
+}
+
+// MaxHeaderLength implements stack.LinkEndpoint.
+func (e *Endpoint) MaxHeaderLength() uint16 {
+ return header.EthernetMinimumSize + e.Endpoint.MaxHeaderLength()
+}
+
+// ARPHardwareType implements stack.LinkEndpoint.
+func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
+ return header.ARPHardwareEther
+}
+
+// AddHeader implements stack.LinkEndpoint.
+func (*Endpoint) AddHeader(local, remote tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize))
+ fields := header.EthernetFields{
+ SrcAddr: local,
+ DstAddr: remote,
+ Type: proto,
+ }
+ eth.Encode(&fields)
+}
diff --git a/pkg/tcpip/link/ethernet/ethernet_state_autogen.go b/pkg/tcpip/link/ethernet/ethernet_state_autogen.go
new file mode 100644
index 000000000..71d255c20
--- /dev/null
+++ b/pkg/tcpip/link/ethernet/ethernet_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package ethernet
diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go
new file mode 100644
index 000000000..25c364391
--- /dev/null
+++ b/pkg/tcpip/link/pipe/pipe.go
@@ -0,0 +1,110 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package pipe provides the implementation of pipe-like data-link layer
+// endpoints. Such endpoints allow packets to be sent between two interfaces.
+package pipe
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+var _ stack.LinkEndpoint = (*Endpoint)(nil)
+
+// New returns both ends of a new pipe.
+func New(linkAddr1, linkAddr2 tcpip.LinkAddress) (*Endpoint, *Endpoint) {
+ ep1 := &Endpoint{
+ linkAddr: linkAddr1,
+ }
+ ep2 := &Endpoint{
+ linkAddr: linkAddr2,
+ }
+ ep1.linked = ep2
+ ep2.linked = ep1
+ return ep1, ep2
+}
+
+// Endpoint is one end of a pipe.
+type Endpoint struct {
+ dispatcher stack.NetworkDispatcher
+ linked *Endpoint
+ linkAddr tcpip.LinkAddress
+}
+
+// WritePacket implements stack.LinkEndpoint.
+func (e *Endpoint) WritePacket(r *stack.Route, _ *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ if !e.linked.IsAttached() {
+ return nil
+ }
+
+ // Note that the local address from the perspective of this endpoint is the
+ // remote address from the perspective of the other end of the pipe
+ // (e.linked). Similarly, the remote address from the perspective of this
+ // endpoint is the local address on the other end.
+ e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress() /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
+ }))
+
+ return nil
+}
+
+// WritePackets implements stack.LinkEndpoint.
+func (*Endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ panic("not implemented")
+}
+
+// Attach implements stack.LinkEndpoint.
+func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.dispatcher = dispatcher
+}
+
+// IsAttached implements stack.LinkEndpoint.
+func (e *Endpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
+// Wait implements stack.LinkEndpoint.
+func (*Endpoint) Wait() {}
+
+// MTU implements stack.LinkEndpoint.
+func (*Endpoint) MTU() uint32 {
+ return header.IPv6MinimumMTU
+}
+
+// Capabilities implements stack.LinkEndpoint.
+func (*Endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return 0
+}
+
+// MaxHeaderLength implements stack.LinkEndpoint.
+func (*Endpoint) MaxHeaderLength() uint16 {
+ return 0
+}
+
+// LinkAddress implements stack.LinkEndpoint.
+func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
+ return e.linkAddr
+}
+
+// ARPHardwareType implements stack.LinkEndpoint.
+func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
+ return header.ARPHardwareNone
+}
+
+// AddHeader implements stack.LinkEndpoint.
+func (*Endpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
+}
diff --git a/pkg/tcpip/link/pipe/pipe_state_autogen.go b/pkg/tcpip/link/pipe/pipe_state_autogen.go
new file mode 100644
index 000000000..d3b40feb4
--- /dev/null
+++ b/pkg/tcpip/link/pipe/pipe_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package pipe
diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go
index 27667f5f0..b7458b620 100644
--- a/pkg/tcpip/link/qdisc/fifo/endpoint.go
+++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go
@@ -154,8 +154,7 @@ func (e *endpoint) GSOMaxSize() uint32 {
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
// WritePacket caller's do not set the following fields in PacketBuffer
// so we populate them here.
- newRoute := r.Clone()
- pkt.EgressRoute = newRoute
+ pkt.EgressRoute = r
pkt.GSOOptions = gso
pkt.NetworkProtocolNumber = protocol
d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)]
@@ -178,11 +177,6 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB
for pkt := pkts.Front(); pkt != nil; {
d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)]
nxt := pkt.Next()
- // Since qdisc can hold onto a packet for long we should Clone
- // the route here to ensure it doesn't get released while the
- // packet is still in our queue.
- newRoute := pkt.EgressRoute.Clone()
- pkt.EgressRoute = newRoute
if !d.q.enqueue(pkt) {
if enqueued > 0 {
d.newPacketWaker.Assert()
diff --git a/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go b/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go
index eb5abb906..45adcbccb 100644
--- a/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go
+++ b/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go
@@ -61,6 +61,7 @@ func (q *packetBufferQueue) enqueue(s *stack.PacketBuffer) bool {
q.mu.Lock()
r := q.used < q.limit
if r {
+ s.EgressRoute.Acquire()
q.list.PushBack(s)
q.used++
}
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index 8d9a91020..1a2cc39eb 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -263,7 +263,7 @@ func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumbe
fragmentOffset = fragOffset
case header.ARPProtocolNumber:
- if parse.ARP(pkt) {
+ if !parse.ARP(pkt) {
return
}
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index a364c5801..bfac358f4 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -264,7 +264,7 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) {
// If the packet does not already have link layer header, and the route
// does not exist, we can't compute it. This is possibly a raw packet, tun
// device doesn't support this at the moment.
- if info.Pkt.LinkHeader().View().IsEmpty() && info.Route.RemoteLinkAddress() == "" {
+ if info.Pkt.LinkHeader().View().IsEmpty() && len(info.Route.RemoteLinkAddress) == 0 {
return nil, false
}
@@ -272,7 +272,7 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) {
if d.hasFlags(linux.IFF_TAP) {
// Add ethernet header if not provided.
if info.Pkt.LinkHeader().View().IsEmpty() {
- d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress(), info.Proto, info.Pkt)
+ d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress, info.Proto, info.Pkt)
}
vv.AppendView(info.Pkt.LinkHeader().View())
}
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/ip/generic_multicast_protocol.go
index f85c5ff9d..f2f0e069c 100644
--- a/pkg/tcpip/network/ip/generic_multicast_protocol.go
+++ b/pkg/tcpip/network/ip/generic_multicast_protocol.go
@@ -131,17 +131,6 @@ type multicastGroupState struct {
// GenericMulticastProtocolOptions holds options for the generic multicast
// protocol.
type GenericMulticastProtocolOptions struct {
- // Enabled indicates whether the generic multicast protocol will be
- // performed.
- //
- // When enabled, the protocol may transmit report and leave messages when
- // joining and leaving multicast groups respectively, and handle incoming
- // packets.
- //
- // When disabled, the protocol will still keep track of locally joined groups,
- // it just won't transmit and handle packets, or update groups' state.
- Enabled bool
-
// Rand is the source of random numbers.
Rand *rand.Rand
@@ -170,6 +159,17 @@ type GenericMulticastProtocolOptions struct {
// MulticastGroupProtocol is a multicast group protocol whose core state machine
// can be represented by GenericMulticastProtocolState.
type MulticastGroupProtocol interface {
+ // Enabled indicates whether the generic multicast protocol will be
+ // performed.
+ //
+ // When enabled, the protocol may transmit report and leave messages when
+ // joining and leaving multicast groups respectively, and handle incoming
+ // packets.
+ //
+ // When disabled, the protocol will still keep track of locally joined groups,
+ // it just won't transmit and handle packets, or update groups' state.
+ Enabled() bool
+
// SendReport sends a multicast report for the specified group address.
//
// Returns false if the caller should queue the report to be sent later. Note,
@@ -196,6 +196,9 @@ type MulticastGroupProtocol interface {
//
// GenericMulticastProtocolState.Init MUST be called before calling any of
// the methods on GenericMulticastProtocolState.
+//
+// GenericMulticastProtocolState.MakeAllNonMemberLocked MUST be called when the
+// multicast group protocol is disabled so that leave messages may be sent.
type GenericMulticastProtocolState struct {
// Do not allow overwriting this state.
_ sync.NoCopy
@@ -235,9 +238,11 @@ func (g *GenericMulticastProtocolState) Init(protocolMU *sync.RWMutex, opts Gene
//
// The groups will still be considered joined locally.
//
+// MUST be called when the multicast group protocol is disabled.
+//
// Precondition: g.protocolMU must be locked.
func (g *GenericMulticastProtocolState) MakeAllNonMemberLocked() {
- if !g.opts.Enabled {
+ if !g.opts.Protocol.Enabled() {
return
}
@@ -255,7 +260,7 @@ func (g *GenericMulticastProtocolState) MakeAllNonMemberLocked() {
//
// Precondition: g.protocolMU must be locked.
func (g *GenericMulticastProtocolState) InitializeGroupsLocked() {
- if !g.opts.Enabled {
+ if !g.opts.Protocol.Enabled() {
return
}
@@ -290,12 +295,8 @@ func (g *GenericMulticastProtocolState) SendQueuedReportsLocked() {
// JoinGroupLocked handles joining a new group.
//
-// If dontInitialize is true, the group will be not be initialized and will be
-// left in the non-member state - no packets will be sent for it until it is
-// initialized via InitializeGroups.
-//
// Precondition: g.protocolMU must be locked.
-func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Address, dontInitialize bool) {
+func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Address) {
if info, ok := g.memberships[groupAddress]; ok {
// The group has already been joined.
info.joins++
@@ -310,6 +311,10 @@ func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Addre
state: nonMember,
lastToSendReport: false,
delayedReportJob: tcpip.NewJob(g.opts.Clock, g.protocolMU, func() {
+ if !g.opts.Protocol.Enabled() {
+ panic(fmt.Sprintf("delayed report job fired for group %s while the multicast group protocol is disabled", groupAddress))
+ }
+
info, ok := g.memberships[groupAddress]
if !ok {
panic(fmt.Sprintf("expected to find group state for group = %s", groupAddress))
@@ -320,7 +325,7 @@ func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Addre
}),
}
- if !dontInitialize && g.opts.Enabled {
+ if g.opts.Protocol.Enabled() {
g.initializeNewMemberLocked(groupAddress, &info)
}
@@ -372,7 +377,7 @@ func (g *GenericMulticastProtocolState) LeaveGroupLocked(groupAddress tcpip.Addr
//
// Precondition: g.protocolMU must be locked.
func (g *GenericMulticastProtocolState) HandleQueryLocked(groupAddress tcpip.Address, maxResponseTime time.Duration) {
- if !g.opts.Enabled {
+ if !g.opts.Protocol.Enabled() {
return
}
@@ -406,7 +411,7 @@ func (g *GenericMulticastProtocolState) HandleQueryLocked(groupAddress tcpip.Add
//
// Precondition: g.protocolMU must be locked.
func (g *GenericMulticastProtocolState) HandleReportLocked(groupAddress tcpip.Address) {
- if !g.opts.Enabled {
+ if !g.opts.Protocol.Enabled() {
return
}
@@ -518,7 +523,7 @@ func (g *GenericMulticastProtocolState) maybeSendDelayedReportLocked(groupAddres
// maybeSendLeave attempts to send a leave message.
func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Address, lastToSendReport bool) {
- if !g.opts.Enabled || !lastToSendReport {
+ if !g.opts.Protocol.Enabled() || !lastToSendReport {
return
}
diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go
index fb7a9e68e..da88d65d1 100644
--- a/pkg/tcpip/network/ipv4/igmp.go
+++ b/pkg/tcpip/network/ipv4/igmp.go
@@ -72,8 +72,6 @@ type igmpState struct {
// The IPv4 endpoint this igmpState is for.
ep *endpoint
- enabled bool
-
genericMulticastProtocol ip.GenericMulticastProtocolState
// igmpV1Present is for maintaining compatibility with IGMPv1 Routers, from
@@ -95,6 +93,13 @@ type igmpState struct {
igmpV1Job *tcpip.Job
}
+// Enabled implements ip.MulticastGroupProtocol.
+func (igmp *igmpState) Enabled() bool {
+ // No need to perform IGMP on loopback interfaces since they don't have
+ // neighbouring nodes.
+ return igmp.ep.protocol.options.IGMP.Enabled && !igmp.ep.nic.IsLoopback() && igmp.ep.Enabled()
+}
+
// SendReport implements ip.MulticastGroupProtocol.
//
// Precondition: igmp.ep.mu must be read locked.
@@ -127,11 +132,7 @@ func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
// Must only be called once for the lifetime of igmp.
func (igmp *igmpState) init(ep *endpoint) {
igmp.ep = ep
- // No need to perform IGMP on loopback interfaces since they don't have
- // neighbouring nodes.
- igmp.enabled = ep.protocol.options.IGMP.Enabled && !igmp.ep.nic.IsLoopback()
igmp.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{
- Enabled: igmp.enabled,
Rand: ep.protocol.stack.Rand(),
Clock: ep.protocol.stack.Clock(),
Protocol: igmp,
@@ -223,7 +224,7 @@ func (igmp *igmpState) handleMembershipQuery(groupAddress tcpip.Address, maxResp
// As per RFC 2236 Section 6, Page 10: If the maximum response time is zero
// then change the state to note that an IGMPv1 router is present and
// schedule the query received Job.
- if igmp.enabled && maxRespTime == 0 {
+ if maxRespTime == 0 && igmp.Enabled() {
igmp.igmpV1Job.Cancel()
igmp.igmpV1Job.Schedule(v1RouterPresentTimeout)
igmp.setV1Present(true)
@@ -296,7 +297,7 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip
//
// Precondition: igmp.ep.mu must be locked.
func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) {
- igmp.genericMulticastProtocol.JoinGroupLocked(groupAddress, !igmp.ep.Enabled() /* dontInitialize */)
+ igmp.genericMulticastProtocol.JoinGroupLocked(groupAddress)
}
// isInGroup returns true if the specified group has been joined locally.
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index a49b5ac77..f2018d073 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -61,6 +61,108 @@ const (
buckets = 2048
)
+// policyTable is the default policy table defined in RFC 6724 section 2.1.
+//
+// A more human-readable version:
+//
+// Prefix Precedence Label
+// ::1/128 50 0
+// ::/0 40 1
+// ::ffff:0:0/96 35 4
+// 2002::/16 30 2
+// 2001::/32 5 5
+// fc00::/7 3 13
+// ::/96 1 3
+// fec0::/10 1 11
+// 3ffe::/16 1 12
+//
+// The table is sorted by prefix length so longest-prefix match can be easily
+// achieved.
+//
+// We willingly left out ::/96, fec0::/10 and 3ffe::/16 since those prefix
+// assignments are deprecated.
+//
+// As per RFC 4291 section 2.5.5.1 (for ::/96),
+//
+// The "IPv4-Compatible IPv6 address" is now deprecated because the
+// current IPv6 transition mechanisms no longer use these addresses.
+// New or updated implementations are not required to support this
+// address type.
+//
+// As per RFC 3879 section 4 (for fec0::/10),
+//
+// This document formally deprecates the IPv6 site-local unicast prefix
+// defined in [RFC3513], i.e., 1111111011 binary or FEC0::/10.
+//
+// As per RFC 3701 section 1 (for 3ffe::/16),
+//
+// As clearly stated in [TEST-NEW], the addresses for the 6bone are
+// temporary and will be reclaimed in the future. It further states
+// that all users of these addresses (within the 3FFE::/16 prefix) will
+// be required to renumber at some time in the future.
+//
+// and section 2,
+//
+// Thus after the pTLA allocation cutoff date January 1, 2004, it is
+// REQUIRED that no new 6bone 3FFE pTLAs be allocated.
+//
+// MUST NOT BE MODIFIED.
+var policyTable = [...]struct {
+ subnet tcpip.Subnet
+
+ label uint8
+}{
+ // ::1/128
+ {
+ subnet: header.IPv6Loopback.WithPrefix().Subnet(),
+ label: 0,
+ },
+ // ::ffff:0:0/96
+ {
+ subnet: header.IPv4MappedIPv6Subnet,
+ label: 4,
+ },
+ // 2001::/32 (Teredo prefix as per RFC 4380 section 2.6).
+ {
+ subnet: tcpip.AddressWithPrefix{
+ Address: "\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ PrefixLen: 32,
+ }.Subnet(),
+ label: 5,
+ },
+ // 2002::/16 (6to4 prefix as per RFC 3056 section 2).
+ {
+ subnet: tcpip.AddressWithPrefix{
+ Address: "\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ PrefixLen: 16,
+ }.Subnet(),
+ label: 2,
+ },
+ // fc00::/7 (Unique local addresses as per RFC 4193 section 3.1).
+ {
+ subnet: tcpip.AddressWithPrefix{
+ Address: "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ PrefixLen: 7,
+ }.Subnet(),
+ label: 13,
+ },
+ // ::/0
+ {
+ subnet: header.IPv6EmptySubnet,
+ label: 1,
+ },
+}
+
+func getLabel(addr tcpip.Address) uint8 {
+ for _, p := range policyTable {
+ if p.subnet.Contains(addr) {
+ return p.label
+ }
+ }
+
+ panic(fmt.Sprintf("should have a label for address = %s", addr))
+}
+
var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
var _ stack.AddressableEndpoint = (*endpoint)(nil)
var _ stack.NetworkEndpoint = (*endpoint)(nil)
@@ -1373,7 +1475,11 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address
// RFC 6724 section 5.
type addrCandidate struct {
addressEndpoint stack.AddressEndpoint
+ addr tcpip.Address
scope header.IPv6AddressScope
+
+ label uint8
+ matchingPrefix uint8
}
if len(remoteAddr) == 0 {
@@ -1400,7 +1506,10 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address
cs = append(cs, addrCandidate{
addressEndpoint: addressEndpoint,
+ addr: addr,
scope: scope,
+ label: getLabel(addr),
+ matchingPrefix: remoteAddr.MatchingPrefix(addr),
})
return true
@@ -1412,18 +1521,20 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address
panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err))
}
+ remoteLabel := getLabel(remoteAddr)
+
// Sort the addresses as per RFC 6724 section 5 rules 1-3.
//
- // TODO(b/146021396): Implement rules 4-8 of RFC 6724 section 5.
+ // TODO(b/146021396): Implement rules 4, 5 of RFC 6724 section 5.
sort.Slice(cs, func(i, j int) bool {
sa := cs[i]
sb := cs[j]
// Prefer same address as per RFC 6724 section 5 rule 1.
- if sa.addressEndpoint.AddressWithPrefix().Address == remoteAddr {
+ if sa.addr == remoteAddr {
return true
}
- if sb.addressEndpoint.AddressWithPrefix().Address == remoteAddr {
+ if sb.addr == remoteAddr {
return false
}
@@ -1440,11 +1551,29 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address
return sbDep
}
+ // Prefer matching label as per RFC 6724 section 5 rule 6.
+ if sa, sb := sa.label == remoteLabel, sb.label == remoteLabel; sa != sb {
+ if sa {
+ return true
+ }
+ if sb {
+ return false
+ }
+ }
+
// Prefer temporary addresses as per RFC 6724 section 5 rule 7.
if saTemp, sbTemp := sa.addressEndpoint.ConfigType() == stack.AddressConfigSlaacTemp, sb.addressEndpoint.ConfigType() == stack.AddressConfigSlaacTemp; saTemp != sbTemp {
return saTemp
}
+ // Use longest matching prefix as per RFC 6724 section 5 rule 8.
+ if sa.matchingPrefix > sb.matchingPrefix {
+ return true
+ }
+ if sb.matchingPrefix > sa.matchingPrefix {
+ return false
+ }
+
// sa and sb are equal, return the endpoint that is closest to the front of
// the primary endpoint list.
return i < j
diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go
index 6f64b8462..e8d1e7a79 100644
--- a/pkg/tcpip/network/ipv6/mld.go
+++ b/pkg/tcpip/network/ipv6/mld.go
@@ -58,6 +58,13 @@ type mldState struct {
genericMulticastProtocol ip.GenericMulticastProtocolState
}
+// Enabled implements ip.MulticastGroupProtocol.
+func (mld *mldState) Enabled() bool {
+ // No need to perform MLD on loopback interfaces since they don't have
+ // neighbouring nodes.
+ return mld.ep.protocol.options.MLD.Enabled && !mld.ep.nic.IsLoopback() && mld.ep.Enabled()
+}
+
// SendReport implements ip.MulticastGroupProtocol.
//
// Precondition: mld.ep.mu must be read locked.
@@ -80,9 +87,6 @@ func (mld *mldState) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
func (mld *mldState) init(ep *endpoint) {
mld.ep = ep
mld.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{
- // No need to perform MLD on loopback interfaces since they don't have
- // neighbouring nodes.
- Enabled: ep.protocol.options.MLD.Enabled && !mld.ep.nic.IsLoopback(),
Rand: ep.protocol.stack.Rand(),
Clock: ep.protocol.stack.Clock(),
Protocol: mld,
@@ -112,7 +116,7 @@ func (mld *mldState) handleMulticastListenerReport(mldHdr header.MLD) {
//
// Precondition: mld.ep.mu must be locked.
func (mld *mldState) joinGroup(groupAddress tcpip.Address) {
- mld.genericMulticastProtocol.JoinGroupLocked(groupAddress, !mld.ep.Enabled() /* dontInitialize */)
+ mld.genericMulticastProtocol.JoinGroupLocked(groupAddress)
}
// isInGroup returns true if the specified group has been joined locally.
diff --git a/pkg/tcpip/sock_err_list.go b/pkg/tcpip/sock_err_list.go
new file mode 100644
index 000000000..8935a8793
--- /dev/null
+++ b/pkg/tcpip/sock_err_list.go
@@ -0,0 +1,193 @@
+package tcpip
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type sockErrorElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (sockErrorElementMapper) linkerFor(elem *SockError) *SockError { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type sockErrorList struct {
+ head *SockError
+ tail *SockError
+}
+
+// Reset resets list l to the empty state.
+func (l *sockErrorList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *sockErrorList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *sockErrorList) Front() *SockError {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *sockErrorList) Back() *SockError {
+ return l.tail
+}
+
+// Len returns the number of elements in the list.
+//
+// NOTE: This is an O(n) operation.
+func (l *sockErrorList) Len() (count int) {
+ for e := l.Front(); e != nil; e = (sockErrorElementMapper{}.linkerFor(e)).Next() {
+ count++
+ }
+ return count
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *sockErrorList) PushFront(e *SockError) {
+ linker := sockErrorElementMapper{}.linkerFor(e)
+ linker.SetNext(l.head)
+ linker.SetPrev(nil)
+ if l.head != nil {
+ sockErrorElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *sockErrorList) PushBack(e *SockError) {
+ linker := sockErrorElementMapper{}.linkerFor(e)
+ linker.SetNext(nil)
+ linker.SetPrev(l.tail)
+ if l.tail != nil {
+ sockErrorElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *sockErrorList) PushBackList(m *sockErrorList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ sockErrorElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ sockErrorElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *sockErrorList) InsertAfter(b, e *SockError) {
+ bLinker := sockErrorElementMapper{}.linkerFor(b)
+ eLinker := sockErrorElementMapper{}.linkerFor(e)
+
+ a := bLinker.Next()
+
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ bLinker.SetNext(e)
+
+ if a != nil {
+ sockErrorElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *sockErrorList) InsertBefore(a, e *SockError) {
+ aLinker := sockErrorElementMapper{}.linkerFor(a)
+ eLinker := sockErrorElementMapper{}.linkerFor(e)
+
+ b := aLinker.Prev()
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ aLinker.SetPrev(e)
+
+ if b != nil {
+ sockErrorElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *sockErrorList) Remove(e *SockError) {
+ linker := sockErrorElementMapper{}.linkerFor(e)
+ prev := linker.Prev()
+ next := linker.Next()
+
+ if prev != nil {
+ sockErrorElementMapper{}.linkerFor(prev).SetNext(next)
+ } else if l.head == e {
+ l.head = next
+ }
+
+ if next != nil {
+ sockErrorElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else if l.tail == e {
+ l.tail = prev
+ }
+
+ linker.SetNext(nil)
+ linker.SetPrev(nil)
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type sockErrorEntry struct {
+ next *SockError
+ prev *SockError
+}
+
+// Next returns the entry that follows e in the list.
+func (e *sockErrorEntry) Next() *SockError {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *sockErrorEntry) Prev() *SockError {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *sockErrorEntry) SetNext(elem *SockError) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *sockErrorEntry) SetPrev(elem *SockError) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go
index b60a5fd76..f3ad40fdf 100644
--- a/pkg/tcpip/socketops.go
+++ b/pkg/tcpip/socketops.go
@@ -42,6 +42,12 @@ type SocketOptionsHandler interface {
// LastError is invoked when SO_ERROR is read for an endpoint.
LastError() *Error
+
+ // UpdateLastError updates the endpoint specific last error field.
+ UpdateLastError(err *Error)
+
+ // HasNIC is invoked to check if the NIC is valid for SO_BINDTODEVICE.
+ HasNIC(v int32) bool
}
// DefaultSocketOptionsHandler is an embeddable type that implements no-op
@@ -70,6 +76,14 @@ func (*DefaultSocketOptionsHandler) LastError() *Error {
return nil
}
+// UpdateLastError implements SocketOptionsHandler.UpdateLastError.
+func (*DefaultSocketOptionsHandler) UpdateLastError(*Error) {}
+
+// HasNIC implements SocketOptionsHandler.HasNIC.
+func (*DefaultSocketOptionsHandler) HasNIC(int32) bool {
+ return false
+}
+
// SocketOptions contains all the variables which store values for SOL_SOCKET,
// SOL_IP, SOL_IPV6 and SOL_TCP level options.
//
@@ -104,7 +118,7 @@ type SocketOptions struct {
keepAliveEnabled uint32
// multicastLoopEnabled determines whether multicast packets sent over a
- // non-loopback interface will be looped back. Analogous to inet->mc_loop.
+ // non-loopback interface will be looped back.
multicastLoopEnabled uint32
// receiveTOSEnabled is used to specify if the TOS ancillary message is
@@ -145,6 +159,17 @@ type SocketOptions struct {
// the incoming packet should be returned as an ancillary message.
receiveOriginalDstAddress uint32
+ // recvErrEnabled determines whether extended reliable error message passing
+ // is enabled.
+ recvErrEnabled uint32
+
+ // errQueue is the per-socket error queue. It is protected by errQueueMu.
+ errQueueMu sync.Mutex `state:"nosave"`
+ errQueue sockErrorList
+
+ // bindToDevice determines the device to which the socket is bound.
+ bindToDevice int32
+
// mu protects the access to the below fields.
mu sync.Mutex `state:"nosave"`
@@ -167,6 +192,11 @@ func storeAtomicBool(addr *uint32, v bool) {
atomic.StoreUint32(addr, val)
}
+// SetLastError sets the last error for a socket.
+func (so *SocketOptions) SetLastError(err *Error) {
+ so.handler.UpdateLastError(err)
+}
+
// GetBroadcast gets value for SO_BROADCAST option.
func (so *SocketOptions) GetBroadcast() bool {
return atomic.LoadUint32(&so.broadcastEnabled) != 0
@@ -334,6 +364,19 @@ func (so *SocketOptions) SetReceiveOriginalDstAddress(v bool) {
storeAtomicBool(&so.receiveOriginalDstAddress, v)
}
+// GetRecvError gets value for IP*_RECVERR option.
+func (so *SocketOptions) GetRecvError() bool {
+ return atomic.LoadUint32(&so.recvErrEnabled) != 0
+}
+
+// SetRecvError sets value for IP*_RECVERR option.
+func (so *SocketOptions) SetRecvError(v bool) {
+ storeAtomicBool(&so.recvErrEnabled, v)
+ if !v {
+ so.pruneErrQueue()
+ }
+}
+
// GetLastError gets value for SO_ERROR option.
func (so *SocketOptions) GetLastError() *Error {
return so.handler.LastError()
@@ -362,3 +405,116 @@ func (so *SocketOptions) SetLinger(linger LingerOption) {
so.linger = linger
so.mu.Unlock()
}
+
+// SockErrOrigin represents the constants for error origin.
+type SockErrOrigin uint8
+
+const (
+ // SockExtErrorOriginNone represents an unknown error origin.
+ SockExtErrorOriginNone SockErrOrigin = iota
+
+ // SockExtErrorOriginLocal indicates a local error.
+ SockExtErrorOriginLocal
+
+ // SockExtErrorOriginICMP indicates an IPv4 ICMP error.
+ SockExtErrorOriginICMP
+
+ // SockExtErrorOriginICMP6 indicates an IPv6 ICMP error.
+ SockExtErrorOriginICMP6
+)
+
+// IsICMPErr indicates if the error originated from an ICMP error.
+func (origin SockErrOrigin) IsICMPErr() bool {
+ return origin == SockExtErrorOriginICMP || origin == SockExtErrorOriginICMP6
+}
+
+// SockError represents a queue entry in the per-socket error queue.
+//
+// +stateify savable
+type SockError struct {
+ sockErrorEntry
+
+ // Err is the error caused by the errant packet.
+ Err *Error
+ // ErrOrigin indicates the error origin.
+ ErrOrigin SockErrOrigin
+ // ErrType is the type in the ICMP header.
+ ErrType uint8
+ // ErrCode is the code in the ICMP header.
+ ErrCode uint8
+ // ErrInfo is additional info about the error.
+ ErrInfo uint32
+
+ // Payload is the errant packet's payload.
+ Payload []byte
+ // Dst is the original destination address of the errant packet.
+ Dst FullAddress
+ // Offender is the original sender address of the errant packet.
+ Offender FullAddress
+ // NetProto is the network protocol being used to transmit the packet.
+ NetProto NetworkProtocolNumber
+}
+
+// pruneErrQueue resets the queue.
+func (so *SocketOptions) pruneErrQueue() {
+ so.errQueueMu.Lock()
+ so.errQueue.Reset()
+ so.errQueueMu.Unlock()
+}
+
+// DequeueErr dequeues a socket extended error from the error queue and returns
+// it. Returns nil if queue is empty.
+func (so *SocketOptions) DequeueErr() *SockError {
+ so.errQueueMu.Lock()
+ defer so.errQueueMu.Unlock()
+
+ err := so.errQueue.Front()
+ if err != nil {
+ so.errQueue.Remove(err)
+ }
+ return err
+}
+
+// PeekErr returns the error in the front of the error queue. Returns nil if
+// the error queue is empty.
+func (so *SocketOptions) PeekErr() *SockError {
+ so.errQueueMu.Lock()
+ defer so.errQueueMu.Unlock()
+ return so.errQueue.Front()
+}
+
+// QueueErr inserts the error at the back of the error queue.
+//
+// Preconditions: so.GetRecvError() == true.
+func (so *SocketOptions) QueueErr(err *SockError) {
+ so.errQueueMu.Lock()
+ defer so.errQueueMu.Unlock()
+ so.errQueue.PushBack(err)
+}
+
+// QueueLocalErr queues a local error onto the local queue.
+func (so *SocketOptions) QueueLocalErr(err *Error, net NetworkProtocolNumber, info uint32, dst FullAddress, payload []byte) {
+ so.QueueErr(&SockError{
+ Err: err,
+ ErrOrigin: SockExtErrorOriginLocal,
+ ErrInfo: info,
+ Payload: payload,
+ Dst: dst,
+ NetProto: net,
+ })
+}
+
+// GetBindToDevice gets value for SO_BINDTODEVICE option.
+func (so *SocketOptions) GetBindToDevice() int32 {
+ return atomic.LoadInt32(&so.bindToDevice)
+}
+
+// SetBindToDevice sets value for SO_BINDTODEVICE option.
+func (so *SocketOptions) SetBindToDevice(bindToDevice int32) *Error {
+ if !so.handler.HasNIC(bindToDevice) {
+ return ErrUnknownDevice
+ }
+
+ atomic.StoreInt32(&so.bindToDevice, bindToDevice)
+ return nil
+}
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
index c9b13cd0e..792f4f170 100644
--- a/pkg/tcpip/stack/linkaddrcache.go
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -18,7 +18,6 @@ import (
"fmt"
"time"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -58,9 +57,6 @@ const (
incomplete entryState = iota
// ready means that the address has been resolved and can be used.
ready
- // failed means that address resolution timed out and the address
- // could not be resolved.
- failed
)
// String implements Stringer.
@@ -70,8 +66,6 @@ func (s entryState) String() string {
return "incomplete"
case ready:
return "ready"
- case failed:
- return "failed"
default:
return fmt.Sprintf("unknown(%d)", s)
}
@@ -80,40 +74,48 @@ func (s entryState) String() string {
// A linkAddrEntry is an entry in the linkAddrCache.
// This struct is thread-compatible.
type linkAddrEntry struct {
+ // linkAddrEntryEntry access is synchronized by the linkAddrCache lock.
linkAddrEntryEntry
+ // TODO(gvisor.dev/issue/5150): move these fields under mu.
+ // mu protects the fields below.
+ mu sync.RWMutex
+
addr tcpip.FullAddress
linkAddr tcpip.LinkAddress
expiration time.Time
s entryState
- // wakers is a set of waiters for address resolution result. Anytime
- // state transitions out of incomplete these waiters are notified.
- wakers map[*sleep.Waker]struct{}
-
- // done is used to allow callers to wait on address resolution. It is nil iff
- // s is incomplete and resolution is not yet in progress.
+ // done is closed when address resolution is complete. It is nil iff s is
+ // incomplete and resolution is not yet in progress.
done chan struct{}
+
+ // onResolve is called with the result of address resolution.
+ onResolve []func(tcpip.LinkAddress, bool)
}
-// changeState sets the entry's state to ns, notifying any waiters.
+func (e *linkAddrEntry) notifyCompletionLocked(linkAddr tcpip.LinkAddress) {
+ for _, callback := range e.onResolve {
+ callback(linkAddr, len(linkAddr) != 0)
+ }
+ e.onResolve = nil
+ if ch := e.done; ch != nil {
+ close(ch)
+ e.done = nil
+ }
+}
+
+// changeStateLocked sets the entry's state to ns.
//
// The entry's expiration is bumped up to the greater of itself and the passed
// expiration; the zero value indicates immediate expiration, and is set
// unconditionally - this is an implementation detail that allows for entries
// to be reused.
-func (e *linkAddrEntry) changeState(ns entryState, expiration time.Time) {
- // Notify whoever is waiting on address resolution when transitioning
- // out of incomplete.
- if e.s == incomplete && ns != incomplete {
- for w := range e.wakers {
- w.Assert()
- }
- e.wakers = nil
- if ch := e.done; ch != nil {
- close(ch)
- }
- e.done = nil
+//
+// Precondition: e.mu must be locked
+func (e *linkAddrEntry) changeStateLocked(ns entryState, expiration time.Time) {
+ if e.s == incomplete && ns == ready {
+ e.notifyCompletionLocked(e.linkAddr)
}
if expiration.IsZero() || expiration.After(e.expiration) {
@@ -122,10 +124,6 @@ func (e *linkAddrEntry) changeState(ns entryState, expiration time.Time) {
e.s = ns
}
-func (e *linkAddrEntry) removeWaker(w *sleep.Waker) {
- delete(e.wakers, w)
-}
-
// add adds a k -> v mapping to the cache.
func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) {
// Calculate expiration time before acquiring the lock, since expiration is
@@ -135,10 +133,12 @@ func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) {
c.cache.Lock()
entry := c.getOrCreateEntryLocked(k)
- entry.linkAddr = v
-
- entry.changeState(ready, expiration)
c.cache.Unlock()
+
+ entry.mu.Lock()
+ defer entry.mu.Unlock()
+ entry.linkAddr = v
+ entry.changeStateLocked(ready, expiration)
}
// getOrCreateEntryLocked retrieves a cache entry associated with k. The
@@ -159,13 +159,14 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt
var entry *linkAddrEntry
if len(c.cache.table) == linkAddrCacheSize {
entry = c.cache.lru.Back()
+ entry.mu.Lock()
delete(c.cache.table, entry.addr)
c.cache.lru.Remove(entry)
- // Wake waiters and mark the soon-to-be-reused entry as expired. Note
- // that the state passed doesn't matter when the zero time is passed.
- entry.changeState(failed, time.Time{})
+ // Wake waiters and mark the soon-to-be-reused entry as expired.
+ entry.notifyCompletionLocked("" /* linkAddr */)
+ entry.mu.Unlock()
} else {
entry = new(linkAddrEntry)
}
@@ -180,9 +181,12 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt
}
// get reports any known link address for k.
-func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
+func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
if linkRes != nil {
if addr, ok := linkRes.ResolveStaticAddress(k.Addr); ok {
+ if onResolve != nil {
+ onResolve(addr, true)
+ }
return addr, nil, nil
}
}
@@ -190,56 +194,35 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo
c.cache.Lock()
defer c.cache.Unlock()
entry := c.getOrCreateEntryLocked(k)
+ entry.mu.Lock()
+ defer entry.mu.Unlock()
+
switch s := entry.s; s {
- case ready, failed:
+ case ready:
if !time.Now().After(entry.expiration) {
// Not expired.
- switch s {
- case ready:
- return entry.linkAddr, nil, nil
- case failed:
- return entry.linkAddr, nil, tcpip.ErrNoLinkAddress
- default:
- panic(fmt.Sprintf("invalid cache entry state: %s", s))
+ if onResolve != nil {
+ onResolve(entry.linkAddr, true)
}
+ return entry.linkAddr, nil, nil
}
- entry.changeState(incomplete, time.Time{})
+ entry.changeStateLocked(incomplete, time.Time{})
fallthrough
case incomplete:
- if waker != nil {
- if entry.wakers == nil {
- entry.wakers = make(map[*sleep.Waker]struct{})
- }
- entry.wakers[waker] = struct{}{}
+ if onResolve != nil {
+ entry.onResolve = append(entry.onResolve, onResolve)
}
-
if entry.done == nil {
- // Address resolution needs to be initiated.
- if linkRes == nil {
- return entry.linkAddr, nil, tcpip.ErrNoLinkAddress
- }
-
entry.done = make(chan struct{})
go c.startAddressResolution(k, linkRes, localAddr, nic, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
}
-
return entry.linkAddr, entry.done, tcpip.ErrWouldBlock
default:
panic(fmt.Sprintf("invalid cache entry state: %s", s))
}
}
-// removeWaker removes a waker previously added through get().
-func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) {
- c.cache.Lock()
- defer c.cache.Unlock()
-
- if entry, ok := c.cache.table[k]; ok {
- entry.removeWaker(waker)
- }
-}
-
func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) {
for i := 0; ; i++ {
// Send link request, then wait for the timeout limit and check
@@ -257,9 +240,9 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link
}
}
-// checkLinkRequest checks whether previous attempt to resolve address has succeeded
-// and mark the entry accordingly, e.g. ready, failed, etc. Return true if request
-// can stop, false if another request should be sent.
+// checkLinkRequest checks whether previous attempt to resolve address has
+// succeeded and mark the entry accordingly. Returns true if request can stop,
+// false if another request should be sent.
func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, attempt int) bool {
c.cache.Lock()
defer c.cache.Unlock()
@@ -268,16 +251,20 @@ func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, att
// Entry was evicted from the cache.
return true
}
+ entry.mu.Lock()
+ defer entry.mu.Unlock()
+
switch s := entry.s; s {
- case ready, failed:
- // Entry was made ready by resolver or failed. Either way we're done.
+ case ready:
+ // Entry was made ready by resolver.
case incomplete:
if attempt+1 < c.resolutionAttempts {
// No response yet, need to send another ARP request.
return false
}
- // Max number of retries reached, mark entry as failed.
- entry.changeState(failed, now.Add(c.ageLimit))
+ // Max number of retries reached, delete entry.
+ entry.notifyCompletionLocked("" /* linkAddr */)
+ delete(c.cache.table, k)
default:
panic(fmt.Sprintf("invalid cache entry state: %s", s))
}
diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go
index 317f6871d..c15f10e76 100644
--- a/pkg/tcpip/stack/neighbor_cache.go
+++ b/pkg/tcpip/stack/neighbor_cache.go
@@ -17,7 +17,6 @@ package stack
import (
"fmt"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -99,9 +98,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA
n.dynamic.lru.Remove(e)
n.dynamic.count--
- e.dispatchRemoveEventLocked()
- e.setStateLocked(Unknown)
- e.notifyWakersLocked()
+ e.removeLocked()
e.mu.Unlock()
}
n.cache[remoteAddr] = entry
@@ -110,21 +107,27 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA
return entry
}
-// entry looks up the neighbor cache for translating address to link address
-// (e.g. IP -> MAC). If the LinkEndpoint requests address resolution and there
-// is a LinkAddressResolver registered with the network protocol, the cache
-// attempts to resolve the address and returns ErrWouldBlock. If a Waker is
-// provided, it will be notified when address resolution is complete (success
-// or not).
+// entry looks up neighbor information matching the remote address, and returns
+// it if readily available.
+//
+// Returns ErrWouldBlock if the link address is not readily available, along
+// with a notification channel for the caller to block on. Triggers address
+// resolution asynchronously.
+//
+// If onResolve is provided, it will be called either immediately, if resolution
+// is not required, or when address resolution is complete, with the resolved
+// link address and whether resolution succeeded. After any callbacks have been
+// called, the returned notification channel is closed.
+//
+// NB: if a callback is provided, it should not call into the neighbor cache.
//
// If specified, the local address must be an address local to the interface the
// neighbor cache belongs to. The local address is the source address of a
// packet prompting NUD/link address resolution.
//
-// If address resolution is required, ErrNoLinkAddress and a notification
-// channel is returned for the top level caller to block. Channel is closed
-// once address resolution is complete (success or not).
-func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, w *sleep.Waker) (NeighborEntry, <-chan struct{}, *tcpip.Error) {
+// TODO(gvisor.dev/issue/5151): Don't return the neighbor entry.
+func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(tcpip.LinkAddress, bool)) (NeighborEntry, <-chan struct{}, *tcpip.Error) {
+ // TODO(gvisor.dev/issue/5149): Handle static resolution in route.Resolve.
if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok {
e := NeighborEntry{
Addr: remoteAddr,
@@ -132,6 +135,9 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA
State: Static,
UpdatedAtNanos: 0,
}
+ if onResolve != nil {
+ onResolve(linkAddr, true)
+ }
return e, nil, nil
}
@@ -149,37 +155,25 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA
// of packets to a neighbor. While reasserting a neighbor's reachability,
// a node continues sending packets to that neighbor using the cached
// link-layer address."
+ if onResolve != nil {
+ onResolve(entry.neigh.LinkAddr, true)
+ }
return entry.neigh, nil, nil
- case Unknown, Incomplete:
- entry.addWakerLocked(w)
-
+ case Unknown, Incomplete, Failed:
+ if onResolve != nil {
+ entry.onResolve = append(entry.onResolve, onResolve)
+ }
if entry.done == nil {
// Address resolution needs to be initiated.
- if linkRes == nil {
- return entry.neigh, nil, tcpip.ErrNoLinkAddress
- }
entry.done = make(chan struct{})
}
-
entry.handlePacketQueuedLocked(localAddr)
return entry.neigh, entry.done, tcpip.ErrWouldBlock
- case Failed:
- return entry.neigh, nil, tcpip.ErrNoLinkAddress
default:
panic(fmt.Sprintf("Invalid cache entry state: %s", s))
}
}
-// removeWaker removes a waker that has been added when link resolution for
-// addr was requested.
-func (n *neighborCache) removeWaker(addr tcpip.Address, waker *sleep.Waker) {
- n.mu.Lock()
- if entry, ok := n.cache[addr]; ok {
- delete(entry.wakers, waker)
- }
- n.mu.Unlock()
-}
-
// entries returns all entries in the neighbor cache.
func (n *neighborCache) entries() []NeighborEntry {
n.mu.RLock()
@@ -222,34 +216,13 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd
return
}
- // Notify that resolution has been interrupted, just in case the entry was
- // in the Incomplete or Probe state.
- entry.dispatchRemoveEventLocked()
- entry.setStateLocked(Unknown)
- entry.notifyWakersLocked()
+ entry.removeLocked()
entry.mu.Unlock()
}
n.cache[addr] = newStaticNeighborEntry(n.nic, addr, linkAddr, n.state)
}
-// removeEntryLocked removes the specified entry from the neighbor cache.
-//
-// Prerequisite: n.mu and entry.mu MUST be locked.
-func (n *neighborCache) removeEntryLocked(entry *neighborEntry) {
- if entry.neigh.State != Static {
- n.dynamic.lru.Remove(entry)
- n.dynamic.count--
- }
- if entry.neigh.State != Failed {
- entry.dispatchRemoveEventLocked()
- }
- entry.setStateLocked(Unknown)
- entry.notifyWakersLocked()
-
- delete(n.cache, entry.neigh.Addr)
-}
-
// removeEntry removes a dynamic or static entry by address from the neighbor
// cache. Returns true if the entry was found and deleted.
func (n *neighborCache) removeEntry(addr tcpip.Address) bool {
@@ -264,7 +237,13 @@ func (n *neighborCache) removeEntry(addr tcpip.Address) bool {
entry.mu.Lock()
defer entry.mu.Unlock()
- n.removeEntryLocked(entry)
+ if entry.neigh.State != Static {
+ n.dynamic.lru.Remove(entry)
+ n.dynamic.count--
+ }
+
+ entry.removeLocked()
+ delete(n.cache, entry.neigh.Addr)
return true
}
@@ -275,9 +254,7 @@ func (n *neighborCache) clear() {
for _, entry := range n.cache {
entry.mu.Lock()
- entry.dispatchRemoveEventLocked()
- entry.setStateLocked(Unknown)
- entry.notifyWakersLocked()
+ entry.removeLocked()
entry.mu.Unlock()
}
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index 32399b4f5..75afb3001 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -19,7 +19,6 @@ import (
"sync"
"time"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
@@ -67,8 +66,7 @@ const (
// Static describes entries that have been explicitly added by the user. They
// do not expire and are not deleted until explicitly removed.
Static
- // Failed means traffic should not be sent to this neighbor since attempts of
- // reachability have returned inconclusive.
+ // Failed means recent attempts of reachability have returned inconclusive.
Failed
)
@@ -93,16 +91,13 @@ type neighborEntry struct {
neigh NeighborEntry
- // wakers is a set of waiters for address resolution result. Anytime state
- // transitions out of incomplete these waiters are notified. It is nil iff
- // address resolution is ongoing and no clients are waiting for the result.
- wakers map[*sleep.Waker]struct{}
-
- // done is used to allow callers to wait on address resolution. It is nil
- // iff nudState is not Reachable and address resolution is not yet in
- // progress.
+ // done is closed when address resolution is complete. It is nil iff s is
+ // incomplete and resolution is not yet in progress.
done chan struct{}
+ // onResolve is called with the result of address resolution.
+ onResolve []func(tcpip.LinkAddress, bool)
+
isRouter bool
job *tcpip.Job
}
@@ -143,25 +138,15 @@ func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAdd
}
}
-// addWaker adds w to the list of wakers waiting for address resolution.
-// Assumes the entry has already been appropriately locked.
-func (e *neighborEntry) addWakerLocked(w *sleep.Waker) {
- if w == nil {
- return
- }
- if e.wakers == nil {
- e.wakers = make(map[*sleep.Waker]struct{})
- }
- e.wakers[w] = struct{}{}
-}
-
-// notifyWakersLocked notifies those waiting for address resolution, whether it
-// succeeded or failed. Assumes the entry has already been appropriately locked.
-func (e *neighborEntry) notifyWakersLocked() {
- for w := range e.wakers {
- w.Assert()
+// notifyCompletionLocked notifies those waiting for address resolution, with
+// the link address if resolution completed successfully.
+//
+// Precondition: e.mu MUST be locked.
+func (e *neighborEntry) notifyCompletionLocked(succeeded bool) {
+ for _, callback := range e.onResolve {
+ callback(e.neigh.LinkAddr, succeeded)
}
- e.wakers = nil
+ e.onResolve = nil
if ch := e.done; ch != nil {
close(ch)
e.done = nil
@@ -170,6 +155,8 @@ func (e *neighborEntry) notifyWakersLocked() {
// dispatchAddEventLocked signals to stack's NUD Dispatcher that the entry has
// been added.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) dispatchAddEventLocked() {
if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
nudDisp.OnNeighborAdded(e.nic.id, e.neigh)
@@ -178,6 +165,8 @@ func (e *neighborEntry) dispatchAddEventLocked() {
// dispatchChangeEventLocked signals to stack's NUD Dispatcher that the entry
// has changed state or link-layer address.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) dispatchChangeEventLocked() {
if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
nudDisp.OnNeighborChanged(e.nic.id, e.neigh)
@@ -186,23 +175,41 @@ func (e *neighborEntry) dispatchChangeEventLocked() {
// dispatchRemoveEventLocked signals to stack's NUD Dispatcher that the entry
// has been removed.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) dispatchRemoveEventLocked() {
if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
nudDisp.OnNeighborRemoved(e.nic.id, e.neigh)
}
}
+// cancelJobLocked cancels the currently scheduled action, if there is one.
+// Entries in Unknown, Stale, or Static state do not have a scheduled action.
+//
+// Precondition: e.mu MUST be locked.
+func (e *neighborEntry) cancelJobLocked() {
+ if job := e.job; job != nil {
+ job.Cancel()
+ }
+}
+
+// removeLocked prepares the entry for removal.
+//
+// Precondition: e.mu MUST be locked.
+func (e *neighborEntry) removeLocked() {
+ e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds()
+ e.dispatchRemoveEventLocked()
+ e.cancelJobLocked()
+ e.notifyCompletionLocked(false /* succeeded */)
+}
+
// setStateLocked transitions the entry to the specified state immediately.
//
// Follows the logic defined in RFC 4861 section 7.3.3.
//
-// e.mu MUST be locked.
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) setStateLocked(next NeighborState) {
- // Cancel the previously scheduled action, if there is one. Entries in
- // Unknown, Stale, or Static state do not have scheduled actions.
- if timer := e.job; timer != nil {
- timer.Cancel()
- }
+ e.cancelJobLocked()
prev := e.neigh.State
e.neigh.State = next
@@ -257,11 +264,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
e.job.Schedule(immediateDuration)
case Failed:
- e.notifyWakersLocked()
- e.job = e.nic.stack.newJob(&doubleLock{first: &e.nic.neigh.mu, second: &e.mu}, func() {
- e.nic.neigh.removeEntryLocked(e)
- })
- e.job.Schedule(config.UnreachableTime)
+ e.notifyCompletionLocked(false /* succeeded */)
case Unknown, Stale, Static:
// Do nothing
@@ -275,8 +278,14 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
// being queued for outgoing transmission.
//
// Follows the logic defined in RFC 4861 section 7.3.3.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
switch e.neigh.State {
+ case Failed:
+ e.nic.stats.Neighbor.FailedEntryLookups.Increment()
+
+ fallthrough
case Unknown:
e.neigh.State = Incomplete
e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds()
@@ -309,7 +318,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
// implementation may find it convenient in some cases to return errors
// to the sender by taking the offending packet, generating an ICMP
// error message, and then delivering it (locally) through the generic
- // error-handling routines.' - RFC 4861 section 2.1
+ // error-handling routines." - RFC 4861 section 2.1
e.dispatchRemoveEventLocked()
e.setStateLocked(Failed)
return
@@ -349,8 +358,6 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
case Incomplete, Reachable, Delay, Probe, Static:
// Do nothing
- case Failed:
- e.nic.stats.Neighbor.FailedEntryLookups.Increment()
default:
panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
}
@@ -360,18 +367,30 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
// Neighbor Solicitation for ARP or NDP, respectively).
//
// Follows the logic defined in RFC 4861 section 7.2.3.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) {
// Probes MUST be silently discarded if the target address is tentative, does
// not exist, or not bound to the NIC as per RFC 4861 section 7.2.3. These
// checks MUST be done by the NetworkEndpoint.
switch e.neigh.State {
- case Unknown, Incomplete, Failed:
+ case Unknown, Failed:
e.neigh.LinkAddr = remoteLinkAddr
e.setStateLocked(Stale)
- e.notifyWakersLocked()
e.dispatchAddEventLocked()
+ case Incomplete:
+ // "If an entry already exists, and the cached link-layer address
+ // differs from the one in the received Source Link-Layer option, the
+ // cached address should be replaced by the received address, and the
+ // entry's reachability state MUST be set to STALE."
+ // - RFC 4861 section 7.2.3
+ e.neigh.LinkAddr = remoteLinkAddr
+ e.setStateLocked(Stale)
+ e.notifyCompletionLocked(true /* succeeded */)
+ e.dispatchChangeEventLocked()
+
case Reachable, Delay, Probe:
if e.neigh.LinkAddr != remoteLinkAddr {
e.neigh.LinkAddr = remoteLinkAddr
@@ -404,6 +423,8 @@ func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) {
// not be possible. SEND uses RSA key pairs to produce Cryptographically
// Generated Addresses (CGA), as defined in RFC 3972. This ensures that the
// claimed source of an NDP message is the owner of the claimed address.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) {
switch e.neigh.State {
case Incomplete:
@@ -422,7 +443,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
}
e.dispatchChangeEventLocked()
e.isRouter = flags.IsRouter
- e.notifyWakersLocked()
+ e.notifyCompletionLocked(true /* succeeded */)
// "Note that the Override flag is ignored if the entry is in the
// INCOMPLETE state." - RFC 4861 section 7.2.5
@@ -457,7 +478,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
wasReachable := e.neigh.State == Reachable
// Set state to Reachable again to refresh timers.
e.setStateLocked(Reachable)
- e.notifyWakersLocked()
+ e.notifyCompletionLocked(true /* succeeded */)
if !wasReachable {
e.dispatchChangeEventLocked()
}
@@ -495,6 +516,8 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
// handleUpperLevelConfirmationLocked processes an incoming upper-level protocol
// (e.g. TCP acknowledgements) reachability confirmation.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) handleUpperLevelConfirmationLocked() {
switch e.neigh.State {
case Reachable, Stale, Delay, Probe:
@@ -512,23 +535,3 @@ func (e *neighborEntry) handleUpperLevelConfirmationLocked() {
panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
}
}
-
-// doubleLock combines two locks into one while maintaining lock ordering.
-//
-// TODO(gvisor.dev/issue/4796): Remove this once subsequent traffic to a Failed
-// neighbor is allowed.
-type doubleLock struct {
- first, second sync.Locker
-}
-
-// Lock locks both locks in order: first then second.
-func (l *doubleLock) Lock() {
- l.first.Lock()
- l.second.Lock()
-}
-
-// Unlock unlocks both locks in reverse order: second then first.
-func (l *doubleLock) Unlock() {
- l.second.Unlock()
- l.first.Unlock()
-}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 5d037a27e..4a34805b5 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -20,7 +20,6 @@ import (
"reflect"
"sync/atomic"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -295,15 +294,17 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb
// the same unresolved IP address, and transmit the saved
// packet when the address has been resolved.
//
- // RFC 4861 section 5.2 (for IPv6):
- // Once the IP address of the next-hop node is known, the sender
- // examines the Neighbor Cache for link-layer information about that
- // neighbor. If no entry exists, the sender creates one, sets its state
- // to INCOMPLETE, initiates Address Resolution, and then queues the data
- // packet pending completion of address resolution.
+ // RFC 4861 section 7.2.2 (for IPv6):
+ // While waiting for address resolution to complete, the sender MUST, for
+ // each neighbor, retain a small queue of packets waiting for address
+ // resolution to complete. The queue MUST hold at least one packet, and MAY
+ // contain more. However, the number of queued packets per neighbor SHOULD
+ // be limited to some small value. When a queue overflows, the new arrival
+ // SHOULD replace the oldest entry. Once address resolution completes, the
+ // node transmits any queued packets.
if ch, err := r.Resolve(nil); err != nil {
if err == tcpip.ErrWouldBlock {
- r := r.Clone()
+ r.Acquire()
n.stack.linkResQueue.enqueue(ch, r, protocol, pkt)
return nil
}
@@ -316,7 +317,9 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb
// WritePacketToRemote implements NetworkInterface.
func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
r := Route{
- NetProto: protocol,
+ routeInfo: routeInfo{
+ NetProto: protocol,
+ },
}
r.ResolveWith(remoteLinkAddr)
return n.writePacket(&r, gso, protocol, pkt)
@@ -545,14 +548,6 @@ func (n *NIC) neighbors() ([]NeighborEntry, *tcpip.Error) {
return n.neigh.entries(), nil
}
-func (n *NIC) removeWaker(addr tcpip.Address, w *sleep.Waker) {
- if n.neigh == nil {
- return
- }
-
- n.neigh.removeWaker(addr, w)
-}
-
func (n *NIC) addStaticNeighbor(addr tcpip.Address, linkAddress tcpip.LinkAddress) *tcpip.Error {
if n.neigh == nil {
return tcpip.ErrNotSupported
diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go
index ab629b3a4..12d67409a 100644
--- a/pkg/tcpip/stack/nud.go
+++ b/pkg/tcpip/stack/nud.go
@@ -109,14 +109,6 @@ const (
//
// Default taken from MAX_NEIGHBOR_ADVERTISEMENT of RFC 4861 section 10.
defaultMaxReachbilityConfirmations = 3
-
- // defaultUnreachableTime is the default duration for how long an entry will
- // remain in the FAILED state before being removed from the neighbor cache.
- //
- // Note, there is no equivalent protocol constant defined in RFC 4861. It
- // leaves the specifics of any garbage collection mechanism up to the
- // implementation.
- defaultUnreachableTime = 5 * time.Second
)
// NUDDispatcher is the interface integrators of netstack must implement to
@@ -278,10 +270,6 @@ type NUDConfigurations struct {
// TODO(gvisor.dev/issue/2246): Discuss if implementation of this NUD
// configuration option is necessary.
MaxReachabilityConfirmations uint32
-
- // UnreachableTime describes how long an entry will remain in the FAILED
- // state before being removed from the neighbor cache.
- UnreachableTime time.Duration
}
// DefaultNUDConfigurations returns a NUDConfigurations populated with default
@@ -299,7 +287,6 @@ func DefaultNUDConfigurations() NUDConfigurations {
MaxUnicastProbes: defaultMaxUnicastProbes,
MaxAnycastDelayTime: defaultMaxAnycastDelayTime,
MaxReachabilityConfirmations: defaultMaxReachbilityConfirmations,
- UnreachableTime: defaultUnreachableTime,
}
}
@@ -329,9 +316,6 @@ func (c *NUDConfigurations) resetInvalidFields() {
if c.MaxUnicastProbes == 0 {
c.MaxUnicastProbes = defaultMaxUnicastProbes
}
- if c.UnreachableTime == 0 {
- c.UnreachableTime = defaultUnreachableTime
- }
}
// calcMaxRandomFactor calculates the maximum value of the random factor used
@@ -416,7 +400,7 @@ func (s *NUDState) ReachableTime() time.Duration {
s.config.BaseReachableTime != s.prevBaseReachableTime ||
s.config.MinRandomFactor != s.prevMinRandomFactor ||
s.config.MaxRandomFactor != s.prevMaxRandomFactor {
- return s.recomputeReachableTimeLocked()
+ s.recomputeReachableTimeLocked()
}
return s.reachableTime
}
@@ -442,7 +426,7 @@ func (s *NUDState) ReachableTime() time.Duration {
// random value gets re-computed at least once every few hours.
//
// s.mu MUST be locked for writing.
-func (s *NUDState) recomputeReachableTimeLocked() time.Duration {
+func (s *NUDState) recomputeReachableTimeLocked() {
s.prevBaseReachableTime = s.config.BaseReachableTime
s.prevMinRandomFactor = s.config.MinRandomFactor
s.prevMaxRandomFactor = s.config.MaxRandomFactor
@@ -462,5 +446,4 @@ func (s *NUDState) recomputeReachableTimeLocked() time.Duration {
}
s.expiration = time.Now().Add(2 * time.Hour)
- return s.reachableTime
}
diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go
index 5d364a2b0..4a3adcf33 100644
--- a/pkg/tcpip/stack/pending_packets.go
+++ b/pkg/tcpip/stack/pending_packets.go
@@ -103,7 +103,7 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro
for _, p := range packets {
if cancelled {
p.route.Stats().IP.OutgoingPacketErrors.Increment()
- } else if _, err := p.route.Resolve(nil); err != nil {
+ } else if p.route.IsResolutionRequired() {
p.route.Stats().IP.OutgoingPacketErrors.Increment()
} else {
p.route.outgoingNIC.writePacket(p.route, nil /* gso */, p.proto, p.pkt)
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index b334e27c4..7e83b7fbb 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -17,7 +17,6 @@ package stack
import (
"fmt"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -799,19 +798,26 @@ type LinkAddressCache interface {
// AddLinkAddress adds a link address to the cache.
AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress)
- // GetLinkAddress looks up the cache to translate address to link address (e.g. IP -> MAC).
- // If the LinkEndpoint requests address resolution and there is a LinkAddressResolver
- // registered with the network protocol, the cache attempts to resolve the address
- // and returns ErrWouldBlock. Waker is notified when address resolution is
- // complete (success or not).
+ // GetLinkAddress finds the link address corresponding to the remote address
+ // (e.g. IP -> MAC).
//
- // If address resolution is required, ErrNoLinkAddress and a notification channel is
- // returned for the top level caller to block. Channel is closed once address resolution
- // is complete (success or not).
- GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error)
-
- // RemoveWaker removes a waker that has been added in GetLinkAddress().
- RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker)
+ // Returns a link address for the remote address, if readily available.
+ //
+ // Returns ErrWouldBlock if the link address is not readily available, along
+ // with a notification channel for the caller to block on. Triggers address
+ // resolution asynchronously.
+ //
+ // If onResolve is provided, it will be called either immediately, if
+ // resolution is not required, or when address resolution is complete, with
+ // the resolved link address and whether resolution succeeded. After any
+ // callbacks have been called, the returned notification channel is closed.
+ //
+ // If specified, the local address must be an address local to the interface
+ // the neighbor cache belongs to. The local address is the source address of
+ // a packet prompting NUD/link address resolution.
+ //
+ // TODO(gvisor.dev/issue/5151): Don't return the link address.
+ GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error)
}
// RawFactory produces endpoints for writing various types of raw packets.
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index de5fe6ffe..b0251d0b4 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -17,7 +17,6 @@ package stack
import (
"fmt"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -31,24 +30,7 @@ import (
//
// TODO(gvisor.dev/issue/4902): Unexpose immutable fields.
type Route struct {
- // RemoteAddress is the final destination of the route.
- RemoteAddress tcpip.Address
-
- // LocalAddress is the local address where the route starts.
- LocalAddress tcpip.Address
-
- // LocalLinkAddress is the link-layer (MAC) address of the
- // where the route starts.
- LocalLinkAddress tcpip.LinkAddress
-
- // NextHop is the next node in the path to the destination.
- NextHop tcpip.Address
-
- // NetProto is the network-layer protocol.
- NetProto tcpip.NetworkProtocolNumber
-
- // Loop controls where WritePacket should send packets.
- Loop PacketLooping
+ routeInfo
// localAddressNIC is the interface the address is associated with.
// TODO(gvisor.dev/issue/4548): Remove this field once we can query the
@@ -78,6 +60,45 @@ type Route struct {
linkRes LinkAddressResolver
}
+type routeInfo struct {
+ // RemoteAddress is the final destination of the route.
+ RemoteAddress tcpip.Address
+
+ // LocalAddress is the local address where the route starts.
+ LocalAddress tcpip.Address
+
+ // LocalLinkAddress is the link-layer (MAC) address of the
+ // where the route starts.
+ LocalLinkAddress tcpip.LinkAddress
+
+ // NextHop is the next node in the path to the destination.
+ NextHop tcpip.Address
+
+ // NetProto is the network-layer protocol.
+ NetProto tcpip.NetworkProtocolNumber
+
+ // Loop controls where WritePacket should send packets.
+ Loop PacketLooping
+}
+
+// RouteInfo contains all of Route's exported fields.
+type RouteInfo struct {
+ routeInfo
+
+ // RemoteLinkAddress is the link-layer (MAC) address of the next hop in the
+ // route.
+ RemoteLinkAddress tcpip.LinkAddress
+}
+
+// GetFields returns a RouteInfo with all of r's exported fields. This allows
+// callers to store the route's fields without retaining a reference to it.
+func (r *Route) GetFields() RouteInfo {
+ return RouteInfo{
+ routeInfo: r.routeInfo,
+ RemoteLinkAddress: r.RemoteLinkAddress(),
+ }
+}
+
// constructAndValidateRoute validates and initializes a route. It takes
// ownership of the provided local address.
//
@@ -152,13 +173,15 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip
func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) *Route {
r := &Route{
- NetProto: netProto,
- LocalAddress: localAddr,
- LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(),
- RemoteAddress: remoteAddr,
- localAddressNIC: localAddressNIC,
- outgoingNIC: outgoingNIC,
- Loop: loop,
+ routeInfo: routeInfo{
+ NetProto: netProto,
+ LocalAddress: localAddr,
+ LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(),
+ RemoteAddress: remoteAddr,
+ Loop: loop,
+ },
+ localAddressNIC: localAddressNIC,
+ outgoingNIC: outgoingNIC,
}
r.mu.Lock()
@@ -264,22 +287,21 @@ func (r *Route) ResolveWith(addr tcpip.LinkAddress) {
r.mu.remoteLinkAddress = addr
}
-// Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in
-// case address resolution requires blocking, e.g. wait for ARP reply. Waker is
-// notified when address resolution is complete (success or not).
+// Resolve attempts to resolve the link address if necessary.
//
-// If address resolution is required, ErrNoLinkAddress and a notification channel is
-// returned for the top level caller to block. Channel is closed once address resolution
-// is complete (success or not).
-//
-// The NIC r uses must not be locked.
-func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
+// Returns tcpip.ErrWouldBlock if address resolution requires blocking (e.g.
+// waiting for ARP reply). If address resolution is required, a notification
+// channel is also returned for the caller to block on. The channel is closed
+// once address resolution is complete (successful or not). If a callback is
+// provided, it will be called when address resolution is complete, regardless
+// of success or failure.
+func (r *Route) Resolve(afterResolve func()) (<-chan struct{}, *tcpip.Error) {
r.mu.Lock()
- defer r.mu.Unlock()
if !r.isResolutionRequiredRLocked() {
// Nothing to do if there is no cache (which does the resolution on cache miss) or
// link address is already known.
+ r.mu.Unlock()
return nil, nil
}
@@ -288,6 +310,7 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
// Local link address is already known.
if r.RemoteAddress == r.LocalAddress {
r.mu.remoteLinkAddress = r.LocalLinkAddress
+ r.mu.Unlock()
return nil, nil
}
nextAddr = r.RemoteAddress
@@ -300,38 +323,36 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
linkAddressResolutionRequestLocalAddr = r.LocalAddress
}
+ // Increment the route's reference count because finishResolution retains a
+ // reference to the route and releases it when called.
+ r.acquireLocked()
+ r.mu.Unlock()
+
+ finishResolution := func(linkAddress tcpip.LinkAddress, ok bool) {
+ if ok {
+ r.ResolveWith(linkAddress)
+ }
+ if afterResolve != nil {
+ afterResolve()
+ }
+ r.Release()
+ }
+
if neigh := r.outgoingNIC.neigh; neigh != nil {
- entry, ch, err := neigh.entry(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, waker)
+ _, ch, err := neigh.entry(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, finishResolution)
if err != nil {
return ch, err
}
- r.mu.remoteLinkAddress = entry.LinkAddr
return nil, nil
}
- linkAddr, ch, err := r.linkCache.GetLinkAddress(r.outgoingNIC.ID(), nextAddr, linkAddressResolutionRequestLocalAddr, r.NetProto, waker)
+ _, ch, err := r.linkCache.GetLinkAddress(r.outgoingNIC.ID(), nextAddr, linkAddressResolutionRequestLocalAddr, r.NetProto, finishResolution)
if err != nil {
return ch, err
}
- r.mu.remoteLinkAddress = linkAddr
return nil, nil
}
-// RemoveWaker removes a waker that has been added in Resolve().
-func (r *Route) RemoveWaker(waker *sleep.Waker) {
- nextAddr := r.NextHop
- if nextAddr == "" {
- nextAddr = r.RemoteAddress
- }
-
- if neigh := r.outgoingNIC.neigh; neigh != nil {
- neigh.removeWaker(nextAddr, waker)
- return
- }
-
- r.linkCache.RemoveWaker(r.outgoingNIC.ID(), nextAddr, waker)
-}
-
// local returns true if the route is a local route.
func (r *Route) local() bool {
return r.Loop == PacketLoop || r.outgoingNIC.IsLoopback()
@@ -419,46 +440,31 @@ func (r *Route) MTU() uint32 {
return r.outgoingNIC.getNetworkEndpoint(r.NetProto).MTU()
}
-// Release frees all resources associated with the route.
+// Release decrements the reference counter of the resources associated with the
+// route.
func (r *Route) Release() {
r.mu.Lock()
defer r.mu.Unlock()
- if r.mu.localAddressEndpoint != nil {
- r.mu.localAddressEndpoint.DecRef()
- r.mu.localAddressEndpoint = nil
+ if ep := r.mu.localAddressEndpoint; ep != nil {
+ ep.DecRef()
}
}
-// Clone clones the route.
-func (r *Route) Clone() *Route {
+// Acquire increments the reference counter of the resources associated with the
+// route.
+func (r *Route) Acquire() {
r.mu.RLock()
defer r.mu.RUnlock()
+ r.acquireLocked()
+}
- newRoute := &Route{
- RemoteAddress: r.RemoteAddress,
- LocalAddress: r.LocalAddress,
- LocalLinkAddress: r.LocalLinkAddress,
- NextHop: r.NextHop,
- NetProto: r.NetProto,
- Loop: r.Loop,
- localAddressNIC: r.localAddressNIC,
- outgoingNIC: r.outgoingNIC,
- linkCache: r.linkCache,
- linkRes: r.linkRes,
- }
-
- newRoute.mu.Lock()
- defer newRoute.mu.Unlock()
- newRoute.mu.localAddressEndpoint = r.mu.localAddressEndpoint
- if newRoute.mu.localAddressEndpoint != nil {
- if !newRoute.mu.localAddressEndpoint.IncRef() {
- panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", newRoute.LocalAddress))
+func (r *Route) acquireLocked() {
+ if ep := r.mu.localAddressEndpoint; ep != nil {
+ if !ep.IncRef() {
+ panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", r.LocalAddress))
}
}
- newRoute.mu.remoteLinkAddress = r.mu.remoteLinkAddress
-
- return newRoute
}
// Stack returns the instance of the Stack that owns this route.
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index dc4f5b3e7..114643b03 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -29,7 +29,6 @@ import (
"golang.org/x/time/rate"
"gvisor.dev/gvisor/pkg/rand"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -171,6 +170,9 @@ type TCPSenderState struct {
// Outstanding is the number of packets in flight.
Outstanding int
+ // SackedOut is the number of packets which have been selectively acked.
+ SackedOut int
+
// SndWnd is the send window size in bytes.
SndWnd seqnum.Size
@@ -1517,7 +1519,7 @@ func (s *Stack) AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr t
}
// GetLinkAddress implements LinkAddressCache.GetLinkAddress.
-func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
+func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
s.mu.RLock()
nic := s.nics[nicID]
if nic == nil {
@@ -1528,7 +1530,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address,
fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
linkRes := s.linkAddrResolvers[protocol]
- return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic, waker)
+ return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic, onResolve)
}
// Neighbors returns all IP to MAC address associations.
@@ -1544,29 +1546,6 @@ func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, *tcpip.Error) {
return nic.neighbors()
}
-// RemoveWaker removes a waker that has been added when link resolution for
-// addr was requested.
-func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) {
- if s.useNeighborCache {
- s.mu.RLock()
- nic, ok := s.nics[nicID]
- s.mu.RUnlock()
-
- if ok {
- nic.removeWaker(addr, waker)
- }
- return
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- if nic := s.nics[nicID]; nic == nil {
- fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
- s.linkAddrCache.removeWaker(fullAddr, waker)
- }
-}
-
// AddStaticNeighbor statically associates an IP address to a MAC address.
func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) *tcpip.Error {
s.mu.RLock()
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 45fa62720..ef0f51f1a 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -258,6 +258,44 @@ func (a Address) Unspecified() bool {
return true
}
+// MatchingPrefix returns the matching prefix length in bits.
+//
+// Panics if b and a have different lengths.
+func (a Address) MatchingPrefix(b Address) uint8 {
+ const bitsInAByte = 8
+
+ if len(a) != len(b) {
+ panic(fmt.Sprintf("addresses %s and %s do not have the same length", a, b))
+ }
+
+ var prefix uint8
+ for i := range a {
+ aByte := a[i]
+ bByte := b[i]
+
+ if aByte == bByte {
+ prefix += bitsInAByte
+ continue
+ }
+
+ // Count the remaining matching bits in the byte from MSbit to LSBbit.
+ mask := uint8(1) << (bitsInAByte - 1)
+ for {
+ if aByte&mask == bByte&mask {
+ prefix++
+ mask >>= 1
+ continue
+ }
+
+ break
+ }
+
+ break
+ }
+
+ return prefix
+}
+
// AddressMask is a bitmask for an address.
type AddressMask string
@@ -500,6 +538,9 @@ type ControlMessages struct {
// OriginalDestinationAddress holds the original destination address
// and port of the incoming packet.
OriginalDstAddress FullAddress
+
+ // SockErr is the dequeued socket error on recvmsg(MSG_ERRQUEUE).
+ SockErr *SockError
}
// PacketOwner is used to get UID and GID of the packet.
@@ -914,14 +955,6 @@ type SettableSocketOption interface {
isSettableSocketOption()
}
-// BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets
-// should bind only on a specific NIC.
-type BindToDeviceOption NICID
-
-func (*BindToDeviceOption) isGettableSocketOption() {}
-
-func (*BindToDeviceOption) isSettableSocketOption() {}
-
// TCPInfoOption is used by GetSockOpt to expose TCP statistics.
//
// TODO(b/64800844): Add and populate stat fields.
diff --git a/pkg/tcpip/tcpip_state_autogen.go b/pkg/tcpip/tcpip_state_autogen.go
index a2c33e1f0..c688eaff5 100644
--- a/pkg/tcpip/tcpip_state_autogen.go
+++ b/pkg/tcpip/tcpip_state_autogen.go
@@ -6,6 +6,58 @@ import (
"gvisor.dev/gvisor/pkg/state"
)
+func (l *sockErrorList) StateTypeName() string {
+ return "pkg/tcpip.sockErrorList"
+}
+
+func (l *sockErrorList) StateFields() []string {
+ return []string{
+ "head",
+ "tail",
+ }
+}
+
+func (l *sockErrorList) beforeSave() {}
+
+func (l *sockErrorList) StateSave(stateSinkObject state.Sink) {
+ l.beforeSave()
+ stateSinkObject.Save(0, &l.head)
+ stateSinkObject.Save(1, &l.tail)
+}
+
+func (l *sockErrorList) afterLoad() {}
+
+func (l *sockErrorList) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &l.head)
+ stateSourceObject.Load(1, &l.tail)
+}
+
+func (e *sockErrorEntry) StateTypeName() string {
+ return "pkg/tcpip.sockErrorEntry"
+}
+
+func (e *sockErrorEntry) StateFields() []string {
+ return []string{
+ "next",
+ "prev",
+ }
+}
+
+func (e *sockErrorEntry) beforeSave() {}
+
+func (e *sockErrorEntry) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ stateSinkObject.Save(0, &e.next)
+ stateSinkObject.Save(1, &e.prev)
+}
+
+func (e *sockErrorEntry) afterLoad() {}
+
+func (e *sockErrorEntry) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.next)
+ stateSourceObject.Load(1, &e.prev)
+}
+
func (so *SocketOptions) StateTypeName() string {
return "pkg/tcpip.SocketOptions"
}
@@ -29,6 +81,9 @@ func (so *SocketOptions) StateFields() []string {
"delayOptionEnabled",
"corkOptionEnabled",
"receiveOriginalDstAddress",
+ "recvErrEnabled",
+ "errQueue",
+ "bindToDevice",
"linger",
}
}
@@ -54,7 +109,10 @@ func (so *SocketOptions) StateSave(stateSinkObject state.Sink) {
stateSinkObject.Save(14, &so.delayOptionEnabled)
stateSinkObject.Save(15, &so.corkOptionEnabled)
stateSinkObject.Save(16, &so.receiveOriginalDstAddress)
- stateSinkObject.Save(17, &so.linger)
+ stateSinkObject.Save(17, &so.recvErrEnabled)
+ stateSinkObject.Save(18, &so.errQueue)
+ stateSinkObject.Save(19, &so.bindToDevice)
+ stateSinkObject.Save(20, &so.linger)
}
func (so *SocketOptions) afterLoad() {}
@@ -77,7 +135,60 @@ func (so *SocketOptions) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(14, &so.delayOptionEnabled)
stateSourceObject.Load(15, &so.corkOptionEnabled)
stateSourceObject.Load(16, &so.receiveOriginalDstAddress)
- stateSourceObject.Load(17, &so.linger)
+ stateSourceObject.Load(17, &so.recvErrEnabled)
+ stateSourceObject.Load(18, &so.errQueue)
+ stateSourceObject.Load(19, &so.bindToDevice)
+ stateSourceObject.Load(20, &so.linger)
+}
+
+func (s *SockError) StateTypeName() string {
+ return "pkg/tcpip.SockError"
+}
+
+func (s *SockError) StateFields() []string {
+ return []string{
+ "sockErrorEntry",
+ "Err",
+ "ErrOrigin",
+ "ErrType",
+ "ErrCode",
+ "ErrInfo",
+ "Payload",
+ "Dst",
+ "Offender",
+ "NetProto",
+ }
+}
+
+func (s *SockError) beforeSave() {}
+
+func (s *SockError) StateSave(stateSinkObject state.Sink) {
+ s.beforeSave()
+ stateSinkObject.Save(0, &s.sockErrorEntry)
+ stateSinkObject.Save(1, &s.Err)
+ stateSinkObject.Save(2, &s.ErrOrigin)
+ stateSinkObject.Save(3, &s.ErrType)
+ stateSinkObject.Save(4, &s.ErrCode)
+ stateSinkObject.Save(5, &s.ErrInfo)
+ stateSinkObject.Save(6, &s.Payload)
+ stateSinkObject.Save(7, &s.Dst)
+ stateSinkObject.Save(8, &s.Offender)
+ stateSinkObject.Save(9, &s.NetProto)
+}
+
+func (s *SockError) afterLoad() {}
+
+func (s *SockError) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &s.sockErrorEntry)
+ stateSourceObject.Load(1, &s.Err)
+ stateSourceObject.Load(2, &s.ErrOrigin)
+ stateSourceObject.Load(3, &s.ErrType)
+ stateSourceObject.Load(4, &s.ErrCode)
+ stateSourceObject.Load(5, &s.ErrInfo)
+ stateSourceObject.Load(6, &s.Payload)
+ stateSourceObject.Load(7, &s.Dst)
+ stateSourceObject.Load(8, &s.Offender)
+ stateSourceObject.Load(9, &s.NetProto)
}
func (e *Error) StateTypeName() string {
@@ -153,6 +264,7 @@ func (c *ControlMessages) StateFields() []string {
"PacketInfo",
"HasOriginalDstAddress",
"OriginalDstAddress",
+ "SockErr",
}
}
@@ -172,6 +284,7 @@ func (c *ControlMessages) StateSave(stateSinkObject state.Sink) {
stateSinkObject.Save(9, &c.PacketInfo)
stateSinkObject.Save(10, &c.HasOriginalDstAddress)
stateSinkObject.Save(11, &c.OriginalDstAddress)
+ stateSinkObject.Save(12, &c.SockErr)
}
func (c *ControlMessages) afterLoad() {}
@@ -189,6 +302,7 @@ func (c *ControlMessages) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(9, &c.PacketInfo)
stateSourceObject.Load(10, &c.HasOriginalDstAddress)
stateSourceObject.Load(11, &c.OriginalDstAddress)
+ stateSourceObject.Load(12, &c.SockErr)
}
func (l *LinkPacketInfo) StateTypeName() string {
@@ -273,7 +387,10 @@ func (i *IPPacketInfo) StateLoad(stateSourceObject state.Source) {
}
func init() {
+ state.Register((*sockErrorList)(nil))
+ state.Register((*sockErrorEntry)(nil))
state.Register((*SocketOptions)(nil))
+ state.Register((*SockError)(nil))
state.Register((*Error)(nil))
state.Register((*FullAddress)(nil))
state.Register((*ControlMessages)(nil))
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 74fe19e98..d1e4a7cb7 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -504,7 +504,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
if err != nil {
return err
}
- defer r.Release()
id := stack.TransportEndpointID{
LocalAddress: r.LocalAddress,
@@ -519,11 +518,12 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
id, err = e.registerWithStack(nicID, netProtos, id)
if err != nil {
+ r.Release()
return err
}
e.ID = id
- e.route = r.Clone()
+ e.route = r
e.RegisterNICID = nicID
e.state = stateConnected
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 9faab4b9e..e5e247342 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -366,6 +366,13 @@ func (ep *endpoint) LastError() *tcpip.Error {
return err
}
+// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError.
+func (ep *endpoint) UpdateLastError(err *tcpip.Error) {
+ ep.lastErrorMu.Lock()
+ ep.lastError = err
+ ep.lastErrorMu.Unlock()
+}
+
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
return tcpip.ErrNotSupported
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index eee3f11c1..7befcfc9b 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -261,15 +261,14 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
e.mu.RLock()
+ defer e.mu.RUnlock()
if e.closed {
- e.mu.RUnlock()
return 0, nil, tcpip.ErrInvalidEndpointState
}
payloadBytes, err := p.FullPayload()
if err != nil {
- e.mu.RUnlock()
return 0, nil, err
}
@@ -278,7 +277,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
if e.ops.GetHeaderIncluded() {
ip := header.IPv4(payloadBytes)
if !ip.IsValid(len(payloadBytes)) {
- e.mu.RUnlock()
return 0, nil, tcpip.ErrInvalidOptionValue
}
dstAddr := ip.DestinationAddress()
@@ -300,39 +298,16 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// If the user doesn't specify a destination, they should have
// connected to another address.
if !e.connected {
- e.mu.RUnlock()
return 0, nil, tcpip.ErrDestinationRequired
}
- if e.route.IsResolutionRequired() {
- savedRoute := e.route
- // Promote lock to exclusive if using a shared route,
- // given that it may need to change in finishWrite.
- e.mu.RUnlock()
- e.mu.Lock()
-
- // Make sure that the route didn't change during the
- // time we didn't hold the lock.
- if !e.connected || savedRoute != e.route {
- e.mu.Unlock()
- return 0, nil, tcpip.ErrInvalidEndpointState
- }
-
- n, ch, err := e.finishWrite(payloadBytes, savedRoute)
- e.mu.Unlock()
- return n, ch, err
- }
-
- n, ch, err := e.finishWrite(payloadBytes, e.route)
- e.mu.RUnlock()
- return n, ch, err
+ return e.finishWrite(payloadBytes, e.route)
}
// The caller provided a destination. Reject destination address if it
// goes through a different NIC than the endpoint was bound to.
nic := opts.To.NIC
if e.bound && nic != 0 && nic != e.BindNICID {
- e.mu.RUnlock()
return 0, nil, tcpip.ErrNoRoute
}
@@ -340,13 +315,11 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// FindRoute will choose an appropriate source address.
route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false)
if err != nil {
- e.mu.RUnlock()
return 0, nil, err
}
n, ch, err := e.finishWrite(payloadBytes, route)
route.Release()
- e.mu.RUnlock()
return n, ch, err
}
@@ -404,7 +377,7 @@ func (*endpoint) Disconnect() *tcpip.Error {
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// Raw sockets do not support connecting to a IPv4 address on a IPv6 endpoint.
if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize {
- return tcpip.ErrInvalidOptionValue
+ return tcpip.ErrAddressFamilyNotSupported
}
e.mu.Lock()
@@ -435,11 +408,11 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
if err != nil {
return err
}
- defer route.Release()
if e.associated {
// Re-register the endpoint with the appropriate NIC.
if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil {
+ route.Release()
return err
}
e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e)
@@ -447,7 +420,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
// Save the route we've connected via.
- e.route = route.Clone()
+ e.route = route
e.connected = true
return nil
@@ -620,6 +593,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
+ e.mu.RLock()
e.rcvMu.Lock()
// Drop the packet if our buffer is currently full or if this is an unassociated
@@ -632,6 +606,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// sockets.
if e.rcvClosed || !e.associated {
e.rcvMu.Unlock()
+ e.mu.RUnlock()
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.ClosedReceiver.Increment()
return
@@ -639,6 +614,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
if e.rcvBufSize >= e.rcvBufSizeMax {
e.rcvMu.Unlock()
+ e.mu.RUnlock()
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
return
@@ -650,11 +626,13 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// If bound to a NIC, only accept data for that NIC.
if e.BindNICID != 0 && e.BindNICID != pkt.NICID {
e.rcvMu.Unlock()
+ e.mu.RUnlock()
return
}
// If bound to an address, only accept data for that address.
if e.BindAddr != "" && e.BindAddr != remoteAddr {
e.rcvMu.Unlock()
+ e.mu.RUnlock()
return
}
}
@@ -663,6 +641,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// connected to.
if e.connected && e.route.RemoteAddress != remoteAddr {
e.rcvMu.Unlock()
+ e.mu.RUnlock()
return
}
@@ -697,6 +676,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
e.rcvList.PushBack(packet)
e.rcvBufSize += packet.data.Size()
e.rcvMu.Unlock()
+ e.mu.RUnlock()
e.stats.PacketsReceived.Increment()
// Notify waiters that there's data to be read.
if wasEmpty {
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 3e1041cbe..2d96a65bd 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -778,7 +778,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr)
}()
- s := sleep.Sleeper{}
+ var s sleep.Sleeper
s.AddWaker(&e.notificationWaker, wakerForNotification)
s.AddWaker(&e.newSegmentWaker, wakerForNewSegment)
for {
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index c944dccc0..0dc710276 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -462,7 +462,7 @@ func (h *handshake) processSegments() *tcpip.Error {
func (h *handshake) resolveRoute() *tcpip.Error {
// Set up the wakers.
- s := sleep.Sleeper{}
+ var s sleep.Sleeper
resolutionWaker := &sleep.Waker{}
s.AddWaker(resolutionWaker, wakerForResolution)
s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
@@ -470,24 +470,27 @@ func (h *handshake) resolveRoute() *tcpip.Error {
// Initial action is to resolve route.
index := wakerForResolution
+ attemptedResolution := false
for {
switch index {
case wakerForResolution:
- if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock {
- if err == tcpip.ErrNoLinkAddress {
- h.ep.stats.SendErrors.NoLinkAddr.Increment()
- } else if err != nil {
+ if _, err := h.ep.route.Resolve(resolutionWaker.Assert); err != tcpip.ErrWouldBlock {
+ if err != nil {
h.ep.stats.SendErrors.NoRoute.Increment()
}
// Either success (err == nil) or failure.
return err
}
+ if attemptedResolution {
+ h.ep.stats.SendErrors.NoLinkAddr.Increment()
+ return tcpip.ErrNoLinkAddress
+ }
+ attemptedResolution = true
// Resolution not completed. Keep trying...
case wakerForNotification:
n := h.ep.fetchNotifications()
if n&notifyClose != 0 {
- h.ep.route.RemoveWaker(resolutionWaker)
return tcpip.ErrAborted
}
if n&notifyDrain != 0 {
@@ -563,7 +566,7 @@ func (h *handshake) start() *tcpip.Error {
// complete completes the TCP 3-way handshake initiated by h.start().
func (h *handshake) complete() *tcpip.Error {
// Set up the wakers.
- s := sleep.Sleeper{}
+ var s sleep.Sleeper
resendWaker := sleep.Waker{}
s.AddWaker(&resendWaker, wakerForResend)
s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
@@ -1512,7 +1515,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
// Initialize the sleeper based on the wakers in funcs.
- s := sleep.Sleeper{}
+ var s sleep.Sleeper
for i := range funcs {
s.AddWaker(funcs[i].w, i)
}
@@ -1699,7 +1702,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
const notification = 2
const timeWaitDone = 3
- s := sleep.Sleeper{}
+ var s sleep.Sleeper
defer s.Done()
s.AddWaker(&e.newSegmentWaker, newSegment)
s.AddWaker(&e.notificationWaker, notification)
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 7a37c10bb..6e3c8860e 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -502,9 +502,6 @@ type endpoint struct {
// sack holds TCP SACK related information for this endpoint.
sack SACKInfo
- // bindToDevice is set to the NIC on which to bind or disabled if 0.
- bindToDevice tcpip.NICID
-
// delay enables Nagle's algorithm.
//
// delay is a boolean (0 is false) and must be accessed atomically.
@@ -1303,6 +1300,15 @@ func (e *endpoint) LastError() *tcpip.Error {
return e.lastErrorLocked()
}
+// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError.
+func (e *endpoint) UpdateLastError(err *tcpip.Error) {
+ e.LockUser()
+ e.lastErrorMu.Lock()
+ e.lastError = err
+ e.lastErrorMu.Unlock()
+ e.UnlockUser()
+}
+
// Read reads data from the endpoint.
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
e.LockUser()
@@ -1812,18 +1818,13 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return nil
}
+func (e *endpoint) HasNIC(id int32) bool {
+ return id == 0 || e.stack.HasNIC(tcpip.NICID(id))
+}
+
// SetSockOpt sets a socket option.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
switch v := opt.(type) {
- case *tcpip.BindToDeviceOption:
- id := tcpip.NICID(*v)
- if id != 0 && !e.stack.HasNIC(id) {
- return tcpip.ErrUnknownDevice
- }
- e.LockUser()
- e.bindToDevice = id
- e.UnlockUser()
-
case *tcpip.KeepaliveIdleOption:
e.keepalive.Lock()
e.keepalive.idle = time.Duration(*v)
@@ -2004,11 +2005,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
switch o := opt.(type) {
- case *tcpip.BindToDeviceOption:
- e.LockUser()
- *o = tcpip.BindToDeviceOption(e.bindToDevice)
- e.UnlockUser()
-
case *tcpip.TCPInfoOption:
*o = tcpip.TCPInfoOption{}
e.LockUser()
@@ -2211,11 +2207,12 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
}
+ bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) {
if sameAddr && p == e.ID.RemotePort {
return false, nil
}
- if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr, nil /* testPort */); err != nil {
+ if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr, nil /* testPort */); err != nil {
if err != tcpip.ErrPortInUse || !reuse {
return false, nil
}
@@ -2253,15 +2250,15 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
tcpEP.notifyProtocolGoroutine(notifyAbort)
tcpEP.UnlockUser()
// Now try and Reserve again if it fails then we skip.
- if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr, nil /* testPort */); err != nil {
+ if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr, nil /* testPort */); err != nil {
return false, nil
}
}
id := e.ID
id.LocalPort = p
- if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, e.bindToDevice); err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr)
+ if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, bindToDevice); err != nil {
+ e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr)
if err == tcpip.ErrPortInUse {
return false, nil
}
@@ -2272,7 +2269,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
// the selected port.
e.ID = id
e.isPortReserved = true
- e.boundBindToDevice = e.bindToDevice
+ e.boundBindToDevice = bindToDevice
e.boundPortFlags = e.portFlags
e.boundDest = addr
return true, nil
@@ -2283,7 +2280,8 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
e.isRegistered = true
e.setEndpointState(StateConnecting)
- e.route = r.Clone()
+ r.Acquire()
+ e.route = r
e.boundNICID = nicID
e.effectiveNetProtos = netProtos
e.connectingAddress = connectingAddr
@@ -2624,7 +2622,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
e.ID.LocalAddress = addr.Addr
}
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{}, func(p uint16) bool {
+ bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, bindToDevice, tcpip.FullAddress{}, func(p uint16) bool {
id := e.ID
id.LocalPort = p
// CheckRegisterTransportEndpoint should only return an error if there is a
@@ -2635,7 +2634,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
// demuxer. Further connected endpoints always have a remote
// address/port. Hence this will only return an error if there is a matching
// listening endpoint.
- if err := e.stack.CheckRegisterTransportEndpoint(nic, netProtos, ProtocolNumber, id, e.portFlags, e.bindToDevice); err != nil {
+ if err := e.stack.CheckRegisterTransportEndpoint(nic, netProtos, ProtocolNumber, id, e.portFlags, bindToDevice); err != nil {
return false
}
return true
@@ -2644,7 +2643,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
return err
}
- e.boundBindToDevice = e.bindToDevice
+ e.boundBindToDevice = bindToDevice
e.boundPortFlags = e.portFlags
// TODO(gvisor.dev/issue/3691): Add test to verify boundNICID is correct.
e.boundNICID = nic
@@ -2708,6 +2707,41 @@ func (e *endpoint) enqueueSegment(s *segment) bool {
return true
}
+func (e *endpoint) onICMPError(err *tcpip.Error, id stack.TransportEndpointID, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) {
+ // Update last error first.
+ e.lastErrorMu.Lock()
+ e.lastError = err
+ e.lastErrorMu.Unlock()
+
+ // Update the error queue if IP_RECVERR is enabled.
+ if e.SocketOptions().GetRecvError() {
+ e.SocketOptions().QueueErr(&tcpip.SockError{
+ Err: err,
+ ErrOrigin: header.ICMPOriginFromNetProto(pkt.NetworkProtocolNumber),
+ ErrType: errType,
+ ErrCode: errCode,
+ ErrInfo: extra,
+ // Linux passes the payload with the TCP header. We don't know if the TCP
+ // header even exists, it may not for fragmented packets.
+ Payload: pkt.Data.ToView(),
+ Dst: tcpip.FullAddress{
+ NIC: pkt.NICID,
+ Addr: id.RemoteAddress,
+ Port: id.RemotePort,
+ },
+ Offender: tcpip.FullAddress{
+ NIC: pkt.NICID,
+ Addr: id.LocalAddress,
+ Port: id.LocalPort,
+ },
+ NetProto: pkt.NetworkProtocolNumber,
+ })
+ }
+
+ // Notify of the error.
+ e.notifyProtocolGoroutine(notifyError)
+}
+
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
switch typ {
@@ -2722,16 +2756,10 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C
e.notifyProtocolGoroutine(notifyMTUChanged)
case stack.ControlNoRoute:
- e.lastErrorMu.Lock()
- e.lastError = tcpip.ErrNoRoute
- e.lastErrorMu.Unlock()
- e.notifyProtocolGoroutine(notifyError)
+ e.onICMPError(tcpip.ErrNoRoute, id, byte(header.ICMPv4DstUnreachable), byte(header.ICMPv4HostUnreachable), extra, pkt)
case stack.ControlNetworkUnreachable:
- e.lastErrorMu.Lock()
- e.lastError = tcpip.ErrNetworkUnreachable
- e.lastErrorMu.Unlock()
- e.notifyProtocolGoroutine(notifyError)
+ e.onICMPError(tcpip.ErrNetworkUnreachable, id, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6NetworkUnreachable), extra, pkt)
}
}
@@ -2989,6 +3017,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
Ssthresh: e.snd.sndSsthresh,
SndCAAckCount: e.snd.sndCAAckCount,
Outstanding: e.snd.outstanding,
+ SackedOut: e.snd.sackedOut,
SndWnd: e.snd.sndWnd,
SndUna: e.snd.sndUna,
SndNxt: e.snd.sndNxt,
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index f2b1b68da..405a6dce7 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -172,14 +172,12 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
// If we started off with a window larger than what can he held in
// the 16bit window field, we ceil the value to the max value.
- // While ceiling, we still do not want to grow the right edge when
- // not applicable.
if scaledWnd > math.MaxUint16 {
- if toGrow {
- scaledWnd = seqnum.Size(math.MaxUint16)
- } else {
- scaledWnd = seqnum.Size(uint16(scaledWnd))
- }
+ scaledWnd = seqnum.Size(math.MaxUint16)
+
+ // Ensure that the stashed receive window always reflects what
+ // is being advertised.
+ r.rcvWnd = scaledWnd << r.rcvWndScale
}
return r.rcvNxt, scaledWnd
}
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index baec762e1..cc991aba6 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -137,6 +137,9 @@ type sender struct {
// that have been sent but not yet acknowledged.
outstanding int
+ // sackedOut is the number of packets which are selectively acked.
+ sackedOut int
+
// sndWnd is the send window size.
sndWnd seqnum.Size
@@ -372,6 +375,7 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) {
m = 1
}
+ oldMSS := s.maxPayloadSize
s.maxPayloadSize = m
if s.gso {
s.ep.gso.MSS = uint16(m)
@@ -394,6 +398,7 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) {
// Rewind writeNext to the first segment exceeding the MTU. Do nothing
// if it is already before such a packet.
+ nextSeg := s.writeNext
for seg := s.writeList.Front(); seg != nil; seg = seg.Next() {
if seg == s.writeNext {
// We got to writeNext before we could find a segment
@@ -401,16 +406,22 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) {
break
}
- if seg.data.Size() > m {
+ if nextSeg == s.writeNext && seg.data.Size() > m {
// We found a segment exceeding the MTU. Rewind
// writeNext and try to retransmit it.
- s.writeNext = seg
- break
+ nextSeg = seg
+ }
+
+ if s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
+ // Update sackedOut for new maximum payload size.
+ s.sackedOut -= s.pCount(seg, oldMSS)
+ s.sackedOut += s.pCount(seg, s.maxPayloadSize)
}
}
// Since we likely reduced the number of outstanding packets, we may be
// ready to send some more.
+ s.writeNext = nextSeg
s.sendData()
}
@@ -629,13 +640,13 @@ func (s *sender) retransmitTimerExpired() bool {
// pCount returns the number of packets in the segment. Due to GSO, a segment
// can be composed of multiple packets.
-func (s *sender) pCount(seg *segment) int {
+func (s *sender) pCount(seg *segment, maxPayloadSize int) int {
size := seg.data.Size()
if size == 0 {
return 1
}
- return (size-1)/s.maxPayloadSize + 1
+ return (size-1)/maxPayloadSize + 1
}
// splitSeg splits a given segment at the size specified and inserts the
@@ -1023,7 +1034,7 @@ func (s *sender) sendData() {
break
}
dataSent = true
- s.outstanding += s.pCount(seg)
+ s.outstanding += s.pCount(seg, s.maxPayloadSize)
s.writeNext = seg.Next()
}
@@ -1038,6 +1049,7 @@ func (s *sender) enterRecovery() {
// We inflate the cwnd by 3 to account for the 3 packets which triggered
// the 3 duplicate ACKs and are now not in flight.
s.sndCwnd = s.sndSsthresh + 3
+ s.sackedOut = 0
s.fr.first = s.sndUna
s.fr.last = s.sndNxt - 1
s.fr.maxCwnd = s.sndCwnd + s.outstanding
@@ -1207,6 +1219,7 @@ func (s *sender) walkSACK(rcvdSeg *segment) {
s.rc.update(seg, rcvdSeg, s.ep.tsOffset)
s.rc.detectReorder(seg)
seg.acked = true
+ s.sackedOut += s.pCount(seg, s.maxPayloadSize)
}
seg = seg.Next()
}
@@ -1380,10 +1393,10 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
datalen := seg.logicalLen()
if datalen > ackLeft {
- prevCount := s.pCount(seg)
+ prevCount := s.pCount(seg, s.maxPayloadSize)
seg.data.TrimFront(int(ackLeft))
seg.sequenceNumber.UpdateForward(ackLeft)
- s.outstanding -= prevCount - s.pCount(seg)
+ s.outstanding -= prevCount - s.pCount(seg, s.maxPayloadSize)
break
}
@@ -1399,11 +1412,13 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
s.writeList.Remove(seg)
- // If SACK is enabled then Only reduce outstanding if
+ // If SACK is enabled then only reduce outstanding if
// the segment was not previously SACKED as these have
// already been accounted for in SetPipe().
if !s.ep.sackPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
- s.outstanding -= s.pCount(seg)
+ s.outstanding -= s.pCount(seg, s.maxPayloadSize)
+ } else {
+ s.sackedOut -= s.pCount(seg, s.maxPayloadSize)
}
seg.decRef()
ackLeft -= datalen
diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
index 8eba0efeb..5922083a9 100644
--- a/pkg/tcpip/transport/tcp/tcp_state_autogen.go
+++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
@@ -187,7 +187,6 @@ func (e *endpoint) StateFields() []string {
"shutdownFlags",
"sackPermitted",
"sack",
- "bindToDevice",
"delay",
"scoreboard",
"segmentQueue",
@@ -232,7 +231,7 @@ func (e *endpoint) StateSave(stateSinkObject state.Sink) {
var recentTSTimeValue unixTime = e.saveRecentTSTime()
stateSinkObject.SaveValue(26, recentTSTimeValue)
var acceptedChanValue []*endpoint = e.saveAcceptedChan()
- stateSinkObject.SaveValue(50, acceptedChanValue)
+ stateSinkObject.SaveValue(49, acceptedChanValue)
stateSinkObject.Save(0, &e.EndpointInfo)
stateSinkObject.Save(1, &e.DefaultSocketOptionsHandler)
stateSinkObject.Save(2, &e.waiterQueue)
@@ -260,36 +259,35 @@ func (e *endpoint) StateSave(stateSinkObject state.Sink) {
stateSinkObject.Save(28, &e.shutdownFlags)
stateSinkObject.Save(29, &e.sackPermitted)
stateSinkObject.Save(30, &e.sack)
- stateSinkObject.Save(31, &e.bindToDevice)
- stateSinkObject.Save(32, &e.delay)
- stateSinkObject.Save(33, &e.scoreboard)
- stateSinkObject.Save(34, &e.segmentQueue)
- stateSinkObject.Save(35, &e.synRcvdCount)
- stateSinkObject.Save(36, &e.userMSS)
- stateSinkObject.Save(37, &e.maxSynRetries)
- stateSinkObject.Save(38, &e.windowClamp)
- stateSinkObject.Save(39, &e.sndBufSize)
- stateSinkObject.Save(40, &e.sndBufUsed)
- stateSinkObject.Save(41, &e.sndClosed)
- stateSinkObject.Save(42, &e.sndBufInQueue)
- stateSinkObject.Save(43, &e.sndQueue)
- stateSinkObject.Save(44, &e.cc)
- stateSinkObject.Save(45, &e.packetTooBigCount)
- stateSinkObject.Save(46, &e.sndMTU)
- stateSinkObject.Save(47, &e.keepalive)
- stateSinkObject.Save(48, &e.userTimeout)
- stateSinkObject.Save(49, &e.deferAccept)
- stateSinkObject.Save(51, &e.rcv)
- stateSinkObject.Save(52, &e.snd)
- stateSinkObject.Save(53, &e.connectingAddress)
- stateSinkObject.Save(54, &e.amss)
- stateSinkObject.Save(55, &e.sendTOS)
- stateSinkObject.Save(56, &e.gso)
- stateSinkObject.Save(57, &e.tcpLingerTimeout)
- stateSinkObject.Save(58, &e.closed)
- stateSinkObject.Save(59, &e.txHash)
- stateSinkObject.Save(60, &e.owner)
- stateSinkObject.Save(61, &e.ops)
+ stateSinkObject.Save(31, &e.delay)
+ stateSinkObject.Save(32, &e.scoreboard)
+ stateSinkObject.Save(33, &e.segmentQueue)
+ stateSinkObject.Save(34, &e.synRcvdCount)
+ stateSinkObject.Save(35, &e.userMSS)
+ stateSinkObject.Save(36, &e.maxSynRetries)
+ stateSinkObject.Save(37, &e.windowClamp)
+ stateSinkObject.Save(38, &e.sndBufSize)
+ stateSinkObject.Save(39, &e.sndBufUsed)
+ stateSinkObject.Save(40, &e.sndClosed)
+ stateSinkObject.Save(41, &e.sndBufInQueue)
+ stateSinkObject.Save(42, &e.sndQueue)
+ stateSinkObject.Save(43, &e.cc)
+ stateSinkObject.Save(44, &e.packetTooBigCount)
+ stateSinkObject.Save(45, &e.sndMTU)
+ stateSinkObject.Save(46, &e.keepalive)
+ stateSinkObject.Save(47, &e.userTimeout)
+ stateSinkObject.Save(48, &e.deferAccept)
+ stateSinkObject.Save(50, &e.rcv)
+ stateSinkObject.Save(51, &e.snd)
+ stateSinkObject.Save(52, &e.connectingAddress)
+ stateSinkObject.Save(53, &e.amss)
+ stateSinkObject.Save(54, &e.sendTOS)
+ stateSinkObject.Save(55, &e.gso)
+ stateSinkObject.Save(56, &e.tcpLingerTimeout)
+ stateSinkObject.Save(57, &e.closed)
+ stateSinkObject.Save(58, &e.txHash)
+ stateSinkObject.Save(59, &e.owner)
+ stateSinkObject.Save(60, &e.ops)
}
func (e *endpoint) StateLoad(stateSourceObject state.Source) {
@@ -320,41 +318,40 @@ func (e *endpoint) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(28, &e.shutdownFlags)
stateSourceObject.Load(29, &e.sackPermitted)
stateSourceObject.Load(30, &e.sack)
- stateSourceObject.Load(31, &e.bindToDevice)
- stateSourceObject.Load(32, &e.delay)
- stateSourceObject.Load(33, &e.scoreboard)
- stateSourceObject.LoadWait(34, &e.segmentQueue)
- stateSourceObject.Load(35, &e.synRcvdCount)
- stateSourceObject.Load(36, &e.userMSS)
- stateSourceObject.Load(37, &e.maxSynRetries)
- stateSourceObject.Load(38, &e.windowClamp)
- stateSourceObject.Load(39, &e.sndBufSize)
- stateSourceObject.Load(40, &e.sndBufUsed)
- stateSourceObject.Load(41, &e.sndClosed)
- stateSourceObject.Load(42, &e.sndBufInQueue)
- stateSourceObject.LoadWait(43, &e.sndQueue)
- stateSourceObject.Load(44, &e.cc)
- stateSourceObject.Load(45, &e.packetTooBigCount)
- stateSourceObject.Load(46, &e.sndMTU)
- stateSourceObject.Load(47, &e.keepalive)
- stateSourceObject.Load(48, &e.userTimeout)
- stateSourceObject.Load(49, &e.deferAccept)
- stateSourceObject.LoadWait(51, &e.rcv)
- stateSourceObject.LoadWait(52, &e.snd)
- stateSourceObject.Load(53, &e.connectingAddress)
- stateSourceObject.Load(54, &e.amss)
- stateSourceObject.Load(55, &e.sendTOS)
- stateSourceObject.Load(56, &e.gso)
- stateSourceObject.Load(57, &e.tcpLingerTimeout)
- stateSourceObject.Load(58, &e.closed)
- stateSourceObject.Load(59, &e.txHash)
- stateSourceObject.Load(60, &e.owner)
- stateSourceObject.Load(61, &e.ops)
+ stateSourceObject.Load(31, &e.delay)
+ stateSourceObject.Load(32, &e.scoreboard)
+ stateSourceObject.LoadWait(33, &e.segmentQueue)
+ stateSourceObject.Load(34, &e.synRcvdCount)
+ stateSourceObject.Load(35, &e.userMSS)
+ stateSourceObject.Load(36, &e.maxSynRetries)
+ stateSourceObject.Load(37, &e.windowClamp)
+ stateSourceObject.Load(38, &e.sndBufSize)
+ stateSourceObject.Load(39, &e.sndBufUsed)
+ stateSourceObject.Load(40, &e.sndClosed)
+ stateSourceObject.Load(41, &e.sndBufInQueue)
+ stateSourceObject.LoadWait(42, &e.sndQueue)
+ stateSourceObject.Load(43, &e.cc)
+ stateSourceObject.Load(44, &e.packetTooBigCount)
+ stateSourceObject.Load(45, &e.sndMTU)
+ stateSourceObject.Load(46, &e.keepalive)
+ stateSourceObject.Load(47, &e.userTimeout)
+ stateSourceObject.Load(48, &e.deferAccept)
+ stateSourceObject.LoadWait(50, &e.rcv)
+ stateSourceObject.LoadWait(51, &e.snd)
+ stateSourceObject.Load(52, &e.connectingAddress)
+ stateSourceObject.Load(53, &e.amss)
+ stateSourceObject.Load(54, &e.sendTOS)
+ stateSourceObject.Load(55, &e.gso)
+ stateSourceObject.Load(56, &e.tcpLingerTimeout)
+ stateSourceObject.Load(57, &e.closed)
+ stateSourceObject.Load(58, &e.txHash)
+ stateSourceObject.Load(59, &e.owner)
+ stateSourceObject.Load(60, &e.ops)
stateSourceObject.LoadValue(4, new(string), func(y interface{}) { e.loadHardError(y.(string)) })
stateSourceObject.LoadValue(5, new(string), func(y interface{}) { e.loadLastError(y.(string)) })
stateSourceObject.LoadValue(13, new(EndpointState), func(y interface{}) { e.loadState(y.(EndpointState)) })
stateSourceObject.LoadValue(26, new(unixTime), func(y interface{}) { e.loadRecentTSTime(y.(unixTime)) })
- stateSourceObject.LoadValue(50, new([]*endpoint), func(y interface{}) { e.loadAcceptedChan(y.([]*endpoint)) })
+ stateSourceObject.LoadValue(49, new([]*endpoint), func(y interface{}) { e.loadAcceptedChan(y.([]*endpoint)) })
stateSourceObject.AfterLoad(e.afterLoad)
}
@@ -724,6 +721,7 @@ func (s *sender) StateFields() []string {
"sndSsthresh",
"sndCAAckCount",
"outstanding",
+ "sackedOut",
"sndWnd",
"sndUna",
"sndNxt",
@@ -755,9 +753,9 @@ func (s *sender) StateSave(stateSinkObject state.Sink) {
var lastSendTimeValue unixTime = s.saveLastSendTime()
stateSinkObject.SaveValue(1, lastSendTimeValue)
var rttMeasureTimeValue unixTime = s.saveRttMeasureTime()
- stateSinkObject.SaveValue(13, rttMeasureTimeValue)
+ stateSinkObject.SaveValue(14, rttMeasureTimeValue)
var firstRetransmittedSegXmitTimeValue unixTime = s.saveFirstRetransmittedSegXmitTime()
- stateSinkObject.SaveValue(14, firstRetransmittedSegXmitTimeValue)
+ stateSinkObject.SaveValue(15, firstRetransmittedSegXmitTimeValue)
stateSinkObject.Save(0, &s.ep)
stateSinkObject.Save(2, &s.dupAckCount)
stateSinkObject.Save(3, &s.fr)
@@ -766,25 +764,26 @@ func (s *sender) StateSave(stateSinkObject state.Sink) {
stateSinkObject.Save(6, &s.sndSsthresh)
stateSinkObject.Save(7, &s.sndCAAckCount)
stateSinkObject.Save(8, &s.outstanding)
- stateSinkObject.Save(9, &s.sndWnd)
- stateSinkObject.Save(10, &s.sndUna)
- stateSinkObject.Save(11, &s.sndNxt)
- stateSinkObject.Save(12, &s.rttMeasureSeqNum)
- stateSinkObject.Save(15, &s.closed)
- stateSinkObject.Save(16, &s.writeNext)
- stateSinkObject.Save(17, &s.writeList)
- stateSinkObject.Save(18, &s.rtt)
- stateSinkObject.Save(19, &s.rto)
- stateSinkObject.Save(20, &s.minRTO)
- stateSinkObject.Save(21, &s.maxRTO)
- stateSinkObject.Save(22, &s.maxRetries)
- stateSinkObject.Save(23, &s.maxPayloadSize)
- stateSinkObject.Save(24, &s.gso)
- stateSinkObject.Save(25, &s.sndWndScale)
- stateSinkObject.Save(26, &s.maxSentAck)
- stateSinkObject.Save(27, &s.state)
- stateSinkObject.Save(28, &s.cc)
- stateSinkObject.Save(29, &s.rc)
+ stateSinkObject.Save(9, &s.sackedOut)
+ stateSinkObject.Save(10, &s.sndWnd)
+ stateSinkObject.Save(11, &s.sndUna)
+ stateSinkObject.Save(12, &s.sndNxt)
+ stateSinkObject.Save(13, &s.rttMeasureSeqNum)
+ stateSinkObject.Save(16, &s.closed)
+ stateSinkObject.Save(17, &s.writeNext)
+ stateSinkObject.Save(18, &s.writeList)
+ stateSinkObject.Save(19, &s.rtt)
+ stateSinkObject.Save(20, &s.rto)
+ stateSinkObject.Save(21, &s.minRTO)
+ stateSinkObject.Save(22, &s.maxRTO)
+ stateSinkObject.Save(23, &s.maxRetries)
+ stateSinkObject.Save(24, &s.maxPayloadSize)
+ stateSinkObject.Save(25, &s.gso)
+ stateSinkObject.Save(26, &s.sndWndScale)
+ stateSinkObject.Save(27, &s.maxSentAck)
+ stateSinkObject.Save(28, &s.state)
+ stateSinkObject.Save(29, &s.cc)
+ stateSinkObject.Save(30, &s.rc)
}
func (s *sender) StateLoad(stateSourceObject state.Source) {
@@ -796,28 +795,29 @@ func (s *sender) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(6, &s.sndSsthresh)
stateSourceObject.Load(7, &s.sndCAAckCount)
stateSourceObject.Load(8, &s.outstanding)
- stateSourceObject.Load(9, &s.sndWnd)
- stateSourceObject.Load(10, &s.sndUna)
- stateSourceObject.Load(11, &s.sndNxt)
- stateSourceObject.Load(12, &s.rttMeasureSeqNum)
- stateSourceObject.Load(15, &s.closed)
- stateSourceObject.Load(16, &s.writeNext)
- stateSourceObject.Load(17, &s.writeList)
- stateSourceObject.Load(18, &s.rtt)
- stateSourceObject.Load(19, &s.rto)
- stateSourceObject.Load(20, &s.minRTO)
- stateSourceObject.Load(21, &s.maxRTO)
- stateSourceObject.Load(22, &s.maxRetries)
- stateSourceObject.Load(23, &s.maxPayloadSize)
- stateSourceObject.Load(24, &s.gso)
- stateSourceObject.Load(25, &s.sndWndScale)
- stateSourceObject.Load(26, &s.maxSentAck)
- stateSourceObject.Load(27, &s.state)
- stateSourceObject.Load(28, &s.cc)
- stateSourceObject.Load(29, &s.rc)
+ stateSourceObject.Load(9, &s.sackedOut)
+ stateSourceObject.Load(10, &s.sndWnd)
+ stateSourceObject.Load(11, &s.sndUna)
+ stateSourceObject.Load(12, &s.sndNxt)
+ stateSourceObject.Load(13, &s.rttMeasureSeqNum)
+ stateSourceObject.Load(16, &s.closed)
+ stateSourceObject.Load(17, &s.writeNext)
+ stateSourceObject.Load(18, &s.writeList)
+ stateSourceObject.Load(19, &s.rtt)
+ stateSourceObject.Load(20, &s.rto)
+ stateSourceObject.Load(21, &s.minRTO)
+ stateSourceObject.Load(22, &s.maxRTO)
+ stateSourceObject.Load(23, &s.maxRetries)
+ stateSourceObject.Load(24, &s.maxPayloadSize)
+ stateSourceObject.Load(25, &s.gso)
+ stateSourceObject.Load(26, &s.sndWndScale)
+ stateSourceObject.Load(27, &s.maxSentAck)
+ stateSourceObject.Load(28, &s.state)
+ stateSourceObject.Load(29, &s.cc)
+ stateSourceObject.Load(30, &s.rc)
stateSourceObject.LoadValue(1, new(unixTime), func(y interface{}) { s.loadLastSendTime(y.(unixTime)) })
- stateSourceObject.LoadValue(13, new(unixTime), func(y interface{}) { s.loadRttMeasureTime(y.(unixTime)) })
- stateSourceObject.LoadValue(14, new(unixTime), func(y interface{}) { s.loadFirstRetransmittedSegXmitTime(y.(unixTime)) })
+ stateSourceObject.LoadValue(14, new(unixTime), func(y interface{}) { s.loadRttMeasureTime(y.(unixTime)) })
+ stateSourceObject.LoadValue(15, new(unixTime), func(y interface{}) { s.loadFirstRetransmittedSegXmitTime(y.(unixTime)) })
stateSourceObject.AfterLoad(s.afterLoad)
}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 763d1d654..9b9e4deb0 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -109,7 +109,6 @@ type endpoint struct {
multicastAddr tcpip.Address
multicastNICID tcpip.NICID
portFlags ports.Flags
- bindToDevice tcpip.NICID
lastErrorMu sync.Mutex `state:"nosave"`
lastError *tcpip.Error `state:".(string)"`
@@ -226,6 +225,13 @@ func (e *endpoint) LastError() *tcpip.Error {
return err
}
+// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError.
+func (e *endpoint) UpdateLastError(err *tcpip.Error) {
+ e.lastErrorMu.Lock()
+ e.lastError = err
+ e.lastErrorMu.Unlock()
+}
+
// Abort implements stack.TransportEndpoint.Abort.
func (e *endpoint) Abort() {
e.Close()
@@ -511,6 +517,20 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
if len(v) > header.UDPMaximumPacketSize {
// Payload can't possibly fit in a packet.
+ so := e.SocketOptions()
+ if so.GetRecvError() {
+ so.QueueLocalErr(
+ tcpip.ErrMessageTooLong,
+ route.NetProto,
+ header.UDPMaximumPacketSize,
+ tcpip.FullAddress{
+ NIC: route.NICID(),
+ Addr: route.RemoteAddress,
+ Port: dstPort,
+ },
+ v,
+ )
+ }
return 0, nil, tcpip.ErrMessageTooLong
}
@@ -638,6 +658,10 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return nil
}
+func (e *endpoint) HasNIC(id int32) bool {
+ return id == 0 || e.stack.HasNIC(tcpip.NICID(id))
+}
+
// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
switch v := opt.(type) {
@@ -754,15 +778,6 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
delete(e.multicastMemberships, memToRemove)
- case *tcpip.BindToDeviceOption:
- id := tcpip.NICID(*v)
- if id != 0 && !e.stack.HasNIC(id) {
- return tcpip.ErrUnknownDevice
- }
- e.mu.Lock()
- e.bindToDevice = id
- e.mu.Unlock()
-
case *tcpip.SocketDetachFilterOption:
return nil
}
@@ -838,11 +853,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
}
e.mu.Unlock()
- case *tcpip.BindToDeviceOption:
- e.mu.RLock()
- *o = tcpip.BindToDeviceOption(e.bindToDevice)
- e.mu.RUnlock()
-
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -996,7 +1006,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
if err != nil {
return err
}
- defer r.Release()
id := stack.TransportEndpointID{
LocalAddress: e.ID.LocalAddress,
@@ -1024,6 +1033,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
id, btd, err := e.registerWithStack(nicID, netProtos, id)
if err != nil {
+ r.Release()
return err
}
@@ -1034,7 +1044,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
e.ID = id
e.boundBindToDevice = btd
- e.route = r.Clone()
+ e.route = r
e.dstPort = addr.Port
e.RegisterNICID = nicID
e.effectiveNetProtos = netProtos
@@ -1092,21 +1102,22 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcp
}
func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) {
+ bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
if e.ID.LocalPort == 0 {
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice, tcpip.FullAddress{}, nil /* testPort */)
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, bindToDevice, tcpip.FullAddress{}, nil /* testPort */)
if err != nil {
- return id, e.bindToDevice, err
+ return id, bindToDevice, err
}
id.LocalPort = port
}
e.boundPortFlags = e.portFlags
- err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, e.bindToDevice)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, bindToDevice)
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice, tcpip.FullAddress{})
+ e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, bindToDevice, tcpip.FullAddress{})
e.boundPortFlags = ports.Flags{}
}
- return id, e.bindToDevice, err
+ return id, bindToDevice, err
}
func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
@@ -1259,6 +1270,7 @@ func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+ // Get the header then trim it from the view.
hdr := header.UDP(pkt.TransportHeader().View())
if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
// Malformed packet.
@@ -1267,10 +1279,6 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
return
}
- // TODO(gvisor.dev/issues/5033): We should mirror the Network layer and cap
- // packets at "Parse" instead of when handling a packet.
- pkt.Data.CapLength(int(hdr.PayloadLength()))
-
if !verifyChecksum(hdr, pkt) {
// Checksum Error.
e.stack.Stats().UDP.ChecksumErrors.Increment()
@@ -1304,7 +1312,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
senderAddress: tcpip.FullAddress{
NIC: pkt.NICID,
Addr: id.RemoteAddress,
- Port: hdr.SourcePort(),
+ Port: header.UDP(hdr).SourcePort(),
},
destinationAddress: tcpip.FullAddress{
NIC: pkt.NICID,
@@ -1341,15 +1349,63 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
}
}
+func (e *endpoint) onICMPError(err *tcpip.Error, id stack.TransportEndpointID, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) {
+ // Update last error first.
+ e.lastErrorMu.Lock()
+ e.lastError = err
+ e.lastErrorMu.Unlock()
+
+ // Update the error queue if IP_RECVERR is enabled.
+ if e.SocketOptions().GetRecvError() {
+ // Linux passes the payload without the UDP header.
+ var payload []byte
+ udp := header.UDP(pkt.Data.ToView())
+ if len(udp) >= header.UDPMinimumSize {
+ payload = udp.Payload()
+ }
+
+ e.SocketOptions().QueueErr(&tcpip.SockError{
+ Err: err,
+ ErrOrigin: header.ICMPOriginFromNetProto(pkt.NetworkProtocolNumber),
+ ErrType: errType,
+ ErrCode: errCode,
+ ErrInfo: extra,
+ Payload: payload,
+ Dst: tcpip.FullAddress{
+ NIC: pkt.NICID,
+ Addr: id.RemoteAddress,
+ Port: id.RemotePort,
+ },
+ Offender: tcpip.FullAddress{
+ NIC: pkt.NICID,
+ Addr: id.LocalAddress,
+ Port: id.LocalPort,
+ },
+ NetProto: pkt.NetworkProtocolNumber,
+ })
+ }
+
+ // Notify of the error.
+ e.waiterQueue.Notify(waiter.EventErr)
+}
+
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
if typ == stack.ControlPortUnreachable {
if e.EndpointState() == StateConnected {
- e.lastErrorMu.Lock()
- e.lastError = tcpip.ErrConnectionRefused
- e.lastErrorMu.Unlock()
-
- e.waiterQueue.Notify(waiter.EventErr)
+ var errType byte
+ var errCode byte
+ switch pkt.NetworkProtocolNumber {
+ case header.IPv4ProtocolNumber:
+ errType = byte(header.ICMPv4DstUnreachable)
+ errCode = byte(header.ICMPv4PortUnreachable)
+ case header.IPv6ProtocolNumber:
+ errType = byte(header.ICMPv6DstUnreachable)
+ errCode = byte(header.ICMPv6PortUnreachable)
+ default:
+ panic(fmt.Sprintf("unsupported net proto for infering ICMP type and code: %d", pkt.NetworkProtocolNumber))
+ }
+ e.onICMPError(tcpip.ErrConnectionRefused, id, errType, errCode, extra, pkt)
return
}
}
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index 14e4648cd..d7fc21f11 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -78,7 +78,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
route.ResolveWith(r.pkt.SourceLinkAddress())
ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue)
- if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil {
+ if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil {
ep.Close()
route.Release()
return nil, err
diff --git a/pkg/tcpip/transport/udp/udp_state_autogen.go b/pkg/tcpip/transport/udp/udp_state_autogen.go
index ec0a8c902..2b7726097 100644
--- a/pkg/tcpip/transport/udp/udp_state_autogen.go
+++ b/pkg/tcpip/transport/udp/udp_state_autogen.go
@@ -73,7 +73,6 @@ func (e *endpoint) StateFields() []string {
"multicastAddr",
"multicastNICID",
"portFlags",
- "bindToDevice",
"lastError",
"boundBindToDevice",
"boundPortFlags",
@@ -91,7 +90,7 @@ func (e *endpoint) StateSave(stateSinkObject state.Sink) {
var rcvBufSizeMaxValue int = e.saveRcvBufSizeMax()
stateSinkObject.SaveValue(6, rcvBufSizeMaxValue)
var lastErrorValue string = e.saveLastError()
- stateSinkObject.SaveValue(19, lastErrorValue)
+ stateSinkObject.SaveValue(18, lastErrorValue)
stateSinkObject.Save(0, &e.TransportEndpointInfo)
stateSinkObject.Save(1, &e.DefaultSocketOptionsHandler)
stateSinkObject.Save(2, &e.waiterQueue)
@@ -109,15 +108,14 @@ func (e *endpoint) StateSave(stateSinkObject state.Sink) {
stateSinkObject.Save(15, &e.multicastAddr)
stateSinkObject.Save(16, &e.multicastNICID)
stateSinkObject.Save(17, &e.portFlags)
- stateSinkObject.Save(18, &e.bindToDevice)
- stateSinkObject.Save(20, &e.boundBindToDevice)
- stateSinkObject.Save(21, &e.boundPortFlags)
- stateSinkObject.Save(22, &e.sendTOS)
- stateSinkObject.Save(23, &e.shutdownFlags)
- stateSinkObject.Save(24, &e.multicastMemberships)
- stateSinkObject.Save(25, &e.effectiveNetProtos)
- stateSinkObject.Save(26, &e.owner)
- stateSinkObject.Save(27, &e.ops)
+ stateSinkObject.Save(19, &e.boundBindToDevice)
+ stateSinkObject.Save(20, &e.boundPortFlags)
+ stateSinkObject.Save(21, &e.sendTOS)
+ stateSinkObject.Save(22, &e.shutdownFlags)
+ stateSinkObject.Save(23, &e.multicastMemberships)
+ stateSinkObject.Save(24, &e.effectiveNetProtos)
+ stateSinkObject.Save(25, &e.owner)
+ stateSinkObject.Save(26, &e.ops)
}
func (e *endpoint) StateLoad(stateSourceObject state.Source) {
@@ -138,17 +136,16 @@ func (e *endpoint) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(15, &e.multicastAddr)
stateSourceObject.Load(16, &e.multicastNICID)
stateSourceObject.Load(17, &e.portFlags)
- stateSourceObject.Load(18, &e.bindToDevice)
- stateSourceObject.Load(20, &e.boundBindToDevice)
- stateSourceObject.Load(21, &e.boundPortFlags)
- stateSourceObject.Load(22, &e.sendTOS)
- stateSourceObject.Load(23, &e.shutdownFlags)
- stateSourceObject.Load(24, &e.multicastMemberships)
- stateSourceObject.Load(25, &e.effectiveNetProtos)
- stateSourceObject.Load(26, &e.owner)
- stateSourceObject.Load(27, &e.ops)
+ stateSourceObject.Load(19, &e.boundBindToDevice)
+ stateSourceObject.Load(20, &e.boundPortFlags)
+ stateSourceObject.Load(21, &e.sendTOS)
+ stateSourceObject.Load(22, &e.shutdownFlags)
+ stateSourceObject.Load(23, &e.multicastMemberships)
+ stateSourceObject.Load(24, &e.effectiveNetProtos)
+ stateSourceObject.Load(25, &e.owner)
+ stateSourceObject.Load(26, &e.ops)
stateSourceObject.LoadValue(6, new(int), func(y interface{}) { e.loadRcvBufSizeMax(y.(int)) })
- stateSourceObject.LoadValue(19, new(string), func(y interface{}) { e.loadLastError(y.(string)) })
+ stateSourceObject.LoadValue(18, new(string), func(y interface{}) { e.loadLastError(y.(string)) })
stateSourceObject.AfterLoad(e.afterLoad)
}