summary refs log tree commit diff homepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- pkg/tcpip/BUILD 2
-rw-r--r-- pkg/tcpip/adapters/gonet/gonet_test.go 112
-rw-r--r-- pkg/tcpip/checker/checker.go 15
-rw-r--r-- pkg/tcpip/header/eth.go 6
-rw-r--r-- pkg/tcpip/header/eth_test.go 4
-rw-r--r-- pkg/tcpip/internal/tcp/BUILD 12
-rw-r--r-- pkg/tcpip/internal/tcp/tcp.go 48
-rw-r--r-- pkg/tcpip/link/channel/channel.go 24
-rw-r--r-- pkg/tcpip/link/ethernet/ethernet.go 28
-rw-r--r-- pkg/tcpip/link/fdbased/endpoint.go 3
-rw-r--r-- pkg/tcpip/link/loopback/loopback.go 31
-rw-r--r-- pkg/tcpip/link/muxed/injectable.go 5
-rw-r--r-- pkg/tcpip/link/nested/nested.go 15
-rw-r--r-- pkg/tcpip/link/packetsocket/BUILD 14
-rw-r--r-- pkg/tcpip/link/packetsocket/endpoint.go 50
-rw-r--r-- pkg/tcpip/link/pipe/pipe.go 3
-rw-r--r-- pkg/tcpip/link/qdisc/fifo/endpoint.go 10
-rw-r--r-- pkg/tcpip/link/rawfile/blockingpoll_amd64.s 2
-rw-r--r-- pkg/tcpip/link/rawfile/blockingpoll_arm64.s 2
-rw-r--r-- pkg/tcpip/link/rawfile/rawfile_unsafe.go 32
-rw-r--r-- pkg/tcpip/link/sharedmem/BUILD 30
-rw-r--r-- pkg/tcpip/link/sharedmem/queue/rx.go 2
-rw-r--r-- pkg/tcpip/link/sharedmem/queuepair.go 199
-rw-r--r-- pkg/tcpip/link/sharedmem/rx.go 30
-rw-r--r-- pkg/tcpip/link/sharedmem/server_rx.go 142
-rw-r--r-- pkg/tcpip/link/sharedmem/server_tx.go 175
-rw-r--r-- pkg/tcpip/link/sharedmem/sharedmem.go 233
-rw-r--r-- pkg/tcpip/link/sharedmem/sharedmem_server.go 333
-rw-r--r-- pkg/tcpip/link/sharedmem/sharedmem_server_test.go 220
-rw-r--r-- pkg/tcpip/link/sharedmem/sharedmem_test.go 103
-rw-r--r-- pkg/tcpip/link/sharedmem/tx.go 20
-rw-r--r-- pkg/tcpip/link/sniffer/sniffer.go 5
-rw-r--r-- pkg/tcpip/link/tun/BUILD 1
-rw-r--r-- pkg/tcpip/link/tun/device.go 5
-rw-r--r-- pkg/tcpip/link/waitable/waitable.go 12
-rw-r--r-- pkg/tcpip/link/waitable/waitable_test.go 5
-rw-r--r-- pkg/tcpip/network/BUILD 2
-rw-r--r-- pkg/tcpip/network/arp/arp.go 1
-rw-r--r-- pkg/tcpip/network/arp/arp_test.go 16
-rw-r--r-- pkg/tcpip/network/internal/testutil/testutil.go 5
-rw-r--r-- pkg/tcpip/network/ip_test.go 175
-rw-r--r-- pkg/tcpip/network/ipv4/icmp.go 105
-rw-r--r-- pkg/tcpip/network/ipv4/igmp_test.go 28
-rw-r--r-- pkg/tcpip/network/ipv4/ipv4.go 85
-rw-r--r-- pkg/tcpip/network/ipv4/ipv4_test.go 218
-rw-r--r-- pkg/tcpip/network/ipv6/BUILD 1
-rw-r--r-- pkg/tcpip/network/ipv6/icmp.go 91
-rw-r--r-- pkg/tcpip/network/ipv6/icmp_test.go 71
-rw-r--r-- pkg/tcpip/network/ipv6/ipv6.go 100
-rw-r--r-- pkg/tcpip/network/ipv6/ipv6_test.go 243
-rw-r--r-- pkg/tcpip/network/ipv6/mld_test.go 33
-rw-r--r-- pkg/tcpip/network/ipv6/ndp.go 6
-rw-r--r-- pkg/tcpip/network/ipv6/ndp_test.go 48
-rw-r--r-- pkg/tcpip/network/multicast_group_test.go 21
-rw-r--r-- pkg/tcpip/sample/tun_tcp_connect/main.go 8
-rw-r--r-- pkg/tcpip/sample/tun_tcp_echo/main.go 16
-rw-r--r-- pkg/tcpip/socketops.go 29
-rw-r--r-- pkg/tcpip/stack/BUILD 1
-rw-r--r-- pkg/tcpip/stack/addressable_endpoint_state.go 24
-rw-r--r-- pkg/tcpip/stack/addressable_endpoint_state_test.go 4
-rw-r--r-- pkg/tcpip/stack/conntrack.go 557
-rw-r--r-- pkg/tcpip/stack/forwarding_test.go 30
-rw-r--r-- pkg/tcpip/stack/icmp_rate_limit.go 39
-rw-r--r-- pkg/tcpip/stack/iptables.go 198
-rw-r--r-- pkg/tcpip/stack/iptables_state.go 4
-rw-r--r-- pkg/tcpip/stack/iptables_targets.go 150
-rw-r--r-- pkg/tcpip/stack/iptables_types.go 28
-rw-r--r-- pkg/tcpip/stack/ndp_test.go 107
-rw-r--r-- pkg/tcpip/stack/nic.go 148
-rw-r--r-- pkg/tcpip/stack/nic_test.go 5
-rw-r--r-- pkg/tcpip/stack/packet_buffer.go 70
-rw-r--r-- pkg/tcpip/stack/packet_buffer_test.go 12
-rw-r--r-- pkg/tcpip/stack/registration.go 47
-rw-r--r-- pkg/tcpip/stack/stack.go 122
-rw-r--r-- pkg/tcpip/stack/stack_test.go 624
-rw-r--r-- pkg/tcpip/stack/tcp.go 15
-rw-r--r-- pkg/tcpip/stack/transport_demuxer.go 50
-rw-r--r-- pkg/tcpip/stack/transport_demuxer_test.go 18
-rw-r--r-- pkg/tcpip/stack/transport_test.go 33
-rw-r--r-- pkg/tcpip/tcpip.go 54
-rw-r--r-- pkg/tcpip/tcpip_state.go 27
-rw-r--r-- pkg/tcpip/tests/integration/BUILD 4
-rw-r--r-- pkg/tcpip/tests/integration/forward_test.go 20
-rw-r--r-- pkg/tcpip/tests/integration/iptables_test.go 386
-rw-r--r-- pkg/tcpip/tests/integration/link_resolution_test.go 24
-rw-r--r-- pkg/tcpip/tests/integration/loopback_test.go 47
-rw-r--r-- pkg/tcpip/tests/integration/multicast_broadcast_test.go 20
-rw-r--r-- pkg/tcpip/tests/integration/route_test.go 38
-rw-r--r-- pkg/tcpip/tests/utils/utils.go 68
-rw-r--r-- pkg/tcpip/transport/BUILD 13
-rw-r--r-- pkg/tcpip/transport/datagram.go 49
-rw-r--r-- pkg/tcpip/transport/icmp/BUILD 2
-rw-r--r-- pkg/tcpip/transport/icmp/endpoint.go 437
-rw-r--r-- pkg/tcpip/transport/icmp/endpoint_state.go 35
-rw-r--r-- pkg/tcpip/transport/icmp/icmp_test.go 8
-rw-r--r-- pkg/tcpip/transport/internal/network/BUILD 46
-rw-r--r-- pkg/tcpip/transport/internal/network/endpoint.go 811
-rw-r--r-- pkg/tcpip/transport/internal/network/endpoint_state.go 58
-rw-r--r-- pkg/tcpip/transport/internal/network/endpoint_test.go 318
-rw-r--r-- pkg/tcpip/transport/packet/endpoint.go 252
-rw-r--r-- pkg/tcpip/transport/packet/endpoint_state.go 24
-rw-r--r-- pkg/tcpip/transport/raw/BUILD 2
-rw-r--r-- pkg/tcpip/transport/raw/endpoint.go 436
-rw-r--r-- pkg/tcpip/transport/raw/endpoint_state.go 30
-rw-r--r-- pkg/tcpip/transport/tcp/BUILD 15
-rw-r--r-- pkg/tcpip/transport/tcp/accept.go 399
-rw-r--r-- pkg/tcpip/transport/tcp/connect.go 133
-rw-r--r-- pkg/tcpip/transport/tcp/endpoint.go 218
-rw-r--r-- pkg/tcpip/transport/tcp/endpoint_state.go 9
-rw-r--r-- pkg/tcpip/transport/tcp/forwarder.go 4
-rw-r--r-- pkg/tcpip/transport/tcp/protocol.go 51
-rw-r--r-- pkg/tcpip/transport/tcp/rack.go 3
-rw-r--r-- pkg/tcpip/transport/tcp/rcv.go 2
-rw-r--r-- pkg/tcpip/transport/tcp/rcv_test.go 2
-rw-r--r-- pkg/tcpip/transport/tcp/segment_test.go 2
-rw-r--r-- pkg/tcpip/transport/tcp/snd.go 141
-rw-r--r-- pkg/tcpip/transport/tcp/tcp_rack_test.go 24
-rw-r--r-- pkg/tcpip/transport/tcp/tcp_sack_test.go 265
-rw-r--r-- pkg/tcpip/transport/tcp/tcp_test.go 684
-rw-r--r-- pkg/tcpip/transport/tcp/tcp_timestamp_test.go 8
-rw-r--r-- pkg/tcpip/transport/tcp/testing/context/context.go 89
-rw-r--r-- pkg/tcpip/transport/transport.go 16
-rw-r--r-- pkg/tcpip/transport/udp/BUILD 3
-rw-r--r-- pkg/tcpip/transport/udp/endpoint.go 903
-rw-r--r-- pkg/tcpip/transport/udp/endpoint_state.go 65
-rw-r--r-- pkg/tcpip/transport/udp/forwarder.go 21
-rw-r--r-- pkg/tcpip/transport/udp/udp_test.go 161
127 files changed, 8341 insertions, 3607 deletions
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index f00cfd0f5..b98de54c5 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -25,6 +25,7 @@ go_library(
"stdclock.go",
"stdclock_state.go",
"tcpip.go",
+ "tcpip_state.go",
"timer.go",
],
visibility = ["//visibility:public"],
@@ -69,7 +70,6 @@ deps_test(
"//pkg/tcpip/header",
"//pkg/tcpip/link/fdbased",
"//pkg/tcpip/link/loopback",
- "//pkg/tcpip/link/packetsocket",
"//pkg/tcpip/link/qdisc/fifo",
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/arp",
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index 48b24692b..c8460e63c 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -137,7 +137,13 @@ func TestCloseReader(t *testing.T) {
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
l, e := ListenTCP(s, addr, ipv4.ProtocolNumber)
if e != nil {
@@ -190,7 +196,13 @@ func TestCloseReaderWithForwarder(t *testing.T) {
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
done := make(chan struct{})
@@ -244,7 +256,13 @@ func TestCloseRead(t *testing.T) {
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
var wq waiter.Queue
@@ -288,7 +306,13 @@ func TestCloseWrite(t *testing.T) {
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
var wq waiter.Queue
@@ -349,10 +373,22 @@ func TestUDPForwarder(t *testing.T) {
ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr1 := tcpip.FullAddress{NICID, ip1, 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip1)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip1.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr1, err)
+ }
ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4())
addr2 := tcpip.FullAddress{NICID, ip2, 11311}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip2)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr2, err)
+ }
done := make(chan struct{})
fwd := udp.NewForwarder(s, func(r *udp.ForwarderRequest) {
@@ -410,7 +446,13 @@ func TestDeadlineChange(t *testing.T) {
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
l, e := ListenTCP(s, addr, ipv4.ProtocolNumber)
if e != nil {
@@ -465,10 +507,22 @@ func TestPacketConnTransfer(t *testing.T) {
ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr1 := tcpip.FullAddress{NICID, ip1, 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip1)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip1.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr1, err)
+ }
ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4())
addr2 := tcpip.FullAddress{NICID, ip2, 11311}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip2)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr2, err)
+ }
c1, err := DialUDP(s, &addr1, nil, ipv4.ProtocolNumber)
if err != nil {
@@ -521,7 +575,13 @@ func TestConnectedPacketConnTransfer(t *testing.T) {
ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr := tcpip.FullAddress{NICID, ip, 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
c1, err := DialUDP(s, &addr, nil, ipv4.ProtocolNumber)
if err != nil {
@@ -565,24 +625,30 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) {
ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr := tcpip.FullAddress{NICID, ip, 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ return nil, nil, nil, fmt.Errorf("AddProtocolAddress(%d, %+v, {}): %w", NICID, protocolAddr, err)
+ }
l, err := ListenTCP(s, addr, ipv4.ProtocolNumber)
if err != nil {
- return nil, nil, nil, fmt.Errorf("NewListener: %v", err)
+ return nil, nil, nil, fmt.Errorf("NewListener: %w", err)
}
c1, err = DialTCP(s, addr, ipv4.ProtocolNumber)
if err != nil {
l.Close()
- return nil, nil, nil, fmt.Errorf("DialTCP: %v", err)
+ return nil, nil, nil, fmt.Errorf("DialTCP: %w", err)
}
c2, err = l.Accept()
if err != nil {
l.Close()
c1.Close()
- return nil, nil, nil, fmt.Errorf("l.Accept: %v", err)
+ return nil, nil, nil, fmt.Errorf("l.Accept: %w", err)
}
stop = func() {
@@ -594,7 +660,7 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) {
if err := l.Close(); err != nil {
stop()
- return nil, nil, nil, fmt.Errorf("l.Close(): %v", err)
+ return nil, nil, nil, fmt.Errorf("l.Close(): %w", err)
}
return c1, c2, stop, nil
@@ -681,7 +747,13 @@ func TestDialContextTCPCanceled(t *testing.T) {
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
@@ -703,7 +775,13 @@ func TestDialContextTCPTimeout(t *testing.T) {
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
time.Sleep(time.Second)
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index e0dfe5813..24c2c3e6b 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -324,6 +324,19 @@ func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker {
}
}
+// ReceiveIPv6PacketInfo creates a checker that checks the IPv6PacketInfo field
+// in ControlMessages.
+func ReceiveIPv6PacketInfo(want tcpip.IPv6PacketInfo) ControlMessagesChecker {
+ return func(t *testing.T, cm tcpip.ControlMessages) {
+ t.Helper()
+ if !cm.HasIPv6PacketInfo {
+ t.Errorf("got cm.HasIPv6PacketInfo = %t, want = true", cm.HasIPv6PacketInfo)
+ } else if diff := cmp.Diff(want, cm.IPv6PacketInfo); diff != "" {
+ t.Errorf("IPv6PacketInfo mismatch (-want +got):\n%s", diff)
+ }
+ }
+}
+
// ReceiveOriginalDstAddr creates a checker that checks the OriginalDstAddress
// field in ControlMessages.
func ReceiveOriginalDstAddr(want tcpip.FullAddress) ControlMessagesChecker {
@@ -729,7 +742,7 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp
return
}
l := int(opts[i+1])
- if i < 2 || i+l > limit {
+ if l < 2 || i+l > limit {
return
}
i += l
diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go
index 95ade0e5c..1f18213e5 100644
--- a/pkg/tcpip/header/eth.go
+++ b/pkg/tcpip/header/eth.go
@@ -49,9 +49,9 @@ const (
// EthernetAddressSize is the size, in bytes, of an ethernet address.
EthernetAddressSize = 6
- // unspecifiedEthernetAddress is the unspecified ethernet address
+ // UnspecifiedEthernetAddress is the unspecified ethernet address
// (all bits set to 0).
- unspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")
+ UnspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")
// EthernetBroadcastAddress is an ethernet address that addresses every node
// on a local link.
@@ -134,7 +134,7 @@ func IsValidUnicastEthernetAddress(addr tcpip.LinkAddress) bool {
return false
}
- if addr == unspecifiedEthernetAddress {
+ if addr == UnspecifiedEthernetAddress {
return false
}
diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go
index bf9ccbf1a..adc04e855 100644
--- a/pkg/tcpip/header/eth_test.go
+++ b/pkg/tcpip/header/eth_test.go
@@ -44,7 +44,7 @@ func TestIsValidUnicastEthernetAddress(t *testing.T) {
},
{
"Unspecified",
- unspecifiedEthernetAddress,
+ UnspecifiedEthernetAddress,
false,
},
{
@@ -91,7 +91,7 @@ func TestIsMulticastEthernetAddress(t *testing.T) {
},
{
"Unspecified",
- unspecifiedEthernetAddress,
+ UnspecifiedEthernetAddress,
false,
},
{
diff --git a/pkg/tcpip/internal/tcp/BUILD b/pkg/tcpip/internal/tcp/BUILD
new file mode 100644
index 000000000..9ae258a0b
--- /dev/null
+++ b/pkg/tcpip/internal/tcp/BUILD
@@ -0,0 +1,12 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "tcp",
+ srcs = ["tcp.go"],
+ visibility = ["//pkg/tcpip:__subpackages__"],
+ deps = [
+ "//pkg/tcpip",
+ ],
+)
diff --git a/pkg/tcpip/internal/tcp/tcp.go b/pkg/tcpip/internal/tcp/tcp.go
new file mode 100644
index 000000000..0616d368c
--- /dev/null
+++ b/pkg/tcpip/internal/tcp/tcp.go
@@ -0,0 +1,48 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tcp contains internal type definitions that are not expected to be
+// used by anyone else outside pkg/tcpip.
+package tcp
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// TSOffset is an offset applied to the value of the TSVal field in the TCP
+// Timestamp option.
+//
+// +stateify savable
+type TSOffset struct {
+ milliseconds uint32
+}
+
+// NewTSOffset creates a new TSOffset from milliseconds.
+func NewTSOffset(milliseconds uint32) TSOffset {
+ return TSOffset{
+ milliseconds: milliseconds,
+ }
+}
+
+// TSVal applies the offset to now and returns the timestamp in milliseconds.
+func (offset TSOffset) TSVal(now tcpip.MonotonicTime) uint32 {
+ return uint32(now.Sub(tcpip.MonotonicTime{}).Milliseconds()) + offset.milliseconds
+}
+
+// Elapsed calculates the elapsed time given now and the echoed back timestamp.
+func (offset TSOffset) Elapsed(now tcpip.MonotonicTime, tsEcr uint32) time.Duration {
+ return time.Duration(offset.TSVal(now)-tsEcr) * time.Millisecond
+}
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index f26c857eb..658557d62 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -28,7 +28,9 @@ import (
// PacketInfo holds all the information about an outbound packet.
type PacketInfo struct {
- Pkt *stack.PacketBuffer
+ Pkt *stack.PacketBuffer
+
+ // TODO(https://gvisor.dev/issue/6537): Remove these fields.
Proto tcpip.NetworkProtocolNumber
Route stack.RouteInfo
}
@@ -244,7 +246,10 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol
Route: r,
}
- e.q.Write(p)
+ // Write returns false if the queue is full. A full queue is not an error
+ // from the perspective of a LinkEndpoint so we ignore Write's return
+ // value and always return nil from this method.
+ _ = e.q.Write(p)
return nil
}
@@ -290,3 +295,18 @@ func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
// AddHeader implements stack.LinkEndpoint.AddHeader.
func (*Endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (e *Endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error {
+ p := PacketInfo{
+ Pkt: pkt,
+ Proto: pkt.NetworkProtocolNumber,
+ }
+
+ // Write returns false if the queue is full. A full queue is not an error
+ // from the perspective of a LinkEndpoint so we ignore Write's return
+ // value and always return nil from this method.
+ _ = e.q.Write(p)
+
+ return nil
+}
diff --git a/pkg/tcpip/link/ethernet/ethernet.go b/pkg/tcpip/link/ethernet/ethernet.go
index b427c6170..8211a2031 100644
--- a/pkg/tcpip/link/ethernet/ethernet.go
+++ b/pkg/tcpip/link/ethernet/ethernet.go
@@ -42,6 +42,14 @@ type Endpoint struct {
nested.Endpoint
}
+// LinkAddress implements stack.LinkEndpoint.
+func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
+ if l := e.Endpoint.LinkAddress(); len(l) != 0 {
+ return l
+ }
+ return header.UnspecifiedEthernetAddress
+}
+
// DeliverNetworkPacket implements stack.NetworkDispatcher.
func (e *Endpoint) DeliverNetworkPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
@@ -57,18 +65,22 @@ func (e *Endpoint) DeliverNetworkPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkP
// Capabilities implements stack.LinkEndpoint.
func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return stack.CapabilityResolutionRequired | e.Endpoint.Capabilities()
+ c := e.Endpoint.Capabilities()
+ if c&stack.CapabilityLoopback == 0 {
+ c |= stack.CapabilityResolutionRequired
+ }
+ return c
}
// WritePacket implements stack.LinkEndpoint.
func (e *Endpoint) WritePacket(r stack.RouteInfo, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress, proto, pkt)
+ e.AddHeader(e.LinkAddress(), r.RemoteLinkAddress, proto, pkt)
return e.Endpoint.WritePacket(r, proto, pkt)
}
// WritePackets implements stack.LinkEndpoint.
func (e *Endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
- linkAddr := e.Endpoint.LinkAddress()
+ linkAddr := e.LinkAddress()
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
e.AddHeader(linkAddr, r.RemoteLinkAddress, proto, pkt)
@@ -83,7 +95,10 @@ func (e *Endpoint) MaxHeaderLength() uint16 {
}
// ARPHardwareType implements stack.LinkEndpoint.
-func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
+func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
+ if a := e.Endpoint.ARPHardwareType(); a != header.ARPHardwareNone {
+ return a
+ }
return header.ARPHardwareEther
}
@@ -97,3 +112,8 @@ func (*Endpoint) AddHeader(local, remote tcpip.LinkAddress, proto tcpip.NetworkP
}
eth.Encode(&fields)
}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (e *Endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error {
+ return e.Endpoint.WriteRawPacket(pkt)
+}
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index 48356c343..058242f96 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -505,6 +505,9 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
}
}
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} }
+
// WritePacket writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index 7012d8829..ca1f9c08d 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -76,19 +76,8 @@ func (*endpoint) Wait() {}
// WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound
// packets to the network-layer dispatcher.
-func (e *endpoint) WritePacket(_ stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- // Construct data as the unparsed portion for the loopback packet.
- data := buffer.NewVectorisedView(pkt.Size(), pkt.Views())
-
- // Because we're immediately turning around and writing the packet back
- // to the rx path, we intentionally don't preserve the remote and local
- // link addresses from the stack.Route we're passed.
- newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: data,
- })
- e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, newPkt)
-
- return nil
+func (e *endpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+ return e.WriteRawPacket(pkt)
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
@@ -103,3 +92,19 @@ func (*endpoint) ARPHardwareType() header.ARPHardwareType {
func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (e *endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error {
+ // Construct data as the unparsed portion for the loopback packet.
+ data := buffer.NewVectorisedView(pkt.Size(), pkt.Views())
+
+ // Because we're immediately turning around and writing the packet back
+ // to the rx path, we intentionally don't preserve the remote and local
+ // link addresses from the stack.Route we're passed.
+ newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: data,
+ })
+ e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, pkt.NetworkProtocolNumber, newPkt)
+
+ return nil
+}
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index 3e2a1aa94..844f5959b 100644
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -131,6 +131,11 @@ func (*InjectableEndpoint) ARPHardwareType() header.ARPHardwareType {
func (*InjectableEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
}
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*InjectableEndpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
+}
+
// NewInjectableEndpoint creates a new multi-endpoint injectable endpoint.
func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint {
return &InjectableEndpoint{
diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go
index 3e816b0c7..83a6c1cc8 100644
--- a/pkg/tcpip/link/nested/nested.go
+++ b/pkg/tcpip/link/nested/nested.go
@@ -60,16 +60,6 @@ func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
}
}
-// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
-func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
- e.mu.RLock()
- d := e.dispatcher
- e.mu.RUnlock()
- if d != nil {
- d.DeliverOutboundPacket(remote, local, protocol, pkt)
- }
-}
-
// Attach implements stack.LinkEndpoint.
func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.mu.Lock()
@@ -152,3 +142,8 @@ func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
e.child.AddHeader(local, remote, protocol, pkt)
}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (e *Endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error {
+ return e.child.WriteRawPacket(pkt)
+}
diff --git a/pkg/tcpip/link/packetsocket/BUILD b/pkg/tcpip/link/packetsocket/BUILD
deleted file mode 100644
index 6fff160ce..000000000
--- a/pkg/tcpip/link/packetsocket/BUILD
+++ /dev/null
@@ -1,14 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "packetsocket",
- srcs = ["endpoint.go"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/link/nested",
- "//pkg/tcpip/stack",
- ],
-)
diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go
deleted file mode 100644
index e01837e2d..000000000
--- a/pkg/tcpip/link/packetsocket/endpoint.go
+++ /dev/null
@@ -1,50 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package packetsocket provides a link layer endpoint that provides the ability
-// to loop outbound packets to any AF_PACKET sockets that may be interested in
-// the outgoing packet.
-package packetsocket
-
-import (
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/link/nested"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
-)
-
-type endpoint struct {
- nested.Endpoint
-}
-
-// New creates a new packetsocket LinkEndpoint.
-func New(lower stack.LinkEndpoint) stack.LinkEndpoint {
- e := &endpoint{}
- e.Endpoint.Init(lower, e)
- return e
-}
-
-// WritePacket implements stack.LinkEndpoint.WritePacket.
-func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt)
- return e.Endpoint.WritePacket(r, protocol, pkt)
-}
-
-// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
- for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt)
- }
-
- return e.Endpoint.WritePackets(r, pkts, proto)
-}
diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go
index 5030b6ba1..3ed0aa3fe 100644
--- a/pkg/tcpip/link/pipe/pipe.go
+++ b/pkg/tcpip/link/pipe/pipe.go
@@ -121,3 +121,6 @@ func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
// AddHeader implements stack.LinkEndpoint.
func (*Endpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*Endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} }
diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go
index 40bd5560b..b41e3e2fa 100644
--- a/pkg/tcpip/link/qdisc/fifo/endpoint.go
+++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go
@@ -108,11 +108,6 @@ func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt)
}
-// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
-func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
- e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt)
-}
-
// Attach implements stack.LinkEndpoint.Attach.
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
// nil means the NIC is being removed.
@@ -228,3 +223,8 @@ func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
e.lower.AddHeader(local, remote, protocol, pkt)
}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+//
+// The packet is handed directly to the lower endpoint's WriteRawPacket and its
+// result is returned unchanged.
+func (e *endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error {
+ return e.lower.WriteRawPacket(pkt)
+}
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_amd64.s b/pkg/tcpip/link/rawfile/blockingpoll_amd64.s
index 298bad55d..f2c230720 100644
--- a/pkg/tcpip/link/rawfile/blockingpoll_amd64.s
+++ b/pkg/tcpip/link/rawfile/blockingpoll_amd64.s
@@ -27,7 +27,7 @@ TEXT ·BlockingPoll(SB),NOSPLIT,$0-40
MOVQ $0x0, R10 // sigmask parameter which isn't used here
MOVQ $0x10f, AX // SYS_PPOLL
SYSCALL
- CMPQ AX, $0xfffffffffffff001
+ CMPQ AX, $0xfffffffffffff002
JLS ok
MOVQ $-1, n+24(FP)
NEGQ AX
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_arm64.s b/pkg/tcpip/link/rawfile/blockingpoll_arm64.s
index b62888b93..8807586c7 100644
--- a/pkg/tcpip/link/rawfile/blockingpoll_arm64.s
+++ b/pkg/tcpip/link/rawfile/blockingpoll_arm64.s
@@ -27,7 +27,7 @@ TEXT ·BlockingPoll(SB),NOSPLIT,$0-40
MOVD $0x0, R3 // sigmask parameter which isn't used here
MOVD $0x49, R8 // SYS_PPOLL
SVC
- CMP $0xfffffffffffff001, R0
+ CMP $0xfffffffffffff002, R0
BLS ok
MOVD $-1, R1
MOVD R1, n+24(FP)
diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
index e76fc55b6..e53789d92 100644
--- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go
+++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
@@ -152,10 +152,22 @@ type PollEvent struct {
// no data is available, it will block in a poll() syscall until the file
// descriptor becomes readable.
func BlockingRead(fd int, b []byte) (int, tcpip.Error) {
+	// Delegate the read/poll loop to the untranslated variant and only convert
+	// the raw errno into a tcpip.Error when one was reported.
+	n, errno := BlockingReadUntranslated(fd, b)
+	if errno == 0 {
+		return n, nil
+	}
+	return n, TranslateErrno(errno)
+}
+
+// BlockingReadUntranslated reads from a file descriptor that is set up as
+// non-blocking. If no data is available, it will block in a poll() syscall
+// until the file descriptor becomes readable. It returns the raw unix.Errno
+// value returned by the underlying syscalls.
+func BlockingReadUntranslated(fd int, b []byte) (int, unix.Errno) {
for {
n, _, e := unix.RawSyscall(unix.SYS_READ, uintptr(fd), uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)))
if e == 0 {
- return int(n), nil
+ return int(n), 0
}
event := PollEvent{
@@ -165,7 +177,7 @@ func BlockingRead(fd int, b []byte) (int, tcpip.Error) {
_, e = BlockingPoll(&event, 1, nil)
if e != 0 && e != unix.EINTR {
- return 0, TranslateErrno(e)
+ return 0, e
}
}
}
@@ -181,7 +193,9 @@ func BlockingReadvUntilStopped(efd int, fd int, iovecs []unix.Iovec) (int, tcpip
if e == 0 {
return int(n), nil
}
-
+ if e != 0 && e != unix.EWOULDBLOCK {
+ return 0, TranslateErrno(e)
+ }
stopped, e := BlockingPollUntilStopped(efd, fd, unix.POLLIN)
if stopped {
return -1, nil
@@ -204,6 +218,10 @@ func BlockingRecvMMsgUntilStopped(efd int, fd int, msgHdrs []MMsgHdr) (int, tcpi
return int(n), nil
}
+ if e != 0 && e != unix.EWOULDBLOCK {
+ return 0, TranslateErrno(e)
+ }
+
stopped, e := BlockingPollUntilStopped(efd, fd, unix.POLLIN)
if stopped {
return -1, nil
@@ -228,5 +246,13 @@ func BlockingPollUntilStopped(efd int, fd int, events int16) (bool, unix.Errno)
},
}
_, errno := BlockingPoll(&pevents[0], len(pevents), nil)
+ if errno != 0 {
+ return pevents[0].Revents&unix.POLLIN != 0, errno
+ }
+
+ if pevents[1].Revents&unix.POLLHUP != 0 || pevents[1].Revents&unix.POLLERR != 0 {
+ errno = unix.ECONNRESET
+ }
+
return pevents[0].Revents&unix.POLLIN != 0, errno
}
diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD
index 4215ee852..f8076d83c 100644
--- a/pkg/tcpip/link/sharedmem/BUILD
+++ b/pkg/tcpip/link/sharedmem/BUILD
@@ -5,19 +5,26 @@ package(licenses = ["notice"])
go_library(
name = "sharedmem",
srcs = [
+ "queuepair.go",
"rx.go",
+ "server_rx.go",
+ "server_tx.go",
"sharedmem.go",
+ "sharedmem_server.go",
"sharedmem_unsafe.go",
"tx.go",
],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/cleanup",
+ "//pkg/eventfd",
"//pkg/log",
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/link/rawfile",
+ "//pkg/tcpip/link/sharedmem/pipe",
"//pkg/tcpip/link/sharedmem/queue",
"//pkg/tcpip/stack",
"@org_golang_x_sys//unix:go_default_library",
@@ -26,9 +33,7 @@ go_library(
go_test(
name = "sharedmem_test",
- srcs = [
- "sharedmem_test.go",
- ],
+ srcs = ["sharedmem_test.go"],
library = ":sharedmem",
deps = [
"//pkg/sync",
@@ -41,3 +46,22 @@ go_test(
"@org_golang_x_sys//unix:go_default_library",
],
)
+
+go_test(
+ name = "sharedmem_server_test",
+ size = "small",
+ srcs = ["sharedmem_server_test.go"],
+ deps = [
+ ":sharedmem",
+ "//pkg/tcpip",
+ "//pkg/tcpip/adapters/gonet",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/tcp",
+ "//pkg/tcpip/transport/udp",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/tcpip/link/sharedmem/queue/rx.go b/pkg/tcpip/link/sharedmem/queue/rx.go
index 696e6c9e5..a78826ebc 100644
--- a/pkg/tcpip/link/sharedmem/queue/rx.go
+++ b/pkg/tcpip/link/sharedmem/queue/rx.go
@@ -119,7 +119,6 @@ func (r *Rx) PostBuffers(buffers []RxBuffer) bool {
}
r.tx.Flush()
-
return true
}
@@ -131,7 +130,6 @@ func (r *Rx) PostBuffers(buffers []RxBuffer) bool {
func (r *Rx) Dequeue(bufs []RxBuffer) ([]RxBuffer, uint32) {
for {
outBufs := bufs
-
// Pull the next descriptor from the rx pipe.
b := r.rx.Pull()
if b == nil {
diff --git a/pkg/tcpip/link/sharedmem/queuepair.go b/pkg/tcpip/link/sharedmem/queuepair.go
new file mode 100644
index 000000000..b12647fdd
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/queuepair.go
@@ -0,0 +1,199 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build linux
+// +build linux
+
+package sharedmem
+
+import (
+ "fmt"
+ "io/ioutil"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/eventfd"
+)
+
+const (
+ // defaultQueueDataSize is the size of the shared memory data region that
+ // holds the scatter/gather buffers.
+ defaultQueueDataSize = 1 << 20 // 1MiB
+
+ // defaultQueuePipeSize is the size of the pipe that holds the packet descriptors.
+ //
+ // Assuming each packet data is approximately 1280 bytes (IPv6 Minimum MTU)
+ // then we can hold approximately 1024*1024/1280 ~ 819 packets in the data
+ // area. Which means the pipe needs to be big enough to hold 819
+ // descriptors.
+ //
+ // Each descriptor is approximately 8 (slot descriptor in pipe) +
+ // 16 (packet descriptor) + 12 (for buffer descriptor) assuming each packet is
+ // stored in exactly 1 buffer descriptor (see queue/tx.go and pipe/tx.go.)
+ //
+ // Which means we need approximately 36*819 ~ 29 KiB to store all packet
+ // descriptors. We could go with a 32 KiB pipe but to give it some slack in
+ // how the upper layer may make use of the scatter gather buffers we double
+ // this to hold enough descriptors.
+ defaultQueuePipeSize = 64 << 10 // 64KiB
+
+ // defaultSharedDataSize is the size of the sharedData region used to
+ // enable/disable notifications.
+ defaultSharedDataSize = 4 << 10 // 4KiB
+)
+
+// A QueuePair represents a pair of TX/RX queues.
+//
+// Both configs are populated by NewQueuePair; Close closes the file
+// descriptors they hold.
+type QueuePair struct {
+ // txCfg is the QueueConfig to be used for transmit queue.
+ txCfg QueueConfig
+
+ // rxCfg is the QueueConfig to be used for receive queue.
+ rxCfg QueueConfig
+}
+
+// NewQueuePair creates a shared memory QueuePair.
+//
+// Both the transmit and the receive queue are created with the package default
+// sizes. If the receive queue cannot be created, the already created transmit
+// queue is closed before returning so that no descriptors are leaked.
+//
+// The returned error wraps the underlying failure so that callers can inspect
+// it with errors.Is / errors.As.
+func NewQueuePair() (*QueuePair, error) {
+	// The same sizes are used for both directions.
+	sizes := queueSizes{
+		dataSize:       defaultQueueDataSize,
+		txPipeSize:     defaultQueuePipeSize,
+		rxPipeSize:     defaultQueuePipeSize,
+		sharedDataSize: defaultSharedDataSize,
+	}
+
+	txCfg, err := createQueueFDs(sizes)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create tx queue: %w", err)
+	}
+
+	rxCfg, err := createQueueFDs(sizes)
+	if err != nil {
+		closeFDs(txCfg)
+		return nil, fmt.Errorf("failed to create rx queue: %w", err)
+	}
+
+	return &QueuePair{
+		txCfg: txCfg,
+		rxCfg: rxCfg,
+	}, nil
+}
+
+// Close closes underlying tx/rx queue fds.
+//
+// The stored configs are not cleared, so values returned by TXQueueConfig and
+// RXQueueConfig refer to closed descriptors after this call.
+func (q *QueuePair) Close() {
+ closeFDs(q.txCfg)
+ closeFDs(q.rxCfg)
+}
+
+// TXQueueConfig returns the QueueConfig for the transmit queue.
+func (q *QueuePair) TXQueueConfig() QueueConfig {
+ return q.txCfg
+}
+
+// RXQueueConfig returns the QueueConfig for the receive queue.
+func (q *QueuePair) RXQueueConfig() QueueConfig {
+ return q.rxCfg
+}
+
+// queueSizes holds the sizes, in bytes, of the files that back a single queue.
+// Each value is passed to createFile, which truncates the corresponding file
+// to that length.
+type queueSizes struct {
+ dataSize int64
+ txPipeSize int64
+ rxPipeSize int64
+ sharedDataSize int64
+}
+
+// createQueueFDs creates the eventfd and the four shared memory files (data,
+// tx pipe, rx pipe and shared data) that back a single queue and returns them
+// as a QueueConfig.
+//
+// On failure a zero QueueConfig and an error are returned, and every
+// descriptor created up to that point is closed. Only descriptors that were
+// actually created are closed: closing the zero value of a not-yet-created
+// descriptor would close file descriptor 0 of the process.
+func createQueueFDs(s queueSizes) (QueueConfig, error) {
+	// closers holds one close function per successfully created descriptor.
+	// They only run if this function returns without setting success.
+	var closers []func()
+	success := false
+	defer func() {
+		if success {
+			return
+		}
+		for _, closeFn := range closers {
+			closeFn()
+		}
+	}()
+
+	eventFD, err := eventfd.Create()
+	if err != nil {
+		return QueueConfig{}, fmt.Errorf("eventfd failed: %v", err)
+	}
+	closers = append(closers, func() { eventFD.Close() })
+
+	// newFile creates one backing file and registers it for cleanup on failure.
+	newFile := func(size int64, initQueue bool) (int, error) {
+		fd, err := createFile(size, initQueue)
+		if err != nil {
+			return -1, err
+		}
+		closers = append(closers, func() { unix.Close(fd) })
+		return fd, nil
+	}
+
+	dataFD, err := newFile(s.dataSize, false)
+	if err != nil {
+		return QueueConfig{}, fmt.Errorf("failed to create dataFD: %s", err)
+	}
+	txPipeFD, err := newFile(s.txPipeSize, true)
+	if err != nil {
+		return QueueConfig{}, fmt.Errorf("failed to create txPipeFD: %s", err)
+	}
+	rxPipeFD, err := newFile(s.rxPipeSize, true)
+	if err != nil {
+		return QueueConfig{}, fmt.Errorf("failed to create rxPipeFD: %s", err)
+	}
+	sharedDataFD, err := newFile(s.sharedDataSize, false)
+	if err != nil {
+		return QueueConfig{}, fmt.Errorf("failed to create sharedDataFD: %s", err)
+	}
+
+	success = true
+	return QueueConfig{
+		EventFD:      eventFD,
+		DataFD:       dataFD,
+		TxPipeFD:     txPipeFD,
+		RxPipeFD:     rxPipeFD,
+		SharedDataFD: sharedDataFD,
+	}, nil
+}
+
+// createFile creates a temporary file under /dev/shm, unlinks it, truncates it
+// to size bytes and returns a duplicated descriptor for it. The caller owns the
+// returned descriptor; on error -1 is returned.
+//
+// When initQueue is true, the "slot-free" flag is written at the start of the
+// file before it is sized.
+func createFile(size int64, initQueue bool) (fd int, err error) {
+	const tmpDir = "/dev/shm/"
+	tmp, err := ioutil.TempFile(tmpDir, "sharedmem_test")
+	if err != nil {
+		return -1, fmt.Errorf("TempFile failed: %v", err)
+	}
+	// tmp is only needed during setup; the duplicated descriptor created below
+	// stays open after tmp is closed.
+	defer tmp.Close()
+	// Drop the name immediately so only open descriptors keep the file alive.
+	unix.Unlink(tmp.Name())
+
+	if initQueue {
+		// Write the "slot-free" flag in the initial queue.
+		slotFree := []byte{0, 0, 0, 0, 0, 0, 0, 0x80}
+		if _, werr := tmp.WriteAt(slotFree, 0); werr != nil {
+			return -1, fmt.Errorf("WriteAt failed: %v", werr)
+		}
+	}
+
+	fd, err = unix.Dup(int(tmp.Fd()))
+	if err != nil {
+		return -1, fmt.Errorf("unix.Dup(%d) failed: %v", tmp.Fd(), err)
+	}
+
+	if terr := unix.Ftruncate(fd, size); terr != nil {
+		unix.Close(fd)
+		return -1, fmt.Errorf("ftruncate(%d, %d) failed: %v", fd, size, terr)
+	}
+
+	return fd, nil
+}
+
+// closeFDs closes every descriptor held by c: the data file, the eventfd, both
+// pipes and the shared data file. Errors returned by the close calls are
+// ignored.
+func closeFDs(c QueueConfig) {
+	unix.Close(c.DataFD)
+	c.EventFD.Close()
+	for _, fd := range []int{c.TxPipeFD, c.RxPipeFD, c.SharedDataFD} {
+		unix.Close(fd)
+	}
+}
diff --git a/pkg/tcpip/link/sharedmem/rx.go b/pkg/tcpip/link/sharedmem/rx.go
index e882a128c..87747dcc7 100644
--- a/pkg/tcpip/link/sharedmem/rx.go
+++ b/pkg/tcpip/link/sharedmem/rx.go
@@ -21,7 +21,7 @@ import (
"sync/atomic"
"golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
+ "gvisor.dev/gvisor/pkg/eventfd"
"gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
)
@@ -30,7 +30,7 @@ type rx struct {
data []byte
sharedData []byte
q queue.Rx
- eventFD int
+ eventFD eventfd.Eventfd
}
// init initializes all state needed by the rx queue based on the information
@@ -68,7 +68,7 @@ func (r *rx) init(mtu uint32, c *QueueConfig) error {
// Duplicate the eventFD so that caller can close it but we can still
// use it.
- efd, err := unix.Dup(c.EventFD)
+ efd, err := c.EventFD.Dup()
if err != nil {
unix.Munmap(txPipe)
unix.Munmap(rxPipe)
@@ -77,16 +77,6 @@ func (r *rx) init(mtu uint32, c *QueueConfig) error {
return err
}
- // Set the eventfd as non-blocking.
- if err := unix.SetNonblock(efd, true); err != nil {
- unix.Munmap(txPipe)
- unix.Munmap(rxPipe)
- unix.Munmap(data)
- unix.Munmap(sharedData)
- unix.Close(efd)
- return err
- }
-
// Initialize state based on buffers.
r.q.Init(txPipe, rxPipe, sharedDataPointer(sharedData))
r.data = data
@@ -105,7 +95,13 @@ func (r *rx) cleanup() {
unix.Munmap(r.data)
unix.Munmap(r.sharedData)
- unix.Close(r.eventFD)
+ r.eventFD.Close()
+}
+
+// notify calls Notify on r.eventFD, the eventfd duplicated for this rx queue
+// in init, to signal whoever is waiting on it.
+func (r *rx) notify() {
+ r.eventFD.Notify()
}
// postAndReceive posts the provided buffers (if any), and then tries to read
@@ -122,8 +118,7 @@ func (r *rx) postAndReceive(b []queue.RxBuffer, stopRequested *uint32) ([]queue.
if len(b) != 0 && !r.q.PostBuffers(b) {
r.q.EnableNotification()
for !r.q.PostBuffers(b) {
- var tmp [8]byte
- rawfile.BlockingRead(r.eventFD, tmp[:])
+ r.eventFD.Wait()
if atomic.LoadUint32(stopRequested) != 0 {
r.q.DisableNotification()
return nil, 0
@@ -147,8 +142,7 @@ func (r *rx) postAndReceive(b []queue.RxBuffer, stopRequested *uint32) ([]queue.
}
// Wait for notification.
- var tmp [8]byte
- rawfile.BlockingRead(r.eventFD, tmp[:])
+ r.eventFD.Wait()
if atomic.LoadUint32(stopRequested) != 0 {
r.q.DisableNotification()
return nil, 0
diff --git a/pkg/tcpip/link/sharedmem/server_rx.go b/pkg/tcpip/link/sharedmem/server_rx.go
new file mode 100644
index 000000000..6ea21ffd1
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/server_rx.go
@@ -0,0 +1,142 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build linux
+// +build linux
+
+package sharedmem
+
+import (
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/cleanup"
+ "gvisor.dev/gvisor/pkg/eventfd"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
+)
+
+// serverRx represents the server end of the sharedmem queue that receives
+// packets: it pulls packet descriptors sent by the client from packetPipe and
+// reports completions back on completionPipe.
+type serverRx struct {
+ // packetPipe represents the receive end of the pipe that carries the packet
+ // descriptors sent by the client.
+ packetPipe pipe.Rx
+
+ // completionPipe represents the transmit end of the pipe that will carry
+ // completion notifications from the server to the client.
+ completionPipe pipe.Tx
+
+ // data represents the buffer area where the packet payload is held.
+ data []byte
+
+ // eventFD is used to notify the peer when transmission is completed.
+ eventFD eventfd.Eventfd
+
+ // sharedData is the memory region to use to enable/disable notifications.
+ sharedData []byte
+}
+
+// init initializes all state needed by the serverRx queue based on the
+// information provided.
+//
+// The caller always retains ownership of all file descriptors passed in. The
+// queue implementation will duplicate any that it may need in the future.
+//
+// If any step fails, everything mapped or duplicated so far is released and s
+// is left unmodified.
+func (s *serverRx) init(c *QueueConfig) error {
+	// Map the pipe that carries packet descriptors from the client.
+	pktMem, err := getBuffer(c.TxPipeFD)
+	if err != nil {
+		return err
+	}
+	undo := cleanup.Make(func() { unix.Munmap(pktMem) })
+	defer undo.Clean()
+
+	// Map the pipe used to send completions back to the client.
+	complMem, err := getBuffer(c.RxPipeFD)
+	if err != nil {
+		return err
+	}
+	undo.Add(func() { unix.Munmap(complMem) })
+
+	// Map the packet payload area.
+	payload, err := getBuffer(c.DataFD)
+	if err != nil {
+		return err
+	}
+	undo.Add(func() { unix.Munmap(payload) })
+
+	// Map the region used to enable/disable notifications.
+	shared, err := getBuffer(c.SharedDataFD)
+	if err != nil {
+		return err
+	}
+	undo.Add(func() { unix.Munmap(shared) })
+
+	// Duplicate the eventFD so that caller can close it but we can still
+	// use it.
+	dupEventFD, err := c.EventFD.Dup()
+	if err != nil {
+		return err
+	}
+	undo.Add(func() { dupEventFD.Close() })
+
+	// Nothing can fail past this point; commit the state.
+	s.packetPipe.Init(pktMem)
+	s.completionPipe.Init(complMem)
+	s.data = payload
+	s.eventFD = dupEventFD
+	s.sharedData = shared
+
+	undo.Release()
+	return nil
+}
+
+// cleanup unmaps the memory backing both pipes, the data area and the shared
+// data region, and closes the eventfd duplicated in init. Errors from the
+// unmap calls are ignored.
+func (s *serverRx) cleanup() {
+ unix.Munmap(s.packetPipe.Bytes())
+ unix.Munmap(s.completionPipe.Bytes())
+ unix.Munmap(s.data)
+ unix.Munmap(s.sharedData)
+ s.eventFD.Close()
+}
+
+// completionNotificationSize is size in bytes of a completion notification sent
+// on the completion queue after a transmitted packet has been handled.
+const completionNotificationSize = 8
+
+// receive receives a single packet from the packetPipe.
+//
+// It returns nil when no descriptor is pending. Otherwise it returns a newly
+// allocated copy of the payload gathered from the descriptor's buffers (at most
+// pktInfo.Size bytes), flushes the packetPipe and pushes a completion for the
+// packet onto the completionPipe.
+//
+// NOTE(review): offsets and sizes decoded from the descriptor are used to slice
+// s.data without validation; presumably the peer is trusted -- confirm.
+// NOTE(review): the result of completionPipe.Push is not checked for nil here,
+// whereas serverTx.transmit does check it -- confirm the completion pipe cannot
+// be full at this point.
+func (s *serverRx) receive() []byte {
+ desc := s.packetPipe.Pull()
+ if desc == nil {
+ return nil
+ }
+
+ pktInfo := queue.DecodeTxPacketHeader(desc)
+ contents := make([]byte, 0, pktInfo.Size)
+ // toCopy tracks how many bytes of the packet are still to be gathered.
+ toCopy := pktInfo.Size
+ for i := 0; i < pktInfo.BufferCount; i++ {
+ txBuf := queue.DecodeTxBufferHeader(desc, i)
+ if txBuf.Size <= toCopy {
+ contents = append(contents, s.data[txBuf.Offset:][:txBuf.Size]...)
+ toCopy -= txBuf.Size
+ continue
+ }
+ // The buffer is larger than what remains; take only the remainder and stop.
+ contents = append(contents, s.data[txBuf.Offset:][:toCopy]...)
+ break
+ }
+
+ // Flush to let peer know that slots queued for transmission have been handled
+ // and its free to reuse the slots.
+ s.packetPipe.Flush()
+ // Encode packet completion.
+ b := s.completionPipe.Push(completionNotificationSize)
+ queue.EncodeTxCompletion(b, pktInfo.ID)
+ s.completionPipe.Flush()
+ return contents
+}
+
+// waitForPackets calls Wait on the queue's eventFD.
+// NOTE(review): presumably this blocks until the peer signals the eventfd --
+// confirm against the eventfd package.
+func (s *serverRx) waitForPackets() {
+ s.eventFD.Wait()
+}
diff --git a/pkg/tcpip/link/sharedmem/server_tx.go b/pkg/tcpip/link/sharedmem/server_tx.go
new file mode 100644
index 000000000..13a82903f
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/server_tx.go
@@ -0,0 +1,175 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build linux
+// +build linux
+
+package sharedmem
+
+import (
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/cleanup"
+ "gvisor.dev/gvisor/pkg/eventfd"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
+)
+
+// serverTx represents the server end of the sharedmem queue and is used to send
+// packets to the peer in the buffers posted by the peer in the fillPipe.
+type serverTx struct {
+ // fillPipe represents the receive end of the pipe that carries the RxBuffers
+ // posted by the peer.
+ fillPipe pipe.Rx
+
+ // completionPipe represents the transmit end of the pipe that carries the
+ // descriptors for filled RxBuffers.
+ completionPipe pipe.Tx
+
+ // data represents the buffer area where the packet payload is held.
+ data []byte
+
+ // eventFD is used to notify the peer when fill requests are fulfilled.
+ eventFD eventfd.Eventfd
+
+ // sharedData is the memory region to use to enable/disable notifications.
+ sharedData []byte
+}
+
+// init initializes all state needed by the serverTx queue based on the
+// information provided.
+//
+// The caller always retains ownership of all file descriptors passed in. The
+// queue implementation will duplicate any that it may need in the future.
+//
+// If any step fails, everything mapped or duplicated so far is released and s
+// is left unmodified.
+func (s *serverTx) init(c *QueueConfig) error {
+	// Map the pipe on which the peer posts free RxBuffers.
+	fillMem, err := getBuffer(c.TxPipeFD)
+	if err != nil {
+		return err
+	}
+	undo := cleanup.Make(func() { unix.Munmap(fillMem) })
+	defer undo.Clean()
+
+	// Map the pipe used to publish filled buffers to the peer.
+	complMem, err := getBuffer(c.RxPipeFD)
+	if err != nil {
+		return err
+	}
+	undo.Add(func() { unix.Munmap(complMem) })
+
+	// Map the packet payload area.
+	payload, err := getBuffer(c.DataFD)
+	if err != nil {
+		return err
+	}
+	undo.Add(func() { unix.Munmap(payload) })
+
+	// Map the region used to enable/disable notifications.
+	shared, err := getBuffer(c.SharedDataFD)
+	if err != nil {
+		return err
+	}
+	undo.Add(func() { unix.Munmap(shared) })
+
+	// Duplicate the eventFD so that caller can close it but we can still
+	// use it.
+	dupEventFD, err := c.EventFD.Dup()
+	if err != nil {
+		return err
+	}
+	undo.Add(func() { dupEventFD.Close() })
+
+	// Nothing can fail past this point; commit the state.
+	s.fillPipe.Init(fillMem)
+	s.completionPipe.Init(complMem)
+	s.data = payload
+	s.eventFD = dupEventFD
+	s.sharedData = shared
+
+	undo.Release()
+	return nil
+}
+
+// cleanup unmaps the memory backing both pipes, the data area and the shared
+// data region, and closes the eventfd duplicated in init. Errors from the
+// unmap calls are ignored.
+func (s *serverTx) cleanup() {
+ unix.Munmap(s.fillPipe.Bytes())
+ unix.Munmap(s.completionPipe.Bytes())
+ unix.Munmap(s.data)
+ unix.Munmap(s.sharedData)
+ s.eventFD.Close()
+}
+
+// fillPacket copies the data in the provided views into buffers pulled from the
+// fillPipe and returns a slice of RxBuffers that contain the copied data as
+// well as the total number of bytes copied.
+//
+// To avoid allocations the filledBuffers are appended to the buffers slice
+// which will be grown as required.
+//
+// The views are trimmed as their contents are copied. The call spins on the
+// fillPipe until the peer has posted enough buffers to hold all of the data.
+func (s *serverTx) fillPacket(views []buffer.View, buffers []queue.RxBuffer) (filledBuffers []queue.RxBuffer, totalCopied uint32) {
+	filledBuffers = buffers[:0]
+	// fillBuffer copies as much of the views as possible into the provided
+	// buffer, sets rxBuf.Size to the number of bytes written and returns any
+	// left over views (if any).
+	fillBuffer := func(rxBuf *queue.RxBuffer, views []buffer.View) (left []buffer.View) {
+		if len(views) == 0 {
+			return nil
+		}
+		capacity := uint64(rxBuf.Size)
+		copied := uint64(0)
+		for copied < capacity && len(views) > 0 {
+			n := copy(s.data[rxBuf.Offset+copied:][:capacity-copied], views[0])
+			views[0].TrimFront(n)
+			// Account for the bytes before looking at the view: when the view
+			// only partially fits, the part that was copied (and trimmed from
+			// the view) must still be counted, otherwise it is lost.
+			copied += uint64(n)
+			if views[0].IsEmpty() {
+				views = views[1:]
+			}
+		}
+		rxBuf.Size = uint32(copied)
+		return views
+	}
+
+	for len(views) > 0 {
+		var b []byte
+		// Spin till we get a free buffer reposted by the peer.
+		for {
+			if b = s.fillPipe.Pull(); b != nil {
+				break
+			}
+		}
+		rxBuffer := queue.DecodeRxBufferHeader(b)
+		// Copy the packet into the posted buffer.
+		views = fillBuffer(&rxBuffer, views)
+		totalCopied += rxBuffer.Size
+		filledBuffers = append(filledBuffers, rxBuffer)
+	}
+
+	return filledBuffers, totalCopied
+}
+
+// transmit copies the packet held in views into buffers posted by the peer and
+// publishes a completion entry describing the filled buffers. It returns false
+// when the completion pipe cannot provide room for that entry, and true once
+// both pipes have been flushed.
+func (s *serverTx) transmit(views []buffer.View) bool {
+	filled, total := s.fillPacket(views, make([]queue.RxBuffer, 8))
+	entry := s.completionPipe.Push(queue.RxCompletionSize(len(filled)))
+	if entry == nil {
+		return false
+	}
+	queue.EncodeRxCompletion(entry, total, 0 /* reserved */)
+	for i, rxBuf := range filled {
+		queue.EncodeRxCompletionBuffer(entry, i, rxBuf)
+	}
+	s.completionPipe.Flush()
+	s.fillPipe.Flush()
+	return true
+}
+
+// notify calls Notify on the queue's eventFD, which the struct documents as
+// the way to tell the peer that fill requests have been fulfilled.
+func (s *serverTx) notify() {
+ s.eventFD.Notify()
+}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index 30cf659b8..bcb37a465 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -24,14 +24,16 @@
package sharedmem
import (
+ "fmt"
"sync/atomic"
- "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/eventfd"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
"gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -47,7 +49,7 @@ type QueueConfig struct {
// EventFD is a file descriptor for the event that is signaled when
// data becomes available in this queue.
- EventFD int
+ EventFD eventfd.Eventfd
// TxPipeFD is a file descriptor for the tx pipe associated with the
// queue.
@@ -63,16 +65,97 @@ type QueueConfig struct {
SharedDataFD int
}
+// FDs returns the FDs in the QueueConfig as a slice of ints, in the order
+// DataFD, EventFD, TxPipeFD, RxPipeFD, SharedDataFD. This must be used in
+// conjunction with QueueConfigFromFDs to ensure the order of FDs matches when
+// reconstructing the config when serialized or sent as part of control
+// messages.
+func (q *QueueConfig) FDs() []int {
+ return []int{q.DataFD, q.EventFD.FD(), q.TxPipeFD, q.RxPipeFD, q.SharedDataFD}
+}
+
+// QueueConfigFromFDs constructs a QueueConfig out of a slice of ints where each
+// entry represents a file descriptor. The slice must hold exactly five entries
+// in the order DataFD, EventFD, TxPipeFD, RxPipeFD, SharedDataFD for the config
+// to be valid. QueueConfig.FDs() should be used when the config needs to be
+// serialized or sent as part of a control message to ensure the correct order.
+//
+// An error is returned when the slice holds fewer or more than five entries.
+func QueueConfigFromFDs(fds []int) (QueueConfig, error) {
+	// The check rejects both too few and too many descriptors, so the message
+	// must not claim that there are too few.
+	if len(fds) != 5 {
+		return QueueConfig{}, fmt.Errorf("unexpected number of fds: len(fds): %d, want: 5", len(fds))
+	}
+	return QueueConfig{
+		DataFD:       fds[0],
+		EventFD:      eventfd.Wrap(fds[1]),
+		TxPipeFD:     fds[2],
+		RxPipeFD:     fds[3],
+		SharedDataFD: fds[4],
+	}, nil
+}
+
+// Options specify the details about the sharedmem endpoint to be created.
+type Options struct {
+ // MTU is the mtu to use for this endpoint.
+ MTU uint32
+
+ // BufferSize is the size of each scatter/gather buffer that will hold packet
+ // data.
+ //
+ // NOTE: This directly determines number of packets that can be held in
+ // the ring buffer at any time. This does not have to be sized to the MTU as
+ // the shared memory queue design allows usage of more than one buffer to be
+ // used to make up a given packet.
+ BufferSize uint32
+
+ // LinkAddress is the link address for this endpoint (required).
+ //
+ // NOTE(review): although marked required, New only reserves space for an
+ // ethernet header and sets CapabilityResolutionRequired when this is
+ // non-empty, and the write path skips AddHeader when it is empty -- confirm
+ // whether an empty address is meant to be supported.
+ LinkAddress tcpip.LinkAddress
+
+ // TX is the transmit queue configuration for this shared memory endpoint.
+ TX QueueConfig
+
+ // RX is the receive queue configuration for this shared memory endpoint.
+ RX QueueConfig
+
+ // PeerFD is the fd for the connected peer which can be used to detect
+ // peer disconnects.
+ //
+ // Attach only monitors this fd when it is >= 0, so the zero value (fd 0) is
+ // monitored; set a negative value to disable peer monitoring.
+ PeerFD int
+
+ // OnClosed is a function that is called when the endpoint is being closed
+ // (probably due to peer going away)
+ OnClosed func(err tcpip.Error)
+
+ // TXChecksumOffload if true, indicates that this endpoints capability
+ // set should include CapabilityTXChecksumOffload.
+ TXChecksumOffload bool
+
+ // RXChecksumOffload if true, indicates that this endpoints capability
+ // set should include CapabilityRXChecksumOffload.
+ RXChecksumOffload bool
+}
+
type endpoint struct {
// mtu (maximum transmission unit) is the maximum size of a packet.
+ // mtu is immutable.
mtu uint32
// bufferSize is the size of each individual buffer.
+ // bufferSize is immutable.
bufferSize uint32
// addr is the local address of this endpoint.
+ // addr is immutable.
addr tcpip.LinkAddress
+ // peerFD is an fd to the peer that can be used to detect when the
+ // peer is gone.
+ // peerFD is immutable.
+ peerFD int
+
+ // caps holds the endpoint capabilities.
+ caps stack.LinkEndpointCapabilities
+
+ // hdrSize is the size of the link layer header if any.
+ // hdrSize is immutable.
+ hdrSize uint32
+
// rx is the receive queue.
rx rx
@@ -83,34 +166,55 @@ type endpoint struct {
// Wait group used to indicate that all workers have stopped.
completed sync.WaitGroup
+ // onClosed is a function to be called when the FD's peer (if any) closes
+ // its end of the communication pipe.
+ onClosed func(tcpip.Error)
+
// mu protects the following fields.
mu sync.Mutex
// tx is the transmit queue.
+ // +checklocks:mu
tx tx
// workerStarted specifies whether the worker goroutine was started.
+ // +checklocks:mu
workerStarted bool
}
// New creates a new shared-memory-based endpoint. Buffers will be broken up
// into buffers of "bufferSize" bytes.
-func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (stack.LinkEndpoint, error) {
+func New(opts Options) (stack.LinkEndpoint, error) {
e := &endpoint{
- mtu: mtu,
- bufferSize: bufferSize,
- addr: addr,
+ mtu: opts.MTU,
+ bufferSize: opts.BufferSize,
+ addr: opts.LinkAddress,
+ peerFD: opts.PeerFD,
+ onClosed: opts.OnClosed,
}
- if err := e.tx.init(bufferSize, &tx); err != nil {
+ if err := e.tx.init(opts.BufferSize, &opts.TX); err != nil {
return nil, err
}
- if err := e.rx.init(bufferSize, &rx); err != nil {
+ if err := e.rx.init(opts.BufferSize, &opts.RX); err != nil {
e.tx.cleanup()
return nil, err
}
+ e.caps = stack.LinkEndpointCapabilities(0)
+ if opts.RXChecksumOffload {
+ e.caps |= stack.CapabilityRXChecksumOffload
+ }
+
+ if opts.TXChecksumOffload {
+ e.caps |= stack.CapabilityTXChecksumOffload
+ }
+
+ if opts.LinkAddress != "" {
+ e.hdrSize = header.EthernetMinimumSize
+ e.caps |= stack.CapabilityResolutionRequired
+ }
return e, nil
}
@@ -119,13 +223,13 @@ func (e *endpoint) Close() {
// Tell dispatch goroutine to stop, then write to the eventfd so that
// it wakes up in case it's sleeping.
atomic.StoreUint32(&e.stopRequested, 1)
- unix.Write(e.rx.eventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ e.rx.eventFD.Notify()
// Cleanup the queues inline if the worker hasn't started yet; we also
// know it won't start from now on because stopRequested is set to 1.
e.mu.Lock()
+ defer e.mu.Unlock()
workerPresent := e.workerStarted
- e.mu.Unlock()
if !workerPresent {
e.tx.cleanup()
@@ -146,6 +250,22 @@ func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
if !e.workerStarted && atomic.LoadUint32(&e.stopRequested) == 0 {
e.workerStarted = true
e.completed.Add(1)
+
+ // Spin up a goroutine to monitor for peer shutdown.
+ if e.peerFD >= 0 {
+ e.completed.Add(1)
+ go func() {
+ defer e.completed.Done()
+ b := make([]byte, 1)
+ // When sharedmem endpoint is in use the peerFD is never used for any data
+ // transfer and this Read should only return if the peer is shutting down.
+ _, err := rawfile.BlockingRead(e.peerFD, b)
+ if e.onClosed != nil {
+ e.onClosed(err)
+ }
+ }()
+ }
+
// Link endpoints are not savable. When transportation endpoints
// are saved, they stop sending outgoing packets and all
// incoming packets are rejected.
@@ -164,18 +284,18 @@ func (e *endpoint) IsAttached() bool {
// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
// during construction.
func (e *endpoint) MTU() uint32 {
- return e.mtu - header.EthernetMinimumSize
+ return e.mtu - e.hdrSize
}
// Capabilities implements stack.LinkEndpoint.Capabilities.
-func (*endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return 0
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.caps
}
// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It returns the
// ethernet frame header size.
-func (*endpoint) MaxHeaderLength() uint16 {
- return header.EthernetMinimumSize
+func (e *endpoint) MaxHeaderLength() uint16 {
+ return uint16(e.hdrSize)
}
// LinkAddress implements stack.LinkEndpoint.LinkAddress. It returns the local
@@ -202,17 +322,18 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
eth.Encode(ethHdr)
}
-// WritePacket writes outbound packets to the file descriptor. If it is not
-// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} }
+
+// +checklocks:e.mu
+func (e *endpoint) writePacketLocked(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+ if e.addr != "" {
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+ }
views := pkt.Views()
// Transmit the packet.
- e.mu.Lock()
ok := e.tx.transmit(views...)
- e.mu.Unlock()
-
if !ok {
return &tcpip.ErrWouldBlock{}
}
@@ -220,9 +341,37 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol
return nil
}
+// WritePacket writes outbound packets to the file descriptor. If it is not
+// currently writable, the packet is dropped.
+func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if err := e.writePacketLocked(r, protocol, pkt); err != nil {
+ return err
+ }
+ e.tx.notify()
+ return nil
+}
+
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (*endpoint) WritePackets(stack.RouteInfo, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
- panic("not implemented")
+func (e *endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+ n := 0
+ var err tcpip.Error
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if err = e.writePacketLocked(r, pkt.NetworkProtocolNumber, pkt); err != nil {
+ break
+ }
+ n++
+ }
+ // WritePackets never returns an error if it successfully transmitted at least
+ // one packet.
+ if err != nil && n == 0 {
+ return 0, err
+ }
+ e.tx.notify()
+ return n, nil
}
// dispatchLoop reads packets from the rx queue in a loop and dispatches them
@@ -265,16 +414,42 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) {
Data: buffer.View(b).ToVectorisedView(),
})
- hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
- if !ok {
- continue
+ var src, dst tcpip.LinkAddress
+ var proto tcpip.NetworkProtocolNumber
+ if e.addr != "" {
+ hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
+ if !ok {
+ continue
+ }
+ eth := header.Ethernet(hdr)
+ src = eth.SourceAddress()
+ dst = eth.DestinationAddress()
+ proto = eth.Type()
+ } else {
+ // We don't get any indication of what the packet is, so try to guess
+ // if it's an IPv4 or IPv6 packet.
+ // IP version information is at the first octet, so pulling up 1 byte.
+ h, ok := pkt.Data().PullUp(1)
+ if !ok {
+ continue
+ }
+ switch header.IPVersion(h) {
+ case header.IPv4Version:
+ proto = header.IPv4ProtocolNumber
+ case header.IPv6Version:
+ proto = header.IPv6ProtocolNumber
+ default:
+ continue
+ }
}
- eth := header.Ethernet(hdr)
// Send packet up the stack.
- d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), pkt)
+ d.DeliverNetworkPacket(src, dst, proto, pkt)
}
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
// Clean state.
e.tx.cleanup()
e.rx.cleanup()
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_server.go b/pkg/tcpip/link/sharedmem/sharedmem_server.go
new file mode 100644
index 000000000..ccc84989d
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/sharedmem_server.go
@@ -0,0 +1,333 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build linux
+// +build linux
+
+package sharedmem
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// serverEndpoint is the link endpoint returned by NewServerEndpoint. Its
+// transmit queue is initialized from the RX queue configuration in Options and
+// its receive queue from the TX configuration, i.e. the queue roles are the
+// mirror image of what the Options field names suggest.
+type serverEndpoint struct {
+ // mtu (maximum transmission unit) is the maximum size of a packet.
+ // mtu is immutable.
+ mtu uint32
+
+ // bufferSize is the size of each individual buffer.
+ // bufferSize is immutable.
+ bufferSize uint32
+
+ // addr is the local address of this endpoint. An empty addr means no
+ // ethernet header is parsed on receive (see dispatchLoop).
+ // addr is immutable.
+ addr tcpip.LinkAddress
+
+ // rx is the receive queue.
+ rx serverRx
+
+ // stopRequested is to be accessed atomically only, and determines if the
+ // worker goroutines should stop.
+ stopRequested uint32
+
+ // Wait group used to indicate that all workers have stopped.
+ completed sync.WaitGroup
+
+ // peerFD is an fd to the peer that can be used to detect when the peer is
+ // gone. A negative value disables peer monitoring (see Attach).
+ // peerFD is immutable.
+ peerFD int
+
+ // caps holds the endpoint capabilities.
+ caps stack.LinkEndpointCapabilities
+
+ // hdrSize is the size of the link layer header if any.
+ // hdrSize is immutable.
+ hdrSize uint32
+
+ // onClosed is a function to be called when the FD's peer (if any) closes its
+ // end of the communication pipe. It may be nil.
+ onClosed func(tcpip.Error)
+
+ // mu protects the following fields.
+ mu sync.Mutex
+
+ // tx is the transmit queue.
+ // +checklocks:mu
+ tx serverTx
+
+ // workerStarted specifies whether the worker goroutine was started.
+ // +checklocks:mu
+ workerStarted bool
+}
+
+// NewServerEndpoint creates a new shared-memory-based endpoint. Buffers will be
+// broken up into buffers of "bufferSize" bytes.
+func NewServerEndpoint(opts Options) (stack.LinkEndpoint, error) {
+ e := &serverEndpoint{
+ mtu: opts.MTU,
+ bufferSize: opts.BufferSize,
+ addr: opts.LinkAddress,
+ peerFD: opts.PeerFD,
+ onClosed: opts.OnClosed,
+ }
+
+ if err := e.tx.init(&opts.RX); err != nil {
+ return nil, err
+ }
+
+ if err := e.rx.init(&opts.TX); err != nil {
+ e.tx.cleanup()
+ return nil, err
+ }
+
+ e.caps = stack.LinkEndpointCapabilities(0)
+ if opts.RXChecksumOffload {
+ e.caps |= stack.CapabilityRXChecksumOffload
+ }
+
+ if opts.TXChecksumOffload {
+ e.caps |= stack.CapabilityTXChecksumOffload
+ }
+
+ if opts.LinkAddress != "" {
+ e.hdrSize = header.EthernetMinimumSize
+ e.caps |= stack.CapabilityResolutionRequired
+ }
+
+ return e, nil
+}
+
+// Close asks the endpoint's workers to stop and, if no worker was ever
+// started, releases the queues itself.
+func (e *serverEndpoint) Close() {
+	// Raise the stop flag first, then poke the rx eventfd so that a dispatch
+	// goroutine that is sleeping wakes up and observes the flag.
+	atomic.StoreUint32(&e.stopRequested, 1)
+	e.rx.eventFD.Notify()
+
+	e.mu.Lock()
+	defer e.mu.Unlock()
+
+	// A started worker cleans up the queues on its way out of dispatchLoop.
+	// Attach refuses to start one once stopRequested is set, so if none is
+	// running now the cleanup has to happen inline.
+	if e.workerStarted {
+		return
+	}
+	e.tx.cleanup()
+	e.rx.cleanup()
+}
+
+// Wait implements stack.LinkEndpoint.Wait. It blocks until every goroutine
+// registered on e.completed by Attach has finished: the dispatch loop and,
+// when a peer FD was supplied, the peer-shutdown monitor.
+func (e *serverEndpoint) Wait() {
+ e.completed.Wait()
+}
+
+// Attach implements stack.LinkEndpoint.Attach. On the first call made before
+// Close, it launches the goroutine that reads packets from the rx queue and,
+// when a non-negative peer FD was supplied, a second goroutine that watches
+// for peer shutdown. Later calls, and calls made after Close has set
+// stopRequested, do nothing.
+func (e *serverEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+
+	if e.workerStarted || atomic.LoadUint32(&e.stopRequested) != 0 {
+		return
+	}
+	e.workerStarted = true
+	e.completed.Add(1)
+
+	if e.peerFD >= 0 {
+		e.completed.Add(1)
+		// Spin up a goroutine to monitor for peer shutdown.
+		go func() {
+			// Deferred, as in the client endpoint's monitor goroutine, so the
+			// wait group is released on every exit path of this goroutine.
+			defer e.completed.Done()
+			b := make([]byte, 1)
+			// When sharedmem endpoint is in use the peerFD is never used for any
+			// data transfer and this Read should only return if the peer is
+			// shutting down.
+			_, err := rawfile.BlockingRead(e.peerFD, b)
+			if e.onClosed != nil {
+				e.onClosed(err)
+			}
+		}()
+	}
+
+	// Link endpoints are not savable. When transportation endpoints are saved,
+	// they stop sending outgoing packets and all incoming packets are rejected.
+	go e.dispatchLoop(dispatcher) // S/R-SAFE: see above.
+}
+
+// IsAttached implements stack.LinkEndpoint.IsAttached. It reports whether
+// Attach has started the worker goroutine. workerStarted is only ever set to
+// true in this file, so the result stays true after Close.
+func (e *serverEndpoint) IsAttached() bool {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.workerStarted
+}
+
+// MTU implements stack.LinkEndpoint.MTU. It returns the MTU given at
+// construction minus the link header size, which is zero when the endpoint
+// was created without a link address.
+//
+// NOTE(review): both operands are uint32 and nothing in this file checks that
+// mtu >= hdrSize, so a smaller mtu would wrap around; confirm callers always
+// pass a large enough MTU.
+func (e *serverEndpoint) MTU() uint32 {
+ return e.mtu - e.hdrSize
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities. It returns the
+// capability set computed once in NewServerEndpoint from the checksum offload
+// options and the presence of a link address.
+func (e *serverEndpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.caps
+}
+
+// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It returns
+// the link header size: the ethernet header size when the endpoint was
+// created with a link address, and zero otherwise.
+func (e *serverEndpoint) MaxHeaderLength() uint16 {
+ return uint16(e.hdrSize)
+}
+
+// LinkAddress implements stack.LinkEndpoint.LinkAddress. It returns the local
+// link address supplied at construction, which may be empty.
+func (e *serverEndpoint) LinkAddress() tcpip.LinkAddress {
+ return e.addr
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader. It pushes an ethernet
+// header onto pkt addressed to remote and carrying protocol as its type.
+func (e *serverEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+	// Use the source address from the route when it has one; otherwise fall
+	// back to this endpoint's own address.
+	src := e.addr
+	if local != "" {
+		src = local
+	}
+
+	hdr := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize))
+	hdr.Encode(&header.EthernetFields{
+		SrcAddr: src,
+		DstAddr: remote,
+		Type:    protocol,
+	})
+}
+
+// WriteRawPacket implements stack.LinkEndpoint. Raw packet writes are not
+// supported by this endpoint, so it always returns ErrNotSupported.
+func (*serverEndpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
+}
+
+// writePacketLocked places a single packet on the tx queue without notifying
+// the peer; callers batch the notification themselves.
+//
+// An ethernet header is prepended only when the endpoint was created with a
+// link address. Without one, hdrSize is zero and dispatchLoop treats frames as
+// bare IP packets, so adding a header would be inconsistent with both. This
+// matches the guard in the client endpoint's writePacketLocked.
+//
+// It returns ErrWouldBlock when the tx queue does not accept the packet.
+//
+// +checklocks:e.mu
+func (e *serverEndpoint) writePacketLocked(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+	if e.addr != "" {
+		e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+	}
+
+	views := pkt.Views()
+	if ok := e.tx.transmit(views); !ok {
+		return &tcpip.ErrWouldBlock{}
+	}
+
+	return nil
+}
+
+// WritePacket queues one outbound packet on the tx queue and, on success,
+// notifies the peer. If the queue cannot take the packet it is dropped and the
+// error from writePacketLocked is returned without a notification.
+func (e *serverEndpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+
+	// Transmit the packet; only signal the peer when something was queued.
+	err := e.writePacketLocked(r, protocol, pkt)
+	if err == nil {
+		e.tx.notify()
+	}
+	return err
+}
+
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (e *serverEndpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+ n := 0
+ var err tcpip.Error
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if err = e.writePacketLocked(r, pkt.NetworkProtocolNumber, pkt); err != nil {
+ break
+ }
+ n++
+ }
+ // WritePackets never returns an error if it successfully transmitted at least
+ // one packet.
+ if err != nil && n == 0 {
+ return 0, err
+ }
+ e.tx.notify()
+ return n, nil
+}
+
+// dispatchLoop pulls frames off the rx queue and hands them to d until a stop
+// is requested. On exit it cleans up both queues under e.mu and marks the
+// worker as completed.
+func (e *serverEndpoint) dispatchLoop(d stack.NetworkDispatcher) {
+	for atomic.LoadUint32(&e.stopRequested) == 0 {
+		frame := e.rx.receive()
+		if frame == nil {
+			// Nothing available right now; wait before trying again.
+			e.rx.waitForPackets()
+			continue
+		}
+		pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+			Data: buffer.View(frame).ToVectorisedView(),
+		})
+
+		var (
+			src, dst tcpip.LinkAddress
+			proto    tcpip.NetworkProtocolNumber
+		)
+		if e.addr == "" {
+			// We don't get any indication of what the packet is, so try to
+			// guess if it's an IPv4 or IPv6 packet. The IP version lives in
+			// the first octet, so one byte is all that needs pulling up.
+			first, ok := pkt.Data().PullUp(1)
+			if !ok {
+				continue
+			}
+			switch header.IPVersion(first) {
+			case header.IPv4Version:
+				proto = header.IPv4ProtocolNumber
+			case header.IPv6Version:
+				proto = header.IPv6ProtocolNumber
+			default:
+				// Neither IPv4 nor IPv6: drop the frame.
+				continue
+			}
+		} else {
+			// Ethernet framing is in use: strip the header and take the
+			// addresses and protocol from it. Short frames are dropped.
+			raw, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
+			if !ok {
+				continue
+			}
+			ethHdr := header.Ethernet(raw)
+			src = ethHdr.SourceAddress()
+			dst = ethHdr.DestinationAddress()
+			proto = ethHdr.Type()
+		}
+
+		// Send packet up the stack.
+		d.DeliverNetworkPacket(src, dst, proto, pkt)
+	}
+
+	e.mu.Lock()
+	defer e.mu.Unlock()
+
+	// Clean state.
+	e.tx.cleanup()
+	e.rx.cleanup()
+
+	e.completed.Done()
+}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. It reports
+// ethernet when the endpoint uses a link header and none otherwise.
+func (e *serverEndpoint) ARPHardwareType() header.ARPHardwareType {
+	if e.hdrSize == 0 {
+		return header.ARPHardwareNone
+	}
+	return header.ARPHardwareEther
+}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_server_test.go b/pkg/tcpip/link/sharedmem/sharedmem_server_test.go
new file mode 100644
index 000000000..1bc58614e
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/sharedmem_server_test.go
@@ -0,0 +1,220 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build linux
+// +build linux
+
+package sharedmem_server_test
+
+import (
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "syscall"
+ "testing"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+)
+
+// Addresses and sizes shared by the client and server stacks in this test.
+// The "local" values are used for the client stack and the "remote" values
+// for the server stack.
+const (
+ localLinkAddr = "\xde\xad\xbe\xef\x56\x78"
+ remoteLinkAddr = "\xde\xad\xbe\xef\x12\x34"
+ localIPv4Address = tcpip.Address("\x0a\x00\x00\x01")
+ remoteIPv4Address = tcpip.Address("\x0a\x00\x00\x02")
+ // serverPort is the TCP port the test HTTP server listens on.
+ serverPort = 10001
+
+ // defaultMTU and defaultBufferSize are passed to both endpoints.
+ defaultMTU = 1500
+ defaultBufferSize = 1500
+)
+
+// stackOptions carries the inputs newStackWithOptions needs to build a stack.
+type stackOptions struct {
+ // ep is the link endpoint the stack's single NIC is created on.
+ ep stack.LinkEndpoint
+ // addr is the protocol address assigned to that NIC; a 16-byte address
+ // selects IPv6, anything else IPv4.
+ addr tcpip.Address
+}
+
+func newStackWithOptions(stackOpts stackOptions) (*stack.Stack, error) {
+ st := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ ipv4.NewProtocolWithOptions(ipv4.Options{
+ AllowExternalLoopbackTraffic: true,
+ }),
+ ipv6.NewProtocolWithOptions(ipv6.Options{
+ AllowExternalLoopbackTraffic: true,
+ }),
+ },
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
+ })
+ nicID := tcpip.NICID(1)
+ sniffEP := sniffer.New(stackOpts.ep)
+ opts := stack.NICOptions{Name: "eth0"}
+ if err := st.CreateNICWithOptions(nicID, sniffEP, opts); err != nil {
+ return nil, fmt.Errorf("method CreateNICWithOptions(%d, _, %v) failed: %s", nicID, opts, err)
+ }
+
+ // Add Protocol Address.
+ protocolNum := ipv4.ProtocolNumber
+ routeTable := []tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}}
+ if len(stackOpts.addr) == 16 {
+ routeTable = []tcpip.Route{{Destination: header.IPv6EmptySubnet, NIC: nicID}}
+ protocolNum = ipv6.ProtocolNumber
+ }
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: protocolNum,
+ AddressWithPrefix: stackOpts.addr.WithPrefix(),
+ }
+ if err := st.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ return nil, fmt.Errorf("AddProtocolAddress(%d, %v, {}): %s", nicID, protocolAddr, err)
+ }
+
+ // Setup route table.
+ st.SetRouteTable(routeTable)
+
+ return st, nil
+}
+
+// newClientStack creates a client sharedmem endpoint over qPair, using peerFD
+// for peer-shutdown detection, and wraps it in a stack that owns
+// localIPv4Address.
+func newClientStack(t *testing.T, qPair *sharedmem.QueuePair, peerFD int) (*stack.Stack, error) {
+	epOpts := sharedmem.Options{
+		MTU:         defaultMTU,
+		BufferSize:  defaultBufferSize,
+		LinkAddress: localLinkAddr,
+		TX:          qPair.TXQueueConfig(),
+		RX:          qPair.RXQueueConfig(),
+		PeerFD:      peerFD,
+	}
+	clientEP, err := sharedmem.New(epOpts)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create sharedmem endpoint: %s", err)
+	}
+
+	clientStk, err := newStackWithOptions(stackOptions{ep: clientEP, addr: localIPv4Address})
+	if err != nil {
+		return nil, fmt.Errorf("failed to create client stack: %s", err)
+	}
+	return clientStk, nil
+}
+
+// newServerStack creates a server sharedmem endpoint over qPair, using peerFD
+// for peer-shutdown detection, and wraps it in a stack that owns
+// remoteIPv4Address.
+func newServerStack(t *testing.T, qPair *sharedmem.QueuePair, peerFD int) (*stack.Stack, error) {
+	ep, err := sharedmem.NewServerEndpoint(sharedmem.Options{
+		MTU:         defaultMTU,
+		BufferSize:  defaultBufferSize,
+		LinkAddress: remoteLinkAddr,
+		TX:          qPair.TXQueueConfig(),
+		RX:          qPair.RXQueueConfig(),
+		PeerFD:      peerFD,
+	})
+	if err != nil {
+		return nil, fmt.Errorf("failed to create sharedmem endpoint: %s", err)
+	}
+	st, err := newStackWithOptions(stackOptions{ep: ep, addr: remoteIPv4Address})
+	if err != nil {
+		// This is the server path; the message previously said "client".
+		return nil, fmt.Errorf("failed to create server stack: %s", err)
+	}
+	return st, nil
+}
+
+// testContext holds the two stacks under test and the socket pair whose ends
+// are handed to them as peer FDs.
+type testContext struct {
+ clientStk *stack.Stack
+ serverStk *stack.Stack
+ // peerFDs[0] is given to the client stack and peerFDs[1] to the server
+ // stack (see newTestContext).
+ peerFDs [2]int
+}
+
+// newTestContext creates a socket pair, a sharedmem queue pair and a client
+// and server stack connected through them. Any failure aborts the test after
+// releasing whatever had been created up to that point.
+func newTestContext(t *testing.T) *testContext {
+	peerFDs, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET|syscall.SOCK_NONBLOCK, 0)
+	if err != nil {
+		t.Fatalf("failed to create peerFDs: %s", err)
+	}
+	closePeerFDs := func() {
+		unix.Close(peerFDs[0])
+		unix.Close(peerFDs[1])
+	}
+
+	q, err := sharedmem.NewQueuePair()
+	if err != nil {
+		// The socket pair already exists at this point; don't leak it.
+		closePeerFDs()
+		t.Fatalf("failed to create sharedmem queue: %s", err)
+	}
+	clientStack, err := newClientStack(t, q, peerFDs[0])
+	if err != nil {
+		q.Close()
+		closePeerFDs()
+		t.Fatalf("failed to create client stack: %s", err)
+	}
+	serverStack, err := newServerStack(t, q, peerFDs[1])
+	if err != nil {
+		q.Close()
+		closePeerFDs()
+		clientStack.Close()
+		t.Fatalf("failed to create server stack: %s", err)
+	}
+	return &testContext{
+		clientStk: clientStack,
+		serverStk: serverStack,
+		peerFDs:   peerFDs,
+	}
+}
+
+// cleanup closes both peer FDs and then shuts down the client and server
+// stacks, in that order.
+func (ctx *testContext) cleanup() {
+	for _, fd := range ctx.peerFDs {
+		unix.Close(fd)
+	}
+	ctx.clientStk.Close()
+	ctx.serverStk.Close()
+}
+
+// TestServerRoundTrip serves HTTP on the server stack, issues a GET from the
+// client stack over the shared-memory link, and checks the status code and
+// body of the response.
+func TestServerRoundTrip(t *testing.T) {
+	ctx := newTestContext(t)
+	defer ctx.cleanup()
+	listenAddr := tcpip.FullAddress{Addr: remoteIPv4Address, Port: serverPort}
+	l, err := gonet.ListenTCP(ctx.serverStk, listenAddr, ipv4.ProtocolNumber)
+	if err != nil {
+		t.Fatalf("failed to start TCP Listener: %s", err)
+	}
+	defer l.Close()
+	const responseString = "response"
+	go func() {
+		http.Serve(l, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			w.Write([]byte(responseString))
+		}))
+	}()
+
+	// The dialer ignores the address it is given and always connects to the
+	// listener through the client stack.
+	dialFunc := func(address, protocol string) (net.Conn, error) {
+		return gonet.DialTCP(ctx.clientStk, listenAddr, ipv4.ProtocolNumber)
+	}
+
+	httpClient := &http.Client{
+		Transport: &http.Transport{
+			Dial: dialFunc,
+		},
+	}
+	serverURL := fmt.Sprintf("http://[%s]:%d/", net.IP(remoteIPv4Address), serverPort)
+	response, err := httpClient.Get(serverURL)
+	if err != nil {
+		t.Fatalf("httpClient.Get(\"/\") failed: %s", err)
+	}
+	// Deferred so the body is also closed when one of the checks below fails.
+	defer response.Body.Close()
+	if got, want := response.StatusCode, http.StatusOK; got != want {
+		t.Fatalf("unexpected status code got: %d, want: %d", got, want)
+	}
+	body, err := io.ReadAll(response.Body)
+	if err != nil {
+		t.Fatalf("io.ReadAll(response.Body) failed: %s", err)
+	}
+	if got, want := string(body), responseString; got != want {
+		t.Fatalf("unexpected response got: %s, want: %s", got, want)
+	}
+}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index d6d953085..66ffc33b8 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -19,9 +19,7 @@ package sharedmem
import (
"bytes"
- "io/ioutil"
"math/rand"
- "os"
"strings"
"testing"
"time"
@@ -104,24 +102,36 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress
t: t,
packetCh: make(chan struct{}, 1000000),
}
- c.txCfg = createQueueFDs(t, queueSizes{
+ c.txCfg, err = createQueueFDs(queueSizes{
dataSize: queueDataSize,
txPipeSize: queuePipeSize,
rxPipeSize: queuePipeSize,
sharedDataSize: 4096,
})
-
- c.rxCfg = createQueueFDs(t, queueSizes{
+ if err != nil {
+ t.Fatalf("createQueueFDs for tx failed: %s", err)
+ }
+ c.rxCfg, err = createQueueFDs(queueSizes{
dataSize: queueDataSize,
txPipeSize: queuePipeSize,
rxPipeSize: queuePipeSize,
sharedDataSize: 4096,
})
+ if err != nil {
+ t.Fatalf("createQueueFDs for rx failed: %s", err)
+ }
initQueue(t, &c.txq, &c.txCfg)
initQueue(t, &c.rxq, &c.rxCfg)
- ep, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg)
+ ep, err := New(Options{
+ MTU: mtu,
+ BufferSize: bufferSize,
+ LinkAddress: addr,
+ TX: c.txCfg,
+ RX: c.rxCfg,
+ PeerFD: -1,
+ })
if err != nil {
t.Fatalf("New failed: %v", err)
}
@@ -150,8 +160,8 @@ func (c *testContext) DeliverOutboundPacket(remoteLinkAddr, localLinkAddr tcpip.
func (c *testContext) cleanup() {
c.ep.Close()
- closeFDs(&c.txCfg)
- closeFDs(&c.rxCfg)
+ closeFDs(c.txCfg)
+ closeFDs(c.rxCfg)
c.txq.cleanup()
c.rxq.cleanup()
}
@@ -191,69 +201,6 @@ func shuffle(b []int) {
}
}
-func createFile(t *testing.T, size int64, initQueue bool) int {
- tmpDir, ok := os.LookupEnv("TEST_TMPDIR")
- if !ok {
- tmpDir = os.Getenv("TMPDIR")
- }
- f, err := ioutil.TempFile(tmpDir, "sharedmem_test")
- if err != nil {
- t.Fatalf("TempFile failed: %v", err)
- }
- defer f.Close()
- unix.Unlink(f.Name())
-
- if initQueue {
- // Write the "slot-free" flag in the initial queue.
- _, err := f.WriteAt([]byte{0, 0, 0, 0, 0, 0, 0, 0x80}, 0)
- if err != nil {
- t.Fatalf("WriteAt failed: %v", err)
- }
- }
-
- fd, err := unix.Dup(int(f.Fd()))
- if err != nil {
- t.Fatalf("Dup failed: %v", err)
- }
-
- if err := unix.Ftruncate(fd, size); err != nil {
- unix.Close(fd)
- t.Fatalf("Ftruncate failed: %v", err)
- }
-
- return fd
-}
-
-func closeFDs(c *QueueConfig) {
- unix.Close(c.DataFD)
- unix.Close(c.EventFD)
- unix.Close(c.TxPipeFD)
- unix.Close(c.RxPipeFD)
- unix.Close(c.SharedDataFD)
-}
-
-type queueSizes struct {
- dataSize int64
- txPipeSize int64
- rxPipeSize int64
- sharedDataSize int64
-}
-
-func createQueueFDs(t *testing.T, s queueSizes) QueueConfig {
- fd, _, err := unix.RawSyscall(unix.SYS_EVENTFD2, 0, 0, 0)
- if err != 0 {
- t.Fatalf("eventfd failed: %v", error(err))
- }
-
- return QueueConfig{
- EventFD: int(fd),
- DataFD: createFile(t, s.dataSize, false),
- TxPipeFD: createFile(t, s.txPipeSize, true),
- RxPipeFD: createFile(t, s.rxPipeSize, true),
- SharedDataFD: createFile(t, s.sharedDataSize, false),
- }
-}
-
// TestSimpleSend sends 1000 packets with random header and payload sizes,
// then checks that the right payload is received on the shared memory queues.
func TestSimpleSend(t *testing.T) {
@@ -672,7 +619,7 @@ func TestSimpleReceive(t *testing.T) {
// Push completion.
c.pushRxCompletion(uint32(len(contents)), bufs)
c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Wait for packet to be received, then check it.
c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet")
@@ -718,7 +665,7 @@ func TestRxBuffersReposted(t *testing.T) {
// Complete the buffer.
c.pushRxCompletion(buffers[i].Size, buffers[i:][:1])
c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Wait for it to be reposted.
bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffer to be reposted"))
@@ -734,7 +681,7 @@ func TestRxBuffersReposted(t *testing.T) {
// Complete with two buffers.
c.pushRxCompletion(2*bufferSize, buffers[2*i:][:2])
c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Wait for them to be reposted.
for j := 0; j < 2; j++ {
@@ -759,7 +706,7 @@ func TestReceivePostingIsFull(t *testing.T) {
first := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for first buffer to be posted"))
c.pushRxCompletion(first.Size, []queue.RxBuffer{first})
c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Check that packet is received.
c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet")
@@ -768,7 +715,7 @@ func TestReceivePostingIsFull(t *testing.T) {
second := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for second buffer to be posted"))
c.pushRxCompletion(second.Size, []queue.RxBuffer{second})
c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Check that no packet is received yet, as the worker is blocked trying
// to repost.
@@ -781,7 +728,7 @@ func TestReceivePostingIsFull(t *testing.T) {
// Flush tx queue, which will allow the first buffer to be reposted,
// and the second completion to be pulled.
c.rxq.tx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Check that second packet completes.
c.waitForPackets(1, time.After(time.Second), "Timeout waiting for second completed packet")
@@ -803,7 +750,7 @@ func TestCloseWhileWaitingToPost(t *testing.T) {
bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for initial buffer to be posted"))
c.pushRxCompletion(bi.Size, []queue.RxBuffer{bi})
c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Wait for packet to be indicated.
c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet")
diff --git a/pkg/tcpip/link/sharedmem/tx.go b/pkg/tcpip/link/sharedmem/tx.go
index e3210051f..35e5bff12 100644
--- a/pkg/tcpip/link/sharedmem/tx.go
+++ b/pkg/tcpip/link/sharedmem/tx.go
@@ -18,6 +18,7 @@ import (
"math"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/eventfd"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
)
@@ -28,10 +29,12 @@ const (
// tx holds all state associated with a tx queue.
type tx struct {
- data []byte
- q queue.Tx
- ids idManager
- bufs bufferManager
+ data []byte
+ q queue.Tx
+ ids idManager
+ bufs bufferManager
+ eventFD eventfd.Eventfd
+ sharedDataFD int
}
// init initializes all state needed by the tx queue based on the information
@@ -64,7 +67,8 @@ func (t *tx) init(mtu uint32, c *QueueConfig) error {
t.ids.init()
t.bufs.init(0, len(data), int(mtu))
t.data = data
-
+ t.eventFD = c.EventFD
+ t.sharedDataFD = c.SharedDataFD
return nil
}
@@ -142,6 +146,12 @@ func (t *tx) transmit(bufs ...buffer.View) bool {
return true
}
+// notify writes to the tx.eventFD to indicate to the peer that there is data to
+// be read.
+func (t *tx) notify() {
+ t.eventFD.Notify()
+}
+
// getBuffer returns a memory region mapped to the full contents of the given
// file descriptor.
func getBuffer(fd int) ([]byte, error) {
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index 28a172e71..2afa95af0 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -140,11 +140,6 @@ func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
e.Endpoint.DeliverNetworkPacket(remote, local, protocol, pkt)
}
-// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
-func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
- e.Endpoint.DeliverOutboundPacket(remote, local, protocol, pkt)
-}
-
func (e *endpoint) dumpPacket(dir direction, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
writer := e.writer
if writer == nil && atomic.LoadUint32(&LogPackets) == 1 {
diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD
index 4758a99ad..c3e4c3455 100644
--- a/pkg/tcpip/link/tun/BUILD
+++ b/pkg/tcpip/link/tun/BUILD
@@ -31,7 +31,6 @@ go_library(
"//pkg/refs",
"//pkg/refsvfs2",
"//pkg/sync",
- "//pkg/syserror",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index d23210503..fa2131c28 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -20,7 +20,6 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -174,7 +173,7 @@ func (d *Device) Write(data []byte) (int64, error) {
return 0, linuxerr.EBADFD
}
if !endpoint.IsAttached() {
- return 0, syserror.EIO
+ return 0, linuxerr.EIO
}
dataLen := int64(len(data))
@@ -249,7 +248,7 @@ func (d *Device) Read() ([]byte, error) {
for {
info, ok := endpoint.Read()
if !ok {
- return nil, syserror.ErrWouldBlock
+ return nil, linuxerr.ErrWouldBlock
}
v, ok := d.encodePkt(&info)
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index a95602aa5..116e4defb 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -59,15 +59,6 @@ func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
e.dispatchGate.Leave()
}
-// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
-func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
- if !e.dispatchGate.Enter() {
- return
- }
- e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt)
- e.dispatchGate.Leave()
-}
-
// Attach implements stack.LinkEndpoint.Attach. It saves the dispatcher and
// registers with the lower endpoint as its dispatcher so that "e" is called
// for inbound packets.
@@ -155,3 +146,6 @@ func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
e.lower.AddHeader(local, remote, protocol, pkt)
}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*Endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} }
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index a71400ee9..b0e4237bd 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -80,6 +80,11 @@ func (e *countedEndpoint) WritePackets(_ stack.RouteInfo, pkts stack.PacketBuffe
return pkts.Len(), nil
}
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*countedEndpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
+}
+
// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType {
panic("unimplemented")
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index 7b1ff44f4..c0179104a 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -23,8 +23,10 @@ go_test(
"//pkg/tcpip/stack",
"//pkg/tcpip/testutil",
"//pkg/tcpip/transport/icmp",
+ "//pkg/tcpip/transport/raw",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
+ "//pkg/waiter",
"@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 6515c31e5..e08243547 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -272,7 +272,6 @@ type protocol struct {
func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber }
func (p *protocol) MinimumPacketSize() int { return header.ARPSize }
-func (p *protocol) DefaultPrefixLen() int { return 0 }
func (*protocol) ParseAddresses(buffer.View) (src, dst tcpip.Address) {
return "", ""
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 5fcbfeaa2..061cc35ae 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -153,8 +153,12 @@ func makeTestContext(t *testing.T, eventDepth int, packetDepth int) testContext
t.Fatalf("CreateNIC failed: %s", err)
}
- if err := tc.s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress for ipv4 failed: %s", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: stackAddr.WithPrefix(),
+ }
+ if err := tc.s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
tc.s.SetRouteTable([]tcpip.Route{{
@@ -569,8 +573,12 @@ func TestLinkAddressRequest(t *testing.T) {
}
if len(test.nicAddr) != 0 {
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, test.nicAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, test.nicAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: test.nicAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
diff --git a/pkg/tcpip/network/internal/testutil/testutil.go b/pkg/tcpip/network/internal/testutil/testutil.go
index 605e9ef8d..4d4d98caf 100644
--- a/pkg/tcpip/network/internal/testutil/testutil.go
+++ b/pkg/tcpip/network/internal/testutil/testutil.go
@@ -101,6 +101,11 @@ func (*MockLinkEndpoint) ARPHardwareType() header.ARPHardwareType { return heade
func (*MockLinkEndpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
}
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*MockLinkEndpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
+}
+
// MakeRandPkt generates a randomized packet. transportHeaderLength indicates
// how many random bytes will be copied in the Transport Header.
// extraHeaderReserveLength indicates how much extra space will be reserved for
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 771b9173a..87f650661 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -15,6 +15,7 @@
package ip_test
import (
+ "bytes"
"fmt"
"strings"
"testing"
@@ -32,8 +33,10 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/raw"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
)
const nicID = 1
@@ -230,7 +233,13 @@ func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) {
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
})
s.CreateNIC(nicID, loopback.New())
- s.AddAddress(nicID, ipv4.ProtocolNumber, local)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: local.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ return nil, err
+ }
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv4EmptySubnet,
Gateway: ipv4Gateway,
@@ -246,7 +255,13 @@ func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) {
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
})
s.CreateNIC(nicID, loopback.New())
- s.AddAddress(nicID, ipv6.ProtocolNumber, local)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: local.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ return nil, err
+ }
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv6EmptySubnet,
Gateway: ipv6Gateway,
@@ -269,13 +284,13 @@ func buildDummyStackWithLinkEndpoint(t *testing.T, mtu uint32) (*stack.Stack, *c
}
v4Addr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: localIPv4AddrWithPrefix}
- if err := s.AddProtocolAddress(nicID, v4Addr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v4Addr, err)
+ if err := s.AddProtocolAddress(nicID, v4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, v4Addr, err)
}
v6Addr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: localIPv6AddrWithPrefix}
- if err := s.AddProtocolAddress(nicID, v6Addr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v6Addr, err)
+ if err := s.AddProtocolAddress(nicID, v6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, v6Addr, err)
}
return s, e
@@ -710,8 +725,8 @@ func TestReceive(t *testing.T) {
if !ok {
t.Fatalf("expected network endpoint with number = %d to implement stack.AddressableEndpoint", test.protoNum)
}
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", test.epAddr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", test.epAddr, err)
} else {
ep.DecRef()
}
@@ -882,8 +897,8 @@ func TestIPv4ReceiveControl(t *testing.T) {
t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint")
}
addr := localIPv4Addr.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
} else {
ep.DecRef()
}
@@ -968,8 +983,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint")
}
addr := localIPv4Addr.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
} else {
ep.DecRef()
}
@@ -1234,8 +1249,8 @@ func TestIPv6ReceiveControl(t *testing.T) {
t.Fatal("expected IPv6 network endpoint to implement stack.AddressableEndpoint")
}
addr := localIPv6Addr.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
} else {
ep.DecRef()
}
@@ -1301,7 +1316,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name string
protoFactory stack.NetworkProtocolFactory
protoNum tcpip.NetworkProtocolNumber
- nicAddr tcpip.Address
+ nicAddr tcpip.AddressWithPrefix
remoteAddr tcpip.Address
pktGen func(*testing.T, tcpip.Address) buffer.VectorisedView
checker func(*testing.T, *stack.PacketBuffer, tcpip.Address)
@@ -1311,7 +1326,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv4MinimumSize + len(data)
@@ -1352,7 +1367,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4 with IHL too small",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv4MinimumSize + len(data)
@@ -1376,7 +1391,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4 too small",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
@@ -1394,7 +1409,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4 minimum size",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
@@ -1430,7 +1445,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4 with options",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ipHdrLen := int(header.IPv4MinimumSize + ipv4Options.Length())
@@ -1475,7 +1490,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4 with options and data across views",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.Length()))
@@ -1516,7 +1531,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv6",
protoFactory: ipv6.NewProtocol,
protoNum: ipv6.ProtocolNumber,
- nicAddr: localIPv6Addr,
+ nicAddr: localIPv6AddrWithPrefix,
remoteAddr: remoteIPv6Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv6MinimumSize + len(data)
@@ -1556,7 +1571,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv6 with extension header",
protoFactory: ipv6.NewProtocol,
protoNum: ipv6.ProtocolNumber,
- nicAddr: localIPv6Addr,
+ nicAddr: localIPv6AddrWithPrefix,
remoteAddr: remoteIPv6Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data)
@@ -1601,7 +1616,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv6 minimum size",
protoFactory: ipv6.NewProtocol,
protoNum: ipv6.ProtocolNumber,
- nicAddr: localIPv6Addr,
+ nicAddr: localIPv6AddrWithPrefix,
remoteAddr: remoteIPv6Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
@@ -1636,7 +1651,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv6 too small",
protoFactory: ipv6.NewProtocol,
protoNum: ipv6.ProtocolNumber,
- nicAddr: localIPv6Addr,
+ nicAddr: localIPv6AddrWithPrefix,
remoteAddr: remoteIPv6Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
@@ -1660,11 +1675,11 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
}{
{
name: "unspecified source",
- srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr))),
+ srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr.Address))),
},
{
name: "random source",
- srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr))),
+ srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr.Address))),
},
}
@@ -1677,15 +1692,19 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, test.protoNum, test.nicAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.protoNum, test.nicAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: test.protoNum,
+ AddressWithPrefix: test.nicAddr,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{Destination: test.remoteAddr.WithPrefix().Subnet(), NIC: nicID}})
- r, err := s.FindRoute(nicID, test.nicAddr, test.remoteAddr, test.protoNum, false /* multicastLoop */)
+ r, err := s.FindRoute(nicID, test.nicAddr.Address, test.remoteAddr, test.protoNum, false /* multicastLoop */)
if err != nil {
- t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr, test.protoNum, err)
+ t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr.Address, test.protoNum, err)
}
defer r.Release()
@@ -2032,3 +2051,97 @@ func TestJoinLeaveAllRoutersGroup(t *testing.T) {
})
}
}
+
+func TestSetNICIDBeforeDeliveringToRawEndpoint(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ proto tcpip.NetworkProtocolNumber
+ addr tcpip.AddressWithPrefix
+ payloadOffset int
+ }{
+ {
+ name: "IPv4",
+ proto: header.IPv4ProtocolNumber,
+ addr: localIPv4AddrWithPrefix,
+ payloadOffset: header.IPv4MinimumSize,
+ },
+ {
+ name: "IPv6",
+ proto: header.IPv6ProtocolNumber,
+ addr: localIPv6AddrWithPrefix,
+ payloadOffset: 0,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ ipv4.NewProtocol,
+ ipv6.NewProtocol,
+ },
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ RawFactory: raw.EndpointFactory{},
+ })
+ if err := s.CreateNIC(nicID, loopback.New()); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: test.proto,
+ AddressWithPrefix: test.addr,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: test.addr.Subnet(),
+ NIC: nicID,
+ },
+ })
+
+ var wq waiter.Queue
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.ReadableEvents)
+ ep, err := s.NewRawEndpoint(udp.ProtocolNumber, test.proto, &wq, true /* associated */)
+ if err != nil {
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err)
+ }
+ defer ep.Close()
+
+ writeOpts := tcpip.WriteOptions{
+ To: &tcpip.FullAddress{
+ Addr: test.addr.Address,
+ },
+ }
+ data := []byte{1, 2, 3, 4}
+ var r bytes.Reader
+ r.Reset(data)
+ if n, err := ep.Write(&r, writeOpts); err != nil {
+ t.Fatalf("ep.Write(_, _): %s", err)
+ } else if want := int64(len(data)); n != want {
+ t.Fatalf("got ep.Write(_, _) = (%d, nil), want = (%d, nil)", n, want)
+ }
+
+ // Wait for the endpoint to become readable.
+ <-ch
+
+ var w bytes.Buffer
+ rr, err := ep.Read(&w, tcpip.ReadOptions{
+ NeedRemoteAddr: true,
+ })
+ if err != nil {
+ t.Fatalf("ep.Read(...): %s", err)
+ }
+ if diff := cmp.Diff(data, w.Bytes()[test.payloadOffset:]); diff != "" {
+ t.Errorf("payload mismatch (-want +got):\n%s", diff)
+ }
+ if diff := cmp.Diff(tcpip.FullAddress{Addr: test.addr.Address, NIC: nicID}, rr.RemoteAddr); diff != "" {
+ t.Errorf("remote addr mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 2aa38eb98..1c3b0887f 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -167,14 +167,17 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet
p := hdr.TransportProtocol()
dstAddr := hdr.DestinationAddress()
// Skip the ip header, then deliver the error.
- pkt.Data().DeleteFront(hlen)
+ if _, ok := pkt.Data().Consume(hlen); !ok {
+ panic(fmt.Sprintf("could not consume the IP header of %d bytes", hlen))
+ }
e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, errInfo, pkt)
}
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
received := e.stats.icmp.packetsReceived
// ICMP packets don't have their TransportHeader fields set. See
- // icmp/protocol.go:protocol.Parse for a full explanation.
+ // icmp/protocol.go:protocol.Parse for a full explanation. Not all ICMP types
+ // require consuming the header, so we only call PullUp.
v, ok := pkt.Data().PullUp(header.ICMPv4MinimumSize)
if !ok {
received.invalid.Increment()
@@ -240,15 +243,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
case header.ICMPv4Echo:
received.echoRequest.Increment()
- sent := e.stats.icmp.packetsSent
- if !e.protocol.stack.AllowICMPMessage() {
- sent.rateLimited.Increment()
- return
- }
-
// DeliverTransportPacket will take ownership of pkt so don't use it beyond
// this point. Make a deep copy of the data before pkt gets sent as we will
- // be modifying fields.
+ // be modifying fields. Both the ICMP header (with its type modified to
+ // EchoReply) and payload are reused in the reply packet.
//
// TODO(gvisor.dev/issue/4399): The copy may not be needed if there are no
// waiting endpoints. Consider moving responsibility for doing the copy to
@@ -281,6 +279,12 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
}
defer r.Release()
+ sent := e.stats.icmp.packetsSent
+ if !e.protocol.allowICMPReply(header.ICMPv4EchoReply, header.ICMPv4UnusedCode) {
+ sent.rateLimited.Increment()
+ return
+ }
+
// TODO(gvisor.dev/issue/3810:) When adding protocol numbers into the
// header information, we may have to change this code to handle the
// ICMP header no longer being in the data buffer.
@@ -331,6 +335,8 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
case header.ICMPv4EchoReply:
received.echoReply.Increment()
+ // ICMP sockets expect the ICMP header to be present, so we don't consume
+ // the ICMP header.
e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt)
case header.ICMPv4DstUnreachable:
@@ -338,7 +344,9 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
mtu := h.MTU()
code := h.Code()
- pkt.Data().DeleteFront(header.ICMPv4MinimumSize)
+ if _, ok := pkt.Data().Consume(header.ICMPv4MinimumSize); !ok {
+ panic("could not consume ICMPv4MinimumSize bytes")
+ }
switch code {
case header.ICMPv4HostUnreachable:
e.handleControl(&icmpv4DestinationHostUnreachableSockError{}, pkt)
@@ -562,13 +570,6 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
return &tcpip.ErrNotConnected{}
}
- sent := netEP.stats.icmp.packetsSent
-
- if !p.stack.AllowICMPMessage() {
- sent.rateLimited.Increment()
- return nil
- }
-
transportHeader := pkt.TransportHeader().View()
// Don't respond to icmp error packets.
@@ -606,6 +607,35 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
}
}
+ sent := netEP.stats.icmp.packetsSent
+ icmpType, icmpCode, counter, pointer := func() (header.ICMPv4Type, header.ICMPv4Code, tcpip.MultiCounterStat, byte) {
+ switch reason := reason.(type) {
+ case *icmpReasonPortUnreachable:
+ return header.ICMPv4DstUnreachable, header.ICMPv4PortUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonProtoUnreachable:
+ return header.ICMPv4DstUnreachable, header.ICMPv4ProtoUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonNetworkUnreachable:
+ return header.ICMPv4DstUnreachable, header.ICMPv4NetUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonHostUnreachable:
+ return header.ICMPv4DstUnreachable, header.ICMPv4HostUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonFragmentationNeeded:
+ return header.ICMPv4DstUnreachable, header.ICMPv4FragmentationNeeded, sent.dstUnreachable, 0
+ case *icmpReasonTTLExceeded:
+ return header.ICMPv4TimeExceeded, header.ICMPv4TTLExceeded, sent.timeExceeded, 0
+ case *icmpReasonReassemblyTimeout:
+ return header.ICMPv4TimeExceeded, header.ICMPv4ReassemblyTimeout, sent.timeExceeded, 0
+ case *icmpReasonParamProblem:
+ return header.ICMPv4ParamProblem, header.ICMPv4UnusedCode, sent.paramProblem, reason.pointer
+ default:
+ panic(fmt.Sprintf("unsupported ICMP type %T", reason))
+ }
+ }()
+
+ if !p.allowICMPReply(icmpType, icmpCode) {
+ sent.rateLimited.Increment()
+ return nil
+ }
+
// Now work out how much of the triggering packet we should return.
// As per RFC 1812 Section 4.3.2.3
//
@@ -658,44 +688,9 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize))
- var counter tcpip.MultiCounterStat
- switch reason := reason.(type) {
- case *icmpReasonPortUnreachable:
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4PortUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonProtoUnreachable:
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4ProtoUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonNetworkUnreachable:
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4NetUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonHostUnreachable:
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4HostUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonFragmentationNeeded:
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4FragmentationNeeded)
- counter = sent.dstUnreachable
- case *icmpReasonTTLExceeded:
- icmpHdr.SetType(header.ICMPv4TimeExceeded)
- icmpHdr.SetCode(header.ICMPv4TTLExceeded)
- counter = sent.timeExceeded
- case *icmpReasonReassemblyTimeout:
- icmpHdr.SetType(header.ICMPv4TimeExceeded)
- icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout)
- counter = sent.timeExceeded
- case *icmpReasonParamProblem:
- icmpHdr.SetType(header.ICMPv4ParamProblem)
- icmpHdr.SetCode(header.ICMPv4UnusedCode)
- icmpHdr.SetPointer(reason.pointer)
- counter = sent.paramProblem
- default:
- panic(fmt.Sprintf("unsupported ICMP type %T", reason))
- }
+ icmpHdr.SetCode(icmpCode)
+ icmpHdr.SetType(icmpType)
+ icmpHdr.SetPointer(pointer)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data().AsRange().Checksum()))
if err := route.WritePacket(
diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go
index 4bd6f462e..c6576fcbc 100644
--- a/pkg/tcpip/network/ipv4/igmp_test.go
+++ b/pkg/tcpip/network/ipv4/igmp_test.go
@@ -120,9 +120,12 @@ func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, ma
// cycles.
func TestIGMPV1Present(t *testing.T) {
e, s, clock := createStack(t, true)
- addr := tcpip.AddressWithPrefix{Address: stackAddr, PrefixLen: defaultPrefixLength}
- if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{Address: stackAddr, PrefixLen: defaultPrefixLength},
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
@@ -215,8 +218,15 @@ func TestSendQueuedIGMPReports(t *testing.T) {
// The initial set of IGMP reports that were queued should be sent once an
// address is assigned.
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, stackAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: stackAddr,
+ PrefixLen: defaultPrefixLength,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if got := reportStat.Value(); got != 1 {
t.Errorf("got reportStat.Value() = %d, want = 1", got)
@@ -350,8 +360,12 @@ func TestIGMPPacketValidation(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
e, s, _ := createStack(t, true)
for _, address := range test.stackAddresses {
- if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, address); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, address, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: address,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
stats := s.Stats()
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 44c85bdb8..9b71738ae 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -167,6 +167,13 @@ func (p *protocol) findEndpointWithAddress(addr tcpip.Address) *endpoint {
return nil
}
+func (p *protocol) getEndpointForNIC(id tcpip.NICID) (*endpoint, bool) {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+ ep, ok := p.mu.eps[id]
+ return ep, ok
+}
+
func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
p.mu.Lock()
defer p.mu.Unlock()
@@ -240,7 +247,7 @@ func (e *endpoint) Enable() tcpip.Error {
}
// Create an endpoint to receive broadcast packets on this interface.
- ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(ipv4BroadcastAddr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */)
+ ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(ipv4BroadcastAddr, stack.AddressProperties{PEB: stack.NeverPrimaryEndpoint})
if err != nil {
return err
}
@@ -419,7 +426,7 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams,
// iptables filtering. All packets that reach here are locally
// generated.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
+ if ok := e.protocol.stack.IPTables().CheckOutput(pkt, r, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesOutputDropped.Increment()
return nil
@@ -459,7 +466,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, headerIn
// Postrouting NAT can only change the source address, and does not alter the
// route or outgoing interface of the packet.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
+ if ok := e.protocol.stack.IPTables().CheckPostrouting(pkt, r, e, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesPostroutingDropped.Increment()
return nil
@@ -542,7 +549,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
// iptables filtering. All packets that reach here are locally
// generated.
- outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, r, "" /* inNicName */, outNicName)
+ outputDropped, natPkts := e.protocol.stack.IPTables().CheckOutputPackets(pkts, r, outNicName)
stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped)))
for pkt := range outputDropped {
pkts.Remove(pkt)
@@ -569,7 +576,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par
// We ignore the list of NAT-ed packets here because Postrouting NAT can only
// change the source address, and does not alter the route or outgoing
// interface of the packet.
- postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, r, "" /* inNicName */, outNicName)
+ postroutingDropped, _ := e.protocol.stack.IPTables().CheckPostroutingPackets(pkts, r, e, outNicName)
stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped)))
for pkt := range postroutingDropped {
pkts.Remove(pkt)
@@ -710,7 +717,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
inNicName := stk.FindNICNameFromID(e.nic.ID())
outNicName := stk.FindNICNameFromID(ep.nic.ID())
- if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+ if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesForwardDropped.Increment()
return nil
@@ -737,7 +744,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
inNicName := stk.FindNICNameFromID(e.nic.ID())
outNicName := stk.FindNICNameFromID(r.NICID())
- if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+ if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesForwardDropped.Increment()
return nil
@@ -746,7 +753,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
// We need to do a deep copy of the IP packet because
// WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do
// not own it.
- newHdr := header.IPv4(stack.PayloadSince(pkt.NetworkHeader()))
+ newPkt := pkt.DeepCopyForForwarding(int(r.MaxHeaderLength()))
+ newHdr := header.IPv4(newPkt.NetworkHeader().View())
// As per RFC 791 page 30, Time to Live,
//
@@ -755,12 +763,19 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
// Even if no local information is available on the time actually
// spent, the field must be decremented by 1.
newHdr.SetTTL(ttl - 1)
+ // We perform a full checksum as we may have updated options above. The IP
+ // header is relatively small so this is not expected to be an expensive
+ // operation.
+ newHdr.SetChecksum(0)
+ newHdr.SetChecksum(^newHdr.CalculateChecksum())
- switch err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()),
- Data: buffer.View(newHdr).ToVectorisedView(),
- IsForwardedPacket: true,
- })); err.(type) {
+ forwardToEp, ok := e.protocol.getEndpointForNIC(r.NICID())
+ if !ok {
+ // The interface was removed after we obtained the route.
+ return &ip.ErrOther{Err: &tcpip.ErrUnknownDevice{}}
+ }
+
+ switch err := forwardToEp.writePacket(r, newPkt, true /* headerIncluded */); err.(type) {
case nil:
return nil
case *tcpip.ErrMessageTooLong:
@@ -826,7 +841,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Loopback traffic skips the prerouting chain.
inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok {
+ if ok := e.protocol.stack.IPTables().CheckPrerouting(pkt, e, inNicName); !ok {
// iptables is telling us to drop the packet.
stats.IPTablesPreroutingDropped.Increment()
return
@@ -856,6 +871,8 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum
}
func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer, inNICName string) {
+ pkt.NICID = e.nic.ID()
+
// Raw socket packets are delivered based solely on the transport protocol
// number. We only require that the packet be valid IPv4, and that they not
// be fragmented.
@@ -863,7 +880,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer,
e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt)
}
- pkt.NICID = e.nic.ID()
stats := e.stats
stats.ip.ValidPacketsReceived.Increment()
@@ -924,7 +940,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer,
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
- if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok {
+ if ok := e.protocol.stack.IPTables().CheckInput(pkt, inNICName); !ok {
// iptables is telling us to drop the packet.
stats.ip.IPTablesInputDropped.Increment()
return
@@ -1074,11 +1090,11 @@ func (e *endpoint) Close() {
}
// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint.
-func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
+func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties stack.AddressProperties) (stack.AddressEndpoint, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
+ ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, properties)
if err == nil {
e.mu.igmp.sendQueuedReports()
}
@@ -1199,6 +1215,9 @@ type protocol struct {
// eps is keyed by NICID to allow protocol methods to retrieve an endpoint
// when handling a packet, by looking at which NIC handled the packet.
eps map[tcpip.NICID]*endpoint
+
+ // ICMP types for which the stack's global rate limiting must apply.
+ icmpRateLimitedTypes map[header.ICMPv4Type]struct{}
}
// defaultTTL is the current default TTL for the protocol. Only the
@@ -1225,11 +1244,6 @@ func (p *protocol) MinimumPacketSize() int {
return header.IPv4MinimumSize
}
-// DefaultPrefixLen returns the IPv4 default prefix length.
-func (p *protocol) DefaultPrefixLen() int {
- return header.IPv4AddressSize * 8
-}
-
// ParseAddresses implements stack.NetworkProtocol.
func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
h := header.IPv4(v)
@@ -1319,6 +1333,23 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true
}
+// allowICMPReply reports whether an ICMP reply with provided type and code may
+// be sent following the rate mask options and global ICMP rate limiter.
+func (p *protocol) allowICMPReply(icmpType header.ICMPv4Type, code header.ICMPv4Code) bool {
+ // Mimic linux and never rate limit for PMTU discovery.
+ // https://github.com/torvalds/linux/blob/9e9fb7655ed585da8f468e29221f0ba194a5f613/net/ipv4/icmp.c#L288
+ if icmpType == header.ICMPv4DstUnreachable && code == header.ICMPv4FragmentationNeeded {
+ return true
+ }
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+
+ if _, ok := p.mu.icmpRateLimitedTypes[icmpType]; ok {
+ return p.stack.AllowICMPMessage()
+ }
+ return true
+}
+
// calculateNetworkMTU calculates the network-layer payload MTU based on the
// link-layer payload mtu.
func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error) {
@@ -1398,6 +1429,14 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
}
p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p)
p.mu.eps = make(map[tcpip.NICID]*endpoint)
+ // Set ICMP rate limiting to Linux defaults.
+ // See https://man7.org/linux/man-pages/man7/icmp.7.html.
+ p.mu.icmpRateLimitedTypes = map[header.ICMPv4Type]struct{}{
+ header.ICMPv4DstUnreachable: struct{}{},
+ header.ICMPv4SrcQuench: struct{}{},
+ header.ICMPv4TimeExceeded: struct{}{},
+ header.ICMPv4ParamProblem: struct{}{},
+ }
return p
}
}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 73407be67..ef91245d7 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -101,8 +101,12 @@ func TestExcludeBroadcast(t *testing.T) {
defer ep.Close()
// Add a valid primary endpoint address, now we can connect.
- if err := s.AddAddress(1, ipv4.ProtocolNumber, "\x0a\x00\x00\x02"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.Address("\x0a\x00\x00\x02").WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
if err := ep.Connect(randomAddr); err != nil {
t.Errorf("Connect failed: %v", err)
@@ -356,8 +360,8 @@ func TestForwarding(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err)
}
incomingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: incomingIPv4Addr}
- if err := s.AddProtocolAddress(incomingNICID, incomingIPv4ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv4ProtoAddr, err)
+ if err := s.AddProtocolAddress(incomingNICID, incomingIPv4ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingIPv4ProtoAddr, err)
}
expectedEmittedPacketCount := 1
@@ -369,8 +373,8 @@ func TestForwarding(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err)
}
outgoingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: outgoingIPv4Addr}
- if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv4ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv4ProtoAddr, err)
+ if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv4ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingIPv4ProtoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -1184,8 +1188,8 @@ func TestIPv4Sanity(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr}
- if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err)
+ if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtoAddr, err)
}
// Default routes for IPv4 so ICMP can find a route to the remote
@@ -1745,8 +1749,8 @@ func TestInvalidFragments(t *testing.T) {
const (
nicID = 1
linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
- addr1 = "\x0a\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x02"
+ addr1 = tcpip.Address("\x0a\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x02")
tos = 0
ident = 1
ttl = 48
@@ -2012,8 +2016,12 @@ func TestInvalidFragments(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
for _, f := range test.fragments {
@@ -2061,8 +2069,8 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
const (
nicID = 1
linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
- addr1 = "\x0a\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x02"
+ addr1 = tcpip.Address("\x0a\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x02")
tos = 0
ident = 1
ttl = 48
@@ -2237,8 +2245,12 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv4EmptySubnet,
@@ -2308,9 +2320,9 @@ func TestReceiveFragments(t *testing.T) {
const (
nicID = 1
- addr1 = "\x0c\xa8\x00\x01" // 192.168.0.1
- addr2 = "\x0c\xa8\x00\x02" // 192.168.0.2
- addr3 = "\x0c\xa8\x00\x03" // 192.168.0.3
+ addr1 = tcpip.Address("\x0c\xa8\x00\x01") // 192.168.0.1
+ addr2 = tcpip.Address("\x0c\xa8\x00\x02") // 192.168.0.2
+ addr3 = tcpip.Address("\x0c\xa8\x00\x03") // 192.168.0.3
)
// Build and return a UDP header containing payload.
@@ -2703,8 +2715,12 @@ func TestReceiveFragments(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
wq := waiter.Queue{}
@@ -2985,11 +3001,15 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route {
t.Fatalf("CreateNIC(1, _) failed: %s", err)
}
const (
- src = "\x10\x00\x00\x01"
- dst = "\x10\x00\x00\x02"
+ src = tcpip.Address("\x10\x00\x00\x01")
+ dst = tcpip.Address("\x10\x00\x00\x02")
)
- if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil {
- t.Fatalf("AddAddress(1, %d, %s) failed: %s", ipv4.ProtocolNumber, src, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: src.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
mask := tcpip.AddressMask(header.IPv4Broadcast)
@@ -3161,8 +3181,8 @@ func TestPacketQueuing(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddProtocolAddress(nicID, host1IPv4Addr); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv4Addr, err)
+ if err := s.AddProtocolAddress(nicID, host1IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv4Addr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -3285,8 +3305,12 @@ func TestCloseLocking(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
}
- if err := s.AddAddress(nicID1, ipv4.ProtocolNumber, src); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) failed: %s", nicID1, ipv4.ProtocolNumber, src, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: src.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{
@@ -3349,3 +3373,139 @@ func TestCloseLocking(t *testing.T) {
}
}()
}
+
+func TestIcmpRateLimit(t *testing.T) {
+ var (
+ host1IPv4Addr = tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()),
+ PrefixLen: 24,
+ },
+ }
+ host2IPv4Addr = tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()),
+ PrefixLen: 24,
+ },
+ }
+ )
+ const icmpBurst = 5
+ e := channel.New(1, defaultMTU, tcpip.LinkAddress(""))
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ Clock: faketime.NewManualClock(),
+ })
+ s.SetICMPBurst(icmpBurst)
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddProtocolAddress(nicID, host1IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv4Addr, err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: nicID,
+ },
+ })
+ tests := []struct {
+ name string
+ createPacket func() buffer.View
+ check func(*testing.T, *channel.Endpoint, int)
+ }{
+ {
+ name: "echo",
+ createPacket: func() buffer.View {
+ totalLength := header.IPv4MinimumSize + header.ICMPv4MinimumSize
+ hdr := buffer.NewPrependable(totalLength)
+ icmpH := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ icmpH.SetIdent(1)
+ icmpH.SetSequence(1)
+ icmpH.SetType(header.ICMPv4Echo)
+ icmpH.SetCode(header.ICMPv4UnusedCode)
+ icmpH.SetChecksum(0)
+ icmpH.SetChecksum(^header.Checksum(icmpH, 0))
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: uint16(totalLength),
+ Protocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: 1,
+ SrcAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv4Addr.AddressWithPrefix.Address,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+ return hdr.View()
+ },
+ check: func(t *testing.T, e *channel.Endpoint, round int) {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected echo response, no packet read in endpoint in round %d", round)
+ }
+ if got, want := p.Proto, header.IPv4ProtocolNumber; got != want {
+ t.Errorf("got p.Proto = %d, want = %d", got, want)
+ }
+ checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4EchoReply),
+ ))
+ },
+ },
+ {
+ name: "dst unreachable",
+ createPacket: func() buffer.View {
+ totalLength := header.IPv4MinimumSize + header.UDPMinimumSize
+ hdr := buffer.NewPrependable(totalLength)
+ udpH := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ udpH.Encode(&header.UDPFields{
+ SrcPort: 100,
+ DstPort: 101,
+ Length: header.UDPMinimumSize,
+ })
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: uint16(totalLength),
+ Protocol: uint8(header.UDPProtocolNumber),
+ TTL: 1,
+ SrcAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv4Addr.AddressWithPrefix.Address,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+ return hdr.View()
+ },
+ check: func(t *testing.T, e *channel.Endpoint, round int) {
+ p, ok := e.Read()
+ if round >= icmpBurst {
+ if ok {
+ t.Errorf("got packet %x in round %d, expected ICMP rate limit to stop it", p.Pkt.Data().Views(), round)
+ }
+ return
+ }
+ if !ok {
+ t.Fatalf("expected unreachable in round %d, no packet read in endpoint", round)
+ }
+ checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4DstUnreachable),
+ ))
+ },
+ },
+ }
+ for _, testCase := range tests {
+ t.Run(testCase.name, func(t *testing.T) {
+ for round := 0; round < icmpBurst+1; round++ {
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: testCase.createPacket().ToVectorisedView(),
+ }))
+ testCase.check(t, e, round)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index f99cbf8f3..f814926a3 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -51,6 +51,7 @@ go_test(
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
"@com_github_google_go_cmp//cmp:go_default_library",
+ "@org_golang_x_time//rate:go_default_library",
],
)
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 94caaae6c..ff23d48e7 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -187,7 +187,9 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe
// Skip the IP header, then handle the fragmentation header if there
// is one.
- pkt.Data().DeleteFront(header.IPv6MinimumSize)
+ if _, ok := pkt.Data().Consume(header.IPv6MinimumSize); !ok {
+ panic("could not consume IPv6MinimumSize bytes")
+ }
if p == header.IPv6FragmentHeader {
f, ok := pkt.Data().PullUp(header.IPv6FragmentHeaderSize)
if !ok {
@@ -203,7 +205,9 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe
// Skip fragmentation header and find out the actual protocol
// number.
- pkt.Data().DeleteFront(header.IPv6FragmentHeaderSize)
+ if _, ok := pkt.Data().Consume(header.IPv6FragmentHeaderSize); !ok {
+ panic("could not consume IPv6FragmentHeaderSize bytes")
+ }
}
e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, transErr, pkt)
@@ -325,7 +329,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
switch icmpType := h.Type(); icmpType {
case header.ICMPv6PacketTooBig:
received.packetTooBig.Increment()
- hdr, ok := pkt.Data().PullUp(header.ICMPv6PacketTooBigMinimumSize)
+ hdr, ok := pkt.Data().Consume(header.ICMPv6PacketTooBigMinimumSize)
if !ok {
received.invalid.Increment()
return
@@ -334,18 +338,16 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
if err != nil {
networkMTU = 0
}
- pkt.Data().DeleteFront(header.ICMPv6PacketTooBigMinimumSize)
e.handleControl(&icmpv6PacketTooBigSockError{mtu: networkMTU}, pkt)
case header.ICMPv6DstUnreachable:
received.dstUnreachable.Increment()
- hdr, ok := pkt.Data().PullUp(header.ICMPv6DstUnreachableMinimumSize)
+ hdr, ok := pkt.Data().Consume(header.ICMPv6DstUnreachableMinimumSize)
if !ok {
received.invalid.Increment()
return
}
code := header.ICMPv6(hdr).Code()
- pkt.Data().DeleteFront(header.ICMPv6DstUnreachableMinimumSize)
switch code {
case header.ICMPv6NetworkUnreachable:
e.handleControl(&icmpv6DestinationNetworkUnreachableSockError{}, pkt)
@@ -692,6 +694,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
}
defer r.Release()
+ if !e.protocol.allowICMPReply(header.ICMPv6EchoReply) {
+ sent.rateLimited.Increment()
+ return
+ }
+
replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize,
Data: pkt.Data().ExtractVV(),
@@ -1174,13 +1181,6 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
return &tcpip.ErrNotConnected{}
}
- sent := netEP.stats.icmp.packetsSent
-
- if !p.stack.AllowICMPMessage() {
- sent.rateLimited.Increment()
- return nil
- }
-
if pkt.TransportProtocolNumber == header.ICMPv6ProtocolNumber {
// TODO(gvisor.dev/issues/3810): Sort this out when ICMP headers are stored.
// Unfortunately at this time ICMP Packets do not have a transport
@@ -1198,6 +1198,33 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
}
}
+ sent := netEP.stats.icmp.packetsSent
+ icmpType, icmpCode, counter, typeSpecific := func() (header.ICMPv6Type, header.ICMPv6Code, tcpip.MultiCounterStat, uint32) {
+ switch reason := reason.(type) {
+ case *icmpReasonParameterProblem:
+ return header.ICMPv6ParamProblem, reason.code, sent.paramProblem, reason.pointer
+ case *icmpReasonPortUnreachable:
+ return header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonNetUnreachable:
+ return header.ICMPv6DstUnreachable, header.ICMPv6NetworkUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonHostUnreachable:
+ return header.ICMPv6DstUnreachable, header.ICMPv6AddressUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonPacketTooBig:
+ return header.ICMPv6PacketTooBig, header.ICMPv6UnusedCode, sent.packetTooBig, 0
+ case *icmpReasonHopLimitExceeded:
+ return header.ICMPv6TimeExceeded, header.ICMPv6HopLimitExceeded, sent.timeExceeded, 0
+ case *icmpReasonReassemblyTimeout:
+ return header.ICMPv6TimeExceeded, header.ICMPv6ReassemblyTimeout, sent.timeExceeded, 0
+ default:
+ panic(fmt.Sprintf("unsupported ICMP type %T", reason))
+ }
+ }()
+
+ if !p.allowICMPReply(icmpType) {
+ sent.rateLimited.Increment()
+ return nil
+ }
+
network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
// As per RFC 4443 section 2.4
@@ -1232,40 +1259,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
newPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
icmpHdr := header.ICMPv6(newPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize))
- var counter tcpip.MultiCounterStat
- switch reason := reason.(type) {
- case *icmpReasonParameterProblem:
- icmpHdr.SetType(header.ICMPv6ParamProblem)
- icmpHdr.SetCode(reason.code)
- icmpHdr.SetTypeSpecific(reason.pointer)
- counter = sent.paramProblem
- case *icmpReasonPortUnreachable:
- icmpHdr.SetType(header.ICMPv6DstUnreachable)
- icmpHdr.SetCode(header.ICMPv6PortUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonNetUnreachable:
- icmpHdr.SetType(header.ICMPv6DstUnreachable)
- icmpHdr.SetCode(header.ICMPv6NetworkUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonHostUnreachable:
- icmpHdr.SetType(header.ICMPv6DstUnreachable)
- icmpHdr.SetCode(header.ICMPv6AddressUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonPacketTooBig:
- icmpHdr.SetType(header.ICMPv6PacketTooBig)
- icmpHdr.SetCode(header.ICMPv6UnusedCode)
- counter = sent.packetTooBig
- case *icmpReasonHopLimitExceeded:
- icmpHdr.SetType(header.ICMPv6TimeExceeded)
- icmpHdr.SetCode(header.ICMPv6HopLimitExceeded)
- counter = sent.timeExceeded
- case *icmpReasonReassemblyTimeout:
- icmpHdr.SetType(header.ICMPv6TimeExceeded)
- icmpHdr.SetCode(header.ICMPv6ReassemblyTimeout)
- counter = sent.timeExceeded
- default:
- panic(fmt.Sprintf("unsupported ICMP type %T", reason))
- }
+ icmpHdr.SetType(icmpType)
+ icmpHdr.SetCode(icmpCode)
+ icmpHdr.SetTypeSpecific(typeSpecific)
+
dataRange := newPkt.Data().AsRange()
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 7c2a3e56b..03d9f425c 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -22,6 +22,7 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
+ "golang.org/x/time/rate"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -225,8 +226,8 @@ func TestICMPCounts(t *testing.T) {
t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint")
}
addr := lladdr0.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
} else {
ep.DecRef()
}
@@ -407,8 +408,12 @@ func newTestContext(t *testing.T) *testContext {
if err := c.s0.CreateNIC(nicID, wrappedEP0); err != nil {
t.Fatalf("CreateNIC s0: %v", err)
}
- if err := c.s0.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress lladdr0: %v", err)
+ llProtocolAddr0 := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := c.s0.AddProtocolAddress(nicID, llProtocolAddr0, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, llProtocolAddr0, err)
}
c.linkEP1 = channel.New(defaultChannelSize, defaultMTU, linkAddr1)
@@ -416,8 +421,12 @@ func newTestContext(t *testing.T) *testContext {
if err := c.s1.CreateNIC(nicID, wrappedEP1); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
- if err := c.s1.AddAddress(nicID, ProtocolNumber, lladdr1); err != nil {
- t.Fatalf("AddAddress lladdr1: %v", err)
+ llProtocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr1.WithPrefix(),
+ }
+ if err := c.s1.AddProtocolAddress(nicID, llProtocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, llProtocolAddr1, err)
}
subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -690,8 +699,12 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
t.Fatalf("CreateNIC(_, _) = %s", err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -883,8 +896,12 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
t.Fatalf("CreateNIC(_, _) = %s", err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -1065,8 +1082,12 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -1240,8 +1261,12 @@ func TestLinkAddressRequest(t *testing.T) {
}
if len(test.nicAddr) != 0 {
- if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: test.nicAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
@@ -1411,12 +1436,14 @@ func TestPacketQueing(t *testing.T) {
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
Clock: clock,
})
+ // Make sure ICMP rate limiting doesn't get in our way.
+ s.SetICMPLimit(rate.Inf)
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddProtocolAddress(nicID, host1IPv6Addr); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv6Addr, err)
+ if err := s.AddProtocolAddress(nicID, host1IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv6Addr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -1669,8 +1696,12 @@ func TestCallsToNeighborCache(t *testing.T) {
if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
t.Fatalf("CreateNIC(_, _) = %s", err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
{
@@ -1704,8 +1735,8 @@ func TestCallsToNeighborCache(t *testing.T) {
t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint")
}
addr := lladdr0.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
} else {
ep.DecRef()
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index b1aec5312..600e805f8 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -748,7 +748,7 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams,
// iptables filtering. All packets that reach here are locally
// generated.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
+ if ok := e.protocol.stack.IPTables().CheckOutput(pkt, r, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesOutputDropped.Increment()
return nil
@@ -788,7 +788,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, protocol
// Postrouting NAT can only change the source address, and does not alter the
// route or outgoing interface of the packet.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
+ if ok := e.protocol.stack.IPTables().CheckPostrouting(pkt, r, e, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesPostroutingDropped.Increment()
return nil
@@ -871,7 +871,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par
// iptables filtering. All packets that reach here are locally
// generated.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, r, "" /* inNicName */, outNicName)
+ outputDropped, natPkts := e.protocol.stack.IPTables().CheckOutputPackets(pkts, r, outNicName)
stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped)))
for pkt := range outputDropped {
pkts.Remove(pkt)
@@ -897,7 +897,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par
// We ignore the list of NAT-ed packets here because Postrouting NAT can only
// change the source address, and does not alter the route or outgoing
// interface of the packet.
- postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, r, "" /* inNicName */, outNicName)
+ postroutingDropped, _ := e.protocol.stack.IPTables().CheckPostroutingPackets(pkts, r, e, outNicName)
stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped)))
for pkt := range postroutingDropped {
pkts.Remove(pkt)
@@ -984,7 +984,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
inNicName := stk.FindNICNameFromID(e.nic.ID())
outNicName := stk.FindNICNameFromID(ep.nic.ID())
- if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+ if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesForwardDropped.Increment()
return nil
@@ -1015,7 +1015,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
inNicName := stk.FindNICNameFromID(e.nic.ID())
outNicName := stk.FindNICNameFromID(r.NICID())
- if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+ if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesForwardDropped.Increment()
return nil
@@ -1024,7 +1024,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
// We need to do a deep copy of the IP packet because
// WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do
// not own it.
- newHdr := header.IPv6(stack.PayloadSince(pkt.NetworkHeader()))
+ newPkt := pkt.DeepCopyForForwarding(int(r.MaxHeaderLength()))
+ newHdr := header.IPv6(newPkt.NetworkHeader().View())
// As per RFC 8200 section 3,
//
@@ -1032,11 +1033,13 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
// each node that forwards the packet.
newHdr.SetHopLimit(hopLimit - 1)
- switch err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()),
- Data: buffer.View(newHdr).ToVectorisedView(),
- IsForwardedPacket: true,
- })); err.(type) {
+ forwardToEp, ok := e.protocol.getEndpointForNIC(r.NICID())
+ if !ok {
+ // The interface was removed after we obtained the route.
+ return &ip.ErrOther{Err: &tcpip.ErrUnknownDevice{}}
+ }
+
+ switch err := forwardToEp.writePacket(r, newPkt, newPkt.TransportProtocolNumber, true /* headerIncluded */); err.(type) {
case nil:
return nil
case *tcpip.ErrMessageTooLong:
@@ -1097,7 +1100,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Loopback traffic skips the prerouting chain.
inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok {
+ if ok := e.protocol.stack.IPTables().CheckPrerouting(pkt, e, inNicName); !ok {
// iptables is telling us to drop the packet.
stats.IPTablesPreroutingDropped.Increment()
return
@@ -1127,11 +1130,12 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum
}
func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer, inNICName string) {
+ pkt.NICID = e.nic.ID()
+
// Raw socket packets are delivered based solely on the transport protocol
// number. We only require that the packet be valid IPv6.
e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt)
- pkt.NICID = e.nic.ID()
stats := e.stats.ip
stats.ValidPacketsReceived.Increment()
@@ -1179,7 +1183,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer,
// iptables filtering. All packets that reach here are intended for
// this machine and need not be forwarded.
- if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok {
+ if ok := e.protocol.stack.IPTables().CheckInput(pkt, inNICName); !ok {
// iptables is telling us to drop the packet.
stats.IPTablesInputDropped.Increment()
return
@@ -1533,19 +1537,22 @@ func (e *endpoint) processExtensionHeaders(h header.IPv6, pkt *stack.PacketBuffe
// If the last header in the payload isn't a known IPv6 extension header,
// handle it as if it is transport layer data.
- // Calculate the number of octets parsed from data. We want to remove all
- // the data except the unparsed portion located at the end, which its size
- // is extHdr.Buf.Size().
+ // Calculate the number of octets parsed from data. We want to consume all
+ // the data except the unparsed portion located at the end, whose size is
+ // extHdr.Buf.Size().
trim := pkt.Data().Size() - extHdr.Buf.Size()
// For unfragmented packets, extHdr still contains the transport header.
- // Get rid of it.
+ // Consume that too.
//
// For reassembled fragments, pkt.TransportHeader is unset, so this is a
// no-op and pkt.Data begins with the transport header.
trim += pkt.TransportHeader().View().Size()
- pkt.Data().DeleteFront(trim)
+ if _, ok := pkt.Data().Consume(trim); !ok {
+ stats.MalformedPacketsReceived.Increment()
+ return fmt.Errorf("could not consume %d bytes", trim)
+ }
stats.PacketsDelivered.Increment()
if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber {
@@ -1627,12 +1634,12 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
}
// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint.
-func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
+func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties stack.AddressProperties) (stack.AddressEndpoint, tcpip.Error) {
// TODO(b/169350103): add checks here after making sure we no longer receive
// an empty address.
e.mu.Lock()
defer e.mu.Unlock()
- return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated)
+ return e.addAndAcquirePermanentAddressLocked(addr, properties)
}
// addAndAcquirePermanentAddressLocked is like AddAndAcquirePermanentAddress but
@@ -1642,8 +1649,8 @@ func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, p
// solicited-node multicast group and start duplicate address detection.
//
// Precondition: e.mu must be write locked.
-func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
- addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
+func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, properties stack.AddressProperties) (stack.AddressEndpoint, tcpip.Error) {
+ addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, properties)
if err != nil {
return nil, err
}
@@ -1986,6 +1993,9 @@ type protocol struct {
// eps is keyed by NICID to allow protocol methods to retrieve an endpoint
// when handling a packet, by looking at which NIC handled the packet.
eps map[tcpip.NICID]*endpoint
+
+ // ICMP types for which the stack's global rate limiting must apply.
+ icmpRateLimitedTypes map[header.ICMPv6Type]struct{}
}
ids []uint32
@@ -1997,7 +2007,8 @@ type protocol struct {
// Must be accessed using atomic operations.
defaultTTL uint32
- fragmentation *fragmentation.Fragmentation
+ fragmentation *fragmentation.Fragmentation
+ icmpRateLimiter *stack.ICMPRateLimiter
}
// Number returns the ipv6 protocol number.
@@ -2010,11 +2021,6 @@ func (p *protocol) MinimumPacketSize() int {
return header.IPv6MinimumSize
}
-// DefaultPrefixLen returns the IPv6 default prefix length.
-func (p *protocol) DefaultPrefixLen() int {
- return header.IPv6AddressSize * 8
-}
-
// ParseAddresses implements stack.NetworkProtocol.
func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
h := header.IPv6(v)
@@ -2086,6 +2092,13 @@ func (p *protocol) findEndpointWithAddress(addr tcpip.Address) *endpoint {
return nil
}
+func (p *protocol) getEndpointForNIC(id tcpip.NICID) (*endpoint, bool) {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+ ep, ok := p.mu.eps[id]
+ return ep, ok
+}
+
func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
p.mu.Lock()
defer p.mu.Unlock()
@@ -2171,6 +2184,18 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
return proto, !fragMore && fragOffset == 0, true
}
+// allowICMPReply reports whether an ICMP reply with provided type may
+// be sent following the rate mask options and global ICMP rate limiter.
+func (p *protocol) allowICMPReply(icmpType header.ICMPv6Type) bool {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+
+ if _, ok := p.mu.icmpRateLimitedTypes[icmpType]; ok {
+ return p.stack.AllowICMPMessage()
+ }
+ return true
+}
+
// calculateNetworkMTU calculates the network-layer payload MTU based on the
// link-layer payload MTU and the length of every IPv6 header.
// Note that this is different than the Payload Length field of the IPv6 header,
@@ -2267,6 +2292,21 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
p.fragmentation = fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p)
p.mu.eps = make(map[tcpip.NICID]*endpoint)
p.SetDefaultTTL(DefaultTTL)
+ // Set default ICMP rate limiting to Linux defaults.
+ //
+ // Default: 0-1,3-127 (rate limit ICMPv6 errors except Packet Too Big)
+ // See https://www.kernel.org/doc/Documentation/networking/ip-sysctl.txt.
+ defaultIcmpTypes := make(map[header.ICMPv6Type]struct{})
+ for i := header.ICMPv6Type(0); i < header.ICMPv6EchoRequest; i++ {
+ switch i {
+ case header.ICMPv6PacketTooBig:
+ // Do not rate limit packet too big by default.
+ default:
+ defaultIcmpTypes[i] = struct{}{}
+ }
+ }
+ p.mu.icmpRateLimitedTypes = defaultIcmpTypes
+
return p
}
}
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index d2a23fd4f..e5286081e 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -41,12 +41,12 @@ import (
)
const (
- addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
// The least significant 3 bytes are the same as addr2 so both addr2 and
// addr3 will have the same solicited-node address.
- addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02"
- addr4 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x03"
+ addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02")
+ addr4 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x03")
// Tests use the extension header identifier values as uint8 instead of
// header.IPv6ExtensionHeaderIdentifier.
@@ -298,16 +298,24 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
// addr2/addr3 yet as we haven't added those addresses.
test.rxf(t, s, e, addr1, snmc, 0)
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr2, err)
}
// Should receive a packet destined to the solicited node address of
// addr2/addr3 now that we have added addr2.
test.rxf(t, s, e, addr1, snmc, 1)
- if err := s.AddAddress(nicID, ProtocolNumber, addr3); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr3, err)
+ protocolAddr3 := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr3.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr3, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr3, err)
}
// Should still receive a packet destined to the solicited node address of
@@ -374,8 +382,12 @@ func TestAddIpv6Address(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, test.addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, nil) = %s", nicID, ProtocolNumber, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: test.addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if addr, err := s.GetMainNICAddress(nicID, ProtocolNumber); err != nil {
@@ -898,8 +910,12 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
// Add a default route so that a return packet knows where to go.
@@ -1992,8 +2008,12 @@ func TestReceiveIPv6Fragments(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
wq := waiter.Queue{}
@@ -2060,8 +2080,8 @@ func TestReceiveIPv6Fragments(t *testing.T) {
func TestInvalidIPv6Fragments(t *testing.T) {
const (
- addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
nicID = 1
hoplimit = 255
@@ -2150,8 +2170,12 @@ func TestInvalidIPv6Fragments(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv6EmptySubnet,
@@ -2216,8 +2240,8 @@ func TestInvalidIPv6Fragments(t *testing.T) {
func TestFragmentReassemblyTimeout(t *testing.T) {
const (
- addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
nicID = 1
hoplimit = 255
@@ -2402,8 +2426,12 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv6EmptySubnet,
@@ -2645,11 +2673,15 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route {
t.Fatalf("CreateNIC(1, _) failed: %s", err)
}
const (
- src = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ src = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ dst = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
)
- if err := s.AddAddress(1, ProtocolNumber, src); err != nil {
- t.Fatalf("AddAddress(1, %d, %s) failed: %s", ProtocolNumber, src, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: src.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
mask := tcpip.AddressMask("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff")
@@ -3297,8 +3329,8 @@ func TestForwarding(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err)
}
incomingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: incomingIPv6Addr}
- if err := s.AddProtocolAddress(incomingNICID, incomingIPv6ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv6ProtoAddr, err)
+ if err := s.AddProtocolAddress(incomingNICID, incomingIPv6ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingIPv6ProtoAddr, err)
}
outgoingEndpoint := channel.New(1, header.IPv6MinimumMTU, "")
@@ -3306,8 +3338,8 @@ func TestForwarding(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err)
}
outgoingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: outgoingIPv6Addr}
- if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv6ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv6ProtoAddr, err)
+ if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv6ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingIPv6ProtoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -3341,7 +3373,8 @@ func TestForwarding(t *testing.T) {
ipHeaderLength := header.IPv6MinimumSize
icmpHeaderLength := header.ICMPv6MinimumSize
- totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength + extHdrLen
+ payloadLength := icmpHeaderLength + test.payloadLength + extHdrLen
+ totalLength := ipHeaderLength + payloadLength
hdr := buffer.NewPrependable(totalLength)
hdr.Prepend(test.payloadLength)
icmpH := header.ICMPv6(hdr.Prepend(icmpHeaderLength))
@@ -3359,7 +3392,7 @@ func TestForwarding(t *testing.T) {
copy(hdr.Prepend(extHdrLen), extHdrBytes)
ip := header.IPv6(hdr.Prepend(ipHeaderLength))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(header.ICMPv6MinimumSize + test.payloadLength),
+ PayloadLength: uint16(payloadLength),
TransportProtocol: transportProtocol,
HopLimit: test.TTL,
SrcAddr: test.sourceAddr,
@@ -3489,3 +3522,149 @@ func TestMultiCounterStatsInitialization(t *testing.T) {
t.Error(err)
}
}
+
+func TestIcmpRateLimit(t *testing.T) {
+ var (
+ host1IPv6Addr = tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("10::1").To16()),
+ PrefixLen: 64,
+ },
+ }
+ host2IPv6Addr = tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("10::2").To16()),
+ PrefixLen: 64,
+ },
+ }
+ )
+ const icmpBurst = 5
+ e := channel.New(1, defaultMTU, tcpip.LinkAddress(""))
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ Clock: faketime.NewManualClock(),
+ })
+ s.SetICMPBurst(icmpBurst)
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddProtocolAddress(nicID, host1IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv6Addr, err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: nicID,
+ },
+ })
+ tests := []struct {
+ name string
+ createPacket func() buffer.View
+ check func(*testing.T, *channel.Endpoint, int)
+ }{
+ {
+ name: "echo",
+ createPacket: func() buffer.View {
+ totalLength := header.IPv6MinimumSize + header.ICMPv6MinimumSize
+ hdr := buffer.NewPrependable(totalLength)
+ icmpH := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
+ icmpH.SetIdent(1)
+ icmpH.SetSequence(1)
+ icmpH.SetType(header.ICMPv6EchoRequest)
+ icmpH.SetCode(header.ICMPv6UnusedCode)
+ icmpH.SetChecksum(0)
+ icmpH.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: icmpH,
+ Src: host2IPv6Addr.AddressWithPrefix.Address,
+ Dst: host1IPv6Addr.AddressWithPrefix.Address,
+ }))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 1,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ })
+ return hdr.View()
+ },
+ check: func(t *testing.T, e *channel.Endpoint, round int) {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected echo response, no packet read in endpoint in round %d", round)
+ }
+ if got, want := p.Proto, header.IPv6ProtocolNumber; got != want {
+ t.Errorf("got p.Proto = %d, want = %d", got, want)
+ }
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6EchoReply),
+ ))
+ },
+ },
+ {
+ name: "dst unreachable",
+ createPacket: func() buffer.View {
+ totalLength := header.IPv6MinimumSize + header.UDPMinimumSize
+ hdr := buffer.NewPrependable(totalLength)
+ udpH := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ udpH.Encode(&header.UDPFields{
+ SrcPort: 100,
+ DstPort: 101,
+ Length: header.UDPMinimumSize,
+ })
+
+ // Calculate the UDP checksum and set it.
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, header.UDPMinimumSize)
+ sum = header.Checksum(nil, sum)
+ udpH.SetChecksum(^udpH.CalculateChecksum(sum))
+
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: 1,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ })
+ return hdr.View()
+ },
+ check: func(t *testing.T, e *channel.Endpoint, round int) {
+ p, ok := e.Read()
+ if round >= icmpBurst {
+ if ok {
+ t.Errorf("got packet %x in round %d, expected ICMP rate limit to stop it", p.Pkt.Data().Views(), round)
+ }
+ return
+ }
+ if !ok {
+ t.Fatalf("expected unreachable in round %d, no packet read in endpoint", round)
+ }
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6DstUnreachable),
+ ))
+ },
+ },
+ }
+ for _, testCase := range tests {
+ t.Run(testCase.name, func(t *testing.T) {
+ for round := 0; round < icmpBurst+1; round++ {
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: testCase.createPacket().ToVectorisedView(),
+ }))
+ testCase.check(t, e, round)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go
index bc9cf6999..3e5c438d3 100644
--- a/pkg/tcpip/network/ipv6/mld_test.go
+++ b/pkg/tcpip/network/ipv6/mld_test.go
@@ -75,8 +75,12 @@ func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) {
// The stack will join an address's solicited node multicast address when
// an address is added. An MLD report message should be sent for the
// solicited-node group.
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: linkLocalAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if p, ok := e.Read(); !ok {
t.Fatal("expected a report message to be sent")
@@ -216,8 +220,13 @@ func TestSendQueuedMLDReports(t *testing.T) {
// Note, we will still expect to send a report for the global address's
// solicited node address from the unspecified address as per RFC 3590
// section 4.
- if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint, err)
+ properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint}
+ globalProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: globalAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, globalProtocolAddr, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, globalProtocolAddr, properties, err)
}
reportCounter++
if got := reportStat.Value(); got != reportCounter {
@@ -252,8 +261,12 @@ func TestSendQueuedMLDReports(t *testing.T) {
// Adding a link-local address should send a report for its solicited node
// address and globalMulticastAddr.
- if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint, err)
+ linkLocalProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: linkLocalAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, linkLocalProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, linkLocalProtocolAddr, err)
}
if dadResolutionTime != 0 {
reportCounter++
@@ -567,8 +580,12 @@ func TestMLDSkipProtocol(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: linkLocalAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if p, ok := e.Read(); !ok {
t.Fatal("expected a report message to be sent")
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index 8837d66d8..938427420 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -1130,7 +1130,11 @@ func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, config
return nil
}
- addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.FirstPrimaryEndpoint, configType, deprecated)
+ addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.AddressProperties{
+ PEB: stack.FirstPrimaryEndpoint,
+ ConfigType: configType,
+ Deprecated: deprecated,
+ })
if err != nil {
panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", addr, err))
}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index f0186c64e..8297a7e10 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -144,8 +144,12 @@ func TestNeighborSolicitationWithSourceLinkLayerOption(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf)
@@ -406,8 +410,12 @@ func TestNeighborSolicitationResponse(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: nicAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -602,8 +610,12 @@ func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf)
@@ -831,8 +843,12 @@ func TestNDPValidation(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber)
@@ -962,8 +978,12 @@ func TestNeighborAdvertisementValidation(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
ndpNASize := header.ICMPv6NeighborAdvertMinimumSize
@@ -1283,8 +1303,12 @@ func TestCheckDuplicateAddress(t *testing.T) {
checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonces[dadPacketsSent])}),
))
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
checkDADMsg()
diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go
index 1b96b1fb8..26640b7ee 100644
--- a/pkg/tcpip/network/multicast_group_test.go
+++ b/pkg/tcpip/network/multicast_group_test.go
@@ -151,15 +151,22 @@ func createStackWithLinkEndpoint(t *testing.T, v4, mgpEnabled bool, e stack.Link
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- addr := tcpip.AddressWithPrefix{
- Address: stackIPv4Addr,
- PrefixLen: defaultIPv4PrefixLength,
+ addr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: stackIPv4Addr,
+ PrefixLen: defaultIPv4PrefixLength,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err)
}
- if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: linkLocalIPv6Addr1.WithPrefix(),
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalIPv6Addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, linkLocalIPv6Addr1, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
return s, clock
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
index 009cab643..05b879543 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -146,8 +146,12 @@ func main() {
log.Fatal(err)
}
- if err := s.AddAddress(1, ipv4.ProtocolNumber, addr); err != nil {
- log.Fatal(err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ log.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
// Add default route.
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
index c10b19aa0..a72afadda 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/main.go
+++ b/pkg/tcpip/sample/tun_tcp_echo/main.go
@@ -124,13 +124,13 @@ func main() {
log.Fatalf("Bad IP address: %v", addrName)
}
- var addr tcpip.Address
+ var addrWithPrefix tcpip.AddressWithPrefix
var proto tcpip.NetworkProtocolNumber
if parsedAddr.To4() != nil {
- addr = tcpip.Address(parsedAddr.To4())
+ addrWithPrefix = tcpip.Address(parsedAddr.To4()).WithPrefix()
proto = ipv4.ProtocolNumber
} else if parsedAddr.To16() != nil {
- addr = tcpip.Address(parsedAddr.To16())
+ addrWithPrefix = tcpip.Address(parsedAddr.To16()).WithPrefix()
proto = ipv6.ProtocolNumber
} else {
log.Fatalf("Unknown IP type: %v", addrName)
@@ -176,11 +176,15 @@ func main() {
log.Fatal(err)
}
- if err := s.AddAddress(1, proto, addr); err != nil {
- log.Fatal(err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: proto,
+ AddressWithPrefix: addrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ log.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
- subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addr))), tcpip.AddressMask(strings.Repeat("\x00", len(addr))))
+ subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addrWithPrefix.Address))), tcpip.AddressMask(strings.Repeat("\x00", len(addrWithPrefix.Address))))
if err != nil {
log.Fatal(err)
}
diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go
index 6bce3af04..b0b2d0afd 100644
--- a/pkg/tcpip/socketops.go
+++ b/pkg/tcpip/socketops.go
@@ -57,6 +57,11 @@ type SocketOptionsHandler interface {
// OnSetReceiveBufferSize is invoked by SO_RCVBUF and SO_RCVBUFFORCE.
OnSetReceiveBufferSize(v, oldSz int64) (newSz int64)
+
+ // WakeupWriters is invoked when the send buffer size for an endpoint is
+ // changed. The handler notifies the writers if the send buffer size is
+ // increased with setsockopt(2) for TCP endpoints.
+ WakeupWriters()
}
// DefaultSocketOptionsHandler is an embeddable type that implements no-op
@@ -98,6 +103,9 @@ func (*DefaultSocketOptionsHandler) OnSetSendBufferSize(v int64) (newSz int64) {
return v
}
+// WakeupWriters implements SocketOptionsHandler.WakeupWriters.
+func (*DefaultSocketOptionsHandler) WakeupWriters() {}
+
// OnSetReceiveBufferSize implements SocketOptionsHandler.OnSetReceiveBufferSize.
func (*DefaultSocketOptionsHandler) OnSetReceiveBufferSize(v, oldSz int64) (newSz int64) {
return v
@@ -162,10 +170,14 @@ type SocketOptions struct {
// message is passed with incoming packets.
receiveTClassEnabled uint32
- // receivePacketInfoEnabled is used to specify if more inforamtion is
- // provided with incoming packets such as interface index and address.
+ // receivePacketInfoEnabled is used to specify if more information is
+ // provided with incoming IPv4 packets.
receivePacketInfoEnabled uint32
+ // receiveIPv6PacketInfoEnabled is used to specify if more information is
+ // provided with incoming IPv6 packets.
+ receiveIPv6PacketInfoEnabled uint32
+
// hdrIncludeEnabled is used to indicate for a raw endpoint that all packets
// being written have an IP header and the endpoint should not attach an IP
// header.
@@ -352,6 +364,16 @@ func (so *SocketOptions) SetReceivePacketInfo(v bool) {
storeAtomicBool(&so.receivePacketInfoEnabled, v)
}
+// GetIPv6ReceivePacketInfo gets value for IPV6_RECVPKTINFO option.
+func (so *SocketOptions) GetIPv6ReceivePacketInfo() bool {
+ return atomic.LoadUint32(&so.receiveIPv6PacketInfoEnabled) != 0
+}
+
+// SetIPv6ReceivePacketInfo sets value for IPV6_RECVPKTINFO option.
+func (so *SocketOptions) SetIPv6ReceivePacketInfo(v bool) {
+ storeAtomicBool(&so.receiveIPv6PacketInfoEnabled, v)
+}
+
// GetHeaderIncluded gets value for IP_HDRINCL option.
func (so *SocketOptions) GetHeaderIncluded() bool {
return atomic.LoadUint32(&so.hdrIncludedEnabled) != 0
@@ -626,6 +648,9 @@ func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) {
sendBufferSize = so.handler.OnSetSendBufferSize(sendBufferSize)
}
so.sendBufferSize.Store(sendBufferSize)
+ if notify {
+ so.handler.WakeupWriters()
+ }
}
// GetReceiveBufferSize gets value for SO_RCVBUF option.
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index e0847e58a..6c42ab29b 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -85,6 +85,7 @@ go_library(
"//pkg/tcpip/buffer",
"//pkg/tcpip/hash/jenkins",
"//pkg/tcpip/header",
+ "//pkg/tcpip/internal/tcp",
"//pkg/tcpip/ports",
"//pkg/tcpip/seqnum",
"//pkg/tcpip/transport/tcpconntrack",
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
index ae0bb4ace..7e4b5bf74 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -117,10 +117,10 @@ func (a *AddressableEndpointState) releaseAddressStateLocked(addrState *addressS
}
// AddAndAcquirePermanentAddress implements AddressableEndpoint.
-func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, tcpip.Error) {
+func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties AddressProperties) (AddressEndpoint, tcpip.Error) {
a.mu.Lock()
defer a.mu.Unlock()
- ep, err := a.addAndAcquireAddressLocked(addr, peb, configType, deprecated, true /* permanent */)
+ ep, err := a.addAndAcquireAddressLocked(addr, properties, true /* permanent */)
// From https://golang.org/doc/faq#nil_error:
//
// Under the covers, interfaces are implemented as two elements, a type T and
@@ -149,7 +149,7 @@ func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.Addr
func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) (AddressEndpoint, tcpip.Error) {
a.mu.Lock()
defer a.mu.Unlock()
- ep, err := a.addAndAcquireAddressLocked(addr, peb, AddressConfigStatic, false /* deprecated */, false /* permanent */)
+ ep, err := a.addAndAcquireAddressLocked(addr, AddressProperties{PEB: peb}, false /* permanent */)
// From https://golang.org/doc/faq#nil_error:
//
// Under the covers, interfaces are implemented as two elements, a type T and
@@ -180,7 +180,7 @@ func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.Addr
// returned, regardless the kind of address that is being added.
//
// Precondition: a.mu must be write locked.
-func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated, permanent bool) (*addressState, tcpip.Error) {
+func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, properties AddressProperties, permanent bool) (*addressState, tcpip.Error) {
// attemptAddToPrimary is false when the address is already in the primary
// address list.
attemptAddToPrimary := true
@@ -208,7 +208,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
// We now promote the address.
for i, s := range a.mu.primary {
if s == addrState {
- switch peb {
+ switch properties.PEB {
case CanBePrimaryEndpoint:
// The address is already in the primary address list.
attemptAddToPrimary = false
@@ -222,7 +222,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
case NeverPrimaryEndpoint:
a.mu.primary = append(a.mu.primary[:i], a.mu.primary[i+1:]...)
default:
- panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb))
+ panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", properties.PEB))
}
break
}
@@ -262,11 +262,11 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
}
// Acquire the address before returning it.
addrState.mu.refs++
- addrState.mu.deprecated = deprecated
- addrState.mu.configType = configType
+ addrState.mu.deprecated = properties.Deprecated
+ addrState.mu.configType = properties.ConfigType
if attemptAddToPrimary {
- switch peb {
+ switch properties.PEB {
case NeverPrimaryEndpoint:
case CanBePrimaryEndpoint:
a.mu.primary = append(a.mu.primary, addrState)
@@ -285,7 +285,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
a.mu.primary[0] = addrState
}
default:
- panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb))
+ panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", properties.PEB))
}
}
@@ -489,12 +489,12 @@ func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tc
// Proceed to add a new temporary endpoint.
addr := localAddr.WithPrefix()
- ep, err := a.addAndAcquireAddressLocked(addr, tempPEB, AddressConfigStatic, false /* deprecated */, false /* permanent */)
+ ep, err := a.addAndAcquireAddressLocked(addr, AddressProperties{PEB: tempPEB}, false /* permanent */)
if err != nil {
// addAndAcquireAddressLocked only returns an error if the address is
// already assigned but we just checked above if the address exists so we
// expect no error.
- panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, %d, %d, false, false): %s", addr, tempPEB, AddressConfigStatic, err))
+ panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, AddressProperties{PEB: %s}, false): %s", addr, tempPEB, err))
}
// From https://golang.org/doc/faq#nil_error:
diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go
index 140f146f6..c55f85743 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state_test.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go
@@ -38,9 +38,9 @@ func TestAddressableEndpointStateCleanup(t *testing.T) {
}
{
- ep, err := s.AddAndAcquirePermanentAddress(addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */)
+ ep, err := s.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{PEB: stack.NeverPrimaryEndpoint})
if err != nil {
- t.Fatalf("s.AddAndAcquirePermanentAddress(%s, %d, %d, false): %s", addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, err)
+ t.Fatalf("s.AddAndAcquirePermanentAddress(%s, AddressProperties{PEB: NeverPrimaryEndpoint}): %s", addr, err)
}
// We don't need the address endpoint.
ep.DecRef()
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 068dab7ce..16d295271 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -64,13 +64,21 @@ type tuple struct {
// tupleEntry is used to build an intrusive list of tuples.
tupleEntry
- tupleID
-
// conn is the connection tracking entry this tuple belongs to.
conn *conn
// direction is the direction of the tuple.
direction direction
+
+ mu sync.RWMutex `state:"nosave"`
+ // +checklocks:mu
+ tupleID tupleID
+}
+
+func (t *tuple) id() tupleID {
+ t.mu.RLock()
+ defer t.mu.RUnlock()
+ return t.tupleID
}
// tupleID uniquely identifies a connection in one direction. It currently
@@ -103,50 +111,43 @@ func (ti tupleID) reply() tupleID {
//
// +stateify savable
type conn struct {
+ ct *ConnTrack
+
// original is the tuple in original direction. It is immutable.
original tuple
- // reply is the tuple in reply direction. It is immutable.
+ // reply is the tuple in reply direction.
reply tuple
- // manip indicates if the packet should be manipulated. It is immutable.
- // TODO(gvisor.dev/issue/5696): Support updating manipulation type.
+ mu sync.RWMutex `state:"nosave"`
+ // Indicates that the connection has been finalized and may handle replies.
+ //
+ // +checklocks:mu
+ finalized bool
+ // manip indicates if the packet should be manipulated.
+ //
+ // +checklocks:mu
manip manipType
-
- // tcbHook indicates if the packet is inbound or outbound to
- // update the state of tcb. It is immutable.
- tcbHook Hook
-
- // mu protects all mutable state.
- mu sync.Mutex `state:"nosave"`
// tcb is TCB control block. It is used to keep track of states
- // of tcp connection and is protected by mu.
+ // of tcp connection.
+ //
+ // +checklocks:mu
tcb tcpconntrack.TCB
// lastUsed is the last time the connection saw a relevant packet, and
- // is updated by each packet on the connection. It is protected by mu.
+ // is updated by each packet on the connection.
//
// TODO(gvisor.dev/issue/5939): do not use the ambient clock.
+ //
+ // +checklocks:mu
lastUsed time.Time `state:".(unixTime)"`
}
-// newConn creates new connection.
-func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn {
- conn := conn{
- manip: manip,
- tcbHook: hook,
- lastUsed: time.Now(),
- }
- conn.original = tuple{conn: &conn, tupleID: orig}
- conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply}
- return &conn
-}
-
// timedOut returns whether the connection timed out based on its state.
func (cn *conn) timedOut(now time.Time) bool {
const establishedTimeout = 5 * 24 * time.Hour
const defaultTimeout = 120 * time.Second
- cn.mu.Lock()
- defer cn.mu.Unlock()
+ cn.mu.RLock()
+ defer cn.mu.RUnlock()
if cn.tcb.State() == tcpconntrack.ResultAlive {
// Use the same default as Linux, which doesn't delete
// established connections for 5(!) days.
@@ -159,17 +160,30 @@ func (cn *conn) timedOut(now time.Time) bool {
// update the connection tracking state.
//
-// Precondition: cn.mu must be held.
-func (cn *conn) updateLocked(tcpHeader header.TCP, hook Hook) {
+// TODO(https://gvisor.dev/issue/6590): annotate r/w locking requirements.
+// +checklocks:cn.mu
+func (cn *conn) updateLocked(pkt *PacketBuffer, dir direction) {
+ if pkt.TransportProtocolNumber != header.TCPProtocolNumber {
+ return
+ }
+
+ tcpHeader := header.TCP(pkt.TransportHeader().View())
+
// Update the state of tcb. tcb assumes it's always initialized on the
// client. However, we only need to know whether the connection is
// established or not, so the client/server distinction isn't important.
if cn.tcb.IsEmpty() {
cn.tcb.Init(tcpHeader)
- } else if hook == cn.tcbHook {
+ return
+ }
+
+ switch dir {
+ case dirOriginal:
cn.tcb.UpdateStateOutbound(tcpHeader)
- } else {
+ case dirReply:
cn.tcb.UpdateStateInbound(tcpHeader)
+ default:
+ panic(fmt.Sprintf("unhandled dir = %d", dir))
}
}
@@ -194,44 +208,34 @@ type ConnTrack struct {
// It is immutable.
seed uint32
+ mu sync.RWMutex `state:"nosave"`
// mu protects the buckets slice, but not buckets' contents. Only take
// the write lock if you are modifying the slice or saving for S/R.
- mu sync.RWMutex `state:"nosave"`
-
- // buckets is protected by mu.
+ //
+ // +checklocks:mu
buckets []bucket
}
// +stateify savable
type bucket struct {
- // mu protects tuples.
- mu sync.Mutex `state:"nosave"`
+ mu sync.RWMutex `state:"nosave"`
+ // +checklocks:mu
tuples tupleList
}
-// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid
-// TCP header.
-//
-// Preconditions: pkt.NetworkHeader() is valid.
-func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) {
- netHeader := pkt.Network()
- if netHeader.TransportProtocol() != header.TCPProtocolNumber {
- return tupleID{}, &tcpip.ErrUnknownProtocol{}
- }
-
- tcpHeader := header.TCP(pkt.TransportHeader().View())
- if len(tcpHeader) < header.TCPMinimumSize {
- return tupleID{}, &tcpip.ErrUnknownProtocol{}
+func getTransportHeader(pkt *PacketBuffer) (header.ChecksummableTransport, bool) {
+ switch pkt.TransportProtocolNumber {
+ case header.TCPProtocolNumber:
+ if tcpHeader := header.TCP(pkt.TransportHeader().View()); len(tcpHeader) >= header.TCPMinimumSize {
+ return tcpHeader, true
+ }
+ case header.UDPProtocolNumber:
+ if udpHeader := header.UDP(pkt.TransportHeader().View()); len(udpHeader) >= header.UDPMinimumSize {
+ return udpHeader, true
+ }
}
- return tupleID{
- srcAddr: netHeader.SourceAddress(),
- srcPort: tcpHeader.SourcePort(),
- dstAddr: netHeader.DestinationAddress(),
- dstPort: tcpHeader.DestinationPort(),
- transProto: netHeader.TransportProtocol(),
- netProto: pkt.NetworkProtocolNumber,
- }, nil
+ return nil, false
}
func (ct *ConnTrack) init() {
@@ -240,167 +244,185 @@ func (ct *ConnTrack) init() {
ct.buckets = make([]bucket, numBuckets)
}
-// connFor gets the conn for pkt if it exists, or returns nil
-// if it does not. It returns an error when pkt does not contain a valid TCP
-// header.
-// TODO(gvisor.dev/issue/6168): Support UDP.
-func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) {
- tid, err := packetToTupleID(pkt)
- if err != nil {
- return nil, dirOriginal
+func (ct *ConnTrack) getConnOrMaybeInsertNoop(pkt *PacketBuffer) *tuple {
+ netHeader := pkt.Network()
+ transportHeader, ok := getTransportHeader(pkt)
+ if !ok {
+ return nil
}
- return ct.connForTID(tid)
-}
-func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) {
- bucket := ct.bucket(tid)
- now := time.Now()
+ tid := tupleID{
+ srcAddr: netHeader.SourceAddress(),
+ srcPort: transportHeader.SourcePort(),
+ dstAddr: netHeader.DestinationAddress(),
+ dstPort: transportHeader.DestinationPort(),
+ transProto: pkt.TransportProtocolNumber,
+ netProto: pkt.NetworkProtocolNumber,
+ }
+
+ bktID := ct.bucket(tid)
ct.mu.RLock()
- defer ct.mu.RUnlock()
- ct.buckets[bucket].mu.Lock()
- defer ct.buckets[bucket].mu.Unlock()
-
- // Iterate over the tuples in a bucket, cleaning up any unused
- // connections we find.
- for other := ct.buckets[bucket].tuples.Front(); other != nil; other = other.Next() {
- // Clean up any timed-out connections we happen to find.
- if ct.reapTupleLocked(other, bucket, now) {
- // The tuple expired.
- continue
- }
- if tid == other.tupleID {
- return other.conn, other.direction
- }
+ bkt := &ct.buckets[bktID]
+ ct.mu.RUnlock()
+
+ now := time.Now()
+ if t := bkt.connForTID(tid, now); t != nil {
+ return t
}
- return nil, dirOriginal
-}
+ bkt.mu.Lock()
+ defer bkt.mu.Unlock()
-func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn {
- tid, err := packetToTupleID(pkt)
- if err != nil {
- return nil
+ // Make sure a connection wasn't added between when we last checked the
+ // bucket and acquired the bucket's write lock.
+ if t := bkt.connForTIDRLocked(tid, now); t != nil {
+ return t
}
- if hook != Prerouting && hook != Output {
- return nil
+
+ // This is the first packet we're seeing for the connection. Create an entry
+ // for this new connection.
+ conn := &conn{
+ ct: ct,
+ original: tuple{tupleID: tid, direction: dirOriginal},
+ reply: tuple{tupleID: tid.reply(), direction: dirReply},
+ manip: manipNone,
+ lastUsed: now,
}
+ conn.original.conn = conn
+ conn.reply.conn = conn
+
+ // For now, we only map an entry for the packet's original tuple as NAT may be
+ // performed on this connection. Until the packet goes through all the hooks
+ // and its final address/port is known, we cannot know what the response
+ // packet's addresses/ports will look like.
+ //
+ // This is okay because the destination cannot send its response until it
+ // receives the packet; the packet will only be received once all the hooks
+ // have been performed.
+ //
+ // See (*conn).finalize.
+ bkt.tuples.PushFront(&conn.original)
+ return &conn.original
+}
- replyTID := tid.reply()
- replyTID.srcAddr = address
- replyTID.srcPort = port
+func (ct *ConnTrack) connForTID(tid tupleID) *tuple {
+ bktID := ct.bucket(tid)
- conn, _ := ct.connForTID(tid)
- if conn != nil {
- // The connection is already tracked.
- // TODO(gvisor.dev/issue/5696): Support updating an existing connection.
- return nil
- }
- conn = newConn(tid, replyTID, manipDestination, hook)
- ct.insertConn(conn)
- return conn
+ ct.mu.RLock()
+ bkt := &ct.buckets[bktID]
+ ct.mu.RUnlock()
+
+ return bkt.connForTID(tid, time.Now())
}
-func (ct *ConnTrack) insertSNATConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn {
- tid, err := packetToTupleID(pkt)
- if err != nil {
- return nil
- }
- if hook != Input && hook != Postrouting {
- return nil
+func (bkt *bucket) connForTID(tid tupleID, now time.Time) *tuple {
+ bkt.mu.RLock()
+ defer bkt.mu.RUnlock()
+ return bkt.connForTIDRLocked(tid, now)
+}
+
+// +checklocks:bkt.mu
+func (bkt *bucket) connForTIDRLocked(tid tupleID, now time.Time) *tuple {
+ for other := bkt.tuples.Front(); other != nil; other = other.Next() {
+ if tid == other.id() && !other.conn.timedOut(now) {
+ return other
+ }
}
+ return nil
+}
- replyTID := tid.reply()
- replyTID.dstAddr = address
- replyTID.dstPort = port
+func (ct *ConnTrack) finalize(cn *conn) {
+ tid := cn.reply.id()
+ id := ct.bucket(tid)
- conn, _ := ct.connForTID(tid)
- if conn != nil {
- // The connection is already tracked.
- // TODO(gvisor.dev/issue/5696): Support updating an existing connection.
- return nil
+ ct.mu.RLock()
+ bkt := &ct.buckets[id]
+ ct.mu.RUnlock()
+
+ bkt.mu.Lock()
+ defer bkt.mu.Unlock()
+
+ if t := bkt.connForTIDRLocked(tid, time.Now()); t != nil {
+ // Another connection for the reply already exists. We can't do much about
+ // this so we leave the connection cn represents in a state where it can
+ // send packets but its responses will be mapped to some other connection.
+ // This may be okay if the connection only expects to send packets without
+ // any responses.
+ return
}
- conn = newConn(tid, replyTID, manipSource, hook)
- ct.insertConn(conn)
- return conn
+
+ bkt.tuples.PushFront(&cn.reply)
}
-// insertConn inserts conn into the appropriate table bucket.
-func (ct *ConnTrack) insertConn(conn *conn) {
- // Lock the buckets in the correct order.
- tupleBucket := ct.bucket(conn.original.tupleID)
- replyBucket := ct.bucket(conn.reply.tupleID)
- ct.mu.RLock()
- defer ct.mu.RUnlock()
- if tupleBucket < replyBucket {
- ct.buckets[tupleBucket].mu.Lock()
- ct.buckets[replyBucket].mu.Lock()
- } else if tupleBucket > replyBucket {
- ct.buckets[replyBucket].mu.Lock()
- ct.buckets[tupleBucket].mu.Lock()
- } else {
- // Both tuples are in the same bucket.
- ct.buckets[tupleBucket].mu.Lock()
- }
-
- // Now that we hold the locks, ensure the tuple hasn't been inserted by
- // another thread.
- // TODO(gvisor.dev/issue/5773): Should check conn.reply.tupleID, too?
- alreadyInserted := false
- for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() {
- if other.tupleID == conn.original.tupleID {
- alreadyInserted = true
- break
+func (cn *conn) finalize() {
+ {
+ cn.mu.RLock()
+ finalized := cn.finalized
+ cn.mu.RUnlock()
+ if finalized {
+ return
}
}
- if !alreadyInserted {
- // Add the tuple to the map.
- ct.buckets[tupleBucket].tuples.PushFront(&conn.original)
- ct.buckets[replyBucket].tuples.PushFront(&conn.reply)
+ cn.mu.Lock()
+ finalized := cn.finalized
+ cn.finalized = true
+ cn.mu.Unlock()
+ if finalized {
+ return
}
- // Unlocking can happen in any order.
- ct.buckets[tupleBucket].mu.Unlock()
- if tupleBucket != replyBucket {
- ct.buckets[replyBucket].mu.Unlock() // +checklocksforce
- }
+ cn.ct.finalize(cn)
}
-// handlePacket will manipulate the port and address of the packet if the
-// connection exists. Returns whether, after the packet traverses the tables,
-// it should create a new entry in the table.
-func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
- if pkt.NatDone {
- return false
+// performNAT sets up the connection for the specified NAT.
+//
+// Generally, only the first packet of a connection reaches this method;
+// other packets will be manipulated without needing to modify the connection.
+func (cn *conn) performNAT(pkt *PacketBuffer, hook Hook, r *Route, port uint16, address tcpip.Address, dnat bool) {
+ cn.performNATIfNoop(port, address, dnat)
+ cn.handlePacket(pkt, hook, r)
+}
+
+func (cn *conn) performNATIfNoop(port uint16, address tcpip.Address, dnat bool) {
+ cn.mu.Lock()
+ defer cn.mu.Unlock()
+
+ if cn.finalized {
+ return
}
- switch hook {
- case Prerouting, Input, Output, Postrouting:
- default:
- return false
+ if cn.manip != manipNone {
+ return
}
- // TODO(gvisor.dev/issue/6168): Support UDP.
- if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
- return false
+ cn.reply.mu.Lock()
+ defer cn.reply.mu.Unlock()
+
+ if dnat {
+ cn.reply.tupleID.srcAddr = address
+ cn.reply.tupleID.srcPort = port
+ cn.manip = manipDestination
+ } else {
+ cn.reply.tupleID.dstAddr = address
+ cn.reply.tupleID.dstPort = port
+ cn.manip = manipSource
}
+}
- conn, dir := ct.connFor(pkt)
- // Connection not found for the packet.
- if conn == nil {
- // If this is the last hook in the data path for this packet (Input if
- // incoming, Postrouting if outgoing), indicate that a connection should be
- // inserted by the end of this hook.
- return hook == Input || hook == Postrouting
+func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) {
+ if pkt.NatDone {
+ return
}
- netHeader := pkt.Network()
- tcpHeader := header.TCP(pkt.TransportHeader().View())
- if len(tcpHeader) < header.TCPMinimumSize {
- return false
+ transportHeader, ok := getTransportHeader(pkt)
+ if !ok {
+ return
}
+ netHeader := pkt.Network()
+
// TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be
// validated if checksum offloading is off. It may require IP defrag if the
// packets are fragmented.
@@ -410,49 +432,58 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
updateSRCFields := false
+ dir := pkt.tuple.direction
+
+ cn.mu.Lock()
+ defer cn.mu.Unlock()
+
switch hook {
case Prerouting, Output:
- if conn.manip == manipDestination {
- switch dir {
- case dirOriginal:
- newPort = conn.reply.srcPort
- newAddr = conn.reply.srcAddr
- case dirReply:
- newPort = conn.original.dstPort
- newAddr = conn.original.dstAddr
-
- updateSRCFields = true
- }
+ if cn.manip == manipDestination && dir == dirOriginal {
+ id := cn.reply.id()
+ newPort = id.srcPort
+ newAddr = id.srcAddr
+ pkt.NatDone = true
+ } else if cn.manip == manipSource && dir == dirReply {
+ id := cn.original.id()
+ newPort = id.srcPort
+ newAddr = id.srcAddr
pkt.NatDone = true
}
case Input, Postrouting:
- if conn.manip == manipSource {
- switch dir {
- case dirOriginal:
- newPort = conn.reply.dstPort
- newAddr = conn.reply.dstAddr
-
- updateSRCFields = true
- case dirReply:
- newPort = conn.original.srcPort
- newAddr = conn.original.srcAddr
- }
+ if cn.manip == manipSource && dir == dirOriginal {
+ id := cn.reply.id()
+ newPort = id.dstPort
+ newAddr = id.dstAddr
+ updateSRCFields = true
+ pkt.NatDone = true
+ } else if cn.manip == manipDestination && dir == dirReply {
+ id := cn.original.id()
+ newPort = id.dstPort
+ newAddr = id.dstAddr
+ updateSRCFields = true
pkt.NatDone = true
}
default:
panic(fmt.Sprintf("unrecognized hook = %s", hook))
}
+
if !pkt.NatDone {
- return false
+ return
}
fullChecksum := false
updatePseudoHeader := false
switch hook {
- case Prerouting, Input:
+ case Prerouting:
+ // Packet came from outside the stack so it must have a checksum set
+ // already.
+ fullChecksum = true
+ updatePseudoHeader = true
+ case Input:
case Output, Postrouting:
// Calculate the TCP checksum and set it.
- if pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum {
+ if pkt.TransportProtocolNumber == header.TCPProtocolNumber && pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum {
updatePseudoHeader = true
} else if r.RequiresTXTransportChecksum() {
fullChecksum = true
@@ -464,7 +495,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
rewritePacket(
netHeader,
- tcpHeader,
+ transportHeader,
updateSRCFields,
fullChecksum,
updatePseudoHeader,
@@ -472,46 +503,10 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
newAddr,
)
- // Update the state of tcb.
- conn.mu.Lock()
- defer conn.mu.Unlock()
-
// Mark the connection as having been used recently so it isn't reaped.
- conn.lastUsed = time.Now()
+ cn.lastUsed = time.Now()
// Update connection state.
- conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook)
-
- return false
-}
-
-// maybeInsertNoop tries to insert a no-op connection entry to keep connections
-// from getting clobbered when replies arrive. It only inserts if there isn't
-// already a connection for pkt.
-//
-// This should be called after traversing iptables rules only, to ensure that
-// pkt.NatDone is set correctly.
-func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
- // If there were a rule applying to this packet, it would be marked
- // with NatDone.
- if pkt.NatDone {
- return
- }
-
- // We only track TCP connections.
- if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
- return
- }
-
- // This is the first packet we're seeing for the TCP connection. Insert
- // the noop entry (an identity mapping) so that the response doesn't
- // get NATed, breaking the connection.
- tid, err := packetToTupleID(pkt)
- if err != nil {
- return
- }
- conn := newConn(tid, tid.reply(), manipNone, hook)
- conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook)
- ct.insertConn(conn)
+ cn.updateLocked(pkt, dir)
}
// bucket gets the conntrack bucket for a tupleID.
@@ -563,14 +558,15 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim
defer ct.mu.RUnlock()
for i := 0; i < len(ct.buckets)/fractionPerReaping; i++ {
idx = (i + start) % len(ct.buckets)
- ct.buckets[idx].mu.Lock()
- for tuple := ct.buckets[idx].tuples.Front(); tuple != nil; tuple = tuple.Next() {
+ bkt := &ct.buckets[idx]
+ bkt.mu.Lock()
+ for tuple := bkt.tuples.Front(); tuple != nil; tuple = tuple.Next() {
checked++
- if ct.reapTupleLocked(tuple, idx, now) {
+ if ct.reapTupleLocked(tuple, idx, bkt, now) {
expired++
}
}
- ct.buckets[idx].mu.Unlock()
+ bkt.mu.Unlock()
}
// We already checked buckets[idx].
idx++
@@ -595,44 +591,48 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim
// reapTupleLocked tries to remove tuple and its reply from the table. It
// returns whether the tuple's connection has timed out.
//
-// Preconditions:
-// * ct.mu is locked for reading.
-// * bucket is locked.
-func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bool {
+// Precondition: ct.mu is read locked and bkt.mu is write locked.
+// TODO(https://gvisor.dev/issue/6590): annotate r/w locking requirements.
+// +checklocks:ct.mu
+// +checklocks:bkt.mu
+func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bktID int, bkt *bucket, now time.Time) bool {
if !tuple.conn.timedOut(now) {
return false
}
// To maintain lock order, we can only reap these tuples if the reply
// appears later in the table.
- replyBucket := ct.bucket(tuple.reply())
- if bucket > replyBucket {
+ replyBktID := ct.bucket(tuple.id().reply())
+ if bktID > replyBktID {
return true
}
// Don't re-lock if both tuples are in the same bucket.
- differentBuckets := bucket != replyBucket
- if differentBuckets {
- ct.buckets[replyBucket].mu.Lock()
+ if bktID != replyBktID {
+ replyBkt := &ct.buckets[replyBktID]
+ replyBkt.mu.Lock()
+ removeConnFromBucket(replyBkt, tuple)
+ replyBkt.mu.Unlock()
+ } else {
+ removeConnFromBucket(bkt, tuple)
}
// The reply tuple is gone; remove tuple itself from bkt, which is still locked.
+ bkt.tuples.Remove(tuple)
+ return true
+}
+
+// TODO(https://gvisor.dev/issue/6590): annotate r/w locking requirements.
+// +checklocks:b.mu
+func removeConnFromBucket(b *bucket, tuple *tuple) {
if tuple.direction == dirOriginal {
- ct.buckets[replyBucket].tuples.Remove(&tuple.conn.reply)
+ b.tuples.Remove(&tuple.conn.reply)
} else {
- ct.buckets[replyBucket].tuples.Remove(&tuple.conn.original)
- }
- ct.buckets[bucket].tuples.Remove(tuple)
-
- // Don't re-unlock if both tuples are in the same bucket.
- if differentBuckets {
- ct.buckets[replyBucket].mu.Unlock() // +checklocksforce
+ b.tuples.Remove(&tuple.conn.original)
}
-
- return true
}
-func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
+func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
// Lookup the connection. The reply's original destination
// describes the original address.
tid := tupleID{
@@ -640,17 +640,22 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ
srcPort: epID.LocalPort,
dstAddr: epID.RemoteAddress,
dstPort: epID.RemotePort,
- transProto: header.TCPProtocolNumber,
+ transProto: transProto,
netProto: netProto,
}
- conn, _ := ct.connForTID(tid)
- if conn == nil {
+ t := ct.connForTID(tid)
+ if t == nil {
// Not a tracked connection.
return "", 0, &tcpip.ErrNotConnected{}
- } else if conn.manip != manipDestination {
+ }
+
+ t.conn.mu.RLock()
+ defer t.conn.mu.RUnlock()
+ if t.conn.manip != manipDestination {
// Unmanipulated destination.
return "", 0, &tcpip.ErrInvalidOptionValue{}
}
- return conn.original.dstAddr, conn.original.dstPort, nil
+ id := t.conn.original.id()
+ return id.dstAddr, id.dstPort, nil
}
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index 72f66441f..c2f1f4798 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -181,10 +181,6 @@ func (*fwdTestNetworkProtocol) MinimumPacketSize() int {
return fwdTestNetHeaderLen
}
-func (*fwdTestNetworkProtocol) DefaultPrefixLen() int {
- return fwdTestNetDefaultPrefixLen
-}
-
func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1])
}
@@ -342,6 +338,10 @@ func (e *fwdTestLinkEndpoint) WritePackets(r RouteInfo, pkts PacketBufferList, p
return n, nil
}
+func (*fwdTestLinkEndpoint) WriteRawPacket(*PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
+}
+
// Wait implements stack.LinkEndpoint.Wait.
func (*fwdTestLinkEndpoint) Wait() {}
@@ -380,8 +380,15 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (*faketime.M
if err := s.CreateNIC(1, ep1); err != nil {
t.Fatal("CreateNIC #1 failed:", err)
}
- if err := s.AddAddress(1, fwdTestNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress #1 failed:", err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fwdTestNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fwdTestNetDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr1, AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err)
}
// NIC 2 has the link address "b", and added the network address 2.
@@ -393,8 +400,15 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (*faketime.M
if err := s.CreateNIC(2, ep2); err != nil {
t.Fatal("CreateNIC #2 failed:", err)
}
- if err := s.AddAddress(2, fwdTestNetNumber, "\x02"); err != nil {
- t.Fatal("AddAddress #2 failed:", err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: fwdTestNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x02",
+ PrefixLen: fwdTestNetDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(2, protocolAddr2, AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err)
}
nic, ok := s.nics[2]
diff --git a/pkg/tcpip/stack/icmp_rate_limit.go b/pkg/tcpip/stack/icmp_rate_limit.go
index 3a20839da..99e5d2df7 100644
--- a/pkg/tcpip/stack/icmp_rate_limit.go
+++ b/pkg/tcpip/stack/icmp_rate_limit.go
@@ -16,6 +16,7 @@ package stack
import (
"golang.org/x/time/rate"
+ "gvisor.dev/gvisor/pkg/tcpip"
)
const (
@@ -31,11 +32,41 @@ const (
// ICMPRateLimiter is a global rate limiter that controls the generation of
// ICMP messages generated by the stack.
type ICMPRateLimiter struct {
- *rate.Limiter
+ limiter *rate.Limiter
+ clock tcpip.Clock
}
// NewICMPRateLimiter returns a global rate limiter for controlling the rate
-// at which ICMP messages are generated by the stack.
-func NewICMPRateLimiter() *ICMPRateLimiter {
- return &ICMPRateLimiter{Limiter: rate.NewLimiter(icmpLimit, icmpBurst)}
+// at which ICMP messages are generated by the stack. The returned limiter
+// does not apply limits to any ICMP types by default.
+func NewICMPRateLimiter(clock tcpip.Clock) *ICMPRateLimiter {
+ return &ICMPRateLimiter{
+ clock: clock,
+ limiter: rate.NewLimiter(icmpLimit, icmpBurst),
+ }
+}
+
+// SetLimit sets a new Limit for the limiter.
+func (l *ICMPRateLimiter) SetLimit(limit rate.Limit) {
+ l.limiter.SetLimitAt(l.clock.Now(), limit)
+}
+
+// Limit returns the maximum overall event rate.
+func (l *ICMPRateLimiter) Limit() rate.Limit {
+ return l.limiter.Limit()
+}
+
+// SetBurst sets a new burst size for the limiter.
+func (l *ICMPRateLimiter) SetBurst(burst int) {
+ l.limiter.SetBurstAt(l.clock.Now(), burst)
+}
+
+// Burst returns the maximum burst size.
+func (l *ICMPRateLimiter) Burst() int {
+ return l.limiter.Burst()
+}
+
+// Allow reports whether one ICMP message may be sent now.
+func (l *ICMPRateLimiter) Allow() bool {
+ return l.limiter.AllowN(l.clock.Now(), 1)
}
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index f152c0d83..5808be685 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -264,26 +264,134 @@ const (
chainReturn
)
-// Check runs pkt through the rules for hook. It returns true when the packet
-// should continue traversing the network stack and false when it should be
-// dropped.
+// CheckPrerouting performs the prerouting hook on the packet.
+//
+// Returns true iff the packet may continue traversing the stack; the packet
+// must be dropped if false is returned.
//
-// Precondition: pkt.NetworkHeader is set.
-func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool {
- if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber {
+// Precondition: The packet's network and transport header must be set.
+func (it *IPTables) CheckPrerouting(pkt *PacketBuffer, addressEP AddressableEndpoint, inNicName string) bool {
+ const hook = Prerouting
+
+ if it.shouldSkip(pkt.NetworkProtocolNumber) {
return true
}
+
+ if t := it.connections.getConnOrMaybeInsertNoop(pkt); t != nil {
+ pkt.tuple = t
+ t.conn.handlePacket(pkt, hook, nil /* route */)
+ }
+
+ return it.check(hook, pkt, nil /* route */, addressEP, inNicName, "" /* outNicName */)
+}
+
+// CheckInput performs the input hook on the packet.
+//
+// Returns true iff the packet may continue traversing the stack; the packet
+// must be dropped if false is returned.
+//
+// Precondition: The packet's network and transport header must be set.
+func (it *IPTables) CheckInput(pkt *PacketBuffer, inNicName string) bool {
+ const hook = Input
+
+ if it.shouldSkip(pkt.NetworkProtocolNumber) {
+ return true
+ }
+
+ if t := pkt.tuple; t != nil {
+ t.conn.handlePacket(pkt, hook, nil /* route */)
+ }
+
+ ret := it.check(hook, pkt, nil /* route */, nil /* addressEP */, inNicName, "" /* outNicName */)
+ if t := pkt.tuple; t != nil {
+ t.conn.finalize()
+ }
+ pkt.tuple = nil
+ return ret
+}
+
+// CheckForward performs the forward hook on the packet.
+//
+// Returns true iff the packet may continue traversing the stack; the packet
+// must be dropped if false is returned.
+//
+// Precondition: The packet's network and transport header must be set.
+func (it *IPTables) CheckForward(pkt *PacketBuffer, inNicName, outNicName string) bool {
+ if it.shouldSkip(pkt.NetworkProtocolNumber) {
+ return true
+ }
+ return it.check(Forward, pkt, nil /* route */, nil /* addressEP */, inNicName, outNicName)
+}
+
+// CheckOutput performs the output hook on the packet.
+//
+// Returns true iff the packet may continue traversing the stack; the packet
+// must be dropped if false is returned.
+//
+// Precondition: The packet's network and transport header must be set.
+func (it *IPTables) CheckOutput(pkt *PacketBuffer, r *Route, outNicName string) bool {
+ const hook = Output
+
+ if it.shouldSkip(pkt.NetworkProtocolNumber) {
+ return true
+ }
+
+ if t := it.connections.getConnOrMaybeInsertNoop(pkt); t != nil {
+ pkt.tuple = t
+ t.conn.handlePacket(pkt, hook, r)
+ }
+
+ return it.check(hook, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName)
+}
+
+// CheckPostrouting performs the postrouting hook on the packet.
+//
+// Returns true iff the packet may continue traversing the stack; the packet
+// must be dropped if false is returned.
+//
+// Precondition: The packet's network and transport header must be set.
+func (it *IPTables) CheckPostrouting(pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, outNicName string) bool {
+ const hook = Postrouting
+
+ if it.shouldSkip(pkt.NetworkProtocolNumber) {
+ return true
+ }
+
+ if t := pkt.tuple; t != nil {
+ t.conn.handlePacket(pkt, hook, r)
+ }
+
+ ret := it.check(hook, pkt, r, addressEP, "" /* inNicName */, outNicName)
+ if t := pkt.tuple; t != nil {
+ t.conn.finalize()
+ }
+ pkt.tuple = nil
+ return ret
+}
+
+func (it *IPTables) shouldSkip(netProto tcpip.NetworkProtocolNumber) bool {
+ switch netProto {
+ case header.IPv4ProtocolNumber, header.IPv6ProtocolNumber:
+ default:
+ // IPTables only supports IPv4/IPv6.
+ return true
+ }
+
+ it.mu.RLock()
+ defer it.mu.RUnlock()
// Many users never configure iptables. Spare them the cost of rule
// traversal if rules have never been set.
+ return !it.modified
+}
+
+// check runs pkt through the rules for hook. It returns true when the packet
+// should continue traversing the network stack and false when it should be
+// dropped.
+//
+// Precondition: The packet's network and transport header must be set.
+func (it *IPTables) check(hook Hook, pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) bool {
it.mu.RLock()
defer it.mu.RUnlock()
- if !it.modified {
- return true
- }
-
- // Packets are manipulated only if connection and matching
- // NAT rule exists.
- shouldTrack := it.connections.handlePacket(pkt, hook, r)
// Go through each table containing the hook.
priorities := it.priorities[hook]
@@ -300,7 +408,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr
table = it.v4Tables[tableID]
}
ruleIdx := table.BuiltinChains[hook]
- switch verdict := it.checkChain(hook, pkt, table, ruleIdx, r, preroutingAddr, inNicName, outNicName); verdict {
+ switch verdict := it.checkChain(hook, pkt, table, ruleIdx, r, addressEP, inNicName, outNicName); verdict {
// If the table returns Accept, move on to the next table.
case chainAccept:
continue
@@ -311,7 +419,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr
// Any Return from a built-in chain means we have to
// call the underflow.
underflow := table.Rules[table.Underflows[hook]]
- switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, r, preroutingAddr); v {
+ switch v, _ := underflow.Target.Action(pkt, hook, r, addressEP); v {
case RuleAccept:
continue
case RuleDrop:
@@ -327,21 +435,6 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr
}
}
- // If this connection should be tracked, try to add an entry for it. If
- // traversing the nat table didn't end in adding an entry,
- // maybeInsertNoop will add a no-op entry for the connection. This is
- // needeed when establishing connections so that the SYN/ACK reply to an
- // outgoing SYN is delivered to the correct endpoint rather than being
- // redirected by a prerouting rule.
- //
- // From the iptables documentation: "If there is no rule, a `null'
- // binding is created: this usually does not map the packet, but exists
- // to ensure we don't map another stream over an existing one."
- if shouldTrack {
- it.connections.maybeInsertNoop(pkt, hook)
- }
-
- // Every table returned Accept.
return true
}
@@ -375,19 +468,32 @@ func (it *IPTables) startReaper(interval time.Duration) {
}()
}
-// CheckPackets runs pkts through the rules for hook and returns a map of packets that
-// should not go forward.
+// CheckOutputPackets performs the output hook on the packets.
//
-// Preconditions:
-// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
-// * pkt.NetworkHeader is not nil.
+// Returns drop, the set of packets that must be dropped, along with natPkts.
//
-// NOTE: unlike the Check API the returned map contains packets that should be
-// dropped.
-func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, r *Route, inNicName, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
+// Precondition: The packets' network and transport header must be set.
+func (it *IPTables) CheckOutputPackets(pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
+ return checkPackets(pkts, func(pkt *PacketBuffer) bool {
+ return it.CheckOutput(pkt, r, outNicName)
+ })
+}
+
+// CheckPostroutingPackets performs the postrouting hook on the packets.
+//
+// Returns drop, the set of packets that must be dropped, along with natPkts.
+//
+// Precondition: The packets' network and transport header must be set.
+func (it *IPTables) CheckPostroutingPackets(pkts PacketBufferList, r *Route, addressEP AddressableEndpoint, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
+ return checkPackets(pkts, func(pkt *PacketBuffer) bool {
+ return it.CheckPostrouting(pkt, r, addressEP, outNicName)
+ })
+}
+
+func checkPackets(pkts PacketBufferList, f func(*PacketBuffer) bool) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
if !pkt.NatDone {
- if ok := it.Check(hook, pkt, r, "", inNicName, outNicName); !ok {
+ if ok := f(pkt); !ok {
if drop == nil {
drop = make(map[*PacketBuffer]struct{})
}
@@ -407,11 +513,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, r *Route, inN
// Preconditions:
// * pkt is an IPv4 packet of at least length header.IPv4MinimumSize.
// * pkt.NetworkHeader is not nil.
-func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) chainVerdict {
+func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) chainVerdict {
// Start from ruleIdx and walk the list of rules until a rule gives us
// a verdict.
for ruleIdx < len(table.Rules) {
- switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, r, preroutingAddr, inNicName, outNicName); verdict {
+ switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, r, addressEP, inNicName, outNicName); verdict {
case RuleAccept:
return chainAccept
@@ -428,7 +534,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
ruleIdx++
continue
}
- switch verdict := it.checkChain(hook, pkt, table, jumpTo, r, preroutingAddr, inNicName, outNicName); verdict {
+ switch verdict := it.checkChain(hook, pkt, table, jumpTo, r, addressEP, inNicName, outNicName); verdict {
case chainAccept:
return chainAccept
case chainDrop:
@@ -454,7 +560,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
// Preconditions:
// * pkt is an IPv4 packet of at least length header.IPv4MinimumSize.
// * pkt.NetworkHeader is not nil.
-func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) (RuleVerdict, int) {
+func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) (RuleVerdict, int) {
rule := table.Rules[ruleIdx]
// Check whether the packet matches the IP header filter.
@@ -477,16 +583,16 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
}
// All the matchers matched, so run the target.
- return rule.Target.Action(pkt, &it.connections, hook, r, preroutingAddr)
+ return rule.Target.Action(pkt, hook, r, addressEP)
}
// OriginalDst returns the original destination of redirected connections. It
// returns an error if the connection doesn't exist or isn't redirected.
-func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
+func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
it.mu.RLock()
defer it.mu.RUnlock()
if !it.modified {
return "", 0, &tcpip.ErrNotConnected{}
}
- return it.connections.originalDst(epID, netProto)
+ return it.connections.originalDst(epID, netProto, transProto)
}
diff --git a/pkg/tcpip/stack/iptables_state.go b/pkg/tcpip/stack/iptables_state.go
index 529e02a07..3d3c39c20 100644
--- a/pkg/tcpip/stack/iptables_state.go
+++ b/pkg/tcpip/stack/iptables_state.go
@@ -26,11 +26,15 @@ type unixTime struct {
// saveLastUsed is invoked by stateify.
func (cn *conn) saveLastUsed() unixTime {
+ cn.mu.Lock()
+ defer cn.mu.Unlock()
return unixTime{cn.lastUsed.Unix(), cn.lastUsed.UnixNano()}
}
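// loadLastUsed is invoked by stateify.
//
// NOTE(review): saveLastUsed stores lastUsed.Unix() and lastUsed.UnixNano()
// positionally. If those land in second and nano respectively, nano already
// holds the whole timestamp and time.Unix(second, nano) counts the seconds
// twice; time.Unix(0, nano) would restore the saved instant. Confirm against
// the unixTime field order.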
func (cn *conn) loadLastUsed(unix unixTime) {
+ cn.mu.Lock()
+ defer cn.mu.Unlock()
cn.lastUsed = time.Unix(unix.second, unix.nano)
}
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 96cc899bb..7e5a1672a 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -29,7 +29,7 @@ type AcceptTarget struct {
}
// Action implements Target.Action.
-func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*AcceptTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
return RuleAccept, 0
}
@@ -40,7 +40,7 @@ type DropTarget struct {
}
// Action implements Target.Action.
-func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*DropTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
return RuleDrop, 0
}
@@ -52,7 +52,7 @@ type ErrorTarget struct {
}
// Action implements Target.Action.
-func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*ErrorTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
log.Debugf("ErrorTarget triggered.")
return RuleDrop, 0
}
@@ -67,7 +67,7 @@ type UserChainTarget struct {
}
// Action implements Target.Action.
-func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*UserChainTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
panic("UserChainTarget should never be called.")
}
@@ -79,7 +79,7 @@ type ReturnTarget struct {
}
// Action implements Target.Action.
-func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*ReturnTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
return RuleReturn, 0
}
@@ -97,7 +97,7 @@ type RedirectTarget struct {
}
// Action implements Target.Action.
-func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) {
+func (rt *RedirectTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) {
// Sanity check.
if rt.NetworkProtocol != pkt.NetworkProtocolNumber {
panic(fmt.Sprintf(
@@ -117,6 +117,7 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r
// Change the address to loopback (127.0.0.1 or ::1) in Output and to
// the primary address of the incoming interface in Prerouting.
+ var address tcpip.Address
switch hook {
case Output:
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
@@ -125,48 +126,18 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r
address = header.IPv6Loopback
}
case Prerouting:
- // No-op, as address is already set correctly.
+ // addressEP is expected to be set for the prerouting hook.
+ address = addressEP.MainAddress().Address
default:
panic("redirect target is supported only on output and prerouting hooks")
}
- switch protocol := pkt.TransportProtocolNumber; protocol {
- case header.UDPProtocolNumber:
- udpHeader := header.UDP(pkt.TransportHeader().View())
-
- if hook == Output {
- // Only calculate the checksum if offloading isn't supported.
- requiresChecksum := r.RequiresTXTransportChecksum()
- rewritePacket(
- pkt.Network(),
- udpHeader,
- false, /* updateSRCFields */
- requiresChecksum,
- requiresChecksum,
- rt.Port,
- address,
- )
- } else {
- udpHeader.SetDestinationPort(rt.Port)
- }
-
- pkt.NatDone = true
- case header.TCPProtocolNumber:
- if ct == nil {
- return RuleAccept, 0
- }
-
- // Set up conection for matching NAT rule. Only the first
- // packet of the connection comes here. Other packets will be
- // manipulated in connection tracking.
- if conn := ct.insertRedirectConn(pkt, hook, rt.Port, address); conn != nil {
- ct.handlePacket(pkt, hook, r)
- }
- default:
- return RuleDrop, 0
+ if t := pkt.tuple; t != nil {
+ t.conn.performNAT(pkt, hook, r, rt.Port, address, true /* dnat */)
+ return RuleAccept, 0
}
- return RuleAccept, 0
+ return RuleDrop, 0
}
// SNATTarget modifies the source port/IP in the outgoing packets.
@@ -179,15 +150,7 @@ type SNATTarget struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// Action implements Target.Action.
-func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) {
- // Sanity check.
- if st.NetworkProtocol != pkt.NetworkProtocolNumber {
- panic(fmt.Sprintf(
- "SNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
- st.NetworkProtocol, pkt.NetworkProtocolNumber))
- }
-
+func snatAction(pkt *PacketBuffer, hook Hook, r *Route, port uint16, address tcpip.Address) (RuleVerdict, int) {
// Packet is already manipulated.
if pkt.NatDone {
return RuleAccept, 0
@@ -198,6 +161,33 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou
return RuleDrop, 0
}
+ // TODO(https://gvisor.dev/issue/5773): If the port is in use, pick a
+ // different port.
+ if port == 0 {
+ switch protocol := pkt.TransportProtocolNumber; protocol {
+ case header.UDPProtocolNumber:
+ port = header.UDP(pkt.TransportHeader().View()).SourcePort()
+ case header.TCPProtocolNumber:
+ port = header.TCP(pkt.TransportHeader().View()).SourcePort()
+ }
+ }
+
+ if t := pkt.tuple; t != nil {
+ t.conn.performNAT(pkt, hook, r, port, address, false /* dnat */)
+ }
+
+ return RuleAccept, 0
+}
+
+// Action implements Target.Action.
+func (st *SNATTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, _ AddressableEndpoint) (RuleVerdict, int) {
+ // Sanity check.
+ if st.NetworkProtocol != pkt.NetworkProtocolNumber {
+ panic(fmt.Sprintf(
+ "SNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
+ st.NetworkProtocol, pkt.NetworkProtocolNumber))
+ }
+
switch hook {
case Postrouting, Input:
case Prerouting, Output, Forward:
@@ -206,37 +196,43 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou
panic(fmt.Sprintf("%s unrecognized", hook))
}
- switch protocol := pkt.TransportProtocolNumber; protocol {
- case header.UDPProtocolNumber:
- // Only calculate the checksum if offloading isn't supported.
- requiresChecksum := r.RequiresTXTransportChecksum()
- rewritePacket(
- pkt.Network(),
- header.UDP(pkt.TransportHeader().View()),
- true, /* updateSRCFields */
- requiresChecksum,
- requiresChecksum,
- st.Port,
- st.Addr,
- )
-
- pkt.NatDone = true
- case header.TCPProtocolNumber:
- if ct == nil {
- return RuleAccept, 0
- }
+ return snatAction(pkt, hook, r, st.Port, st.Addr)
+}
- // Set up conection for matching NAT rule. Only the first
- // packet of the connection comes here. Other packets will be
- // manipulated in connection tracking.
- if conn := ct.insertSNATConn(pkt, hook, st.Port, st.Addr); conn != nil {
- ct.handlePacket(pkt, hook, r)
- }
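+// MasqueradeTarget rewrites the source address of outgoing packets to a
+// primary address acquired from addressEP. It passes port 0 to snatAction, which
+// then reuses the packet's own TCP/UDP source port.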
+type MasqueradeTarget struct {
+ // NetworkProtocol is the network protocol the target is used with. It
+ // is immutable.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+// Action implements Target.Action.
+func (mt *MasqueradeTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) {
+ // Sanity check.
+ if mt.NetworkProtocol != pkt.NetworkProtocolNumber {
+ panic(fmt.Sprintf(
+ "MasqueradeTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
+ mt.NetworkProtocol, pkt.NetworkProtocolNumber))
+ }
+
+ switch hook {
+ case Postrouting:
+ case Prerouting, Input, Forward, Output:
+ panic(fmt.Sprintf("masquerade target is supported only on postrouting hook; hook = %d", hook))
default:
+ panic(fmt.Sprintf("%s unrecognized", hook))
+ }
+
+ // addressEP is expected to be set for the postrouting hook.
+ ep := addressEP.AcquireOutgoingPrimaryAddress(pkt.Network().DestinationAddress(), false /* allowExpired */)
+ if ep == nil {
+ // No address exists that we can use as a source address.
return RuleDrop, 0
}
- return RuleAccept, 0
+ address := ep.AddressWithPrefix().Address
+ ep.DecRef()
+ return snatAction(pkt, hook, r, 0 /* port */, address)
}
func rewritePacket(n header.Network, t header.ChecksummableTransport, updateSRCFields, fullChecksum, updatePseudoHeader bool, newPort uint16, newAddr tcpip.Address) {
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index 66e5f22ac..b22024667 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -81,17 +81,6 @@ const (
//
// +stateify savable
type IPTables struct {
- // mu protects v4Tables, v6Tables, and modified.
- mu sync.RWMutex
- // v4Tables and v6tables map tableIDs to tables. They hold builtin
- // tables only, not user tables. mu must be locked for accessing.
- v4Tables [NumTables]Table
- v6Tables [NumTables]Table
- // modified is whether tables have been modified at least once. It is
- // used to elide the iptables performance overhead for workloads that
- // don't utilize iptables.
- modified bool
-
// priorities maps each hook to a list of table names. The order of the
// list is the order in which each table should be visited for that
// hook. It is immutable.
@@ -101,6 +90,21 @@ type IPTables struct {
// reaperDone can be signaled to stop the reaper goroutine.
reaperDone chan struct{}
+
+ mu sync.RWMutex
+ // v4Tables and v6Tables map tableIDs to tables. They hold builtin
+ // tables only, not user tables.
+ //
+ // +checklocks:mu
+ v4Tables [NumTables]Table
+ // +checklocks:mu
+ v6Tables [NumTables]Table
+ // modified is whether tables have been modified at least once. It is
+ // used to elide the iptables performance overhead for workloads that
+ // don't utilize iptables.
+ //
+ // +checklocks:mu
+ modified bool
}
// VisitTargets traverses all the targets of all tables and replaces each with
@@ -352,5 +356,5 @@ type Target interface {
// Action takes an action on the packet and returns a verdict on how
// traversal should (or should not) continue. If the return value is
// Jump, it also returns the index of the rule to jump to.
- Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int)
+ Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int)
}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 4d5431da1..40b33b6b5 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -333,8 +333,12 @@ func TestDADDisabled(t *testing.T) {
Address: addr1,
PrefixLen: defaultPrefixLen,
}
- if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err)
}
// Should get the address immediately since we should not have performed
@@ -379,12 +383,15 @@ func TestDADResolveLoopback(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- addrWithPrefix := tcpip.AddressWithPrefix{
- Address: addr1,
- PrefixLen: defaultPrefixLen,
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addr1,
+ PrefixLen: defaultPrefixLen,
+ },
}
- if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err)
}
// Address should not be considered bound to the NIC yet (DAD ongoing).
@@ -517,8 +524,12 @@ func TestDADResolve(t *testing.T) {
Address: addr1,
PrefixLen: defaultPrefixLen,
}
- if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err)
}
// Make sure the address does not resolve before the resolution time has
@@ -740,8 +751,12 @@ func TestDADFail(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr1.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
// Address should not be considered bound to the NIC yet
@@ -778,8 +793,8 @@ func TestDADFail(t *testing.T) {
// Attempting to add the address again should not fail if the address's
// state was cleaned up when DAD failed.
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
})
}
@@ -851,8 +866,12 @@ func TestDADStop(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr1.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
// Address should not be considered bound to the NIC yet (DAD ongoing).
@@ -975,17 +994,29 @@ func TestSetNDPConfigurations(t *testing.T) {
// Add addresses for each NIC.
addrWithPrefix1 := tcpip.AddressWithPrefix{Address: addr1, PrefixLen: defaultPrefixLen}
- if err := s.AddAddressWithPrefix(nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addrWithPrefix1, err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addrWithPrefix1,
+ }
+ if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID1, protocolAddr1, err)
}
addrWithPrefix2 := tcpip.AddressWithPrefix{Address: addr2, PrefixLen: defaultPrefixLen}
- if err := s.AddAddressWithPrefix(nicID2, header.IPv6ProtocolNumber, addrWithPrefix2); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addrWithPrefix2, err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addrWithPrefix2,
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID2, protocolAddr2, err)
}
expectDADEvent(nicID2, addr2)
addrWithPrefix3 := tcpip.AddressWithPrefix{Address: addr3, PrefixLen: defaultPrefixLen}
- if err := s.AddAddressWithPrefix(nicID3, header.IPv6ProtocolNumber, addrWithPrefix3); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addrWithPrefix3, err)
+ protocolAddr3 := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addrWithPrefix3,
+ }
+ if err := s.AddProtocolAddress(nicID3, protocolAddr3, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID3, protocolAddr3, err)
}
expectDADEvent(nicID3, addr3)
@@ -2788,8 +2819,12 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
continue
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, test.addrs[j].Address); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, test.addrs[j].Address, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: test.addrs[j].Address.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
manuallyAssignedAddresses[test.addrs[j].Address] = struct{}{}
@@ -3644,8 +3679,9 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) {
Protocol: header.IPv6ProtocolNumber,
AddressWithPrefix: addr2,
}
- if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil {
- t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err)
+ properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint}
+ if err := s.AddProtocolAddress(nicID, protoAddr2, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v) = %s", nicID, protoAddr2, properties, err)
}
// addr2 should be more preferred now since it is at the front of the primary
// list.
@@ -3733,8 +3769,9 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
}
// Add the address as a static address before SLAAC tries to add it.
- if err := s.AddProtocolAddress(1, tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr.Address, err)
+ protocolAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}) = %s", protocolAddr, err)
}
if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) {
t.Fatalf("Should have %s in the list of addresses", addr1)
@@ -4073,8 +4110,12 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
// Attempting to add the address manually should not fail if the
// address's state was cleaned up when DAD failed.
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr.Address); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr.Address, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if err := s.RemoveAddress(nicID, addr.Address); err != nil {
t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err)
@@ -5362,8 +5403,12 @@ func TestRouterSolicitation(t *testing.T) {
}
if addr := test.nicAddr; addr != "" {
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index b854d868c..29d580e76 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -72,9 +72,15 @@ type nic struct {
sync.RWMutex
spoofing bool
promiscuous bool
- // packetEPs is protected by mu, but the contained packetEndpointList are
- // not.
- packetEPs map[tcpip.NetworkProtocolNumber]*packetEndpointList
+ }
+
+ packetEPs struct {
+ mu sync.RWMutex
+
+ // eps is protected by the mutex, but the values contained in it are not.
+ //
+ // +checklocks:mu
+ eps map[tcpip.NetworkProtocolNumber]*packetEndpointList
}
}
@@ -91,6 +97,8 @@ type packetEndpointList struct {
mu sync.RWMutex
// eps is protected by mu, but the contained PacketEndpoint values are not.
+ //
+ // +checklocks:mu
eps []PacketEndpoint
}
@@ -111,6 +119,12 @@ func (p *packetEndpointList) remove(ep PacketEndpoint) {
}
}
+func (p *packetEndpointList) len() int {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+ return len(p.eps)
+}
+
// forEach calls fn with each endpoints in p while holding the read lock on p.
func (p *packetEndpointList) forEach(fn func(PacketEndpoint)) {
p.mu.RLock()
@@ -143,18 +157,16 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
duplicateAddressDetectors: make(map[tcpip.NetworkProtocolNumber]DuplicateAddressDetector),
}
nic.linkResQueue.init(nic)
- nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList)
+
+ nic.packetEPs.mu.Lock()
+ defer nic.packetEPs.mu.Unlock()
+
+ nic.packetEPs.eps = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList)
resolutionRequired := ep.Capabilities()&CapabilityResolutionRequired != 0
- // Register supported packet and network endpoint protocols.
- for _, netProto := range header.Ethertypes {
- nic.mu.packetEPs[netProto] = new(packetEndpointList)
- }
for _, netProto := range stack.networkProtocols {
netNum := netProto.Number()
- nic.mu.packetEPs[netNum] = new(packetEndpointList)
-
netEP := netProto.NewEndpoint(nic, nic)
nic.networkEndpoints[netNum] = netEP
@@ -365,6 +377,8 @@ func (n *nic) writePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt
pkt.EgressRoute = r
pkt.NetworkProtocolNumber = protocol
+ n.deliverOutboundPacket(r.RemoteLinkAddress, pkt)
+
if err := n.LinkEndpoint.WritePacket(r, protocol, pkt); err != nil {
return err
}
@@ -383,6 +397,7 @@ func (n *nic) writePackets(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pk
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
pkt.EgressRoute = r
pkt.NetworkProtocolNumber = protocol
+ n.deliverOutboundPacket(r.RemoteLinkAddress, pkt)
}
writtenPackets, err := n.LinkEndpoint.WritePackets(r, pkts, protocol)
@@ -501,7 +516,7 @@ func (n *nic) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber,
// addAddress adds a new address to n, so that it starts accepting packets
// targeted at the given address (and network protocol).
-func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error {
+func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, properties AddressProperties) tcpip.Error {
ep, ok := n.networkEndpoints[protocolAddress.Protocol]
if !ok {
return &tcpip.ErrUnknownProtocol{}
@@ -512,7 +527,7 @@ func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo
return &tcpip.ErrNotSupported{}
}
- addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */)
+ addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, properties)
if err == nil {
// We have no need for the address endpoint.
addressEndpoint.DecRef()
@@ -699,12 +714,9 @@ func (n *nic) isInGroup(addr tcpip.Address) bool {
// This rule applies only to the slice itself, not to the items of the slice;
// the ownership of the items is not retained by the caller.
func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
- n.mu.RLock()
enabled := n.Enabled()
// If the NIC is not yet enabled, don't receive any packets.
if !enabled {
- n.mu.RUnlock()
-
n.stats.disabledRx.packets.Increment()
n.stats.disabledRx.bytes.IncrementBy(uint64(pkt.Data().Size()))
return
@@ -715,7 +727,6 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
networkEndpoint, ok := n.networkEndpoints[protocol]
if !ok {
- n.mu.RUnlock()
n.stats.unknownL3ProtocolRcvdPackets.Increment()
return
}
@@ -727,44 +738,87 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
}
pkt.RXTransportChecksumValidated = n.LinkEndpoint.Capabilities()&CapabilityRXChecksumOffload != 0
- // Are any packet type sockets listening for this network protocol?
- protoEPs := n.mu.packetEPs[protocol]
- // Other packet type sockets that are listening for all protocols.
- anyEPs := n.mu.packetEPs[header.EthernetProtocolAll]
- n.mu.RUnlock()
-
// Deliver to interested packet endpoints without holding NIC lock.
+ var packetEPPkt *PacketBuffer
deliverPacketEPs := func(ep PacketEndpoint) {
- p := pkt.Clone()
- p.PktType = tcpip.PacketHost
- ep.HandlePacket(n.id, local, protocol, p)
+ if packetEPPkt == nil {
+ // Packet endpoints hold the full packet.
+ //
+ // We perform a deep copy because higher-level endpoints may point to
+ // the middle of a view that is held by a packet endpoint. Save/Restore
+ // does not support overlapping slices and will panic in this case.
+ //
+ // TODO(https://gvisor.dev/issue/6517): Avoid this copy once S/R supports
+ // overlapping slices (e.g. by passing a shallow copy of pkt to the packet
+ // endpoint).
+ packetEPPkt = NewPacketBuffer(PacketBufferOptions{
+ Data: PayloadSince(pkt.LinkHeader()).ToVectorisedView(),
+ })
+ // If a link header was populated in the original packet buffer, then
+ // populate it in the packet buffer we provide to packet endpoints as
+ // packet endpoints inspect link headers.
+ packetEPPkt.LinkHeader().Consume(pkt.LinkHeader().View().Size())
+ packetEPPkt.PktType = tcpip.PacketHost
+ }
+
+ ep.HandlePacket(n.id, local, protocol, packetEPPkt.Clone())
}
- if protoEPs != nil {
+
+ n.packetEPs.mu.Lock()
+ // Are any packet type sockets listening for this network protocol?
+ protoEPs, protoEPsOK := n.packetEPs.eps[protocol]
+ // Other packet type sockets that are listening for all protocols.
+ anyEPs, anyEPsOK := n.packetEPs.eps[header.EthernetProtocolAll]
+ n.packetEPs.mu.Unlock()
+
+ if protoEPsOK {
protoEPs.forEach(deliverPacketEPs)
}
- if anyEPs != nil {
+ if anyEPsOK {
anyEPs.forEach(deliverPacketEPs)
}
networkEndpoint.HandlePacket(pkt)
}
-// DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket.
-func (n *nic) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
- n.mu.RLock()
+// deliverOutboundPacket delivers outgoing packets to interested endpoints.
+func (n *nic) deliverOutboundPacket(remote tcpip.LinkAddress, pkt *PacketBuffer) {
+ n.packetEPs.mu.RLock()
+ defer n.packetEPs.mu.RUnlock()
// We do not deliver to protocol specific packet endpoints as on Linux
// only ETH_P_ALL endpoints get outbound packets.
// Add any other packet sockets that maybe listening for all protocols.
- eps := n.mu.packetEPs[header.EthernetProtocolAll]
- n.mu.RUnlock()
+ eps, ok := n.packetEPs.eps[header.EthernetProtocolAll]
+ if !ok {
+ return
+ }
+
+ local := n.LinkAddress()
+ var packetEPPkt *PacketBuffer
eps.forEach(func(ep PacketEndpoint) {
- p := pkt.Clone()
- p.PktType = tcpip.PacketOutgoing
- // Add the link layer header as outgoing packets are intercepted
- // before the link layer header is created.
- n.LinkEndpoint.AddHeader(local, remote, protocol, p)
- ep.HandlePacket(n.id, local, protocol, p)
+ if packetEPPkt == nil {
+ // Packet endpoints hold the full packet.
+ //
+ // We perform a deep copy because higher-level endpoints may point to
+ // the middle of a view that is held by a packet endpoint. Save/Restore
+ // does not support overlapping slices and will panic in this case.
+ //
+ // TODO(https://gvisor.dev/issue/6517): Avoid this copy once S/R supports
+ // overlapping slices (e.g. by passing a shallow copy of pkt to the packet
+ // endpoint).
+ packetEPPkt = NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: pkt.AvailableHeaderBytes(),
+ Data: PayloadSince(pkt.NetworkHeader()).ToVectorisedView(),
+ })
+ // Add the link layer header as outgoing packets are intercepted before
+ // the link layer header is created and packet endpoints are interested
+ // in the link header.
+ n.LinkEndpoint.AddHeader(local, remote, pkt.NetworkProtocolNumber, packetEPPkt)
+ packetEPPkt.PktType = tcpip.PacketOutgoing
+ }
+
+ ep.HandlePacket(n.id, local, pkt.NetworkProtocolNumber, packetEPPkt.Clone())
})
}
@@ -917,12 +971,13 @@ func (n *nic) setNUDConfigs(protocol tcpip.NetworkProtocolNumber, c NUDConfigura
}
func (n *nic) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) tcpip.Error {
- n.mu.Lock()
- defer n.mu.Unlock()
+ n.packetEPs.mu.Lock()
+ defer n.packetEPs.mu.Unlock()
- eps, ok := n.mu.packetEPs[netProto]
+ eps, ok := n.packetEPs.eps[netProto]
if !ok {
- return &tcpip.ErrNotSupported{}
+ eps = new(packetEndpointList)
+ n.packetEPs.eps[netProto] = eps
}
eps.add(ep)
@@ -930,14 +985,17 @@ func (n *nic) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa
}
func (n *nic) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) {
- n.mu.Lock()
- defer n.mu.Unlock()
+ n.packetEPs.mu.Lock()
+ defer n.packetEPs.mu.Unlock()
- eps, ok := n.mu.packetEPs[netProto]
+ eps, ok := n.packetEPs.eps[netProto]
if !ok {
return
}
eps.remove(ep)
+ if eps.len() == 0 {
+ delete(n.packetEPs.eps, netProto)
+ }
}
// isValidForOutgoing returns true if the endpoint can be used to send out a
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index 5cb342f78..c8ad93f29 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -127,11 +127,6 @@ func (*testIPv6Protocol) MinimumPacketSize() int {
return header.IPv6MinimumSize
}
-// DefaultPrefixLen implements NetworkProtocol.DefaultPrefixLen.
-func (*testIPv6Protocol) DefaultPrefixLen() int {
- return header.IPv6AddressSize * 8
-}
-
// ParseAddresses implements NetworkProtocol.ParseAddresses.
func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
h := header.IPv6(v)
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 9192d8433..888a8bd9d 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -143,6 +143,8 @@ type PacketBuffer struct {
// NetworkPacketInfo holds an incoming packet's network-layer information.
NetworkPacketInfo NetworkPacketInfo
+
+ tuple *tuple
}
// NewPacketBuffer creates a new PacketBuffer with opts.
@@ -282,14 +284,12 @@ func (pk *PacketBuffer) headerView(typ headerType) tcpipbuffer.View {
return v
}
-// Clone makes a shallow copy of pk.
-//
-// Clone should be called in such cases so that no modifications is done to
-// underlying packet payload.
+// Clone makes a semi-deep copy of pk. The underlying packet payload is
+// shared. Hence, no modifications are made to the underlying packet payload.
func (pk *PacketBuffer) Clone() *PacketBuffer {
return &PacketBuffer{
PacketBufferEntry: pk.PacketBufferEntry,
- buf: pk.buf,
+ buf: pk.buf.Clone(),
reserved: pk.reserved,
pushed: pk.pushed,
consumed: pk.consumed,
@@ -304,6 +304,7 @@ func (pk *PacketBuffer) Clone() *PacketBuffer {
NICID: pk.NICID,
RXTransportChecksumValidated: pk.RXTransportChecksumValidated,
NetworkPacketInfo: pk.NetworkPacketInfo,
+ tuple: pk.tuple,
}
}
@@ -321,25 +322,51 @@ func (pk *PacketBuffer) Network() header.Network {
}
}
-// CloneToInbound makes a shallow copy of the packet buffer to be used as an
-// inbound packet.
+// CloneToInbound makes a semi-deep copy of the packet buffer (similar to
+// Clone) to be used as an inbound packet.
//
// See PacketBuffer.Data for details about how a packet buffer holds an inbound
// packet.
func (pk *PacketBuffer) CloneToInbound() *PacketBuffer {
newPk := &PacketBuffer{
- buf: pk.buf,
+ buf: pk.buf.Clone(),
// Treat unfilled header portion as reserved.
reserved: pk.AvailableHeaderBytes(),
+ tuple: pk.tuple,
+ }
+ return newPk
+}
+
+// DeepCopyForForwarding creates a deep copy of the packet buffer for
+// forwarding.
+//
+// The returned packet buffer will have the network and transport headers
+// set if the original packet buffer did.
+func (pk *PacketBuffer) DeepCopyForForwarding(reservedHeaderBytes int) *PacketBuffer {
+ newPk := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: reservedHeaderBytes,
+ Data: PayloadSince(pk.NetworkHeader()).ToVectorisedView(),
+ IsForwardedPacket: true,
+ })
+
+ {
+ consumeBytes := pk.NetworkHeader().View().Size()
+ if _, consumed := newPk.NetworkHeader().Consume(consumeBytes); !consumed {
+ panic(fmt.Sprintf("expected to consume network header %d bytes from new packet", consumeBytes))
+ }
+ newPk.NetworkProtocolNumber = pk.NetworkProtocolNumber
}
- // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to
- // maintain this flag in the packet. Currently conntrack needs this flag to
- // tell if a noop connection should be inserted at Input hook. Once conntrack
- // redefines the manipulation field as mutable, we won't need the special noop
- // connection.
- if pk.NatDone {
- newPk.NatDone = true
+
+ {
+ consumeBytes := pk.TransportHeader().View().Size()
+ if _, consumed := newPk.TransportHeader().Consume(consumeBytes); !consumed {
+ panic(fmt.Sprintf("expected to consume transport header %d bytes from new packet", consumeBytes))
+ }
+ newPk.TransportProtocolNumber = pk.TransportProtocolNumber
}
+
+ newPk.tuple = pk.tuple
+
return newPk
}
@@ -391,13 +418,14 @@ func (d PacketData) PullUp(size int) (tcpipbuffer.View, bool) {
return d.pk.buf.PullUp(d.pk.dataOffset(), size)
}
-// DeleteFront removes count from the beginning of d. It panics if count >
-// d.Size(). All backing storage references after the front of the d are
-// invalidated.
-func (d PacketData) DeleteFront(count int) {
- if !d.pk.buf.Remove(d.pk.dataOffset(), count) {
- panic("count > d.Size()")
+// Consume is the same as PullUp except that it additionally consumes the
+// returned bytes. Subsequent PullUp or Consume will not return these bytes.
+func (d PacketData) Consume(size int) (tcpipbuffer.View, bool) {
+ v, ok := d.PullUp(size)
+ if ok {
+ d.pk.consumed += size
}
+ return v, ok
}
// CapLength reduces d to at most length bytes.
diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go
index a8da34992..c376ed1a1 100644
--- a/pkg/tcpip/stack/packet_buffer_test.go
+++ b/pkg/tcpip/stack/packet_buffer_test.go
@@ -435,11 +435,17 @@ func TestPacketBufferData(t *testing.T) {
}
})
- // DeleteFront
+ // Consume.
for _, n := range []int{1, len(tc.data)} {
- t.Run(fmt.Sprintf("DeleteFront%d", n), func(t *testing.T) {
+ t.Run(fmt.Sprintf("Consume%d", n), func(t *testing.T) {
pkt := tc.makePkt(t)
- pkt.Data().DeleteFront(n)
+ v, ok := pkt.Data().Consume(n)
+ if !ok {
+ t.Fatalf("Consume failed")
+ }
+ if want := []byte(tc.data)[:n]; !bytes.Equal(v, want) {
+ t.Fatalf("pkt.Data().Consume(n) = 0x%x, want 0x%x", v, want)
+ }
checkData(t, pkt, []byte(tc.data)[n:])
})
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index dfe2c886f..31b3a554d 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -318,8 +318,7 @@ type PrimaryEndpointBehavior int
const (
// CanBePrimaryEndpoint indicates the endpoint can be used as a primary
- // endpoint for new connections with no local address. This is the
- // default when calling NIC.AddAddress.
+ // endpoint for new connections with no local address.
CanBePrimaryEndpoint PrimaryEndpointBehavior = iota
// FirstPrimaryEndpoint indicates the endpoint should be the first
@@ -332,6 +331,19 @@ const (
NeverPrimaryEndpoint
)
+func (peb PrimaryEndpointBehavior) String() string {
+ switch peb {
+ case CanBePrimaryEndpoint:
+ return "CanBePrimaryEndpoint"
+ case FirstPrimaryEndpoint:
+ return "FirstPrimaryEndpoint"
+ case NeverPrimaryEndpoint:
+ return "NeverPrimaryEndpoint"
+ default:
+ panic(fmt.Sprintf("unknown primary endpoint behavior: %d", peb))
+ }
+}
+
// AddressConfigType is the method used to add an address.
type AddressConfigType int
@@ -351,6 +363,14 @@ const (
AddressConfigSlaacTemp
)
+// AddressProperties contains additional properties that can be configured when
+// adding an address.
+type AddressProperties struct {
+ PEB PrimaryEndpointBehavior
+ ConfigType AddressConfigType
+ Deprecated bool
+}
+
// AssignableAddressEndpoint is a reference counted address endpoint that may be
// assigned to a NetworkEndpoint.
type AssignableAddressEndpoint interface {
@@ -457,7 +477,7 @@ type AddressableEndpoint interface {
// Returns *tcpip.ErrDuplicateAddress if the address exists.
//
// Acquires and returns the AddressEndpoint for the added address.
- AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, tcpip.Error)
+ AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties AddressProperties) (AddressEndpoint, tcpip.Error)
// RemovePermanentAddress removes the passed address if it is a permanent
// address.
@@ -685,9 +705,6 @@ type NetworkProtocol interface {
// than this targeted at this protocol.
MinimumPacketSize() int
- // DefaultPrefixLen returns the protocol's default prefix length.
- DefaultPrefixLen() int
-
// ParseAddresses returns the source and destination addresses stored in a
// packet of this protocol.
ParseAddresses(v buffer.View) (src, dst tcpip.Address)
@@ -733,16 +750,6 @@ type NetworkDispatcher interface {
//
// DeliverNetworkPacket takes ownership of pkt.
DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
-
- // DeliverOutboundPacket is called by link layer when a packet is being
- // sent out.
- //
- // pkt.LinkHeader may or may not be set before calling
- // DeliverOutboundPacket. Some packets do not have link headers (e.g.
- // packets sent via loopback), and won't have the field set.
- //
- // DeliverOutboundPacket takes ownership of pkt.
- DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
// LinkEndpointCapabilities is the type associated with the capabilities
@@ -846,6 +853,14 @@ type LinkEndpoint interface {
// offload is enabled. If it will be used for something else, syscall filters
// may need to be updated.
WritePackets(RouteInfo, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error)
+
+ // WriteRawPacket writes a packet directly to the link.
+ //
+ // If the link-layer has its own header, the payload must already include the
+ // header.
+ //
+ // WriteRawPacket takes ownership of the packet.
+ WriteRawPacket(*PacketBuffer) tcpip.Error
}
// InjectableLinkEndpoint is a LinkEndpoint where inbound packets are
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index c73890c4c..428350f31 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -72,7 +72,8 @@ type Stack struct {
// rawFactory creates raw endpoints. If nil, raw endpoints are
// disabled. It is set during Stack creation and is immutable.
- rawFactory RawFactory
+ rawFactory RawFactory
+ packetEndpointWriteSupported bool
demux *transportDemuxer
@@ -119,8 +120,7 @@ type Stack struct {
// by the stack.
icmpRateLimiter *ICMPRateLimiter
- // seed is a one-time random value initialized at stack startup
- // and is used to seed the TCP port picking on active connections
+ // seed is a one-time random value initialized at stack startup.
//
// TODO(gvisor.dev/issue/940): S/R this field.
seed uint32
@@ -161,6 +161,10 @@ type Stack struct {
// This is required to prevent potential ACK loops.
// Setting this to 0 will disable all rate limiting.
tcpInvalidRateLimit time.Duration
+
+ // tsOffsetSecret is the secret key for generating timestamp offsets
+ // initialized at stack startup.
+ tsOffsetSecret uint32
}
// UniqueID is an abstract generator of unique identifiers.
@@ -215,6 +219,10 @@ type Options struct {
// this is non-nil.
RawFactory RawFactory
+ // AllowPacketEndpointWrite determines if packet endpoints support write
+ // operations.
+ AllowPacketEndpointWrite bool
+
// RandSource is an optional source to use to generate random
// numbers. If omitted it defaults to a Source seeded by the data
// returned by the stack secure RNG.
@@ -356,23 +364,24 @@ func New(opts Options) *Stack {
opts.NUDConfigs.resetInvalidFields()
s := &Stack{
- transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
- networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
- nics: make(map[tcpip.NICID]*nic),
- defaultForwardingEnabled: make(map[tcpip.NetworkProtocolNumber]struct{}),
- cleanupEndpoints: make(map[TransportEndpoint]struct{}),
- PortManager: ports.NewPortManager(),
- clock: clock,
- stats: opts.Stats.FillIn(),
- handleLocal: opts.HandleLocal,
- tables: opts.IPTables,
- icmpRateLimiter: NewICMPRateLimiter(),
- seed: seed,
- nudConfigs: opts.NUDConfigs,
- uniqueIDGenerator: opts.UniqueID,
- nudDisp: opts.NUDDisp,
- randomGenerator: randomGenerator,
- secureRNG: opts.SecureRNG,
+ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
+ networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
+ nics: make(map[tcpip.NICID]*nic),
+ packetEndpointWriteSupported: opts.AllowPacketEndpointWrite,
+ defaultForwardingEnabled: make(map[tcpip.NetworkProtocolNumber]struct{}),
+ cleanupEndpoints: make(map[TransportEndpoint]struct{}),
+ PortManager: ports.NewPortManager(),
+ clock: clock,
+ stats: opts.Stats.FillIn(),
+ handleLocal: opts.HandleLocal,
+ tables: opts.IPTables,
+ icmpRateLimiter: NewICMPRateLimiter(clock),
+ seed: seed,
+ nudConfigs: opts.NUDConfigs,
+ uniqueIDGenerator: opts.UniqueID,
+ nudDisp: opts.NUDDisp,
+ randomGenerator: randomGenerator,
+ secureRNG: opts.SecureRNG,
sendBufferSize: tcpip.SendBufferSizeOption{
Min: MinBufferSize,
Default: DefaultBufferSize,
@@ -384,6 +393,7 @@ func New(opts Options) *Stack {
Max: DefaultMaxBufferSize,
},
tcpInvalidRateLimit: defaultTCPInvalidRateLimit,
+ tsOffsetSecret: randomGenerator.Uint32(),
}
// Add specified network protocols.
@@ -906,46 +916,9 @@ type NICStateFlags struct {
Loopback bool
}
-// AddAddress adds a new network-layer address to the specified NIC.
-func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error {
- return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint)
-}
-
-// AddAddressWithPrefix is the same as AddAddress, but allows you to specify
-// the address prefix.
-func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.AddressWithPrefix) tcpip.Error {
- ap := tcpip.ProtocolAddress{
- Protocol: protocol,
- AddressWithPrefix: addr,
- }
- return s.AddProtocolAddressWithOptions(id, ap, CanBePrimaryEndpoint)
-}
-
-// AddProtocolAddress adds a new network-layer protocol address to the
-// specified NIC.
-func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) tcpip.Error {
- return s.AddProtocolAddressWithOptions(id, protocolAddress, CanBePrimaryEndpoint)
-}
-
-// AddAddressWithOptions is the same as AddAddress, but allows you to specify
-// whether the new endpoint can be primary or not.
-func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) tcpip.Error {
- netProto, ok := s.networkProtocols[protocol]
- if !ok {
- return &tcpip.ErrUnknownProtocol{}
- }
- return s.AddProtocolAddressWithOptions(id, tcpip.ProtocolAddress{
- Protocol: protocol,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: addr,
- PrefixLen: netProto.DefaultPrefixLen(),
- },
- }, peb)
-}
-
-// AddProtocolAddressWithOptions is the same as AddProtocolAddress, but allows
-// you to specify whether the new endpoint can be primary or not.
-func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error {
+// AddProtocolAddress adds an address to the specified NIC, possibly with extra
+// properties.
+func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, properties AddressProperties) tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -954,7 +927,7 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc
return &tcpip.ErrUnknownNICID{}
}
- return nic.addAddress(protocolAddress, peb)
+ return nic.addAddress(protocolAddress, properties)
}
// RemoveAddress removes an existing network-layer address from the specified
@@ -1649,9 +1622,27 @@ func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress,
ReserveHeaderBytes: int(nic.MaxHeaderLength()),
Data: payload,
})
+ pkt.NetworkProtocolNumber = netProto
return nic.WritePacketToRemote(remote, netProto, pkt)
}
+// WriteRawPacket writes data directly to the specified NIC without adding any
+// headers.
+func (s *Stack) WriteRawPacket(nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) tcpip.Error {
+ s.mu.RLock()
+ nic, ok := s.nics[nicID]
+ s.mu.RUnlock()
+ if !ok {
+ return &tcpip.ErrUnknownNICID{}
+ }
+
+ pkt := NewPacketBuffer(PacketBufferOptions{
+ Data: payload,
+ })
+ pkt.NetworkProtocolNumber = proto
+ return nic.WriteRawPacket(pkt)
+}
+
// NetworkProtocolInstance returns the protocol instance in the stack for the
// specified network protocol. This method is public for protocol implementers
// and tests to use.
@@ -1819,8 +1810,7 @@ func (s *Stack) SetNUDConfigurations(id tcpip.NICID, proto tcpip.NetworkProtocol
return nic.setNUDConfigs(proto, c)
}
-// Seed returns a 32 bit value that can be used as a seed value for port
-// picking, ISN generation etc.
+// Seed returns a 32 bit value that can be used as a seed value.
//
// NOTE: The seed is generated once during stack initialization only.
func (s *Stack) Seed() uint32 {
@@ -1944,3 +1934,9 @@ func (s *Stack) IsSubnetBroadcast(nicID tcpip.NICID, protocol tcpip.NetworkProto
return false
}
+
+// PacketEndpointWriteSupported returns true iff packet endpoints support write
+// operations.
+func (s *Stack) PacketEndpointWriteSupported() bool {
+ return s.packetEndpointWriteSupported
+}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 3089c0ef4..c23e91702 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -139,18 +139,15 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Handle control packets.
if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) {
- hdr, ok := pkt.Data().PullUp(fakeNetHeaderLen)
+ hdr, ok := pkt.Data().Consume(fakeNetHeaderLen)
if !ok {
return
}
- // DeleteFront invalidates slices. Make a copy before trimming.
- nb := append([]byte(nil), hdr...)
- pkt.Data().DeleteFront(fakeNetHeaderLen)
f.dispatcher.DeliverTransportError(
- tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]),
- tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]),
+ tcpip.Address(hdr[srcAddrOffset:srcAddrOffset+1]),
+ tcpip.Address(hdr[dstAddrOffset:dstAddrOffset+1]),
fakeNetNumber,
- tcpip.TransportProtocolNumber(nb[protocolNumberOffset]),
+ tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]),
// Nothing checks the error.
nil, /* transport error */
pkt,
@@ -234,10 +231,6 @@ func (*fakeNetworkProtocol) MinimumPacketSize() int {
return fakeNetHeaderLen
}
-func (*fakeNetworkProtocol) DefaultPrefixLen() int {
- return fakeDefaultPrefixLen
-}
-
func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int {
return f.packetCount[int(intfAddr)%len(f.packetCount)]
}
@@ -349,12 +342,26 @@ func TestNetworkReceive(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x02"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x02",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr2, err)
}
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
@@ -517,8 +524,15 @@ func TestNetworkSend(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
// Make sure that the link-layer endpoint received the outbound packet.
@@ -538,12 +552,26 @@ func TestNetworkSendMultiRoute(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr3 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x03",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr3, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr3, err)
}
ep2 := channel.New(10, defaultMTU, "")
@@ -551,12 +579,26 @@ func TestNetworkSendMultiRoute(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x02",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(2, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err)
}
- if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr4 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x04",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(2, protocolAddr4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr4, err)
}
// Set a route table that sends all packets with odd destination
@@ -812,8 +854,15 @@ func TestRouteWithDownNIC(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
}
- if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addr1,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr1, err)
}
ep2 := channel.New(1, defaultMTU, "")
@@ -821,8 +870,15 @@ func TestRouteWithDownNIC(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
}
- if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addr2,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddr2, err)
}
// Set a route table that sends all packets with odd destination
@@ -978,12 +1034,26 @@ func TestRoutes(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr3 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x03",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr3, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr3, err)
}
ep2 := channel.New(10, defaultMTU, "")
@@ -991,12 +1061,26 @@ func TestRoutes(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x02",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(2, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err)
}
- if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr4 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x04",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(2, protocolAddr4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr4, err)
}
// Set a route table that sends all packets with odd destination
@@ -1058,8 +1142,15 @@ func TestAddressRemoval(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet("\x00", "\x00")
@@ -1108,8 +1199,15 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) {
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
buf := buffer.NewView(30)
- if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet("\x00", "\x00")
@@ -1242,8 +1340,15 @@ func TestEndpointExpiration(t *testing.T) {
// 2. Add Address, everything should work.
//-----------------------
- if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
verifyAddress(t, s, nicID, localAddr)
testRecv(t, fakeNet, localAddrByte, ep, buf)
@@ -1270,8 +1375,8 @@ func TestEndpointExpiration(t *testing.T) {
// 4. Add Address back, everything should work again.
//-----------------------
- if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
verifyAddress(t, s, nicID, localAddr)
testRecv(t, fakeNet, localAddrByte, ep, buf)
@@ -1310,8 +1415,8 @@ func TestEndpointExpiration(t *testing.T) {
// 7. Add Address back, everything should work again.
//-----------------------
- if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
verifyAddress(t, s, nicID, localAddr)
testRecv(t, fakeNet, localAddrByte, ep, buf)
@@ -1453,8 +1558,15 @@ func TestExternalSendWithHandleLocal(t *testing.T) {
if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: nicID}})
@@ -1510,8 +1622,15 @@ func TestSpoofingWithAddress(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
@@ -1633,8 +1752,8 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) {
}
protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{Address: header.IPv4Any}}
- if err := s.AddProtocolAddress(1, protoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err)
+ if err := s.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}) failed: %s", protoAddr, err)
}
r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
if err != nil {
@@ -1678,13 +1797,13 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
t.Fatalf("CreateNIC failed: %s", err)
}
nic1ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic1Addr}
- if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err)
+ if err := s.AddProtocolAddress(1, nic1ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}) failed: %s", nic1ProtoAddr, err)
}
nic2ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic2Addr}
- if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil {
- t.Fatalf("AddAddress(2, %v) failed: %v", nic2ProtoAddr, err)
+ if err := s.AddProtocolAddress(2, nic2ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(2, %+v, {}) failed: %s", nic2ProtoAddr, err)
}
// Set the initial route table.
@@ -1726,7 +1845,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
// 2. Case: Having an explicit route for broadcast will select that one.
rt = append(
[]tcpip.Route{
- {Destination: tcpip.AddressWithPrefix{Address: header.IPv4Broadcast, PrefixLen: 8 * header.IPv4AddressSize}.Subnet(), NIC: 1},
+ {Destination: header.IPv4Broadcast.WithPrefix().Subnet(), NIC: 1},
},
rt...,
)
@@ -1808,8 +1927,15 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, want)
}
- if err := s.AddAddress(1, fakeNetNumber, anyAddr); err != nil {
- t.Fatalf("AddAddress(%v, %v) failed: %v", fakeNetNumber, anyAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: anyAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
if r, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); tc.routeNeeded {
@@ -1886,22 +2012,27 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
// Add an address and in case of a primary one include a
// prefixLen.
address := tcpip.Address(bytes.Repeat([]byte{byte(i)}, addrLen))
+ properties := stack.AddressProperties{PEB: behavior}
if behavior == stack.CanBePrimaryEndpoint {
protocolAddress := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: address,
- PrefixLen: addrLen * 8,
- },
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: address.WithPrefix(),
}
- if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil {
- t.Fatalf("AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, protocolAddress, behavior, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddress, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddress, properties, err)
}
// Remember the address/prefix.
primaryAddrAdded[protocolAddress.AddressWithPrefix] = struct{}{}
} else {
- if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s:", nicID, fakeNetNumber, address, behavior, err)
+ protocolAddress := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddress, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddress, properties, err)
}
}
}
@@ -1996,8 +2127,8 @@ func TestGetMainNICAddressAddRemove(t *testing.T) {
PrefixLen: tc.prefixLen,
},
}
- if err := s.AddProtocolAddress(1, protocolAddress); err != nil {
- t.Fatal("AddProtocolAddress failed:", err)
+ if err := s.AddProtocolAddress(1, protocolAddress, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", protocolAddress, err)
}
// Check that we get the right initial address and prefix length.
@@ -2047,33 +2178,6 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto
}
}
-func TestAddAddress(t *testing.T) {
- const nicID = 1
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- var addrGen addressGenerator
- expectedAddresses := make([]tcpip.ProtocolAddress, 0, 2)
- for _, addrLen := range []int{4, 16} {
- address := addrGen.next(addrLen)
- if err := s.AddAddress(nicID, fakeNetNumber, address); err != nil {
- t.Fatalf("AddAddress(address=%s) failed: %s", address, err)
- }
- expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen},
- })
- }
-
- gotAddresses := s.AllAddresses()[nicID]
- verifyAddresses(t, expectedAddresses, gotAddresses)
-}
-
func TestAddProtocolAddress(t *testing.T) {
const nicID = 1
s := stack.New(stack.Options{
@@ -2084,96 +2188,43 @@ func TestAddProtocolAddress(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- var addrGen addressGenerator
- addrLenRange := []int{4, 16}
- prefixLenRange := []int{8, 13, 20, 32}
- expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange))
- for _, addrLen := range addrLenRange {
- for _, prefixLen := range prefixLenRange {
- protocolAddress := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: addrGen.next(addrLen),
- PrefixLen: prefixLen,
- },
- }
- if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil {
- t.Errorf("AddProtocolAddress(%+v) failed: %s", protocolAddress, err)
- }
- expectedAddresses = append(expectedAddresses, protocolAddress)
- }
- }
-
- gotAddresses := s.AllAddresses()[nicID]
- verifyAddresses(t, expectedAddresses, gotAddresses)
-}
-
-func TestAddAddressWithOptions(t *testing.T) {
- const nicID = 1
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
addrLenRange := []int{4, 16}
behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint}
- expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(behaviorRange))
+ configTypeRange := []stack.AddressConfigType{stack.AddressConfigStatic, stack.AddressConfigSlaac, stack.AddressConfigSlaacTemp}
+ deprecatedRange := []bool{false, true}
+ wantAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(behaviorRange)*len(configTypeRange)*len(deprecatedRange))
var addrGen addressGenerator
for _, addrLen := range addrLenRange {
for _, behavior := range behaviorRange {
- address := addrGen.next(addrLen)
- if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil {
- t.Fatalf("AddAddressWithOptions(address=%s, behavior=%d) failed: %s", address, behavior, err)
- }
- expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen},
- })
- }
- }
-
- gotAddresses := s.AllAddresses()[nicID]
- verifyAddresses(t, expectedAddresses, gotAddresses)
-}
-
-func TestAddProtocolAddressWithOptions(t *testing.T) {
- const nicID = 1
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- addrLenRange := []int{4, 16}
- prefixLenRange := []int{8, 13, 20, 32}
- behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint}
- expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)*len(behaviorRange))
- var addrGen addressGenerator
- for _, addrLen := range addrLenRange {
- for _, prefixLen := range prefixLenRange {
- for _, behavior := range behaviorRange {
- protocolAddress := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: addrGen.next(addrLen),
- PrefixLen: prefixLen,
- },
- }
- if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil {
- t.Fatalf("AddProtocolAddressWithOptions(%+v, %d) failed: %s", protocolAddress, behavior, err)
+ for _, configType := range configTypeRange {
+ for _, deprecated := range deprecatedRange {
+ address := addrGen.next(addrLen)
+ properties := stack.AddressProperties{
+ PEB: behavior,
+ ConfigType: configType,
+ Deprecated: deprecated,
+ }
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v) failed: %s", nicID, protocolAddr, properties, err)
+ }
+ wantAddresses = append(wantAddresses, tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen},
+ })
}
- expectedAddresses = append(expectedAddresses, protocolAddress)
}
}
}
gotAddresses := s.AllAddresses()[nicID]
- verifyAddresses(t, expectedAddresses, gotAddresses)
+ verifyAddresses(t, wantAddresses, gotAddresses)
}
func TestCreateNICWithOptions(t *testing.T) {
@@ -2290,8 +2341,15 @@ func TestNICStats(t *testing.T) {
if err := s.CreateNIC(nicid, ep); err != nil {
t.Fatal("CreateNIC failed: ", err)
}
- if err := s.AddAddress(nicid, fakeNetNumber, nic.addr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: nic.addr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicid, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicid, protocolAddr, err)
}
{
@@ -2735,8 +2793,16 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
// be returned by a call to GetMainNICAddress;
// else, it should.
const address1 = tcpip.Address("\x01")
- if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address1, pi); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address1, pi, err)
+ properties := stack.AddressProperties{PEB: pi}
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address1,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr, properties, err)
}
addr, err := s.GetMainNICAddress(nicID, fakeNetNumber)
if err != nil {
@@ -2785,16 +2851,31 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
// Add some other address with peb set to
// FirstPrimaryEndpoint.
const address3 = tcpip.Address("\x03")
- if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address3, stack.FirstPrimaryEndpoint); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address3, stack.FirstPrimaryEndpoint, err)
-
+ protocolAddr3 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address3,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ properties = stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint}
+ if err := s.AddProtocolAddress(nicID, protocolAddr3, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr3, properties, err)
}
// Add back the address we removed earlier and
// make sure the new peb was respected.
// (The address should just be promoted now).
- if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address1, ps); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address1, pi, err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address1,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ properties = stack.AddressProperties{PEB: ps}
+ if err := s.AddProtocolAddress(nicID, protocolAddr1, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr1, properties, err)
}
var primaryAddrs []tcpip.Address
for _, pa := range s.NICInfo()[nicID].ProtocolAddresses {
@@ -3096,8 +3177,12 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
}
for _, a := range test.nicAddrs {
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, a); err != nil {
- t.Errorf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, a, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: a.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
@@ -3203,8 +3288,12 @@ func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, addr1, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: addr1.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
// The NIC should have joined addr1's solicited node multicast address.
@@ -3359,8 +3448,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
PrefixLen: 128,
},
}
- if err := s.AddProtocolAddress(nicID, addr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err)
+ if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err)
}
// Address should be in the list of all addresses.
@@ -3687,8 +3776,8 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
if err := s.CreateNIC(nicID1, ep); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
}
- if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err)
+ if err := s.AddProtocolAddress(nicID1, test.nicAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, test.nicAddr, err)
}
s.SetRouteTable(test.routes)
@@ -3750,8 +3839,8 @@ func TestResolveWith(t *testing.T) {
PrefixLen: 24,
},
}
- if err := s.AddProtocolAddress(nicID, addr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, addr, err)
+ if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err)
}
s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}})
@@ -3792,8 +3881,15 @@ func TestRouteReleaseAfterAddrRemoval(t *testing.T) {
if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet("\x00", "\x00")
@@ -3881,8 +3977,8 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) {
PrefixLen: 8,
},
}
- if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protocolAddress, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddress, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddress, err)
}
// Check that we get the right initial address and prefix length.
@@ -3990,44 +4086,44 @@ func TestFindRouteWithForwarding(t *testing.T) {
)
type netCfg struct {
- proto tcpip.NetworkProtocolNumber
- factory stack.NetworkProtocolFactory
- nic1Addr tcpip.Address
- nic2Addr tcpip.Address
- remoteAddr tcpip.Address
+ proto tcpip.NetworkProtocolNumber
+ factory stack.NetworkProtocolFactory
+ nic1AddrWithPrefix tcpip.AddressWithPrefix
+ nic2AddrWithPrefix tcpip.AddressWithPrefix
+ remoteAddr tcpip.Address
}
fakeNetCfg := netCfg{
- proto: fakeNetNumber,
- factory: fakeNetFactory,
- nic1Addr: nic1Addr,
- nic2Addr: nic2Addr,
- remoteAddr: remoteAddr,
+ proto: fakeNetNumber,
+ factory: fakeNetFactory,
+ nic1AddrWithPrefix: tcpip.AddressWithPrefix{Address: nic1Addr, PrefixLen: fakeDefaultPrefixLen},
+ nic2AddrWithPrefix: tcpip.AddressWithPrefix{Address: nic2Addr, PrefixLen: fakeDefaultPrefixLen},
+ remoteAddr: remoteAddr,
}
globalIPv6Addr1 := tcpip.Address(net.ParseIP("a::1").To16())
globalIPv6Addr2 := tcpip.Address(net.ParseIP("a::2").To16())
ipv6LinkLocalNIC1WithGlobalRemote := netCfg{
- proto: ipv6.ProtocolNumber,
- factory: ipv6.NewProtocol,
- nic1Addr: llAddr1,
- nic2Addr: globalIPv6Addr2,
- remoteAddr: globalIPv6Addr1,
+ proto: ipv6.ProtocolNumber,
+ factory: ipv6.NewProtocol,
+ nic1AddrWithPrefix: llAddr1.WithPrefix(),
+ nic2AddrWithPrefix: globalIPv6Addr2.WithPrefix(),
+ remoteAddr: globalIPv6Addr1,
}
ipv6GlobalNIC1WithLinkLocalRemote := netCfg{
- proto: ipv6.ProtocolNumber,
- factory: ipv6.NewProtocol,
- nic1Addr: globalIPv6Addr1,
- nic2Addr: llAddr1,
- remoteAddr: llAddr2,
+ proto: ipv6.ProtocolNumber,
+ factory: ipv6.NewProtocol,
+ nic1AddrWithPrefix: globalIPv6Addr1.WithPrefix(),
+ nic2AddrWithPrefix: llAddr1.WithPrefix(),
+ remoteAddr: llAddr2,
}
ipv6GlobalNIC1WithLinkLocalMulticastRemote := netCfg{
- proto: ipv6.ProtocolNumber,
- factory: ipv6.NewProtocol,
- nic1Addr: globalIPv6Addr1,
- nic2Addr: globalIPv6Addr2,
- remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ proto: ipv6.ProtocolNumber,
+ factory: ipv6.NewProtocol,
+ nic1AddrWithPrefix: globalIPv6Addr1.WithPrefix(),
+ nic2AddrWithPrefix: globalIPv6Addr2.WithPrefix(),
+ remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
}
tests := []struct {
@@ -4036,8 +4132,8 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg netCfg
forwardingEnabled bool
- addrNIC tcpip.NICID
- localAddr tcpip.Address
+ addrNIC tcpip.NICID
+ localAddrWithPrefix tcpip.AddressWithPrefix
findRouteErr tcpip.Error
dependentOnForwarding bool
@@ -4047,7 +4143,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: false,
addrNIC: nicID1,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4056,7 +4152,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: true,
addrNIC: nicID1,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4065,7 +4161,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: false,
addrNIC: nicID1,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4074,7 +4170,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: true,
addrNIC: nicID1,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: true,
},
@@ -4083,7 +4179,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: false,
addrNIC: nicID2,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4092,7 +4188,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: true,
addrNIC: nicID2,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4101,7 +4197,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: false,
addrNIC: nicID2,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4110,7 +4206,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: true,
addrNIC: nicID2,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4118,7 +4214,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and localAddr on same NIC as route",
netCfg: fakeNetCfg,
forwardingEnabled: false,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4126,7 +4222,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and localAddr on same NIC as route",
netCfg: fakeNetCfg,
forwardingEnabled: false,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4134,7 +4230,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and localAddr on different NIC as route",
netCfg: fakeNetCfg,
forwardingEnabled: false,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4142,7 +4238,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and localAddr on different NIC as route",
netCfg: fakeNetCfg,
forwardingEnabled: true,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: true,
},
@@ -4166,7 +4262,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and link-local local addr with route on different NIC",
netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
forwardingEnabled: false,
- localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr,
+ localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4174,7 +4270,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and link-local local addr with route on same NIC",
netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
forwardingEnabled: true,
- localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr,
+ localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4182,7 +4278,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and global local addr with route on same NIC",
netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
forwardingEnabled: true,
- localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic2Addr,
+ localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4190,7 +4286,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and link-local local addr with route on same NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
forwardingEnabled: false,
- localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4198,7 +4294,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and link-local local addr with route on same NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
forwardingEnabled: true,
- localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4206,7 +4302,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and global local addr with link-local remote on different NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
forwardingEnabled: false,
- localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNetworkUnreachable{},
dependentOnForwarding: false,
},
@@ -4214,7 +4310,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and global local addr with link-local remote on different NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
forwardingEnabled: true,
- localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNetworkUnreachable{},
dependentOnForwarding: false,
},
@@ -4222,7 +4318,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and global local addr with link-local multicast remote on different NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
forwardingEnabled: false,
- localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNetworkUnreachable{},
dependentOnForwarding: false,
},
@@ -4230,7 +4326,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and global local addr with link-local multicast remote on different NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
forwardingEnabled: true,
- localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNetworkUnreachable{},
dependentOnForwarding: false,
},
@@ -4238,7 +4334,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and global local addr with link-local multicast remote on same NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
forwardingEnabled: false,
- localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4246,7 +4342,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and global local addr with link-local multicast remote on same NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
forwardingEnabled: true,
- localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4268,12 +4364,20 @@ func TestFindRouteWithForwarding(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s:", nicID2, err)
}
- if err := s.AddAddress(nicID1, test.netCfg.proto, test.netCfg.nic1Addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, test.netCfg.proto, test.netCfg.nic1Addr, err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: test.netCfg.proto,
+ AddressWithPrefix: test.netCfg.nic1AddrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr1, err)
}
- if err := s.AddAddress(nicID2, test.netCfg.proto, test.netCfg.nic2Addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, test.netCfg.proto, test.netCfg.nic2Addr, err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: test.netCfg.proto,
+ AddressWithPrefix: test.netCfg.nic2AddrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddr2, err)
}
if err := s.SetForwardingDefaultAndAllNICs(test.netCfg.proto, test.forwardingEnabled); err != nil {
@@ -4282,20 +4386,20 @@ func TestFindRouteWithForwarding(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}})
- r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */)
+ r, err := s.FindRoute(test.addrNIC, test.localAddrWithPrefix.Address, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */)
if err == nil {
defer r.Release()
}
if diff := cmp.Diff(test.findRouteErr, err); diff != "" {
- t.Fatalf("unexpected error from FindRoute(%d, %s, %s, %d, false), (-want, +got):\n%s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, diff)
+ t.Fatalf("unexpected error from FindRoute(%d, %s, %s, %d, false), (-want, +got):\n%s", test.addrNIC, test.localAddrWithPrefix.Address, test.netCfg.remoteAddr, test.netCfg.proto, diff)
}
if test.findRouteErr != nil {
return
}
- if r.LocalAddress() != test.localAddr {
- t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), test.localAddr)
+ if r.LocalAddress() != test.localAddrWithPrefix.Address {
+ t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), test.localAddrWithPrefix.Address)
}
if r.RemoteAddress() != test.netCfg.remoteAddr {
t.Errorf("got r.RemoteAddress() = %s, want = %s", r.RemoteAddress(), test.netCfg.remoteAddr)
@@ -4318,8 +4422,8 @@ func TestFindRouteWithForwarding(t *testing.T) {
if !ok {
t.Fatal("packet not sent through ep2")
}
- if pkt.Route.LocalAddress != test.localAddr {
- t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddr)
+ if pkt.Route.LocalAddress != test.localAddrWithPrefix.Address {
+ t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddrWithPrefix.Address)
}
if pkt.Route.RemoteAddress != test.netCfg.remoteAddr {
t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.netCfg.remoteAddr)
diff --git a/pkg/tcpip/stack/tcp.go b/pkg/tcpip/stack/tcp.go
index 90a8ba6cf..a941091b0 100644
--- a/pkg/tcpip/stack/tcp.go
+++ b/pkg/tcpip/stack/tcp.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/internal/tcp"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
)
@@ -288,6 +289,12 @@ type TCPSenderState struct {
// RACKState holds the state related to RACK loss detection algorithm.
RACKState TCPRACKState
+
+ // RetransmitTS records the timestamp used to detect spurious recovery.
+ RetransmitTS uint32
+
+ // SpuriousRecovery indicates if the sender entered recovery spuriously.
+ SpuriousRecovery bool
}
// TCPSACKInfo holds TCP SACK related information for a given TCP endpoint.
@@ -386,6 +393,12 @@ type TCPSndBufState struct {
// SndMTU is the smallest MTU seen in the control packets received.
SndMTU int
+
+ // AutoTuneSndBufDisabled indicates that the auto tuning of send buffer
+ // is disabled.
+ //
+ // Must be accessed using atomic operations.
+ AutoTuneSndBufDisabled uint32
}
// TCPEndpointStateInner contains the members of TCPEndpointState used directly
@@ -396,7 +409,7 @@ type TCPSndBufState struct {
type TCPEndpointStateInner struct {
// TSOffset is a randomized offset added to the value of the TSVal
// field in the timestamp option.
- TSOffset uint32
+ TSOffset tcp.TSOffset
// SACKPermitted is set to true if the peer sends the TCPSACKPermitted
// option in the SYN/SYN-ACK.
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index dda57e225..542d9257c 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -32,11 +32,13 @@ type protocolIDs struct {
// transportEndpoints manages all endpoints of a given protocol. It has its own
// mutex so as to reduce interference between protocols.
type transportEndpoints struct {
- // mu protects all fields of the transportEndpoints.
- mu sync.RWMutex
+ mu sync.RWMutex
+ // +checklocks:mu
endpoints map[TransportEndpointID]*endpointsByNIC
// rawEndpoints contains endpoints for raw sockets, which receive all
// traffic of a given protocol regardless of port.
+ //
+ // +checklocks:mu
rawEndpoints []RawTransportEndpoint
}
@@ -69,7 +71,7 @@ func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint {
// descending order of match quality. If a call to yield returns false,
// iterEndpointsLocked stops iteration and returns immediately.
//
-// Preconditions: eps.mu must be locked.
+// +checklocks:eps.mu
func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield func(*endpointsByNIC) bool) {
// Try to find a match with the id as provided.
if ep, ok := eps.endpoints[id]; ok {
@@ -110,7 +112,7 @@ func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield
// findAllEndpointsLocked returns all endpointsByNIC in eps that match id, in
// descending order of match quality.
//
-// Preconditions: eps.mu must be locked.
+// +checklocks:eps.mu
func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []*endpointsByNIC {
var matchedEPs []*endpointsByNIC
eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
@@ -122,7 +124,7 @@ func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []
// findEndpointLocked returns the endpoint that most closely matches the given id.
//
-// Preconditions: eps.mu must be locked.
+// +checklocks:eps.mu
func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpointsByNIC {
var matchedEP *endpointsByNIC
eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
@@ -133,10 +135,12 @@ func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpo
}
type endpointsByNIC struct {
- mu sync.RWMutex
- endpoints map[tcpip.NICID]*multiPortEndpoint
// seed is a random secret for a jenkins hash.
seed uint32
+
+ mu sync.RWMutex
+ // +checklocks:mu
+ endpoints map[tcpip.NICID]*multiPortEndpoint
}
func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint {
@@ -171,7 +175,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *Packet
return true
}
// multiPortEndpoints are guaranteed to have at least one element.
- transEP := selectEndpoint(id, mpep, epsByNIC.seed)
+ transEP := mpep.selectEndpoint(id, epsByNIC.seed)
if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue {
queuedProtocol.QueuePacket(transEP, id, pkt)
epsByNIC.mu.RUnlock()
@@ -200,7 +204,7 @@ func (epsByNIC *endpointsByNIC) handleError(n *nic, id TransportEndpointID, tran
// broadcast like we are doing with handlePacket above?
// multiPortEndpoints are guaranteed to have at least one element.
- selectEndpoint(id, mpep, epsByNIC.seed).HandleError(transErr, pkt)
+ mpep.selectEndpoint(id, epsByNIC.seed).HandleError(transErr, pkt)
}
// registerEndpoint returns true if it succeeds. It fails and returns
@@ -333,15 +337,18 @@ func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber
//
// +stateify savable
type multiPortEndpoint struct {
- mu sync.RWMutex `state:"nosave"`
demux *transportDemuxer
netProto tcpip.NetworkProtocolNumber
transProto tcpip.TransportProtocolNumber
+ flags ports.FlagCounter
+
+ mu sync.RWMutex `state:"nosave"`
// endpoints stores the transport endpoints in the order in which they
// were bound. This is required for UDP SO_REUSEADDR.
+ //
+ // +checklocks:mu
endpoints []TransportEndpoint
- flags ports.FlagCounter
}
func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint {
@@ -362,13 +369,16 @@ func reciprocalScale(val, n uint32) uint32 {
// selectEndpoint calculates a hash of destination and source addresses and
// ports then uses it to select a socket. In this case, all packets from one
// address will be sent to same endpoint.
-func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint {
- if len(mpep.endpoints) == 1 {
- return mpep.endpoints[0]
+func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID, seed uint32) TransportEndpoint {
+ ep.mu.RLock()
+ defer ep.mu.RUnlock()
+
+ if len(ep.endpoints) == 1 {
+ return ep.endpoints[0]
}
- if mpep.flags.SharedFlags().ToFlags().Effective().MostRecent {
- return mpep.endpoints[len(mpep.endpoints)-1]
+ if ep.flags.SharedFlags().ToFlags().Effective().MostRecent {
+ return ep.endpoints[len(ep.endpoints)-1]
}
payload := []byte{
@@ -384,8 +394,8 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32
h.Write([]byte(id.RemoteAddress))
hash := h.Sum32()
- idx := reciprocalScale(hash, uint32(len(mpep.endpoints)))
- return mpep.endpoints[idx]
+ idx := reciprocalScale(hash, uint32(len(ep.endpoints)))
+ return ep.endpoints[idx]
}
func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *PacketBuffer) {
@@ -479,7 +489,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
if !ok {
epsByNIC = &endpointsByNIC{
endpoints: make(map[tcpip.NICID]*multiPortEndpoint),
- seed: d.stack.Seed(),
+ seed: d.stack.seed,
}
}
if err := epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice); err != nil {
@@ -657,7 +667,7 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN
}
}
- ep := selectEndpoint(id, mpep, epsByNIC.seed)
+ ep := mpep.selectEndpoint(id, epsByNIC.seed)
epsByNIC.mu.RUnlock()
return ep
}
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 45b09110d..cd3a8c25a 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -35,7 +35,7 @@ import (
const (
testSrcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- testDstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ testDstAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
testSrcAddrV4 = "\x0a\x00\x00\x01"
testDstAddrV4 = "\x0a\x00\x00\x02"
@@ -64,12 +64,20 @@ func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICI
}
linkEps[linkEpID] = channelEp
- if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, testDstAddrV4); err != nil {
- t.Fatalf("AddAddress IPv4 failed: %s", err)
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.Address(testDstAddrV4).WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(linkEpID, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", linkEpID, protocolAddrV4, err)
}
- if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, testDstAddrV6); err != nil {
- t.Fatalf("AddAddress IPv6 failed: %s", err)
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: testDstAddrV6.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(linkEpID, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", linkEpID, protocolAddrV6, err)
}
}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 839178809..655931715 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -357,8 +357,15 @@ func TestTransportReceive(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
// Create endpoint and connect to remote address.
@@ -428,8 +435,15 @@ func TestTransportControlReceive(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
// Create endpoint and connect to remote address.
@@ -497,8 +511,15 @@ func TestTransportSend(t *testing.T) {
t.Fatalf("CreateNIC failed: %v", err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 55683b4fb..460a6afaf 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -19,7 +19,7 @@
// The starting point is the creation and configuration of a stack. A stack can
// be created by calling the New() function of the tcpip/stack/stack package;
// configuring a stack involves creating NICs (via calls to Stack.CreateNIC()),
-// adding network addresses (via calls to Stack.AddAddress()), and
+// adding network addresses (via calls to Stack.AddProtocolAddress()), and
// setting a route table (via a call to Stack.SetRouteTable()).
//
// Once a stack is configured, endpoints can be created by calling
@@ -423,9 +423,9 @@ type ControlMessages struct {
// HasTimestamp indicates whether Timestamp is valid/set.
HasTimestamp bool
- // Timestamp is the time (in ns) that the last packet used to create
- // the read data was received.
- Timestamp int64
+ // Timestamp is the time that the last packet used to create the read data
+ // was received.
+ Timestamp time.Time `state:".(int64)"`
// HasInq indicates whether Inq is valid/set.
HasInq bool
@@ -451,6 +451,12 @@ type ControlMessages struct {
// PacketInfo holds interface and address data on an incoming packet.
PacketInfo IPPacketInfo
+ // HasIPv6PacketInfo indicates whether IPv6PacketInfo is set.
+ HasIPv6PacketInfo bool
+
+ // IPv6PacketInfo holds interface and address data on an incoming packet.
+ IPv6PacketInfo IPv6PacketInfo
+
// HasOriginalDestinationAddress indicates whether OriginalDstAddress is
// set.
HasOriginalDstAddress bool
@@ -465,10 +471,10 @@ type ControlMessages struct {
// PacketOwner is used to get UID and GID of the packet.
type PacketOwner interface {
- // UID returns KUID of the packet.
+ // KUID returns KUID of the packet.
KUID() uint32
- // GID returns KGID of the packet.
+ // KGID returns KGID of the packet.
KGID() uint32
}
@@ -1164,6 +1170,14 @@ type IPPacketInfo struct {
DestinationAddr Address
}
+// IPv6PacketInfo is the message structure for IPV6_PKTINFO.
+//
+// +stateify savable
+type IPv6PacketInfo struct {
+ Addr Address
+ NIC NICID
+}
+
// SendBufferSizeOption is used by stack.(Stack*).Option/SetOption to
// get/set the default, min and max send buffer sizes.
type SendBufferSizeOption struct {
@@ -1231,11 +1245,11 @@ type Route struct {
// String implements the fmt.Stringer interface.
func (r Route) String() string {
var out strings.Builder
- fmt.Fprintf(&out, "%s", r.Destination)
+ _, _ = fmt.Fprintf(&out, "%s", r.Destination)
if len(r.Gateway) > 0 {
- fmt.Fprintf(&out, " via %s", r.Gateway)
+ _, _ = fmt.Fprintf(&out, " via %s", r.Gateway)
}
- fmt.Fprintf(&out, " nic %d", r.NIC)
+ _, _ = fmt.Fprintf(&out, " nic %d", r.NIC)
return out.String()
}
@@ -1255,6 +1269,8 @@ type TransportProtocolNumber uint32
type NetworkProtocolNumber uint32
// A StatCounter keeps track of a statistic.
+//
+// +stateify savable
type StatCounter struct {
count atomicbitops.AlignedAtomicUint64
}
@@ -1270,7 +1286,7 @@ func (s *StatCounter) Decrement() {
}
// Value returns the current value of the counter.
-func (s *StatCounter) Value(name ...string) uint64 {
+func (s *StatCounter) Value(...string) uint64 {
return s.count.Load()
}
@@ -1849,6 +1865,10 @@ type TCPStats struct {
// SegmentsAckedWithDSACK is the number of segments acknowledged with
// DSACK.
SegmentsAckedWithDSACK *StatCounter
+
+ // SpuriousRecovery is the number of times the connection entered loss
+ // recovery spuriously.
+ SpuriousRecovery *StatCounter
}
// UDPStats collects UDP-specific stats.
@@ -1981,6 +2001,8 @@ type Stats struct {
}
// ReceiveErrors collects packet receive errors within transport endpoint.
+//
+// +stateify savable
type ReceiveErrors struct {
// ReceiveBufferOverflow is the number of received packets dropped
// due to the receive buffer being full.
@@ -1998,8 +2020,10 @@ type ReceiveErrors struct {
ChecksumErrors StatCounter
}
-// SendErrors collects packet send errors within the transport layer for
-// an endpoint.
+// SendErrors collects packet send errors within the transport layer for an
+// endpoint.
+//
+// +stateify savable
type SendErrors struct {
// SendToNetworkFailed is the number of packets failed to be written to
// the network endpoint.
@@ -2010,6 +2034,8 @@ type SendErrors struct {
}
// ReadErrors collects segment read errors from an endpoint read call.
+//
+// +stateify savable
type ReadErrors struct {
// ReadClosed is the number of received packet drops because the endpoint
// was shutdown for read.
@@ -2025,6 +2051,8 @@ type ReadErrors struct {
}
// WriteErrors collects packet write errors from an endpoint write call.
+//
+// +stateify savable
type WriteErrors struct {
// WriteClosed is the number of packet drops because the endpoint
// was shutdown for write.
@@ -2040,6 +2068,8 @@ type WriteErrors struct {
}
// TransportEndpointStats collects statistics about the endpoint.
+//
+// +stateify savable
type TransportEndpointStats struct {
// PacketsReceived is the number of successful packet receives.
PacketsReceived StatCounter
diff --git a/pkg/tcpip/tcpip_state.go b/pkg/tcpip/tcpip_state.go
new file mode 100644
index 000000000..1953e24a1
--- /dev/null
+++ b/pkg/tcpip/tcpip_state.go
@@ -0,0 +1,27 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpip
+
+import (
+ "time"
+)
+
+func (c *ControlMessages) saveTimestamp() int64 {
+ return c.Timestamp.UnixNano()
+}
+
+func (c *ControlMessages) loadTimestamp(nsec int64) {
+ c.Timestamp = time.Unix(0, nsec)
+}
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index 181ef799e..7c998eaae 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -34,12 +34,16 @@ go_test(
"//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
+ "//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
"//pkg/tcpip/tests/utils",
"//pkg/tcpip/testutil",
+ "//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
+ "//pkg/waiter",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
index 92fa6257d..6e1d4720d 100644
--- a/pkg/tcpip/tests/integration/forward_test.go
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -473,11 +473,19 @@ func TestMulticastForwarding(t *testing.T) {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err)
}
- if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err)
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv4Addr,
}
- if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err)
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err)
+ }
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv6Addr,
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err)
}
if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
@@ -612,8 +620,8 @@ func TestPerInterfaceForwarding(t *testing.T) {
addr: utils.RouterNIC2IPv6Addr,
},
} {
- if err := s.AddProtocolAddress(add.nicID, add.addr); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", add.nicID, add.addr, err)
+ if err := s.AddProtocolAddress(add.nicID, add.addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", add.nicID, add.addr, err)
}
}
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
index f9ab7d0af..f01e2b128 100644
--- a/pkg/tcpip/tests/integration/iptables_test.go
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -15,19 +15,24 @@
package iptables_test
import (
+ "bytes"
"testing"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/tests/utils"
"gvisor.dev/gvisor/pkg/tcpip/testutil"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
)
type inputIfNameMatcher struct {
@@ -49,10 +54,10 @@ const (
nicName = "nic1"
anotherNicName = "nic2"
linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
- srcAddrV4 = "\x0a\x00\x00\x01"
- dstAddrV4 = "\x0a\x00\x00\x02"
- srcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- dstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ srcAddrV4 = tcpip.Address("\x0a\x00\x00\x01")
+ dstAddrV4 = tcpip.Address("\x0a\x00\x00\x02")
+ srcAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ dstAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
payloadSize = 20
)
@@ -66,8 +71,12 @@ func genStackV6(t *testing.T) (*stack.Stack, *channel.Endpoint) {
if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err)
}
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, dstAddrV6); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, dstAddrV6, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: dstAddrV6.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
return s, e
}
@@ -82,8 +91,12 @@ func genStackV4(t *testing.T) (*stack.Stack, *channel.Endpoint) {
if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err)
}
- if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, dstAddrV4); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, dstAddrV4, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: dstAddrV4.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
return s, e
}
@@ -601,11 +614,19 @@ func TestIPTableWritePackets(t *testing.T) {
if err := s.CreateNIC(nicID, &e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, srcAddrV6); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, srcAddrV6, err)
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: srcAddrV6.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV6, err)
+ }
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: srcAddrV4.WithPrefix(),
}
- if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, srcAddrV4); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, srcAddrV4, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV4, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -856,11 +877,19 @@ func TestForwardingHook(t *testing.T) {
t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err)
}
- if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err)
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv4Addr.Address.WithPrefix(),
}
- if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err)
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err)
+ }
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv6Addr.Address.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err)
}
if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
@@ -1037,22 +1066,22 @@ func TestInputHookWithLocalForwarding(t *testing.T) {
if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil {
t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err)
}
- if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv4Addr1, err)
+ if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID1, utils.Ipv4Addr1, err)
}
- if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv6Addr1, err)
+ if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID1, utils.Ipv6Addr1, err)
}
e2 := channel.New(1, header.IPv6MinimumMTU, "")
if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil {
t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err)
}
- if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv4Addr2, err)
+ if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID2, utils.Ipv4Addr2, err)
}
- if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv6Addr2, err)
+ if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID2, utils.Ipv6Addr2, err)
}
if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
@@ -1132,3 +1161,312 @@ func TestInputHookWithLocalForwarding(t *testing.T) {
})
}
}
+
+func TestSNAT(t *testing.T) {
+ const listenPort = 8080
+
+ type endpointAndAddresses struct {
+ serverEP tcpip.Endpoint
+ serverAddr tcpip.Address
+ serverReadableCH chan struct{}
+
+ clientEP tcpip.Endpoint
+ clientAddr tcpip.Address
+ clientReadableCH chan struct{}
+
+ nattedClientAddr tcpip.Address
+ }
+
+ newEP := func(t *testing.T, s *stack.Stack, transProto tcpip.TransportProtocolNumber, netProto tcpip.NetworkProtocolNumber) (tcpip.Endpoint, chan struct{}) {
+ t.Helper()
+ var wq waiter.Queue
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.ReadableEvents)
+ t.Cleanup(func() {
+ wq.EventUnregister(&we)
+ })
+
+ ep, err := s.NewEndpoint(transProto, netProto, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", transProto, netProto, err)
+ }
+ t.Cleanup(ep.Close)
+
+ return ep, ch
+ }
+
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses
+ }{
+ {
+ name: "IPv4 host1 server with host2 client",
+ netProto: ipv4.ProtocolNumber,
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
+ t.Helper()
+
+ ep1, ep1WECH := newEP(t, host1Stack, proto, ipv4.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
+ serverReadableCH: ep1WECH,
+
+ clientEP: ep2,
+ clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+
+ nattedClientAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
+ }
+ },
+ },
+ {
+ name: "IPv6 host1 server with host2 client",
+ netProto: ipv6.ProtocolNumber,
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
+ t.Helper()
+
+ ep1, ep1WECH := newEP(t, host1Stack, proto, ipv6.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
+ serverReadableCH: ep1WECH,
+
+ clientEP: ep2,
+ clientAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+
+ nattedClientAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
+ }
+ },
+ },
+ }
+
+ subTests := []struct {
+ name string
+ proto tcpip.TransportProtocolNumber
+ expectedConnectErr tcpip.Error
+ setupServer func(t *testing.T, ep tcpip.Endpoint)
+ setupServerConn func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{})
+ needRemoteAddr bool
+ }{
+ {
+ name: "UDP",
+ proto: udp.ProtocolNumber,
+ expectedConnectErr: nil,
+ setupServerConn: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
+ t.Helper()
+
+ if err := ep.Connect(clientAddr); err != nil {
+ t.Fatalf("ep.Connect(%#v): %s", clientAddr, err)
+ }
+ return nil, nil
+ },
+ needRemoteAddr: true,
+ },
+ {
+ name: "TCP",
+ proto: tcp.ProtocolNumber,
+ expectedConnectErr: &tcpip.ErrConnectStarted{},
+ setupServer: func(t *testing.T, ep tcpip.Endpoint) {
+ t.Helper()
+
+ if err := ep.Listen(1); err != nil {
+ t.Fatalf("ep.Listen(1): %s", err)
+ }
+ },
+ setupServerConn: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
+ t.Helper()
+
+ var addr tcpip.FullAddress
+ for {
+ newEP, wq, err := ep.Accept(&addr)
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ <-ch
+ continue
+ }
+ if err != nil {
+ t.Fatalf("ep.Accept(_): %s", err)
+ }
+ if diff := cmp.Diff(clientAddr, addr, checker.IgnoreCmpPath(
+ "NIC",
+ )); diff != "" {
+ t.Errorf("accepted address mismatch (-want +got):\n%s", diff)
+ }
+
+ we, newCH := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.ReadableEvents)
+ return newEP, newCH
+ }
+ },
+ needRemoteAddr: false,
+ },
+ }
+
+ setupNAT := func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, target stack.Target) {
+ t.Helper()
+
+ ipv6 := netProto == ipv6.ProtocolNumber
+ ipt := s.IPTables()
+ filter := ipt.GetTable(stack.NATID, ipv6)
+ ruleIdx := filter.BuiltinChains[stack.Postrouting]
+ filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{OutputInterface: utils.RouterNIC1Name}
+ filter.Rules[ruleIdx].Target = target
+ // Make sure the packet is not dropped by the next rule.
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.NATID, filter, ipv6); err != nil {
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err)
+ }
+ }
+
+ natTypes := []struct {
+ name string
+ setupNAT func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber, tcpip.Address)
+ }{
+ {
+ name: "SNAT",
+ setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, natToAddr tcpip.Address) {
+ t.Helper()
+
+ setupNAT(t, s, netProto, &stack.SNATTarget{NetworkProtocol: netProto, Addr: natToAddr})
+ },
+ },
+ {
+ name: "Masquerade",
+ setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, natToAddr tcpip.Address) {
+ t.Helper()
+
+ setupNAT(t, s, netProto, &stack.MasqueradeTarget{NetworkProtocol: netProto})
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, subTest := range subTests {
+ t.Run(subTest.name, func(t *testing.T) {
+ for _, natType := range natTypes {
+ t.Run(natType.name, func(t *testing.T) {
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
+ }
+
+ host1Stack := stack.New(stackOpts)
+ routerStack := stack.New(stackOpts)
+ host2Stack := stack.New(stackOpts)
+ utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack)
+
+ epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto)
+
+ natType.setupNAT(t, routerStack, test.netProto, epsAndAddrs.nattedClientAddr)
+
+ serverAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverAddr, Port: listenPort}
+ if err := epsAndAddrs.serverEP.Bind(serverAddr); err != nil {
+ t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", serverAddr, err)
+ }
+ clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr}
+ if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil {
+ t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err)
+ }
+
+ if subTest.setupServer != nil {
+ subTest.setupServer(t, epsAndAddrs.serverEP)
+ }
+ {
+ err := epsAndAddrs.clientEP.Connect(serverAddr)
+ if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" {
+ t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", serverAddr, diff)
+ }
+ }
+ nattedClientAddr := tcpip.FullAddress{Addr: epsAndAddrs.nattedClientAddr}
+ if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil {
+ t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err)
+ } else {
+ nattedClientAddr.Port = addr.Port
+ }
+
+ serverEP := epsAndAddrs.serverEP
+ serverCH := epsAndAddrs.serverReadableCH
+ if ep, ch := subTest.setupServerConn(t, serverEP, serverCH, nattedClientAddr); ep != nil {
+ defer ep.Close()
+ serverEP = ep
+ serverCH = ch
+ }
+
+ write := func(ep tcpip.Endpoint, data []byte) {
+ t.Helper()
+
+ var r bytes.Reader
+ r.Reset(data)
+ var wOpts tcpip.WriteOptions
+ n, err := ep.Write(&r, wOpts)
+ if err != nil {
+ t.Fatalf("ep.Write(_, %#v): %s", wOpts, err)
+ }
+ if want := int64(len(data)); n != want {
+ t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want)
+ }
+ }
+
+ read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) {
+ t.Helper()
+
+ var buf bytes.Buffer
+ var res tcpip.ReadResult
+ for {
+ var err tcpip.Error
+ opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr}
+ res, err = ep.Read(&buf, opts)
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ <-ch
+ continue
+ }
+ if err != nil {
+ t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err)
+ }
+ break
+ }
+
+ readResult := tcpip.ReadResult{
+ Count: len(data),
+ Total: len(data),
+ }
+ if subTest.needRemoteAddr {
+ readResult.RemoteAddr = expectedFrom
+ }
+ if diff := cmp.Diff(readResult, res, checker.IgnoreCmpPath(
+ "ControlMessages",
+ "RemoteAddr.NIC",
+ )); diff != "" {
+ t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
+ }
+ if diff := cmp.Diff(buf.Bytes(), data); diff != "" {
+ t.Errorf("received data mismatch (-want +got):\n%s", diff)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+
+ {
+ data := []byte{1, 2, 3, 4}
+ write(epsAndAddrs.clientEP, data)
+ read(serverCH, serverEP, data, nattedClientAddr)
+ }
+
+ {
+ data := []byte{5, 6, 7, 8, 9, 10, 11, 12}
+ write(serverEP, data)
+ read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, serverAddr)
+ }
+ })
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
index 27caa0c28..95ddd8ec3 100644
--- a/pkg/tcpip/tests/integration/link_resolution_test.go
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -56,17 +56,17 @@ func setupStack(t *testing.T, stackOpts stack.Options, host1NICID, host2NICID tc
t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err)
}
- if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv4Addr1); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, utils.Ipv4Addr1, err)
+ if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv4Addr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", host1NICID, utils.Ipv4Addr1, err)
}
- if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv4Addr2); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, utils.Ipv4Addr2, err)
+ if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv4Addr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", host2NICID, utils.Ipv4Addr2, err)
}
- if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv6Addr1); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, utils.Ipv6Addr1, err)
+ if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv6Addr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", host1NICID, utils.Ipv6Addr1, err)
}
- if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv6Addr2); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, utils.Ipv6Addr2, err)
+ if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv6Addr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", host2NICID, utils.Ipv6Addr2, err)
}
host1Stack.SetRouteTable([]tcpip.Route{
@@ -568,8 +568,8 @@ func TestForwardingWithLinkResolutionFailure(t *testing.T) {
Protocol: test.networkProtocolNumber,
AddressWithPrefix: test.incomingAddr,
}
- if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingProtoAddr, err)
+ if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingProtoAddr, err)
}
// Set up endpoint through which we will attempt to forward packets.
@@ -582,8 +582,8 @@ func TestForwardingWithLinkResolutionFailure(t *testing.T) {
Protocol: test.networkProtocolNumber,
AddressWithPrefix: test.outgoingAddr,
}
- if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingProtoAddr, err)
+ if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingProtoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index b2008f0b2..f33223e79 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -195,8 +195,8 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) {
if err := s.CreateNIC(nicID, loopback.New()); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddProtocolAddress(nicID, test.addAddress); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, test.addAddress, err)
+ if err := s.AddProtocolAddress(nicID, test.addAddress, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, test.addAddress, err)
}
s.SetRouteTable([]tcpip.Route{
{
@@ -290,8 +290,8 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) {
if err := s.CreateNIC(nicID, loopback.New()); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
+ if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
{
@@ -431,8 +431,8 @@ func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) {
if err := s.CreateNIC(nicID, loopback.New()); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddProtocolAddress(nicID, test.addAddress); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, test.addAddress, err)
+ if err := s.AddProtocolAddress(nicID, test.addAddress, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, test.addAddress, err)
}
s.SetRouteTable([]tcpip.Route{
{
@@ -693,21 +693,40 @@ func TestExternalLoopbackTraffic(t *testing.T) {
if err := s.CreateNIC(nicID1, e); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
}
- if err := s.AddAddressWithPrefix(nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr, err)
+ v4Addr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv4Addr,
}
- if err := s.AddAddressWithPrefix(nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr, err)
+ if err := s.AddProtocolAddress(nicID1, v4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, v4Addr, err)
+ }
+ v6Addr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv6Addr,
+ }
+ if err := s.AddProtocolAddress(nicID1, v6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, v6Addr, err)
}
if err := s.CreateNIC(nicID2, loopback.New()); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
}
- if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, ipv4Loopback); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, ipv4Loopback, err)
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: ipv4Loopback,
+ PrefixLen: 8,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err)
+ }
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: header.IPv6Loopback.WithPrefix(),
}
- if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, header.IPv6Loopback); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, header.IPv6Loopback, err)
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err)
}
if test.forwarding {
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index 2d0a6e6a7..7753e7d6e 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -119,12 +119,12 @@ func TestPingMulticastBroadcast(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: utils.Ipv4Addr}
- if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err)
+ if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtoAddr, err)
}
ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: utils.Ipv6Addr}
- if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv6ProtoAddr, err)
+ if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtoAddr, err)
}
// Default routes for IPv4 and IPv6 so ICMP can find a route to the remote
@@ -396,8 +396,8 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr}
- if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
+ if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
}
var wq waiter.Queue
@@ -474,8 +474,8 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
PrefixLen: 8,
},
}
- if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
+ if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -642,8 +642,8 @@ func TestUDPAddRemoveMembershipSocketOption(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr}
- if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
+ if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
}
// Set the route table so that UDP can find a NIC that is
diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go
index ac3c703d4..422eb8408 100644
--- a/pkg/tcpip/tests/integration/route_test.go
+++ b/pkg/tcpip/tests/integration/route_test.go
@@ -47,7 +47,10 @@ func TestLocalPing(t *testing.T) {
// request/reply packets.
icmpDataOffset = 8
)
- ipv4Loopback := testutil.MustParse4("127.0.0.1")
+ ipv4Loopback := tcpip.AddressWithPrefix{
+ Address: testutil.MustParse4("127.0.0.1"),
+ PrefixLen: 8,
+ }
channelEP := func() stack.LinkEndpoint { return channel.New(1, header.IPv6MinimumMTU, "") }
channelEPCheck := func(t *testing.T, e stack.LinkEndpoint) {
@@ -82,7 +85,7 @@ func TestLocalPing(t *testing.T) {
transProto tcpip.TransportProtocolNumber
netProto tcpip.NetworkProtocolNumber
linkEndpoint func() stack.LinkEndpoint
- localAddr tcpip.Address
+ localAddr tcpip.AddressWithPrefix
icmpBuf func(*testing.T) buffer.View
expectedConnectErr tcpip.Error
checkLinkEndpoint func(t *testing.T, e stack.LinkEndpoint)
@@ -101,7 +104,7 @@ func TestLocalPing(t *testing.T) {
transProto: icmp.ProtocolNumber6,
netProto: ipv6.ProtocolNumber,
linkEndpoint: loopback.New,
- localAddr: header.IPv6Loopback,
+ localAddr: header.IPv6Loopback.WithPrefix(),
icmpBuf: ipv6ICMPBuf,
checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {},
},
@@ -110,7 +113,7 @@ func TestLocalPing(t *testing.T) {
transProto: icmp.ProtocolNumber4,
netProto: ipv4.ProtocolNumber,
linkEndpoint: channelEP,
- localAddr: utils.Ipv4Addr.Address,
+ localAddr: utils.Ipv4Addr,
icmpBuf: ipv4ICMPBuf,
checkLinkEndpoint: channelEPCheck,
},
@@ -119,7 +122,7 @@ func TestLocalPing(t *testing.T) {
transProto: icmp.ProtocolNumber6,
netProto: ipv6.ProtocolNumber,
linkEndpoint: channelEP,
- localAddr: utils.Ipv6Addr.Address,
+ localAddr: utils.Ipv6Addr,
icmpBuf: ipv6ICMPBuf,
checkLinkEndpoint: channelEPCheck,
},
@@ -182,9 +185,13 @@ func TestLocalPing(t *testing.T) {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if len(test.localAddr) != 0 {
- if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err)
+ if len(test.localAddr.Address) != 0 {
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: test.netProto,
+ AddressWithPrefix: test.localAddr,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
@@ -197,7 +204,7 @@ func TestLocalPing(t *testing.T) {
}
defer ep.Close()
- connAddr := tcpip.FullAddress{Addr: test.localAddr}
+ connAddr := tcpip.FullAddress{Addr: test.localAddr.Address}
if err := ep.Connect(connAddr); err != test.expectedConnectErr {
t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr)
}
@@ -229,8 +236,8 @@ func TestLocalPing(t *testing.T) {
if diff := cmp.Diff(buffer.View(w.Bytes()[icmpDataOffset:]), payload[icmpDataOffset:]); diff != "" {
t.Errorf("received data mismatch (-want +got):\n%s", diff)
}
- if rr.RemoteAddr.Addr != test.localAddr {
- t.Errorf("got addr.Addr = %s, want = %s", rr.RemoteAddr.Addr, test.localAddr)
+ if rr.RemoteAddr.Addr != test.localAddr.Address {
+ t.Errorf("got addr.Addr = %s, want = %s", rr.RemoteAddr.Addr, test.localAddr.Address)
}
test.checkLinkEndpoint(t, e)
@@ -302,11 +309,12 @@ func TestLocalUDP(t *testing.T) {
}
if subTest.addAddress {
- if err := s.AddProtocolAddressWithOptions(nicID, test.canBePrimaryAddr, stack.CanBePrimaryEndpoint); err != nil {
- t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.canBePrimaryAddr, stack.FirstPrimaryEndpoint, err)
+ if err := s.AddProtocolAddress(nicID, test.canBePrimaryAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, test.canBePrimaryAddr, err)
}
- if err := s.AddProtocolAddressWithOptions(nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint); err != nil {
- t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint, err)
+ properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint}
+ if err := s.AddProtocolAddress(nicID, test.firstPrimaryAddr, properties); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, %+v): %s", nicID, test.firstPrimaryAddr, properties, err)
}
}
diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go
index 2e6ae55ea..c69410859 100644
--- a/pkg/tcpip/tests/utils/utils.go
+++ b/pkg/tcpip/tests/utils/utils.go
@@ -40,6 +40,14 @@ const (
Host2NICID = 4
)
+// Common NIC names used by tests. SetupRoutedStacks passes each one as stack.NICOptions.Name when it creates the matching NIC.
+const (
+ Host1NICName = "host1NIC" // name given to the NIC created with Host1NICID on host1Stack
+ RouterNIC1Name = "routerNIC1" // name given to the NIC created with RouterNICID1 on routerStack
+ RouterNIC2Name = "routerNIC2" // name given to the NIC created with RouterNICID2 on routerStack
+ Host2NICName = "host2NIC" // name given to the NIC created with Host2NICID on host2Stack
+)
+
// Common link addresses used by tests.
const (
LinkAddr1 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
@@ -211,17 +219,29 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.
host1NIC, routerNIC1 := pipe.New(LinkAddr1, LinkAddr2)
routerNIC2, host2NIC := pipe.New(LinkAddr3, LinkAddr4)
- if err := host1Stack.CreateNIC(Host1NICID, NewEthernetEndpoint(host1NIC)); err != nil {
- t.Fatalf("host1Stack.CreateNIC(%d, _): %s", Host1NICID, err)
+ {
+ opts := stack.NICOptions{Name: Host1NICName}
+ if err := host1Stack.CreateNICWithOptions(Host1NICID, NewEthernetEndpoint(host1NIC), opts); err != nil {
+ t.Fatalf("host1Stack.CreateNICWithOptions(%d, _, %#v): %s", Host1NICID, opts, err)
+ }
}
- if err := routerStack.CreateNIC(RouterNICID1, NewEthernetEndpoint(routerNIC1)); err != nil {
- t.Fatalf("routerStack.CreateNIC(%d, _): %s", RouterNICID1, err)
+ {
+ opts := stack.NICOptions{Name: RouterNIC1Name}
+ if err := routerStack.CreateNICWithOptions(RouterNICID1, NewEthernetEndpoint(routerNIC1), opts); err != nil {
+ t.Fatalf("routerStack.CreateNICWithOptions(%d, _, %#v): %s", RouterNICID1, opts, err)
+ }
}
- if err := routerStack.CreateNIC(RouterNICID2, NewEthernetEndpoint(routerNIC2)); err != nil {
- t.Fatalf("routerStack.CreateNIC(%d, _): %s", RouterNICID2, err)
+ {
+ opts := stack.NICOptions{Name: RouterNIC2Name}
+ if err := routerStack.CreateNICWithOptions(RouterNICID2, NewEthernetEndpoint(routerNIC2), opts); err != nil {
+ t.Fatalf("routerStack.CreateNICWithOptions(%d, _, %#v): %s", RouterNICID2, opts, err)
+ }
}
- if err := host2Stack.CreateNIC(Host2NICID, NewEthernetEndpoint(host2NIC)); err != nil {
- t.Fatalf("host2Stack.CreateNIC(%d, _): %s", Host2NICID, err)
+ {
+ opts := stack.NICOptions{Name: Host2NICName}
+ if err := host2Stack.CreateNICWithOptions(Host2NICID, NewEthernetEndpoint(host2NIC), opts); err != nil {
+ t.Fatalf("host2Stack.CreateNICWithOptions(%d, _, %#v): %s", Host2NICID, opts, err)
+ }
}
if err := routerStack.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
@@ -231,29 +251,29 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.
t.Fatalf("routerStack.SetForwardingDefaultAndAllNICs(%d): %s", ipv6.ProtocolNumber, err)
}
- if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv4Addr); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", Host1NICID, Host1IPv4Addr, err)
+ if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", Host1NICID, Host1IPv4Addr, err)
}
- if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv4Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID1, RouterNIC1IPv4Addr, err)
+ if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID1, RouterNIC1IPv4Addr, err)
}
- if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv4Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID2, RouterNIC2IPv4Addr, err)
+ if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID2, RouterNIC2IPv4Addr, err)
}
- if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv4Addr); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", Host2NICID, Host2IPv4Addr, err)
+ if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", Host2NICID, Host2IPv4Addr, err)
}
- if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv6Addr); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", Host1NICID, Host1IPv6Addr, err)
+ if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", Host1NICID, Host1IPv6Addr, err)
}
- if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv6Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID1, RouterNIC1IPv6Addr, err)
+ if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID1, RouterNIC1IPv6Addr, err)
}
- if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv6Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID2, RouterNIC2IPv6Addr, err)
+ if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID2, RouterNIC2IPv6Addr, err)
}
- if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv6Addr); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", Host2NICID, Host2IPv6Addr, err)
+ if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", Host2NICID, Host2IPv6Addr, err)
}
host1Stack.SetRouteTable([]tcpip.Route{
diff --git a/pkg/tcpip/transport/BUILD b/pkg/tcpip/transport/BUILD
new file mode 100644
index 000000000..af332ed91
--- /dev/null
+++ b/pkg/tcpip/transport/BUILD
@@ -0,0 +1,13 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "transport",
+ srcs = [
+ "datagram.go",
+ "transport.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = ["//pkg/tcpip"],
+)
diff --git a/pkg/tcpip/transport/datagram.go b/pkg/tcpip/transport/datagram.go
new file mode 100644
index 000000000..dfce72c69
--- /dev/null
+++ b/pkg/tcpip/transport/datagram.go
@@ -0,0 +1,49 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package transport
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// DatagramEndpointState is the state of a datagram-based endpoint. It is a distinct named type defined in terms of tcpip.EndpointState.
+type DatagramEndpointState tcpip.EndpointState
+
+// The states a datagram-based endpoint may be in. The named values start at 1; String panics for any value outside this set.
+const (
+ _ DatagramEndpointState = iota // iota value 0 is discarded, so 0 is not one of the named states
+ DatagramEndpointStateInitial
+ DatagramEndpointStateBound
+ DatagramEndpointStateConnected
+ DatagramEndpointStateClosed
+)
+
+// String implements fmt.Stringer. It returns the upper-case name of a declared state and panics for any other value.
+func (s DatagramEndpointState) String() string {
+ switch s {
+ case DatagramEndpointStateInitial:
+ return "INITIAL"
+ case DatagramEndpointStateBound:
+ return "BOUND"
+ case DatagramEndpointStateConnected:
+ return "CONNECTED"
+ case DatagramEndpointStateClosed:
+ return "CLOSED"
+ default:
+ panic(fmt.Sprintf("unhandled %[1]T variant = %[1]d", s)) // both verbs index argument 1: the message shows s's type and its numeric value
+ }
+}
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
index bbc0e3ecc..4718ec4ec 100644
--- a/pkg/tcpip/transport/icmp/BUILD
+++ b/pkg/tcpip/transport/icmp/BUILD
@@ -33,6 +33,8 @@ go_library(
"//pkg/tcpip/header",
"//pkg/tcpip/ports",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport",
+ "//pkg/tcpip/transport/internal/network",
"//pkg/tcpip/transport/raw",
"//pkg/tcpip/transport/tcp",
"//pkg/waiter",
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index f9a15efb2..31579a896 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -15,6 +15,7 @@
package icmp
import (
+ "fmt"
"io"
"time"
@@ -24,6 +25,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -35,15 +38,6 @@ type icmpPacket struct {
receivedAt time.Time `state:".(int64)"`
}
-type endpointState int
-
-const (
- stateInitial endpointState = iota
- stateBound
- stateConnected
- stateClosed
-)
-
// endpoint represents an ICMP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
@@ -51,14 +45,17 @@ const (
//
// +stateify savable
type endpoint struct {
- stack.TransportEndpointInfo
tcpip.DefaultSocketOptionsHandler
// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
+ transProto tcpip.TransportProtocolNumber
waiterQueue *waiter.Queue
uniqueID uint64
+ net network.Endpoint
+ stats tcpip.TransportEndpointStats
+ ops tcpip.SocketOptions
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.
@@ -70,38 +67,23 @@ type endpoint struct {
// The following fields are protected by the mu mutex.
mu sync.RWMutex `state:"nosave"`
- // shutdownFlags represent the current shutdown state of the endpoint.
- shutdownFlags tcpip.ShutdownFlags
- state endpointState
- route *stack.Route `state:"manual"`
- ttl uint8
- stats tcpip.TransportEndpointStats `state:"nosave"`
-
- // owner is used to get uid and gid of the packet.
- owner tcpip.PacketOwner
-
- // ops is used to get socket level options.
- ops tcpip.SocketOptions
-
// frozen indicates if the packets should be delivered to the endpoint
// during restore.
frozen bool
+ ident uint16
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
ep := &endpoint{
- stack: s,
- TransportEndpointInfo: stack.TransportEndpointInfo{
- NetProto: netProto,
- TransProto: transProto,
- },
+ stack: s,
+ transProto: transProto,
waiterQueue: waiterQueue,
- state: stateInitial,
uniqueID: s.UniqueID(),
}
ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
ep.ops.SetSendBufferSize(32*1024, false /* notify */)
ep.ops.SetReceiveBufferSize(32*1024, false /* notify */)
+ ep.net.Init(s, netProto, transProto, &ep.ops)
// Override with stack defaults.
var ss tcpip.SendBufferSizeOption
@@ -128,35 +110,40 @@ func (e *endpoint) Abort() {
// Close puts the endpoint in a closed state and frees all resources
// associated with it.
func (e *endpoint) Close() {
- e.mu.Lock()
- e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
- switch e.state {
- case stateBound, stateConnected:
- bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
- e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, bindToDevice)
- }
-
- // Close the receive list and drain it.
- e.rcvMu.Lock()
- e.rcvClosed = true
- e.rcvBufSize = 0
- for !e.rcvList.Empty() {
- p := e.rcvList.Front()
- e.rcvList.Remove(p)
- }
- e.rcvMu.Unlock()
+ notify := func() bool { // closure: both deferred unlocks run when it returns, before the waiter queue is notified
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial: // empty case: nothing is unregistered and teardown continues after the switch
+ case transport.DatagramEndpointStateClosed:
+ return false // already closed: skip the teardown and the notification below
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ info := e.net.Info()
+ info.ID.LocalPort = e.ident // e.ident holds the local port recorded after registration in bindLocked/Connect
+ e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{info.NetProto}, e.transProto, info.ID, e, ports.Flags{}, tcpip.NICID(e.ops.GetBindToDevice()))
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
+ }
- if e.route != nil {
- e.route.Release()
- e.route = nil
- }
+ e.net.Shutdown() // the returned error is not checked here
+ e.net.Close()
- // Update the state.
- e.state = stateClosed
+ e.rcvMu.Lock()
+ defer e.rcvMu.Unlock()
+ e.rcvClosed = true
+ e.rcvBufSize = 0
+ for !e.rcvList.Empty() { // remove every entry still on the receive list
+ p := e.rcvList.Front()
+ e.rcvList.Remove(p)
+ }
- e.mu.Unlock()
+ return true
+ }()
- e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
+ if notify { // only the call that actually performed the teardown notifies
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
+ }
}
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
@@ -164,7 +151,7 @@ func (*endpoint) ModerateRecvBuf(int) {}
// SetOwner implements tcpip.Endpoint.SetOwner.
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
- e.owner = owner
+ e.net.SetOwner(owner) // forwarded to e.net; this change removes the endpoint's own owner field
}
// Read implements tcpip.Endpoint.Read.
@@ -193,7 +180,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
Total: p.data.Size(),
ControlMessages: tcpip.ControlMessages{
HasTimestamp: true,
- Timestamp: p.receivedAt.UnixNano(),
+ Timestamp: p.receivedAt,
},
}
if opts.NeedRemoteAddr {
@@ -214,13 +201,12 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
//
// Returns true for retry if preparation should be retried.
// +checklocks:e.mu
-func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
- switch e.state {
- case stateInitial:
- case stateConnected:
+func (e *endpoint) prepareForWriteInner(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
+ switch e.net.State() {
+ case transport.DatagramEndpointStateInitial:
+ case transport.DatagramEndpointStateConnected:
return false, nil
-
- case stateBound:
+ case transport.DatagramEndpointStateBound:
if to == nil {
return false, &tcpip.ErrDestinationRequired{}
}
@@ -235,7 +221,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip
// The state changed when we released the shared locked and re-acquired
// it in exclusive mode. Try again.
- if e.state != stateInitial {
+ if e.net.State() != transport.DatagramEndpointStateInitial {
return true, nil
}
@@ -270,27 +256,15 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
return n, err
}
-func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
- // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
- if opts.More {
- return 0, &tcpip.ErrInvalidOptionValue{}
- }
-
- to := opts.To
-
+func (e *endpoint) prepareForWrite(opts tcpip.WriteOptions) (network.WriteContext, uint16, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- // If we've shutdown with SHUT_WR we are in an invalid state for sending.
- if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
- return 0, &tcpip.ErrClosedForSend{}
- }
-
// Prepare for write.
for {
- retry, err := e.prepareForWrite(to)
+ retry, err := e.prepareForWriteInner(opts.To)
if err != nil {
- return 0, err
+ return network.WriteContext{}, 0, err
}
if !retry {
@@ -298,53 +272,35 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
}
}
- route := e.route
- if to != nil {
- // Reject destination address if it goes through a different
- // NIC than the endpoint was bound to.
- nicID := to.NIC
- if nicID == 0 {
- nicID = tcpip.NICID(e.ops.GetBindToDevice())
- }
- if e.BindNICID != 0 {
- if nicID != 0 && nicID != e.BindNICID {
- return 0, &tcpip.ErrNoRoute{}
- }
-
- nicID = e.BindNICID
- }
-
- dst, netProto, err := e.checkV4MappedLocked(*to)
- if err != nil {
- return 0, err
- }
-
- // Find the endpoint.
- r, err := e.stack.FindRoute(nicID, e.BindAddr, dst.Addr, netProto, false /* multicastLoop */)
- if err != nil {
- return 0, err
- }
- defer r.Release()
+ ctx, err := e.net.AcquireContextForWrite(opts)
+ return ctx, e.ident, err
+}
- route = r
+func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
+ ctx, ident, err := e.prepareForWrite(opts)
+ if err != nil {
+ return 0, err
}
+ defer ctx.Release()
+ // TODO(https://gvisor.dev/issue/6538): Avoid this allocation.
v := make([]byte, p.Len())
if _, err := io.ReadFull(p, v); err != nil {
return 0, &tcpip.ErrBadBuffer{}
}
- var err tcpip.Error
- switch e.NetProto {
+ switch netProto, pktInfo := e.net.NetProto(), ctx.PacketInfo(); netProto {
case header.IPv4ProtocolNumber:
- err = send4(route, e.ID.LocalPort, v, e.ttl, e.owner)
+ if err := send4(e.stack, &ctx, ident, v, pktInfo.MaxHeaderLength); err != nil {
+ return 0, err
+ }
case header.IPv6ProtocolNumber:
- err = send6(route, e.ID.LocalPort, v, e.ttl)
- }
-
- if err != nil {
- return 0, err
+ if err := send6(e.stack, &ctx, ident, v, pktInfo.LocalAddress, pktInfo.RemoteAddress, pktInfo.MaxHeaderLength); err != nil {
+ return 0, err
+ }
+ default:
+ panic(fmt.Sprintf("unhandled network protocol = %d", netProto))
}
return int64(len(v)), nil
@@ -357,24 +313,17 @@ func (e *endpoint) HasNIC(id int32) bool {
return e.stack.HasNIC(tcpip.NICID(id))
}
-// SetSockOpt sets a socket option.
-func (*endpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error {
- return nil
+// SetSockOpt implements tcpip.Endpoint. It forwards opt to e.net and returns that call's result.
+func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
+ return e.net.SetSockOpt(opt)
}
-// SetSockOptInt sets a socket option. Currently not supported.
+// SetSockOptInt implements tcpip.Endpoint. It forwards opt and v to e.net; the TTL option is no longer handled in this method.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
- switch opt {
- case tcpip.TTLOption:
- e.mu.Lock()
- e.ttl = uint8(v)
- e.mu.Unlock()
-
- }
- return nil
+ return e.net.SetSockOptInt(opt, v)
}
-// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
+// GetSockOptInt implements tcpip.Endpoint.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
@@ -387,31 +336,24 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
e.rcvMu.Unlock()
return v, nil
- case tcpip.TTLOption:
- e.rcvMu.Lock()
- v := int(e.ttl)
- e.rcvMu.Unlock()
- return v, nil
-
default:
- return -1, &tcpip.ErrUnknownProtocolOption{}
+ return e.net.GetSockOptInt(opt)
}
}
-// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error {
- return &tcpip.ErrUnknownProtocolOption{}
+// GetSockOpt implements tcpip.Endpoint. It forwards opt to e.net instead of always returning ErrUnknownProtocolOption.
+func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
+ return e.net.GetSockOpt(opt)
}
-func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) tcpip.Error {
+func send4(s *stack.Stack, ctx *network.WriteContext, ident uint16, data buffer.View, maxHeaderLength uint16) tcpip.Error {
if len(data) < header.ICMPv4MinimumSize {
return &tcpip.ErrInvalidEndpointState{}
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.ICMPv4MinimumSize + int(r.MaxHeaderLength()),
+ ReserveHeaderBytes: header.ICMPv4MinimumSize + int(maxHeaderLength),
})
- pkt.Owner = owner
icmpv4 := header.ICMPv4(pkt.TransportHeader().Push(header.ICMPv4MinimumSize))
pkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
@@ -426,36 +368,31 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi
return &tcpip.ErrInvalidEndpointState{}
}
- // Because this icmp endpoint is implemented in the transport layer, we can
- // only increment the 'stack-wide' stats but we can't increment the
- // 'per-NetworkEndpoint' stats.
- sentStat := r.Stats().ICMP.V4.PacketsSent.EchoRequest
-
icmpv4.SetChecksum(0)
icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0)))
-
pkt.Data().AppendView(data)
- if ttl == 0 {
- ttl = r.DefaultTTL()
- }
+ // Because this icmp endpoint is implemented in the transport layer, we can
+ // only increment the 'stack-wide' stats but we can't increment the
+ // 'per-NetworkEndpoint' stats.
+ stats := s.Stats().ICMP.V4.PacketsSent
- if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil {
- r.Stats().ICMP.V4.PacketsSent.Dropped.Increment()
+ if err := ctx.WritePacket(pkt, false /* headerIncluded */); err != nil {
+ stats.Dropped.Increment()
return err
}
- sentStat.Increment()
+ stats.EchoRequest.Increment()
return nil
}
-func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Error {
+func send6(s *stack.Stack, ctx *network.WriteContext, ident uint16, data buffer.View, src, dst tcpip.Address, maxHeaderLength uint16) tcpip.Error {
if len(data) < header.ICMPv6EchoMinimumSize {
return &tcpip.ErrInvalidEndpointState{}
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.ICMPv6MinimumSize + int(r.MaxHeaderLength()),
+ ReserveHeaderBytes: header.ICMPv6MinimumSize + int(maxHeaderLength),
})
icmpv6 := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6MinimumSize))
@@ -468,43 +405,31 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Erro
if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 {
return &tcpip.ErrInvalidEndpointState{}
}
- // Because this icmp endpoint is implemented in the transport layer, we can
- // only increment the 'stack-wide' stats but we can't increment the
- // 'per-NetworkEndpoint' stats.
- sentStat := r.Stats().ICMP.V6.PacketsSent.EchoRequest
pkt.Data().AppendView(data)
dataRange := pkt.Data().AsRange()
icmpv6.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpv6,
- Src: r.LocalAddress(),
- Dst: r.RemoteAddress(),
+ Src: src,
+ Dst: dst,
PayloadCsum: dataRange.Checksum(),
PayloadLen: dataRange.Size(),
}))
- if ttl == 0 {
- ttl = r.DefaultTTL()
- }
+ // Because this icmp endpoint is implemented in the transport layer, we can
+ // only increment the 'stack-wide' stats but we can't increment the
+ // 'per-NetworkEndpoint' stats.
+ stats := s.Stats().ICMP.V6.PacketsSent
- if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil {
- r.Stats().ICMP.V6.PacketsSent.Dropped.Increment()
+ if err := ctx.WritePacket(pkt, false /* headerIncluded */); err != nil {
+ stats.Dropped.Increment()
+ return err
}
- sentStat.Increment()
+ stats.EchoRequest.Increment()
return nil
}
-// checkV4MappedLocked determines the effective network protocol and converts
-// addr to its canonical form.
-func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
- unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, false /* v6only */)
- if err != nil {
- return tcpip.FullAddress{}, 0, err
- }
- return unwrapped, netProto, nil
-}
-
// Disconnect implements tcpip.Endpoint.Disconnect.
func (*endpoint) Disconnect() tcpip.Error {
return &tcpip.ErrNotSupported{}
@@ -515,59 +440,21 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- nicID := addr.NIC
- localPort := uint16(0)
- switch e.state {
- case stateInitial:
- case stateBound, stateConnected:
- localPort = e.ID.LocalPort
- if e.BindNICID == 0 {
- break
- }
+ err := e.net.ConnectAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error {
+ nextID.LocalPort = e.ident
- if nicID != 0 && nicID != e.BindNICID {
- return &tcpip.ErrInvalidEndpointState{}
+ nextID, err := e.registerWithStack(netProto, nextID)
+ if err != nil {
+ return err
}
- nicID = e.BindNICID
- default:
- return &tcpip.ErrInvalidEndpointState{}
- }
-
- addr, netProto, err := e.checkV4MappedLocked(addr)
- if err != nil {
- return err
- }
-
- // Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicID, e.BindAddr, addr.Addr, netProto, false /* multicastLoop */)
- if err != nil {
- return err
- }
-
- id := stack.TransportEndpointID{
- LocalAddress: r.LocalAddress(),
- LocalPort: localPort,
- RemoteAddress: r.RemoteAddress(),
- }
-
- // Even if we're connected, this endpoint can still be used to send
- // packets on a different network protocol, so we register both even if
- // v6only is set to false and this is an ipv6 endpoint.
- netProtos := []tcpip.NetworkProtocolNumber{netProto}
-
- id, err = e.registerWithStack(nicID, netProtos, id)
+ e.ident = nextID.LocalPort
+ return nil
+ })
if err != nil {
- r.Release()
return err
}
- e.ID = id
- e.route = r
- e.RegisterNICID = nicID
-
- e.state = stateConnected
-
e.rcvMu.Lock()
e.rcvReady = true
e.rcvMu.Unlock()
@@ -585,10 +472,19 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error {
func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- e.shutdownFlags |= flags
- if e.state != stateConnected {
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
return &tcpip.ErrNotConnected{}
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
+ }
+
+ if flags&tcpip.ShutdownWrite != 0 {
+ if err := e.net.Shutdown(); err != nil {
+ return err
+ }
}
if flags&tcpip.ShutdownRead != 0 {
@@ -615,19 +511,18 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi
return nil, nil, &tcpip.ErrNotSupported{}
}
-func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) {
+func (e *endpoint) registerWithStack(netProto tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) {
bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
if id.LocalPort != 0 {
// The endpoint already has a local port, just attempt to
// register it.
- err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice)
- return id, err
+ return id, e.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{netProto}, e.transProto, id, e, ports.Flags{}, bindToDevice)
}
// We need to find a port for the endpoint.
_, err := e.stack.PickEphemeralPort(e.stack.Rand(), func(p uint16) (bool, tcpip.Error) {
id.LocalPort = p
- err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice)
+ err := e.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{netProto}, e.transProto, id, e, ports.Flags{}, bindToDevice)
switch err.(type) {
case nil:
return true, nil
@@ -644,42 +539,27 @@ func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkPro
func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error {
// Don't allow binding once endpoint is not in the initial state
// anymore.
- if e.state != stateInitial {
+ if e.net.State() != transport.DatagramEndpointStateInitial {
return &tcpip.ErrInvalidEndpointState{}
}
- addr, netProto, err := e.checkV4MappedLocked(addr)
- if err != nil {
- return err
- }
-
- // Expand netProtos to include v4 and v6 if the caller is binding to a
- // wildcard (empty) address, and this is an IPv6 endpoint with v6only
- // set to false.
- netProtos := []tcpip.NetworkProtocolNumber{netProto}
-
- if len(addr.Addr) != 0 {
- // A local address was specified, verify that it's valid.
- if e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) == 0 {
- return &tcpip.ErrBadLocalAddress{}
+ err := e.net.BindAndThen(addr, func(boundNetProto tcpip.NetworkProtocolNumber, boundAddr tcpip.Address) tcpip.Error {
+ id := stack.TransportEndpointID{
+ LocalPort: addr.Port,
+ LocalAddress: addr.Addr,
+ }
+ id, err := e.registerWithStack(boundNetProto, id)
+ if err != nil {
+ return err
}
- }
- id := stack.TransportEndpointID{
- LocalPort: addr.Port,
- LocalAddress: addr.Addr,
- }
- id, err = e.registerWithStack(addr.NIC, netProtos, id)
+ e.ident = id.LocalPort
+ return nil
+ })
if err != nil {
return err
}
- e.ID = id
- e.RegisterNICID = addr.NIC
-
- // Mark endpoint as bound.
- e.state = stateBound
-
e.rcvMu.Lock()
e.rcvReady = true
e.rcvMu.Unlock()
@@ -687,21 +567,24 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error {
return nil
}
+func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, addr tcpip.Address) bool { // true when any of the four checks below matches addr
+ return addr == header.IPv4Broadcast || // exact match with header.IPv4Broadcast
+ header.IsV4MulticastAddress(addr) || // IPv4 multicast, as judged by the header package
+ header.IsV6MulticastAddress(addr) || // IPv6 multicast, as judged by the header package
+ e.stack.IsSubnetBroadcast(nicID, e.net.NetProto(), addr) // subnet broadcast, as judged by the stack for nicID and the endpoint's network protocol
+}
+
// Bind binds the endpoint to a specific local address and port.
// Specifying a NIC is optional.
func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
- e.mu.Lock()
- defer e.mu.Unlock()
-
- err := e.bindLocked(addr)
- if err != nil {
- return err
+ if len(addr.Addr) != 0 && e.isBroadcastOrMulticast(addr.NIC, addr.Addr) { // an empty address skips this check; it runs before e.mu is taken
+ return &tcpip.ErrBadLocalAddress{} // broadcast and multicast addresses are rejected as local bind addresses
}
- e.BindNICID = addr.NIC
- e.BindAddr = addr.Addr
+ e.mu.Lock()
+ defer e.mu.Unlock()
- return nil
+ return e.bindLocked(addr) // bindLocked's result is returned directly; BindNICID/BindAddr are no longer recorded here
}
// GetLocalAddress returns the address to which the endpoint is bound.
@@ -709,11 +592,9 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- return tcpip.FullAddress{
- NIC: e.RegisterNICID,
- Addr: e.ID.LocalAddress,
- Port: e.ID.LocalPort,
- }, nil
+ addr := e.net.GetLocalAddress()
+ addr.Port = e.ident
+ return addr, nil
}
// GetRemoteAddress returns the address to which the endpoint is connected.
@@ -721,15 +602,11 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- if e.state != stateConnected {
- return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
+ if addr, connected := e.net.GetRemoteAddress(); connected {
+ return addr, nil
}
- return tcpip.FullAddress{
- NIC: e.RegisterNICID,
- Addr: e.ID.RemoteAddress,
- Port: e.ID.RemotePort,
- }, nil
+ return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
}
// Readiness returns the current readiness of the endpoint. For example, if
@@ -754,7 +631,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// endpoint.
func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
// Only accept echo replies.
- switch e.NetProto {
+ switch e.net.NetProto() {
case header.IPv4ProtocolNumber:
h := header.ICMPv4(pkt.TransportHeader().View())
if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply {
@@ -828,9 +705,9 @@ func (e *endpoint) State() uint32 {
// Info returns a copy of the endpoint info.
func (e *endpoint) Info() tcpip.EndpointInfo {
e.mu.RLock()
- // Make a copy of the endpoint info.
- ret := e.TransportEndpointInfo
- e.mu.RUnlock()
+ defer e.mu.RUnlock()
+ ret := e.net.Info()
+ ret.ID.LocalPort = e.ident
return &ret
}
diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go
index b8b839e4a..dfe453ff9 100644
--- a/pkg/tcpip/transport/icmp/endpoint_state.go
+++ b/pkg/tcpip/transport/icmp/endpoint_state.go
@@ -15,11 +15,13 @@
package icmp
import (
+ "fmt"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
)
// saveReceivedAt is invoked by stateify.
@@ -61,29 +63,24 @@ func (e *endpoint) beforeSave() {
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
e.thaw()
+
+ e.net.Resume(s)
+
e.stack = s
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
- if e.state != stateBound && e.state != stateConnected {
- return
- }
-
- var err tcpip.Error
- if e.state == stateConnected {
- e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.ID.RemoteAddress, e.NetProto, false /* multicastLoop */)
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ var err tcpip.Error
+ info := e.net.Info()
+ info.ID.LocalPort = e.ident
+ info.ID, err = e.registerWithStack(info.NetProto, info.ID)
if err != nil {
- panic(err)
+ panic(fmt.Sprintf("e.registerWithStack(%d, %#v): %s", info.NetProto, info.ID, err))
}
-
- e.ID.LocalAddress = e.route.LocalAddress()
- } else if len(e.ID.LocalAddress) != 0 { // stateBound
- if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.ID.LocalAddress) == 0 {
- panic(&tcpip.ErrBadLocalAddress{})
- }
- }
-
- e.ID, err = e.registerWithStack(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.ID)
- if err != nil {
- panic(err)
+ e.ident = info.ID.LocalPort
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
}
}
diff --git a/pkg/tcpip/transport/icmp/icmp_test.go b/pkg/tcpip/transport/icmp/icmp_test.go
index cc950cbde..729f50e9a 100644
--- a/pkg/tcpip/transport/icmp/icmp_test.go
+++ b/pkg/tcpip/transport/icmp/icmp_test.go
@@ -55,8 +55,12 @@ func addNICWithDefaultRoute(t *testing.T, s *stack.Stack, id tcpip.NICID, name s
t.Fatalf("s.CreateNIC(%d, _) = %s", id, err)
}
- if err := s.AddAddress(id, ipv4.ProtocolNumber, addrV4); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s) = %s", id, ipv4.ProtocolNumber, addrV4, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addrV4.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(id, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", id, protocolAddr, err)
}
s.AddRoute(tcpip.Route{
diff --git a/pkg/tcpip/transport/internal/network/BUILD b/pkg/tcpip/transport/internal/network/BUILD
new file mode 100644
index 000000000..3818cb04e
--- /dev/null
+++ b/pkg/tcpip/transport/internal/network/BUILD
@@ -0,0 +1,46 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "network",
+ srcs = [
+ "endpoint.go",
+ "endpoint_state.go",
+ ],
+ visibility = [
+ "//pkg/tcpip/transport/icmp:__pkg__",
+ "//pkg/tcpip/transport/raw:__pkg__",
+ "//pkg/tcpip/transport/udp:__pkg__",
+ ],
+ deps = [
+ "//pkg/sync",
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport",
+ ],
+)
+
+go_test(
+ name = "network_test",
+ size = "small",
+ srcs = ["endpoint_test.go"],
+ deps = [
+ ":network",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/testutil",
+ "//pkg/tcpip/transport",
+ "//pkg/tcpip/transport/udp",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ ],
+)
diff --git a/pkg/tcpip/transport/internal/network/endpoint.go b/pkg/tcpip/transport/internal/network/endpoint.go
new file mode 100644
index 000000000..e3094f59f
--- /dev/null
+++ b/pkg/tcpip/transport/internal/network/endpoint.go
@@ -0,0 +1,811 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package network provides facilities to support tcpip.Endpoints that operate
+// at the network layer or above.
+package network
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
+)
+
+// Endpoint is a datagram-based endpoint. It only supports sending datagrams to
+// a peer.
+//
+// +stateify savable
+type Endpoint struct {
+ // The following fields must only be set once then never changed.
+ stack *stack.Stack `state:"manual"`
+ ops *tcpip.SocketOptions
+ netProto tcpip.NetworkProtocolNumber
+ transProto tcpip.TransportProtocolNumber
+
+ mu sync.RWMutex `state:"nosave"`
+ // +checklocks:mu
+ wasBound bool
+ // owner is the owner of transmitted packets.
+ //
+ // +checklocks:mu
+ owner tcpip.PacketOwner
+ // +checklocks:mu
+ writeShutdown bool
+ // +checklocks:mu
+ effectiveNetProto tcpip.NetworkProtocolNumber
+ // +checklocks:mu
+ connectedRoute *stack.Route `state:"manual"`
+ // +checklocks:mu
+ multicastMemberships map[multicastMembership]struct{}
+ // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ // +checklocks:mu
+ ttl uint8
+ // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ // +checklocks:mu
+ multicastTTL uint8
+ // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ // +checklocks:mu
+ multicastAddr tcpip.Address
+ // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ // +checklocks:mu
+ multicastNICID tcpip.NICID
+ // +checklocks:mu
+ ipv4TOS uint8
+ // +checklocks:mu
+ ipv6TClass uint8
+
+ // Lock ordering: mu > infoMu.
+ infoMu sync.RWMutex `state:"nosave"`
+ // info has a dedicated mutex so that we can avoid lock ordering violations
+ // when reading the endpoint's info. If we used mu, we need to guarantee
+ // that any lock taken while mu is held is not held when calling Info()
+ // which is not true as of writing (we hold mu while registering transport
+ // endpoints (taking the transport demuxer lock) but we also hold the demuxer
+ // lock when delivering packets/errors to endpoints).
+ //
+ // Writes must be performed through setInfo.
+ //
+ // +checklocks:infoMu
+ info stack.TransportEndpointInfo
+
+ // state holds a transport.DatagramEndpointState.
+ //
+ // state must be accessed with atomics so that we can avoid lock ordering
+ // violations when reading the state. If we used mu, we need to guarantee
+ // that any lock taken while mu is held is not held when calling State()
+ // which is not true as of writing (we hold mu while registering transport
+ // endpoints (taking the transport demuxer lock) but we also hold the demuxer
+ // lock when delivering packets/errors to endpoints).
+ //
+ // Writes must be performed through setEndpointState.
+ //
+ // +checkatomics
+ state uint32
+}
+
+// +stateify savable
+type multicastMembership struct {
+ nicID tcpip.NICID
+ multicastAddr tcpip.Address
+}
+
+// Init initializes the endpoint.
+func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ops *tcpip.SocketOptions) {
+ e.mu.Lock()
+ memberships := e.multicastMemberships
+ e.mu.Unlock()
+ if memberships != nil {
+ panic(fmt.Sprintf("endpoint is already initialized; got e.multicastMemberships = %#v, want = nil", memberships))
+ }
+
+ switch netProto {
+ case header.IPv4ProtocolNumber, header.IPv6ProtocolNumber:
+ default:
+ panic(fmt.Sprintf("invalid protocol number = %d", netProto))
+ }
+
+ *e = Endpoint{
+ stack: s,
+ ops: ops,
+ netProto: netProto,
+ transProto: transProto,
+
+ info: stack.TransportEndpointInfo{
+ NetProto: netProto,
+ TransProto: transProto,
+ },
+ effectiveNetProto: netProto,
+ // Linux defaults to TTL=1.
+ multicastTTL: 1,
+ multicastMemberships: make(map[multicastMembership]struct{}),
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.setEndpointState(transport.DatagramEndpointStateInitial)
+}
+
+// NetProto returns the network protocol the endpoint was initialized with.
+func (e *Endpoint) NetProto() tcpip.NetworkProtocolNumber {
+ return e.netProto
+}
+
+// setEndpointState sets the state of the endpoint.
+//
+// e.mu must be held to synchronize changes to state with the rest of the
+// endpoint.
+//
+// +checklocks:e.mu
+func (e *Endpoint) setEndpointState(state transport.DatagramEndpointState) {
+ atomic.StoreUint32(&e.state, uint32(state))
+}
+
+// State returns the state of the endpoint.
+func (e *Endpoint) State() transport.DatagramEndpointState {
+ return transport.DatagramEndpointState(atomic.LoadUint32(&e.state))
+}
+
+// Close cleans the endpoint's resources and leaves the endpoint in a closed
+// state.
+func (e *Endpoint) Close() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if e.State() == transport.DatagramEndpointStateClosed {
+ return
+ }
+
+ for mem := range e.multicastMemberships {
+ e.stack.LeaveGroup(e.netProto, mem.nicID, mem.multicastAddr)
+ }
+ e.multicastMemberships = nil
+
+ if e.connectedRoute != nil {
+ e.connectedRoute.Release()
+ e.connectedRoute = nil
+ }
+
+ e.setEndpointState(transport.DatagramEndpointStateClosed)
+}
+
+// SetOwner sets the owner of transmitted packets.
+func (e *Endpoint) SetOwner(owner tcpip.PacketOwner) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.owner = owner
+}
+
+func calculateTTL(route *stack.Route, ttl uint8, multicastTTL uint8) uint8 {
+ if header.IsV4MulticastAddress(route.RemoteAddress()) || header.IsV6MulticastAddress(route.RemoteAddress()) {
+ return multicastTTL
+ }
+
+ if ttl == 0 {
+ return route.DefaultTTL()
+ }
+
+ return ttl
+}
+
+// WriteContext holds the context for a write.
+type WriteContext struct {
+ transProto tcpip.TransportProtocolNumber
+ route *stack.Route
+ ttl uint8
+ tos uint8
+ owner tcpip.PacketOwner
+}
+
+// Release releases held resources.
+func (c *WriteContext) Release() {
+ c.route.Release()
+ *c = WriteContext{}
+}
+
+// WritePacketInfo is the properties of a packet that may be written.
+type WritePacketInfo struct {
+ NetProto tcpip.NetworkProtocolNumber
+ LocalAddress, RemoteAddress tcpip.Address
+ MaxHeaderLength uint16
+ RequiresTXTransportChecksum bool
+}
+
+// PacketInfo returns the properties of a packet that will be written.
+func (c *WriteContext) PacketInfo() WritePacketInfo {
+ return WritePacketInfo{
+ NetProto: c.route.NetProto(),
+ LocalAddress: c.route.LocalAddress(),
+ RemoteAddress: c.route.RemoteAddress(),
+ MaxHeaderLength: c.route.MaxHeaderLength(),
+ RequiresTXTransportChecksum: c.route.RequiresTXTransportChecksum(),
+ }
+}
+
+// WritePacket attempts to write the packet.
+func (c *WriteContext) WritePacket(pkt *stack.PacketBuffer, headerIncluded bool) tcpip.Error {
+ pkt.Owner = c.owner
+
+ if headerIncluded {
+ return c.route.WriteHeaderIncludedPacket(pkt)
+ }
+
+ return c.route.WritePacket(stack.NetworkHeaderParams{
+ Protocol: c.transProto,
+ TTL: c.ttl,
+ TOS: c.tos,
+ }, pkt)
+}
+
+// AcquireContextForWrite acquires a WriteContext.
+func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext, tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op.
+ if opts.More {
+ return WriteContext{}, &tcpip.ErrInvalidOptionValue{}
+ }
+
+ if e.State() == transport.DatagramEndpointStateClosed {
+ return WriteContext{}, &tcpip.ErrInvalidEndpointState{}
+ }
+
+ if e.writeShutdown {
+ return WriteContext{}, &tcpip.ErrClosedForSend{}
+ }
+
+ route := e.connectedRoute
+ if opts.To == nil {
+ // If the user doesn't specify a destination, they should have
+ // connected to another address.
+ if e.State() != transport.DatagramEndpointStateConnected {
+ return WriteContext{}, &tcpip.ErrDestinationRequired{}
+ }
+
+ route.Acquire()
+ } else {
+ // Reject destination address if it goes through a different
+ // NIC than the endpoint was bound to.
+ nicID := opts.To.NIC
+ if nicID == 0 {
+ nicID = tcpip.NICID(e.ops.GetBindToDevice())
+ }
+ info := e.Info()
+ if info.BindNICID != 0 {
+ if nicID != 0 && nicID != info.BindNICID {
+ return WriteContext{}, &tcpip.ErrNoRoute{}
+ }
+
+ nicID = info.BindNICID
+ }
+ if nicID == 0 {
+ nicID = info.RegisterNICID
+ }
+
+ dst, netProto, err := e.checkV4Mapped(*opts.To)
+ if err != nil {
+ return WriteContext{}, err
+ }
+
+ route, _, err = e.connectRouteRLocked(nicID, dst, netProto)
+ if err != nil {
+ return WriteContext{}, err
+ }
+ }
+
+ if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() {
+ route.Release()
+ return WriteContext{}, &tcpip.ErrBroadcastDisabled{}
+ }
+
+ var tos uint8
+ switch netProto := route.NetProto(); netProto {
+ case header.IPv4ProtocolNumber:
+ tos = e.ipv4TOS
+ case header.IPv6ProtocolNumber:
+ tos = e.ipv6TClass
+ default:
+ panic(fmt.Sprintf("invalid protocol number = %d", netProto))
+ }
+
+ return WriteContext{
+ transProto: e.transProto,
+ route: route,
+ ttl: calculateTTL(route, e.ttl, e.multicastTTL),
+ tos: tos,
+ owner: e.owner,
+ }, nil
+}
+
+// Disconnect disconnects the endpoint from its peer.
+func (e *Endpoint) Disconnect() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if e.State() != transport.DatagramEndpointStateConnected {
+ return
+ }
+
+ info := e.Info()
+ // Exclude ephemerally bound endpoints.
+ if e.wasBound {
+ info.ID = stack.TransportEndpointID{
+ LocalAddress: info.BindAddr,
+ }
+ e.setEndpointState(transport.DatagramEndpointStateBound)
+ } else {
+ info.ID = stack.TransportEndpointID{}
+ e.setEndpointState(transport.DatagramEndpointStateInitial)
+ }
+ e.setInfo(info)
+
+ e.connectedRoute.Release()
+ e.connectedRoute = nil
+}
+
+// connectRouteRLocked establishes a route to the specified interface or the
+// configured multicast interface if no interface is specified and the
+// specified address is a multicast address.
+//
+// TODO(https://gvisor.dev/issue/6590): Annotate read lock requirement.
+// +checklocks:e.mu
+func (e *Endpoint) connectRouteRLocked(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) {
+ localAddr := e.Info().ID.LocalAddress
+ if e.isBroadcastOrMulticast(nicID, netProto, localAddr) {
+ // A packet can only originate from a unicast address (i.e., an interface).
+ localAddr = ""
+ }
+
+ if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) {
+ if nicID == 0 {
+ nicID = e.multicastNICID
+ }
+ if localAddr == "" && nicID == 0 {
+ localAddr = e.multicastAddr
+ }
+ }
+
+ // Find a route to the desired destination.
+ r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.ops.GetMulticastLoop())
+ if err != nil {
+ return nil, 0, err
+ }
+ return r, nicID, nil
+}
+
+// Connect connects the endpoint to the address.
+func (e *Endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
+ return e.ConnectAndThen(addr, func(_ tcpip.NetworkProtocolNumber, _, _ stack.TransportEndpointID) tcpip.Error {
+ return nil
+ })
+}
+
+// ConnectAndThen connects the endpoint to the address and then calls the
+// provided function.
+//
+// If the function returns an error, the endpoint's state does not change. The
+// function will be called with the network protocol used to connect to the peer
+// and the source and destination addresses that will be used to send traffic to
+// the peer.
+func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error) tcpip.Error {
+ addr.Port = 0
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ info := e.Info()
+ nicID := addr.NIC
+ switch e.State() {
+ case transport.DatagramEndpointStateInitial:
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ if info.BindNICID == 0 {
+ break
+ }
+
+ if nicID != 0 && nicID != info.BindNICID {
+ return &tcpip.ErrInvalidEndpointState{}
+ }
+
+ nicID = info.BindNICID
+ default:
+ return &tcpip.ErrInvalidEndpointState{}
+ }
+
+ addr, netProto, err := e.checkV4Mapped(addr)
+ if err != nil {
+ return err
+ }
+
+ r, nicID, err := e.connectRouteRLocked(nicID, addr, netProto)
+ if err != nil {
+ return err
+ }
+
+ id := stack.TransportEndpointID{
+ LocalAddress: info.ID.LocalAddress,
+ RemoteAddress: r.RemoteAddress(),
+ }
+ if e.State() == transport.DatagramEndpointStateInitial {
+ id.LocalAddress = r.LocalAddress()
+ }
+
+ if err := f(r.NetProto(), info.ID, id); err != nil {
+ return err
+ }
+
+ if e.connectedRoute != nil {
+ // If the endpoint was previously connected then release any previous route.
+ e.connectedRoute.Release()
+ }
+ e.connectedRoute = r
+ info.ID = id
+ info.RegisterNICID = nicID
+ e.setInfo(info)
+ e.effectiveNetProto = netProto
+ e.setEndpointState(transport.DatagramEndpointStateConnected)
+ return nil
+}
+
+// Shutdown shuts down the endpoint.
+func (e *Endpoint) Shutdown() tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ switch state := e.State(); state {
+ case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
+ return &tcpip.ErrNotConnected{}
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ e.writeShutdown = true
+ return nil
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
+ }
+}
+
+// checkV4Mapped determines the effective network protocol and converts
+// addr to its canonical form.
+func (e *Endpoint) checkV4Mapped(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
+ info := e.Info()
+ unwrapped, netProto, err := info.AddrNetProtoLocked(addr, e.ops.GetV6Only())
+ if err != nil {
+ return tcpip.FullAddress{}, 0, err
+ }
+ return unwrapped, netProto, nil
+}
+
+func (e *Endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
+ return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) || e.stack.IsSubnetBroadcast(nicID, netProto, addr)
+}
+
+// Bind binds the endpoint to the address.
+func (e *Endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
+ return e.BindAndThen(addr, func(tcpip.NetworkProtocolNumber, tcpip.Address) tcpip.Error {
+ return nil
+ })
+}
+
+// BindAndThen binds the endpoint to the address and then calls the provided
+// function.
+//
+// If the function returns an error, the endpoint's state does not change. The
+// function will be called with the bound network protocol and address.
+func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProtocolNumber, tcpip.Address) tcpip.Error) tcpip.Error {
+ addr.Port = 0
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // Don't allow binding once endpoint is not in the initial state
+ // anymore.
+ if e.State() != transport.DatagramEndpointStateInitial {
+ return &tcpip.ErrInvalidEndpointState{}
+ }
+
+ addr, netProto, err := e.checkV4Mapped(addr)
+ if err != nil {
+ return err
+ }
+
+ nicID := addr.NIC
+ if len(addr.Addr) != 0 && !e.isBroadcastOrMulticast(addr.NIC, netProto, addr.Addr) {
+ nicID = e.stack.CheckLocalAddress(nicID, netProto, addr.Addr)
+ if nicID == 0 {
+ return &tcpip.ErrBadLocalAddress{}
+ }
+ }
+
+ if err := f(netProto, addr.Addr); err != nil {
+ return err
+ }
+
+ e.wasBound = true
+
+ info := e.Info()
+ info.ID = stack.TransportEndpointID{
+ LocalAddress: addr.Addr,
+ }
+ info.BindNICID = addr.NIC
+ info.RegisterNICID = nicID
+ info.BindAddr = addr.Addr
+ e.setInfo(info)
+ e.effectiveNetProto = netProto
+ e.setEndpointState(transport.DatagramEndpointStateBound)
+ return nil
+}
+
+// WasBound returns true iff the endpoint was ever bound.
+func (e *Endpoint) WasBound() bool {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.wasBound
+}
+
+// GetLocalAddress returns the address that the endpoint is bound to.
+func (e *Endpoint) GetLocalAddress() tcpip.FullAddress {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ info := e.Info()
+ addr := info.BindAddr
+ if e.State() == transport.DatagramEndpointStateConnected {
+ addr = e.connectedRoute.LocalAddress()
+ }
+
+ return tcpip.FullAddress{
+ NIC: info.RegisterNICID,
+ Addr: addr,
+ }
+}
+
+// GetRemoteAddress returns the address that the endpoint is connected to.
+func (e *Endpoint) GetRemoteAddress() (tcpip.FullAddress, bool) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ if e.State() != transport.DatagramEndpointStateConnected {
+ return tcpip.FullAddress{}, false
+ }
+
+ return tcpip.FullAddress{
+ Addr: e.connectedRoute.RemoteAddress(),
+ NIC: e.Info().RegisterNICID,
+ }, true
+}
+
+// SetSockOptInt sets the socket option.
+func (e *Endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
+ switch opt {
+ case tcpip.MTUDiscoverOption:
+ // Return not supported if the value is not disabling path
+ // MTU discovery.
+ if v != tcpip.PMTUDiscoveryDont {
+ return &tcpip.ErrNotSupported{}
+ }
+
+ case tcpip.MulticastTTLOption:
+ e.mu.Lock()
+ e.multicastTTL = uint8(v)
+ e.mu.Unlock()
+
+ case tcpip.TTLOption:
+ e.mu.Lock()
+ e.ttl = uint8(v)
+ e.mu.Unlock()
+
+ case tcpip.IPv4TOSOption:
+ e.mu.Lock()
+ e.ipv4TOS = uint8(v)
+ e.mu.Unlock()
+
+ case tcpip.IPv6TrafficClassOption:
+ e.mu.Lock()
+ e.ipv6TClass = uint8(v)
+ e.mu.Unlock()
+ }
+
+ return nil
+}
+
+// GetSockOptInt returns the socket option.
+func (e *Endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
+ switch opt {
+ case tcpip.MTUDiscoverOption:
+ // The only supported setting is path MTU discovery disabled.
+ return tcpip.PMTUDiscoveryDont, nil
+
+ case tcpip.MulticastTTLOption:
+ e.mu.Lock()
+ v := int(e.multicastTTL)
+ e.mu.Unlock()
+ return v, nil
+
+ case tcpip.TTLOption:
+ e.mu.Lock()
+ v := int(e.ttl)
+ e.mu.Unlock()
+ return v, nil
+
+ case tcpip.IPv4TOSOption:
+ e.mu.RLock()
+ v := int(e.ipv4TOS)
+ e.mu.RUnlock()
+ return v, nil
+
+ case tcpip.IPv6TrafficClassOption:
+ e.mu.RLock()
+ v := int(e.ipv6TClass)
+ e.mu.RUnlock()
+ return v, nil
+
+ default:
+ return -1, &tcpip.ErrUnknownProtocolOption{}
+ }
+}
+
+// SetSockOpt sets the socket option.
+func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
+ switch v := opt.(type) {
+ case *tcpip.MulticastInterfaceOption:
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ fa := tcpip.FullAddress{Addr: v.InterfaceAddr}
+ fa, netProto, err := e.checkV4Mapped(fa)
+ if err != nil {
+ return err
+ }
+ nic := v.NIC
+ addr := fa.Addr
+
+ if nic == 0 && addr == "" {
+ e.multicastAddr = ""
+ e.multicastNICID = 0
+ break
+ }
+
+ if nic != 0 {
+ if !e.stack.CheckNIC(nic) {
+ return &tcpip.ErrBadLocalAddress{}
+ }
+ } else {
+ nic = e.stack.CheckLocalAddress(0, netProto, addr)
+ if nic == 0 {
+ return &tcpip.ErrBadLocalAddress{}
+ }
+ }
+
+ if info := e.Info(); info.BindNICID != 0 && info.BindNICID != nic {
+ return &tcpip.ErrInvalidEndpointState{}
+ }
+
+ e.multicastNICID = nic
+ e.multicastAddr = addr
+
+ case *tcpip.AddMembershipOption:
+ if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
+ return &tcpip.ErrInvalidOptionValue{}
+ }
+
+ nicID := v.NIC
+
+ if v.InterfaceAddr.Unspecified() {
+ if nicID == 0 {
+ if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.netProto, false /* multicastLoop */); err == nil {
+ nicID = r.NICID()
+ r.Release()
+ }
+ }
+ } else {
+ nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
+ }
+ if nicID == 0 {
+ return &tcpip.ErrUnknownDevice{}
+ }
+
+ memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if _, ok := e.multicastMemberships[memToInsert]; ok {
+ return &tcpip.ErrPortInUse{}
+ }
+
+ if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+ return err
+ }
+
+ e.multicastMemberships[memToInsert] = struct{}{}
+
+ case *tcpip.RemoveMembershipOption:
+ if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
+ return &tcpip.ErrInvalidOptionValue{}
+ }
+
+ nicID := v.NIC
+ if v.InterfaceAddr.Unspecified() {
+ if nicID == 0 {
+ if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.netProto, false /* multicastLoop */); err == nil {
+ nicID = r.NICID()
+ r.Release()
+ }
+ }
+ } else {
+ nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
+ }
+ if nicID == 0 {
+ return &tcpip.ErrUnknownDevice{}
+ }
+
+ memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if _, ok := e.multicastMemberships[memToRemove]; !ok {
+ return &tcpip.ErrBadLocalAddress{}
+ }
+
+ if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+ return err
+ }
+
+ delete(e.multicastMemberships, memToRemove)
+
+ case *tcpip.SocketDetachFilterOption:
+ return nil
+ }
+ return nil
+}
+
+// GetSockOpt returns the socket option.
+func (e *Endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
+ switch o := opt.(type) {
+ case *tcpip.MulticastInterfaceOption:
+ e.mu.Lock()
+ *o = tcpip.MulticastInterfaceOption{
+ NIC: e.multicastNICID,
+ InterfaceAddr: e.multicastAddr,
+ }
+ e.mu.Unlock()
+
+ default:
+ return &tcpip.ErrUnknownProtocolOption{}
+ }
+ return nil
+}
+
+// Info returns a copy of the endpoint info.
+func (e *Endpoint) Info() stack.TransportEndpointInfo {
+ e.infoMu.RLock()
+ defer e.infoMu.RUnlock()
+ return e.info
+}
+
+// setInfo sets the endpoint's info.
+//
+// e.mu must be held to synchronize changes to info with the rest of the
+// endpoint.
+//
+// +checklocks:e.mu
+func (e *Endpoint) setInfo(info stack.TransportEndpointInfo) {
+ e.infoMu.Lock()
+ defer e.infoMu.Unlock()
+ e.info = info
+}
diff --git a/pkg/tcpip/transport/internal/network/endpoint_state.go b/pkg/tcpip/transport/internal/network/endpoint_state.go
new file mode 100644
index 000000000..68bd1fbf6
--- /dev/null
+++ b/pkg/tcpip/transport/internal/network/endpoint_state.go
@@ -0,0 +1,58 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package network
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
+)
+
+// Resume implements tcpip.ResumableEndpoint.Resume.
+func (e *Endpoint) Resume(s *stack.Stack) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ e.stack = s
+
+ for m := range e.multicastMemberships {
+ if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil {
+ panic(fmt.Sprintf("e.stack.JoinGroup(%d, %d, %s): %s", e.netProto, m.nicID, m.multicastAddr, err))
+ }
+ }
+
+ info := e.Info()
+
+ switch state := e.State(); state {
+ case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
+ case transport.DatagramEndpointStateBound:
+ if len(info.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(info.RegisterNICID, e.effectiveNetProto, info.ID.LocalAddress) {
+ if e.stack.CheckLocalAddress(info.RegisterNICID, e.effectiveNetProto, info.ID.LocalAddress) == 0 {
+ panic(fmt.Sprintf("got e.stack.CheckLocalAddress(%d, %d, %s) = 0, want != 0", info.RegisterNICID, e.effectiveNetProto, info.ID.LocalAddress))
+ }
+ }
+ case transport.DatagramEndpointStateConnected:
+ var err tcpip.Error
+ multicastLoop := e.ops.GetMulticastLoop()
+ e.connectedRoute, err = e.stack.FindRoute(info.RegisterNICID, info.ID.LocalAddress, info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop)
+ if err != nil {
+ panic(fmt.Sprintf("e.stack.FindRoute(%d, %s, %s, %d, %t): %s", info.RegisterNICID, info.ID.LocalAddress, info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop, err))
+ }
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
+ }
+}
diff --git a/pkg/tcpip/transport/internal/network/endpoint_test.go b/pkg/tcpip/transport/internal/network/endpoint_test.go
new file mode 100644
index 000000000..f263a9ea2
--- /dev/null
+++ b/pkg/tcpip/transport/internal/network/endpoint_test.go
@@ -0,0 +1,318 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package network_test
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+)
+
+var (
+ ipv4NICAddr = testutil.MustParse4("1.2.3.4")
+ ipv6NICAddr = testutil.MustParse6("a::1")
+ ipv4RemoteAddr = testutil.MustParse4("6.7.8.9")
+ ipv6RemoteAddr = testutil.MustParse6("b::1")
+)
+
+// TestEndpointStateTransitions walks a network.Endpoint through the Initial,
+// Bound, Connected and Closed states for IPv4, IPv6 and IPv4-mapped-IPv6
+// addresses. After each transition it checks the reported state and the
+// local/remote addresses; once connected it also writes a packet and verifies
+// what is read back from the channel link endpoint.
+func TestEndpointStateTransitions(t *testing.T) {
+	const nicID = 1
+
+	data := buffer.View([]byte{1, 2, 4, 5})
+	v4Checker := func(t *testing.T, b buffer.View) {
+		checker.IPv4(t, b,
+			checker.SrcAddr(ipv4NICAddr),
+			checker.DstAddr(ipv4RemoteAddr),
+			checker.IPPayload(data),
+		)
+	}
+
+	v6Checker := func(t *testing.T, b buffer.View) {
+		checker.IPv6(t, b,
+			checker.SrcAddr(ipv6NICAddr),
+			checker.DstAddr(ipv6RemoteAddr),
+			checker.IPPayload(data),
+		)
+	}
+
+	tests := []struct {
+		name                    string
+		netProto                tcpip.NetworkProtocolNumber
+		expectedMaxHeaderLength uint16
+		expectedNetProto        tcpip.NetworkProtocolNumber
+		expectedLocalAddr       tcpip.Address
+		bindAddr                tcpip.Address
+		expectedBoundAddr       tcpip.Address
+		remoteAddr              tcpip.Address
+		expectedRemoteAddr      tcpip.Address
+		checker                 func(*testing.T, buffer.View)
+	}{
+		{
+			name:                    "IPv4",
+			netProto:                ipv4.ProtocolNumber,
+			expectedMaxHeaderLength: header.IPv4MaximumHeaderSize,
+			expectedNetProto:        ipv4.ProtocolNumber,
+			expectedLocalAddr:       ipv4NICAddr,
+			bindAddr:                header.IPv4AllSystems,
+			expectedBoundAddr:       header.IPv4AllSystems,
+			remoteAddr:              ipv4RemoteAddr,
+			expectedRemoteAddr:      ipv4RemoteAddr,
+			checker:                 v4Checker,
+		},
+		{
+			name:                    "IPv6",
+			netProto:                ipv6.ProtocolNumber,
+			expectedMaxHeaderLength: header.IPv6FixedHeaderSize,
+			expectedNetProto:        ipv6.ProtocolNumber,
+			expectedLocalAddr:       ipv6NICAddr,
+			bindAddr:                header.IPv6AllNodesMulticastAddress,
+			expectedBoundAddr:       header.IPv6AllNodesMulticastAddress,
+			remoteAddr:              ipv6RemoteAddr,
+			expectedRemoteAddr:      ipv6RemoteAddr,
+			checker:                 v6Checker,
+		},
+		{
+			name:                    "IPv4-mapped-IPv6",
+			netProto:                ipv6.ProtocolNumber,
+			expectedMaxHeaderLength: header.IPv4MaximumHeaderSize,
+			expectedNetProto:        ipv4.ProtocolNumber,
+			expectedLocalAddr:       ipv4NICAddr,
+			bindAddr:                testutil.MustParse6("::ffff:e000:0001"),
+			expectedBoundAddr:       header.IPv4AllSystems,
+			remoteAddr:              testutil.MustParse6("::ffff:0607:0809"),
+			expectedRemoteAddr:      ipv4RemoteAddr,
+			checker:                 v4Checker,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			s := stack.New(stack.Options{
+				NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+				TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+				Clock:              &faketime.NullClock{},
+			})
+			e := channel.New(1, header.IPv6MinimumMTU, "")
+			if err := s.CreateNIC(nicID, e); err != nil {
+				t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+			}
+
+			ipv4ProtocolAddr := tcpip.ProtocolAddress{
+				Protocol:          ipv4.ProtocolNumber,
+				AddressWithPrefix: ipv4NICAddr.WithPrefix(),
+			}
+			if err := s.AddProtocolAddress(nicID, ipv4ProtocolAddr, stack.AddressProperties{}); err != nil {
+				t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtocolAddr, err)
+			}
+			ipv6ProtocolAddr := tcpip.ProtocolAddress{
+				Protocol:          ipv6.ProtocolNumber,
+				AddressWithPrefix: ipv6NICAddr.WithPrefix(),
+			}
+
+			if err := s.AddProtocolAddress(nicID, ipv6ProtocolAddr, stack.AddressProperties{}); err != nil {
+				t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtocolAddr, err)
+			}
+
+			s.SetRouteTable([]tcpip.Route{
+				{Destination: ipv4RemoteAddr.WithPrefix().Subnet(), NIC: nicID},
+				{Destination: ipv6RemoteAddr.WithPrefix().Subnet(), NIC: nicID},
+			})
+
+			// Initial state.
+			var ops tcpip.SocketOptions
+			var ep network.Endpoint
+			ep.Init(s, test.netProto, udp.ProtocolNumber, &ops)
+			defer ep.Close()
+			if state := ep.State(); state != transport.DatagramEndpointStateInitial {
+				t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateInitial)
+			}
+
+			// Initial -> Bound.
+			bindAddr := tcpip.FullAddress{Addr: test.bindAddr}
+			if err := ep.Bind(bindAddr); err != nil {
+				t.Fatalf("ep.Bind(%#v): %s", bindAddr, err)
+			}
+			if state := ep.State(); state != transport.DatagramEndpointStateBound {
+				t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateBound)
+			}
+			// cmp.Diff reports its first argument with "-" and its second with
+			// "+", so the wanted value goes first to match the message.
+			if diff := cmp.Diff(tcpip.FullAddress{Addr: test.expectedBoundAddr}, ep.GetLocalAddress()); diff != "" {
+				t.Errorf("ep.GetLocalAddress() mismatch (-want +got):\n%s", diff)
+			}
+			if addr, connected := ep.GetRemoteAddress(); connected {
+				t.Errorf("got ep.GetRemoteAddress() = (%#v, true), want = (_, false)", addr)
+			}
+
+			// Bound -> Connected.
+			connectAddr := tcpip.FullAddress{Addr: test.remoteAddr}
+			if err := ep.Connect(connectAddr); err != nil {
+				t.Fatalf("ep.Connect(%#v): %s", connectAddr, err)
+			}
+			if state := ep.State(); state != transport.DatagramEndpointStateConnected {
+				t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateConnected)
+			}
+			if diff := cmp.Diff(tcpip.FullAddress{Addr: test.expectedLocalAddr}, ep.GetLocalAddress()); diff != "" {
+				t.Errorf("ep.GetLocalAddress() mismatch (-want +got):\n%s", diff)
+			}
+			wantRemoteAddr := tcpip.FullAddress{Addr: test.expectedRemoteAddr}
+			if addr, connected := ep.GetRemoteAddress(); !connected {
+				t.Errorf("got ep.GetRemoteAddress() = (_, false), want = (%#v, true)", wantRemoteAddr)
+			} else if diff := cmp.Diff(wantRemoteAddr, addr); diff != "" {
+				t.Errorf("remote address mismatch (-want +got):\n%s", diff)
+			}
+
+			// Write a packet through the connected endpoint and inspect what
+			// reaches the link endpoint.
+			ctx, err := ep.AcquireContextForWrite(tcpip.WriteOptions{})
+			if err != nil {
+				t.Fatalf("ep.AcquireContextForWrite({}): %s", err)
+			}
+			defer ctx.Release()
+			info := ctx.PacketInfo()
+			if diff := cmp.Diff(network.WritePacketInfo{
+				NetProto:                    test.expectedNetProto,
+				LocalAddress:                test.expectedLocalAddr,
+				RemoteAddress:               test.expectedRemoteAddr,
+				MaxHeaderLength:             test.expectedMaxHeaderLength,
+				RequiresTXTransportChecksum: true,
+			}, info); diff != "" {
+				t.Errorf("write packet info mismatch (-want +got):\n%s", diff)
+			}
+			if err := ctx.WritePacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
+				ReserveHeaderBytes: int(info.MaxHeaderLength),
+				Data:               data.ToVectorisedView(),
+			}), false /* headerIncluded */); err != nil {
+				t.Fatalf("ctx.WritePacket(_, false): %s", err)
+			}
+			pkt, ok := e.Read()
+			if !ok {
+				t.Fatal("expected packet to be read from link endpoint")
+			}
+			test.checker(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()))
+
+			// Connected -> Closed.
+			ep.Close()
+			if state := ep.State(); state != transport.DatagramEndpointStateClosed {
+				t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateClosed)
+			}
+		})
+	}
+}
+
+// TestBindNICID checks the stack.TransportEndpointInfo reported by a
+// network.Endpoint before and after Bind, for unicast and multicast bind
+// addresses, with the bind address's NIC either left at 0 or set to nicID.
+func TestBindNICID(t *testing.T) {
+	const nicID = 1
+
+	// unicast is left at its zero value (false) for the multicast cases.
+	cases := []struct {
+		name     string
+		netProto tcpip.NetworkProtocolNumber
+		bindAddr tcpip.Address
+		unicast  bool
+	}{
+		{name: "IPv4 multicast", netProto: ipv4.ProtocolNumber, bindAddr: header.IPv4AllSystems},
+		{name: "IPv6 multicast", netProto: ipv6.ProtocolNumber, bindAddr: header.IPv6AllNodesMulticastAddress},
+		{name: "IPv4 unicast", netProto: ipv4.ProtocolNumber, bindAddr: ipv4NICAddr, unicast: true},
+		{name: "IPv6 unicast", netProto: ipv6.ProtocolNumber, bindAddr: ipv6NICAddr, unicast: true},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			for _, bindNICID := range []tcpip.NICID{0, nicID} {
+				t.Run(fmt.Sprintf("BindNICID=%d", bindNICID), func(t *testing.T) {
+					s := stack.New(stack.Options{
+						NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+						TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+						Clock:              &faketime.NullClock{},
+					})
+					if err := s.CreateNIC(nicID, loopback.New()); err != nil {
+						t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+					}
+
+					// Assign one address per network protocol, IPv4 first.
+					for _, protocolAddr := range []tcpip.ProtocolAddress{
+						{Protocol: ipv4.ProtocolNumber, AddressWithPrefix: ipv4NICAddr.WithPrefix()},
+						{Protocol: ipv6.ProtocolNumber, AddressWithPrefix: ipv6NICAddr.WithPrefix()},
+					} {
+						if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+							t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
+						}
+					}
+
+					var (
+						ops tcpip.SocketOptions
+						ep  network.Endpoint
+					)
+					ep.Init(s, tc.netProto, udp.ProtocolNumber, &ops)
+					defer ep.Close()
+
+					// Before Bind: not bound, and only the protocol numbers are set.
+					if ep.WasBound() {
+						t.Fatal("got ep.WasBound() = true, want = false")
+					}
+					wantInfo := stack.TransportEndpointInfo{NetProto: tc.netProto, TransProto: udp.ProtocolNumber}
+					if diff := cmp.Diff(wantInfo, ep.Info()); diff != "" {
+						t.Fatalf("ep.Info() mismatch (-want +got):\n%s", diff)
+					}
+
+					bindAddr := tcpip.FullAddress{Addr: tc.bindAddr, NIC: bindNICID}
+					if err := ep.Bind(bindAddr); err != nil {
+						t.Fatalf("ep.Bind(%#v): %s", bindAddr, err)
+					}
+					if !ep.WasBound() {
+						t.Error("got ep.WasBound() = false, want = true")
+					}
+
+					// After Bind: the requested address and NIC are recorded. The
+					// registered NIC is expected to be nicID for unicast binds and
+					// the requested NIC otherwise.
+					wantInfo.ID = stack.TransportEndpointID{LocalAddress: bindAddr.Addr}
+					wantInfo.BindAddr = bindAddr.Addr
+					wantInfo.BindNICID = bindAddr.NIC
+					wantInfo.RegisterNICID = bindAddr.NIC
+					if tc.unicast {
+						wantInfo.RegisterNICID = nicID
+					}
+					if diff := cmp.Diff(wantInfo, ep.Info()); diff != "" {
+						t.Errorf("ep.Info() mismatch (-want +got):\n%s", diff)
+					}
+				})
+			}
+		})
+	}
+}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 8e7bb6c6e..80eef39e9 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -25,7 +25,6 @@
package packet
import (
- "fmt"
"io"
"time"
@@ -60,52 +59,47 @@ type packet struct {
//
// +stateify savable
type endpoint struct {
- stack.TransportEndpointInfo
tcpip.DefaultSocketOptionsHandler
// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
- netProto tcpip.NetworkProtocolNumber
waiterQueue *waiter.Queue
cooked bool
-
- // The following fields are used to manage the receive queue and are
- // protected by rcvMu.
- rcvMu sync.Mutex `state:"nosave"`
- rcvList packetList
+ ops tcpip.SocketOptions
+ stats tcpip.TransportEndpointStats
+
+ // The following fields are used to manage the receive queue.
+ rcvMu sync.Mutex `state:"nosave"`
+ // +checklocks:rcvMu
+ rcvList packetList
+ // +checklocks:rcvMu
rcvBufSize int
- rcvClosed bool
-
- // The following fields are protected by mu.
- mu sync.RWMutex `state:"nosave"`
- closed bool
- stats tcpip.TransportEndpointStats `state:"nosave"`
- bound bool
+ // +checklocks:rcvMu
+ rcvClosed bool
+ // +checklocks:rcvMu
+ rcvDisabled bool
+
+ mu sync.RWMutex `state:"nosave"`
+ // +checklocks:mu
+ closed bool
+ // +checklocks:mu
+ boundNetProto tcpip.NetworkProtocolNumber
+ // +checklocks:mu
boundNIC tcpip.NICID
- // lastErrorMu protects lastError.
lastErrorMu sync.Mutex `state:"nosave"`
- lastError tcpip.Error
-
- // ops is used to get socket level options.
- ops tcpip.SocketOptions
-
- // frozen indicates if the packets should be delivered to the endpoint
- // during restore.
- frozen bool
+ // +checklocks:lastErrorMu
+ lastError tcpip.Error
}
// NewEndpoint returns a new packet endpoint.
func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
ep := &endpoint{
- stack: s,
- TransportEndpointInfo: stack.TransportEndpointInfo{
- NetProto: netProto,
- },
- cooked: cooked,
- netProto: netProto,
- waiterQueue: waiterQueue,
+ stack: s,
+ cooked: cooked,
+ boundNetProto: netProto,
+ waiterQueue: waiterQueue,
}
ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
ep.ops.SetReceiveBufferSize(32*1024, false /* notify */)
@@ -141,7 +135,7 @@ func (ep *endpoint) Close() {
return
}
- ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep)
+ ep.stack.UnregisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep)
ep.rcvMu.Lock()
defer ep.rcvMu.Unlock()
@@ -154,7 +148,6 @@ func (ep *endpoint) Close() {
}
ep.closed = true
- ep.bound = false
ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
}
@@ -189,7 +182,7 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul
Total: packet.data.Size(),
ControlMessages: tcpip.ControlMessages{
HasTimestamp: true,
- Timestamp: packet.receivedAt.UnixNano(),
+ Timestamp: packet.receivedAt,
},
}
if opts.NeedRemoteAddr {
@@ -207,8 +200,52 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul
return res, nil
}
-func (*endpoint) Write(tcpip.Payloader, tcpip.WriteOptions) (int64, tcpip.Error) {
- return 0, &tcpip.ErrInvalidOptionValue{}
+// Write implements tcpip.Endpoint.Write.
+//
+// The outgoing NIC and network protocol default to the endpoint's bound
+// values. When opts.To is set, its Addr is used as the remote link address and
+// its non-zero NIC and Port fields override the NIC and network protocol
+// respectively. Cooked endpoints send through stack.WritePacketToRemote; all
+// others send through stack.WriteRawPacket.
+//
+// It returns ErrNotSupported when the stack does not support packet endpoint
+// writes, ErrClosedForSend when the endpoint is closed, ErrInvalidOptionValue
+// when no NIC can be determined and ErrBadBuffer when the payload cannot be
+// read in full.
+func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
+	if !ep.stack.PacketEndpointWriteSupported() {
+		return 0, &tcpip.ErrNotSupported{}
+	}
+
+	// Only a snapshot of the bound state is read here, so a read lock is
+	// enough.
+	ep.mu.RLock()
+	closed := ep.closed
+	nicID := ep.boundNIC
+	proto := ep.boundNetProto
+	ep.mu.RUnlock()
+	if closed {
+		return 0, &tcpip.ErrClosedForSend{}
+	}
+
+	var remote tcpip.LinkAddress
+	if to := opts.To; to != nil {
+		remote = tcpip.LinkAddress(to.Addr)
+
+		if to.NIC != 0 {
+			nicID = to.NIC
+		}
+
+		// The port field of the destination carries the network protocol
+		// number.
+		if to.Port != 0 {
+			proto = tcpip.NetworkProtocolNumber(to.Port)
+		}
+	}
+
+	if nicID == 0 {
+		return 0, &tcpip.ErrInvalidOptionValue{}
+	}
+
+	// TODO(https://gvisor.dev/issue/6538): Avoid this allocation.
+	payloadBytes := make(buffer.View, p.Len())
+	if _, err := io.ReadFull(p, payloadBytes); err != nil {
+		return 0, &tcpip.ErrBadBuffer{}
+	}
+
+	var err tcpip.Error
+	if ep.cooked {
+		err = ep.stack.WritePacketToRemote(nicID, remote, proto, payloadBytes.ToVectorisedView())
+	} else {
+		err = ep.stack.WriteRawPacket(nicID, proto, payloadBytes.ToVectorisedView())
+	}
+	if err != nil {
+		return 0, err
+	}
+	return int64(len(payloadBytes)), nil
 }
// Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be
@@ -253,29 +290,42 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
- if ep.bound && ep.boundNIC == addr.NIC {
- // If the NIC being bound is the same then just return success.
+ netProto := tcpip.NetworkProtocolNumber(addr.Port)
+ if netProto == 0 {
+ // Do not allow unbinding the network protocol.
+ netProto = ep.boundNetProto
+ }
+
+ if ep.boundNIC == addr.NIC && ep.boundNetProto == netProto {
+ // Already bound to the requested NIC and network protocol.
return nil
}
- // Unregister endpoint with all the nics.
- ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep)
- ep.bound = false
+ // TODO(https://gvisor.dev/issue/6618): Unregister after registering the new
+ // binding.
+ ep.stack.UnregisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep)
+ ep.boundNIC = 0
+ ep.boundNetProto = 0
// Bind endpoint to receive packets from specific interface.
- if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil {
+ if err := ep.stack.RegisterPacketEndpoint(addr.NIC, netProto, ep); err != nil {
return err
}
- ep.bound = true
ep.boundNIC = addr.NIC
-
+ ep.boundNetProto = netProto
return nil
}
// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
-func (*endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
- return tcpip.FullAddress{}, &tcpip.ErrNotSupported{}
+func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
+	ep.mu.RLock()
+	// The bound network protocol number is reported through the Port field.
+	local := tcpip.FullAddress{NIC: ep.boundNIC, Port: uint16(ep.boundNetProto)}
+	ep.mu.RUnlock()
+
+	return local, nil
 }
// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
@@ -359,7 +409,7 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
}
// HandlePacket implements stack.PacketEndpoint.HandlePacket.
-func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+func (ep *endpoint) HandlePacket(nicID tcpip.NICID, _ tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
ep.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
@@ -371,7 +421,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
}
rcvBufSize := ep.ops.GetReceiveBufferSize()
- if ep.frozen || ep.rcvBufSize >= int(rcvBufSize) {
+ if ep.rcvDisabled || ep.rcvBufSize >= int(rcvBufSize) {
ep.rcvMu.Unlock()
ep.stack.Stats().DroppedPackets.Increment()
ep.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
@@ -380,76 +430,39 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
wasEmpty := ep.rcvBufSize == 0
- // Push new packet into receive list and increment the buffer size.
- var packet packet
+ rcvdPkt := packet{
+ packetInfo: tcpip.LinkPacketInfo{
+ Protocol: netProto,
+ PktType: pkt.PktType,
+ },
+ senderAddr: tcpip.FullAddress{
+ NIC: nicID,
+ },
+ receivedAt: ep.stack.Clock().Now(),
+ }
+
if !pkt.LinkHeader().View().IsEmpty() {
- // Get info directly from the ethernet header.
hdr := header.Ethernet(pkt.LinkHeader().View())
- packet.senderAddr = tcpip.FullAddress{
- NIC: nicID,
- Addr: tcpip.Address(hdr.SourceAddress()),
- }
- packet.packetInfo.Protocol = netProto
- packet.packetInfo.PktType = pkt.PktType
- } else {
- // Guess the would-be ethernet header.
- packet.senderAddr = tcpip.FullAddress{
- NIC: nicID,
- Addr: tcpip.Address(localAddr),
- }
- packet.packetInfo.Protocol = netProto
- packet.packetInfo.PktType = pkt.PktType
+ rcvdPkt.senderAddr.Addr = tcpip.Address(hdr.SourceAddress())
}
if ep.cooked {
- // Cooked packets can simply be queued.
- switch pkt.PktType {
- case tcpip.PacketHost:
- packet.data = pkt.Data().ExtractVV()
- case tcpip.PacketOutgoing:
- // Strip Link Header.
- var combinedVV buffer.VectorisedView
- if v := pkt.NetworkHeader().View(); !v.IsEmpty() {
- combinedVV.AppendView(v)
- }
- if v := pkt.TransportHeader().View(); !v.IsEmpty() {
- combinedVV.AppendView(v)
- }
- combinedVV.Append(pkt.Data().ExtractVV())
- packet.data = combinedVV
- default:
- panic(fmt.Sprintf("unexpected PktType in pkt: %+v", pkt))
+ // Cooked packet endpoints don't include the link-headers in received
+ // packets.
+ if v := pkt.NetworkHeader().View(); !v.IsEmpty() {
+ rcvdPkt.data.AppendView(v)
}
- } else {
- // Raw packets need their ethernet headers prepended before
- // queueing.
- var linkHeader buffer.View
- if pkt.PktType != tcpip.PacketOutgoing {
- if pkt.LinkHeader().View().IsEmpty() {
- // We weren't provided with an actual ethernet header,
- // so fake one.
- ethFields := header.EthernetFields{
- SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}),
- DstAddr: localAddr,
- Type: netProto,
- }
- fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
- fakeHeader.Encode(&ethFields)
- linkHeader = buffer.View(fakeHeader)
- } else {
- linkHeader = append(buffer.View(nil), pkt.LinkHeader().View()...)
- }
- combinedVV := linkHeader.ToVectorisedView()
- combinedVV.Append(pkt.Data().ExtractVV())
- packet.data = combinedVV
- } else {
- packet.data = buffer.NewVectorisedView(pkt.Size(), pkt.Views())
+ if v := pkt.TransportHeader().View(); !v.IsEmpty() {
+ rcvdPkt.data.AppendView(v)
}
+ rcvdPkt.data.Append(pkt.Data().ExtractVV())
+ } else {
+ // Raw packet endpoints include link-headers in received packets.
+ rcvdPkt.data = buffer.NewVectorisedView(pkt.Size(), pkt.Views())
}
- packet.receivedAt = ep.stack.Clock().Now()
- ep.rcvList.PushBack(&packet)
- ep.rcvBufSize += packet.data.Size()
+ ep.rcvList.PushBack(&rcvdPkt)
+ ep.rcvBufSize += rcvdPkt.data.Size()
ep.rcvMu.Unlock()
ep.stats.PacketsReceived.Increment()
@@ -467,10 +480,8 @@ func (*endpoint) State() uint32 {
// Info returns a copy of the endpoint info.
func (ep *endpoint) Info() tcpip.EndpointInfo {
ep.mu.RLock()
- // Make a copy of the endpoint info.
- ret := ep.TransportEndpointInfo
- ep.mu.RUnlock()
- return &ret
+ defer ep.mu.RUnlock()
+ return &stack.TransportEndpointInfo{NetProto: ep.boundNetProto}
}
// Stats returns a pointer to the endpoint stats.
@@ -485,18 +496,3 @@ func (*endpoint) SetOwner(tcpip.PacketOwner) {}
func (ep *endpoint) SocketOptions() *tcpip.SocketOptions {
return &ep.ops
}
-
-// freeze prevents any more packets from being delivered to the endpoint.
-func (ep *endpoint) freeze() {
- ep.mu.Lock()
- ep.frozen = true
- ep.mu.Unlock()
-}
-
-// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows
-// new packets to be delivered again.
-func (ep *endpoint) thaw() {
- ep.mu.Lock()
- ep.frozen = false
- ep.mu.Unlock()
-}
diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go
index e729921db..88cd80ad3 100644
--- a/pkg/tcpip/transport/packet/endpoint_state.go
+++ b/pkg/tcpip/transport/packet/endpoint_state.go
@@ -15,6 +15,7 @@
package packet
import (
+ "fmt"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -34,33 +35,34 @@ func (p *packet) loadReceivedAt(nsec int64) {
// saveData saves packet.data field.
func (p *packet) saveData() buffer.VectorisedView {
- // We cannot save p.data directly as p.data.views may alias to p.views,
- // which is not allowed by state framework (in-struct pointer).
return p.data.Clone(nil)
}
// loadData loads packet.data field.
func (p *packet) loadData(data buffer.VectorisedView) {
- // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization
- // here because data.views is not guaranteed to be loaded by now. Plus,
- // data.views will be allocated anyway so there really is little point
- // of utilizing p.views for data.views.
p.data = data
}
// beforeSave is invoked by stateify.
func (ep *endpoint) beforeSave() {
- ep.freeze()
+ ep.rcvMu.Lock()
+ defer ep.rcvMu.Unlock()
+ ep.rcvDisabled = true
}
// afterLoad is invoked by stateify.
func (ep *endpoint) afterLoad() {
- ep.thaw()
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
ep.stack = stack.StackFromEnv
ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
- // TODO(gvisor.dev/173): Once bind is supported, choose the right NIC.
- if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil {
- panic(err)
+ if err := ep.stack.RegisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep); err != nil {
+ panic(fmt.Sprintf("RegisterPacketEndpoint(%d, %d, _): %s", ep.boundNIC, ep.boundNetProto, err))
}
+
+ ep.rcvMu.Lock()
+ ep.rcvDisabled = false
+ ep.rcvMu.Unlock()
}
diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD
index 2eab09088..b7e97e218 100644
--- a/pkg/tcpip/transport/raw/BUILD
+++ b/pkg/tcpip/transport/raw/BUILD
@@ -33,6 +33,8 @@ go_library(
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport",
+ "//pkg/tcpip/transport/internal/network",
"//pkg/tcpip/transport/packet",
"//pkg/waiter",
],
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index b3d8951ff..181b478d0 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -26,6 +26,7 @@
package raw
import (
+ "fmt"
"io"
"time"
@@ -34,6 +35,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -57,15 +60,19 @@ type rawPacket struct {
//
// +stateify savable
type endpoint struct {
- stack.TransportEndpointInfo
tcpip.DefaultSocketOptionsHandler
// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
+ transProto tcpip.TransportProtocolNumber
waiterQueue *waiter.Queue
associated bool
+ net network.Endpoint
+ stats tcpip.TransportEndpointStats
+ ops tcpip.SocketOptions
+
// The following fields are used to manage the receive queue and are
// protected by rcvMu.
rcvMu sync.Mutex `state:"nosave"`
@@ -74,20 +81,7 @@ type endpoint struct {
rcvClosed bool
// The following fields are protected by mu.
- mu sync.RWMutex `state:"nosave"`
- closed bool
- connected bool
- bound bool
- // route is the route to a remote network endpoint. It is set via
- // Connect(), and is valid only when conneted is true.
- route *stack.Route `state:"manual"`
- stats tcpip.TransportEndpointStats `state:"nosave"`
- // owner is used to get uid and gid of the packet.
- owner tcpip.PacketOwner
-
- // ops is used to get socket level options.
- ops tcpip.SocketOptions
-
+ mu sync.RWMutex `state:"nosave"`
// frozen indicates if the packets should be delivered to the endpoint
// during restore.
frozen bool
@@ -99,16 +93,9 @@ func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, tcpip.Error) {
- if netProto != header.IPv4ProtocolNumber && netProto != header.IPv6ProtocolNumber {
- return nil, &tcpip.ErrUnknownProtocol{}
- }
-
e := &endpoint{
- stack: s,
- TransportEndpointInfo: stack.TransportEndpointInfo{
- NetProto: netProto,
- TransProto: transProto,
- },
+ stack: s,
+ transProto: transProto,
waiterQueue: waiterQueue,
associated: associated,
}
@@ -116,6 +103,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
e.ops.SetHeaderIncluded(!associated)
e.ops.SetSendBufferSize(32*1024, false /* notify */)
e.ops.SetReceiveBufferSize(32*1024, false /* notify */)
+ e.net.Init(s, netProto, transProto, &e.ops)
// Override with stack defaults.
var ss tcpip.SendBufferSizeOption
@@ -137,7 +125,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
return e, nil
}
- if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil {
+ if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil {
return nil, err
}
@@ -154,11 +142,17 @@ func (e *endpoint) Close() {
e.mu.Lock()
defer e.mu.Unlock()
- if e.closed || !e.associated {
+ if e.net.State() == transport.DatagramEndpointStateClosed {
return
}
- e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e)
+ e.net.Close()
+
+ if !e.associated {
+ return
+ }
+
+ e.stack.UnregisterRawTransportEndpoint(e.net.NetProto(), e.transProto, e)
e.rcvMu.Lock()
defer e.rcvMu.Unlock()
@@ -170,15 +164,6 @@ func (e *endpoint) Close() {
e.rcvList.Remove(e.rcvList.Front())
}
- e.connected = false
-
- if e.route != nil {
- e.route.Release()
- e.route = nil
- }
-
- e.closed = true
-
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
}
@@ -186,9 +171,7 @@ func (e *endpoint) Close() {
func (*endpoint) ModerateRecvBuf(int) {}
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
- e.mu.Lock()
- defer e.mu.Unlock()
- e.owner = owner
+ e.net.SetOwner(owner)
}
// Read implements tcpip.Endpoint.Read.
@@ -219,7 +202,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
Total: pkt.data.Size(),
ControlMessages: tcpip.ControlMessages{
HasTimestamp: true,
- Timestamp: pkt.receivedAt.UnixNano(),
+ Timestamp: pkt.receivedAt,
},
}
if opts.NeedRemoteAddr {
@@ -236,14 +219,15 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// Write implements tcpip.Endpoint.Write.
func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
+ netProto := e.net.NetProto()
// We can create, but not write to, unassociated IPv6 endpoints.
- if !e.associated && e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber {
+ if !e.associated && netProto == header.IPv6ProtocolNumber {
return 0, &tcpip.ErrInvalidOptionValue{}
}
if opts.To != nil {
// Raw sockets do not support sending to a IPv4 address on a IPv6 endpoint.
- if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize {
+ if netProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize {
return 0, &tcpip.ErrInvalidOptionValue{}
}
}
@@ -269,79 +253,25 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
}
func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
- // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op.
- if opts.More {
- return 0, &tcpip.ErrInvalidOptionValue{}
+ ctx, err := e.net.AcquireContextForWrite(opts)
+ if err != nil {
+ return 0, err
}
- payloadBytes, route, owner, err := func() ([]byte, *stack.Route, tcpip.PacketOwner, tcpip.Error) {
- e.mu.RLock()
- defer e.mu.RUnlock()
-
- if e.closed {
- return nil, nil, nil, &tcpip.ErrInvalidEndpointState{}
- }
-
- payloadBytes := make([]byte, p.Len())
- if _, err := io.ReadFull(p, payloadBytes); err != nil {
- return nil, nil, nil, &tcpip.ErrBadBuffer{}
- }
-
- // Did the user caller provide a destination? If not, use the connected
- // destination.
- if opts.To == nil {
- // If the user doesn't specify a destination, they should have
- // connected to another address.
- if !e.connected {
- return nil, nil, nil, &tcpip.ErrDestinationRequired{}
- }
- e.route.Acquire()
-
- return payloadBytes, e.route, e.owner, nil
- }
-
- // The caller provided a destination. Reject destination address if it
- // goes through a different NIC than the endpoint was bound to.
- nic := opts.To.NIC
- if e.bound && nic != 0 && nic != e.BindNICID {
- return nil, nil, nil, &tcpip.ErrNoRoute{}
- }
+ // TODO(https://gvisor.dev/issue/6538): Avoid this allocation.
+ payloadBytes := make([]byte, p.Len())
+ if _, err := io.ReadFull(p, payloadBytes); err != nil {
+ return 0, &tcpip.ErrBadBuffer{}
+ }
- // Find the route to the destination. If BindAddress is 0,
- // FindRoute will choose an appropriate source address.
- route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false)
- if err != nil {
- return nil, nil, nil, err
- }
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(ctx.PacketInfo().MaxHeaderLength),
+ Data: buffer.View(payloadBytes).ToVectorisedView(),
+ })
- return payloadBytes, route, e.owner, nil
- }()
- if err != nil {
+ if err := ctx.WritePacket(pkt, e.ops.GetHeaderIncluded()); err != nil {
return 0, err
}
- defer route.Release()
-
- if e.ops.GetHeaderIncluded() {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buffer.View(payloadBytes).ToVectorisedView(),
- })
- if err := route.WriteHeaderIncludedPacket(pkt); err != nil {
- return 0, err
- }
- } else {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(route.MaxHeaderLength()),
- Data: buffer.View(payloadBytes).ToVectorisedView(),
- })
- pkt.Owner = owner
- if err := route.WritePacket(stack.NetworkHeaderParams{
- Protocol: e.TransProto,
- TTL: route.DefaultTTL(),
- TOS: stack.DefaultTOS,
- }, pkt); err != nil {
- return 0, err
- }
- }
return int64(len(payloadBytes)), nil
}
@@ -353,66 +283,29 @@ func (*endpoint) Disconnect() tcpip.Error {
// Connect implements tcpip.Endpoint.Connect.
func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
+ netProto := e.net.NetProto()
+
// Raw sockets do not support connecting to an IPv4 address on an IPv6 endpoint.
- if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize {
+ if netProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize {
return &tcpip.ErrAddressFamilyNotSupported{}
}
- e.mu.Lock()
- defer e.mu.Unlock()
-
- if e.closed {
- return &tcpip.ErrInvalidEndpointState{}
- }
-
- nic := addr.NIC
- if e.bound {
- if e.BindNICID == 0 {
- // If we're bound, but not to a specific NIC, the NIC
- // in addr will be used. Nothing to do here.
- } else if addr.NIC == 0 {
- // If we're bound to a specific NIC, but addr doesn't
- // specify a NIC, use the bound NIC.
- nic = e.BindNICID
- } else if addr.NIC != e.BindNICID {
- // We're bound and addr specifies a NIC. They must be
- // the same.
- return &tcpip.ErrInvalidEndpointState{}
- }
- }
-
- // Find a route to the destination.
- route, err := e.stack.FindRoute(nic, "", addr.Addr, e.NetProto, false)
- if err != nil {
- return err
- }
-
- if e.associated {
- // Re-register the endpoint with the appropriate NIC.
- if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil {
- route.Release()
- return err
+ return e.net.ConnectAndThen(addr, func(_ tcpip.NetworkProtocolNumber, _, _ stack.TransportEndpointID) tcpip.Error {
+ if e.associated {
+ // Re-register the endpoint with the appropriate NIC.
+ if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil {
+ return err
+ }
+ e.stack.UnregisterRawTransportEndpoint(netProto, e.transProto, e)
}
- e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e)
- e.RegisterNICID = nic
- }
-
- if e.route != nil {
- // If the endpoint was previously connected then release any previous route.
- e.route.Release()
- }
- e.route = route
- e.connected = true
- return nil
+ return nil
+ })
}
// Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets.
func (e *endpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error {
- e.mu.Lock()
- defer e.mu.Unlock()
-
- if !e.connected {
+ if e.net.State() != transport.DatagramEndpointStateConnected {
return &tcpip.ErrNotConnected{}
}
return nil
@@ -430,46 +323,26 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi
// Bind implements tcpip.Endpoint.Bind.
func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
- e.mu.Lock()
- defer e.mu.Unlock()
-
- // If a local address was specified, verify that it's valid.
- if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, addr.Addr) == 0 {
- return &tcpip.ErrBadLocalAddress{}
- }
+ return e.net.BindAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, _ tcpip.Address) tcpip.Error {
+ if !e.associated {
+ return nil
+ }
- if e.associated {
// Re-register the endpoint with the appropriate NIC.
- if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil {
+ if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil {
return err
}
- e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e)
- e.RegisterNICID = addr.NIC
- e.BindNICID = addr.NIC
- }
-
- e.BindAddr = addr.Addr
- e.bound = true
-
- return nil
+ e.stack.UnregisterRawTransportEndpoint(netProto, e.transProto, e)
+ return nil
+ })
}
// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
- e.mu.RLock()
- defer e.mu.RUnlock()
-
- addr := e.BindAddr
- if e.connected {
- addr = e.route.LocalAddress()
- }
-
- return tcpip.FullAddress{
- NIC: e.RegisterNICID,
- Addr: addr,
- // Linux returns the protocol in the port field.
- Port: uint16(e.TransProto),
- }, nil
+ a := e.net.GetLocalAddress()
+ // Linux returns the protocol in the port field.
+ a.Port = uint16(e.transProto)
+ return a, nil
}
// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
@@ -502,17 +375,17 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
return nil
default:
- return &tcpip.ErrUnknownProtocolOption{}
+ return e.net.SetSockOpt(opt)
}
}
-func (*endpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error {
- return &tcpip.ErrUnknownProtocolOption{}
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
+ return e.net.SetSockOptInt(opt, v)
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error {
- return &tcpip.ErrUnknownProtocolOption{}
+func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
+ return e.net.GetSockOpt(opt)
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
@@ -529,100 +402,108 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
return v, nil
default:
- return -1, &tcpip.ErrUnknownProtocolOption{}
+ return e.net.GetSockOptInt(opt)
}
}
// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
- e.mu.RLock()
- e.rcvMu.Lock()
+ notifyReadableEvents := func() bool {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ e.rcvMu.Lock()
+ defer e.rcvMu.Unlock()
+
+ // Drop the packet if the receiver is closed or if this is an unassociated
+ // endpoint (i.e. an endpoint created w/ IPPROTO_RAW). Such endpoints are send only.
+ // See: https://man7.org/linux/man-pages/man7/raw.7.html
+ //
+ // An IPPROTO_RAW socket is send only. If you really want to receive
+ // all IP packets, use a packet(7) socket with the ETH_P_IP protocol.
+ // Note that packet sockets don't reassemble IP fragments, unlike raw
+ // sockets.
+ if e.rcvClosed || !e.associated {
+ e.stack.Stats().DroppedPackets.Increment()
+ e.stats.ReceiveErrors.ClosedReceiver.Increment()
+ return false
+ }
- // Drop the packet if our buffer is currently full or if this is an unassociated
- // endpoint (i.e endpoint created w/ IPPROTO_RAW). Such endpoints are send only
- // See: https://man7.org/linux/man-pages/man7/raw.7.html
- //
- // An IPPROTO_RAW socket is send only. If you really want to receive
- // all IP packets, use a packet(7) socket with the ETH_P_IP protocol.
- // Note that packet sockets don't reassemble IP fragments, unlike raw
- // sockets.
- if e.rcvClosed || !e.associated {
- e.rcvMu.Unlock()
- e.mu.RUnlock()
- e.stack.Stats().DroppedPackets.Increment()
- e.stats.ReceiveErrors.ClosedReceiver.Increment()
- return
- }
+ rcvBufSize := e.ops.GetReceiveBufferSize()
+ if e.frozen || e.rcvBufSize >= int(rcvBufSize) {
+ e.stack.Stats().DroppedPackets.Increment()
+ e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
+ return false
+ }
- rcvBufSize := e.ops.GetReceiveBufferSize()
- if e.frozen || e.rcvBufSize >= int(rcvBufSize) {
- e.rcvMu.Unlock()
- e.mu.RUnlock()
- e.stack.Stats().DroppedPackets.Increment()
- e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
- return
- }
+ srcAddr := pkt.Network().SourceAddress()
+ info := e.net.Info()
- remoteAddr := pkt.Network().SourceAddress()
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial:
+ case transport.DatagramEndpointStateConnected:
+ // If connected, only accept packets from the remote address we
+ // connected to.
+ if info.ID.RemoteAddress != srcAddr {
+ return false
+ }
- if e.bound {
- // If bound to a NIC, only accept data for that NIC.
- if e.BindNICID != 0 && e.BindNICID != pkt.NICID {
- e.rcvMu.Unlock()
- e.mu.RUnlock()
- return
- }
- // If bound to an address, only accept data for that address.
- if e.BindAddr != "" && e.BindAddr != remoteAddr {
- e.rcvMu.Unlock()
- e.mu.RUnlock()
- return
+ // Connected sockets may also have been bound to a specific
+ // address/NIC.
+ fallthrough
+ case transport.DatagramEndpointStateBound:
+ // If bound to a NIC, only accept data for that NIC.
+ if info.BindNICID != 0 && info.BindNICID != pkt.NICID {
+ return false
+ }
+
+ // If bound to an address, only accept data for that address.
+ if info.BindAddr != "" && info.BindAddr != pkt.Network().DestinationAddress() {
+ return false
+ }
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
}
- }
- // If connected, only accept packets from the remote address we
- // connected to.
- if e.connected && e.route.RemoteAddress() != remoteAddr {
- e.rcvMu.Unlock()
- e.mu.RUnlock()
- return
- }
+ wasEmpty := e.rcvBufSize == 0
- wasEmpty := e.rcvBufSize == 0
+ // Push new packet into receive list and increment the buffer size.
+ packet := &rawPacket{
+ senderAddr: tcpip.FullAddress{
+ NIC: pkt.NICID,
+ Addr: srcAddr,
+ },
+ }
- // Push new packet into receive list and increment the buffer size.
- packet := &rawPacket{
- senderAddr: tcpip.FullAddress{
- NIC: pkt.NICID,
- Addr: remoteAddr,
- },
- }
+ // Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not.
+ // We copy headers' underlying bytes because pkt.*Header may point to
+ // the middle of a slice, and another struct may point to the "outer"
+ // slice. Save/restore doesn't support overlapping slices and will fail.
+ //
+ // TODO(https://gvisor.dev/issue/6517): Avoid the copy once S/R supports
+ // overlapping slices.
+ var combinedVV buffer.VectorisedView
+ if info.NetProto == header.IPv4ProtocolNumber {
+ networkHeader, transportHeader := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
+ headers := make(buffer.View, 0, len(networkHeader)+len(transportHeader))
+ headers = append(headers, networkHeader...)
+ headers = append(headers, transportHeader...)
+ combinedVV = headers.ToVectorisedView()
+ } else {
+ combinedVV = append(buffer.View(nil), pkt.TransportHeader().View()...).ToVectorisedView()
+ }
+ combinedVV.Append(pkt.Data().ExtractVV())
+ packet.data = combinedVV
+ packet.receivedAt = e.stack.Clock().Now()
- // Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not.
- // We copy headers' underlying bytes because pkt.*Header may point to
- // the middle of a slice, and another struct may point to the "outer"
- // slice. Save/restore doesn't support overlapping slices and will fail.
- var combinedVV buffer.VectorisedView
- if e.TransportEndpointInfo.NetProto == header.IPv4ProtocolNumber {
- network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
- headers := make(buffer.View, 0, len(network)+len(transport))
- headers = append(headers, network...)
- headers = append(headers, transport...)
- combinedVV = headers.ToVectorisedView()
- } else {
- combinedVV = append(buffer.View(nil), pkt.TransportHeader().View()...).ToVectorisedView()
- }
- combinedVV.Append(pkt.Data().ExtractVV())
- packet.data = combinedVV
- packet.receivedAt = e.stack.Clock().Now()
+ e.rcvList.PushBack(packet)
+ e.rcvBufSize += packet.data.Size()
+ e.stats.PacketsReceived.Increment()
- e.rcvList.PushBack(packet)
- e.rcvBufSize += packet.data.Size()
- e.rcvMu.Unlock()
- e.mu.RUnlock()
- e.stats.PacketsReceived.Increment()
- // Notify waiters that there's data to be read.
- if wasEmpty {
+ // Notify waiters that there is data to be read now.
+ return wasEmpty
+ }()
+
+ if notifyReadableEvents {
e.waiterQueue.Notify(waiter.ReadableEvents)
}
}
@@ -634,10 +515,7 @@ func (e *endpoint) State() uint32 {
// Info returns a copy of the endpoint info.
func (e *endpoint) Info() tcpip.EndpointInfo {
- e.mu.RLock()
- // Make a copy of the endpoint info.
- ret := e.TransportEndpointInfo
- e.mu.RUnlock()
+ ret := e.net.Info()
return &ret
}
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
index 39669b445..e74713064 100644
--- a/pkg/tcpip/transport/raw/endpoint_state.go
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -15,6 +15,7 @@
package raw
import (
+ "fmt"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -60,35 +61,16 @@ func (e *endpoint) beforeSave() {
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
+ e.net.Resume(s)
+
e.thaw()
e.stack = s
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
- // If the endpoint is connected, re-connect.
- if e.connected {
- var err tcpip.Error
- // TODO(gvisor.dev/issue/4906): Properly restore the route with the right
- // remote address. We used to pass e.remote.RemoteAddress which was
- // effectively the empty address but since moving e.route to hold a pointer
- // to a route instead of the route by value, we pass the empty address
- // directly. Obviously this was always wrong since we should provide the
- // remote address we were connected to, to properly restore the route.
- e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, "", e.NetProto, false)
- if err != nil {
- panic(err)
- }
- }
-
- // If the endpoint is bound, re-bind.
- if e.bound {
- if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.BindAddr) == 0 {
- panic(&tcpip.ErrBadLocalAddress{})
- }
- }
-
if e.associated {
- if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil {
- panic(err)
+ netProto := e.net.NetProto()
+ if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil {
+ panic(fmt.Sprintf("e.stack.RegisterRawTransportEndpoint(%d, %d, _): %s", netProto, e.transProto, err))
}
}
}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 8436d2cf0..20958d882 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -68,6 +68,7 @@ go_library(
"//pkg/tcpip/hash/jenkins",
"//pkg/tcpip/header",
"//pkg/tcpip/header/parse",
+ "//pkg/tcpip/internal/tcp",
"//pkg/tcpip/ports",
"//pkg/tcpip/seqnum",
"//pkg/tcpip/stack",
@@ -79,9 +80,10 @@ go_library(
go_test(
name = "tcp_x_test",
- size = "medium",
+ size = "large",
srcs = [
"dual_stack_test.go",
+ "rcv_test.go",
"sack_scoreboard_test.go",
"tcp_noracedetector_test.go",
"tcp_rack_test.go",
@@ -96,6 +98,7 @@ go_test(
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/loopback",
"//pkg/tcpip/link/sniffer",
@@ -112,16 +115,6 @@ go_test(
)
go_test(
- name = "rcv_test",
- size = "small",
- srcs = ["rcv_test.go"],
- deps = [
- "//pkg/tcpip/header",
- "//pkg/tcpip/seqnum",
- ],
-)
-
-go_test(
name = "tcp_test",
size = "small",
srcs = [
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index aa413ad05..caf14b0dc 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -15,12 +15,12 @@
package tcp
import (
+ "container/list"
"crypto/sha1"
"encoding/binary"
"fmt"
"hash"
"io"
- "sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/sleep"
@@ -72,7 +72,8 @@ func encodeMSS(mss uint16) uint32 {
// and must not be accessed or have its methods called concurrently as they
// may mutate the stored objects.
type listenContext struct {
- stack *stack.Stack
+ stack *stack.Stack
+ protocol *protocol
// rcvWnd is the receive window that is sent by this listening context
// in the initial SYN-ACK.
@@ -99,18 +100,6 @@ type listenContext struct {
// netProto indicates the network protocol(IPv4/v6) for the listening
// endpoint.
netProto tcpip.NetworkProtocolNumber
-
- // pendingMu protects pendingEndpoints. This should only be accessed
- // by the listening endpoint's worker goroutine.
- //
- // Lock Ordering: listenEP.workerMu -> pendingMu
- pendingMu sync.Mutex
- // pending is used to wait for all pendingEndpoints to finish when
- // a socket is closed.
- pending sync.WaitGroup
- // pendingEndpoints is a map of all endpoints for which a handshake is
- // in progress.
- pendingEndpoints map[stack.TransportEndpointID]*endpoint
}
// timeStamp returns an 8-bit timestamp with a granularity of 64 seconds.
@@ -119,15 +108,15 @@ func timeStamp(clock tcpip.Clock) uint32 {
}
// newListenContext creates a new listen context.
-func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6Only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
+func newListenContext(stk *stack.Stack, protocol *protocol, listenEP *endpoint, rcvWnd seqnum.Size, v6Only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
l := &listenContext{
- stack: stk,
- rcvWnd: rcvWnd,
- hasher: sha1.New(),
- v6Only: v6Only,
- netProto: netProto,
- listenEP: listenEP,
- pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint),
+ stack: stk,
+ protocol: protocol,
+ rcvWnd: rcvWnd,
+ hasher: sha1.New(),
+ v6Only: v6Only,
+ netProto: netProto,
+ listenEP: listenEP,
}
for i := range l.nonce {
@@ -191,17 +180,9 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu
return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true
}
-func (l *listenContext) useSynCookies() bool {
- var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies
- if err := l.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil {
- panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err))
- }
- return bool(alwaysUseSynCookies) || (l.listenEP != nil && l.listenEP.synRcvdBacklogFull())
-}
-
// createConnectingEndpoint creates a new endpoint in a connecting state, with
// the connection parameters given by the arguments.
-func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) {
+func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) {
// Create a new endpoint.
netProto := l.netProto
if netProto == 0 {
@@ -213,7 +194,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header
return nil, err
}
- n := newEndpoint(l.stack, netProto, queue)
+ n := newEndpoint(l.stack, l.protocol, netProto, queue)
n.ops.SetV6Only(l.v6Only)
n.TransportEndpointInfo.ID = s.id
n.boundNICID = s.nicID
@@ -244,10 +225,10 @@ func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header
// On success, a handshake h is returned with h.ep.mu held.
//
// Precondition: if l.listenEP != nil, l.listenEP.mu must be locked.
-func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, tcpip.Error) {
+func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, tcpip.Error) {
// Create new endpoint.
irs := s.sequenceNumber
- isn := generateSecureISN(s.id, l.stack.Clock(), l.stack.Seed())
+ isn := generateSecureISN(s.id, l.stack.Clock(), l.protocol.seqnumSecret)
ep, err := l.createConnectingEndpoint(s, opts, queue)
if err != nil {
return nil, err
@@ -271,18 +252,15 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q
return nil, &tcpip.ErrConnectionAborted{}
}
- l.addPendingEndpoint(ep)
// Propagate any inheritable options from the listening endpoint
// to the newly created endpoint.
- l.listenEP.propagateInheritableOptionsLocked(ep)
+ l.listenEP.propagateInheritableOptionsLocked(ep) // +checklocksforce
if !ep.reserveTupleLocked() {
ep.mu.Unlock()
ep.Close()
- l.removePendingEndpoint(ep)
-
return nil, &tcpip.ErrConnectionAborted{}
}
@@ -301,10 +279,6 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q
ep.mu.Unlock()
ep.Close()
- if l.listenEP != nil {
- l.removePendingEndpoint(ep)
- }
-
ep.drainClosingSegmentQueue()
return nil, err
@@ -323,7 +297,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q
// established endpoint is returned with e.mu held.
//
// Precondition: if l.listenEP != nil, l.listenEP.mu must be locked.
-func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, tcpip.Error) {
+func (l *listenContext) performHandshake(s *segment, opts header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, tcpip.Error) {
h, err := l.startHandshake(s, opts, queue, owner)
if err != nil {
return nil, err
@@ -342,39 +316,12 @@ func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions,
return ep, nil
}
-func (l *listenContext) addPendingEndpoint(n *endpoint) {
- l.pendingMu.Lock()
- l.pendingEndpoints[n.TransportEndpointInfo.ID] = n
- l.pending.Add(1)
- l.pendingMu.Unlock()
-}
-
-func (l *listenContext) removePendingEndpoint(n *endpoint) {
- l.pendingMu.Lock()
- delete(l.pendingEndpoints, n.TransportEndpointInfo.ID)
- l.pending.Done()
- l.pendingMu.Unlock()
-}
-
-func (l *listenContext) closeAllPendingEndpoints() {
- l.pendingMu.Lock()
- for _, n := range l.pendingEndpoints {
- n.notifyProtocolGoroutine(notifyClose)
- }
- l.pendingMu.Unlock()
- l.pending.Wait()
-}
-
-// Precondition: h.ep.mu must be held.
// +checklocks:h.ep.mu
func (l *listenContext) cleanupFailedHandshake(h *handshake) {
e := h.ep
e.mu.Unlock()
e.Close()
e.notifyAborted()
- if l.listenEP != nil {
- l.removePendingEndpoint(e)
- }
e.drainClosingSegmentQueue()
e.h = nil
}
@@ -382,12 +329,9 @@ func (l *listenContext) cleanupFailedHandshake(h *handshake) {
// cleanupCompletedHandshake transfers any state from the completed handshake to
// the new endpoint.
//
-// Precondition: h.ep.mu must be held.
+// +checklocks:h.ep.mu
func (l *listenContext) cleanupCompletedHandshake(h *handshake) {
e := h.ep
- if l.listenEP != nil {
- l.removePendingEndpoint(e)
- }
e.isConnectNotified = true
// Update the receive window scaling. We can't do it before the
@@ -399,47 +343,11 @@ func (l *listenContext) cleanupCompletedHandshake(h *handshake) {
e.h = nil
}
-// deliverAccepted delivers the newly-accepted endpoint to the listener. If the
-// listener has transitioned out of the listen state (accepted is the zero
-// value), the new endpoint is reset instead.
-func (e *endpoint) deliverAccepted(n *endpoint, withSynCookie bool) {
- e.mu.Lock()
- e.pendingAccepted.Add(1)
- e.mu.Unlock()
- defer e.pendingAccepted.Done()
-
- // Drop the lock before notifying to avoid deadlock in user-specified
- // callbacks.
- delivered := func() bool {
- e.acceptMu.Lock()
- defer e.acceptMu.Unlock()
- for {
- if e.accepted == (accepted{}) {
- return false
- }
- if e.accepted.endpoints.Len() == e.accepted.cap {
- e.acceptCond.Wait()
- continue
- }
-
- e.accepted.endpoints.PushBack(n)
- if !withSynCookie {
- atomic.AddInt32(&e.synRcvdCount, -1)
- }
- return true
- }
- }()
- if delivered {
- e.waiterQueue.Notify(waiter.ReadableEvents)
- } else {
- n.notifyProtocolGoroutine(notifyReset)
- }
-}
-
// propagateInheritableOptionsLocked propagates any options set on the listening
// endpoint to the newly created endpoint.
//
-// Precondition: e.mu and n.mu must be held.
+// +checklocks:e.mu
+// +checklocks:n.mu
func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) {
n.userTimeout = e.userTimeout
n.portFlags = e.portFlags
@@ -450,9 +358,9 @@ func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) {
// reserveTupleLocked reserves an accepted endpoint's tuple.
//
-// Preconditions:
-// * propagateInheritableOptionsLocked has been called.
-// * e.mu is held.
+// Precondition: e.propagateInheritableOptionsLocked has been called.
+//
+// +checklocks:e.mu
func (e *endpoint) reserveTupleLocked() bool {
dest := tcpip.FullAddress{
Addr: e.TransportEndpointInfo.ID.RemoteAddress,
@@ -487,70 +395,36 @@ func (e *endpoint) notifyAborted() {
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
}
-// handleSynSegment is called in its own goroutine once the listening endpoint
-// receives a SYN segment. It is responsible for completing the handshake and
-// queueing the new endpoint for acceptance.
-//
-// A limited number of these goroutines are allowed before TCP starts using SYN
-// cookies to accept connections.
-//
-// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked.
-func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) tcpip.Error {
- defer s.decRef()
-
- h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner)
- if err != nil {
- e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
- e.stats.FailedConnectionAttempts.Increment()
- atomic.AddInt32(&e.synRcvdCount, -1)
- return err
- }
+func (e *endpoint) acceptQueueIsFull() bool {
+ e.acceptMu.Lock()
+ full := e.acceptQueue.isFull()
+ e.acceptMu.Unlock()
+ return full
+}
- go func() {
- // Note that startHandshake returns a locked endpoint. The
- // force call here just makes it so.
- if err := h.complete(); err != nil { // +checklocksforce
- e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
- e.stats.FailedConnectionAttempts.Increment()
- ctx.cleanupFailedHandshake(h)
- atomic.AddInt32(&e.synRcvdCount, -1)
- return
- }
- ctx.cleanupCompletedHandshake(h)
- h.ep.startAcceptedLoop()
- e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
- e.deliverAccepted(h.ep, false /*withSynCookie*/)
- }()
+// +stateify savable
+type acceptQueue struct {
+ // NB: this could be an endpointList, but ilist only permits endpoints to
+ // belong to one list at a time, and endpoints are already stored in the
+ // dispatcher's list.
+ endpoints list.List `state:".([]*endpoint)"`
- return nil
-}
+ // pendingEndpoints is a set of all endpoints for which a handshake is
+ // in progress.
+ pendingEndpoints map[*endpoint]struct{}
-func (e *endpoint) synRcvdBacklogFull() bool {
- e.acceptMu.Lock()
- acceptedCap := e.accepted.cap
- e.acceptMu.Unlock()
- // The capacity of the accepted queue would always be one greater than the
- // listen backlog. But, the SYNRCVD connections count is always checked
- // against the listen backlog value for Linux parity reason.
- // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280
- //
- // We maintain an equality check here as the synRcvdCount is incremented
- // and compared only from a single listener context and the capacity of
- // the accepted queue can only increase by a new listen call.
- return int(atomic.LoadInt32(&e.synRcvdCount)) == acceptedCap-1
+ // capacity is the maximum number of endpoints that can be in endpoints.
+ capacity int
}
-func (e *endpoint) acceptQueueIsFull() bool {
- e.acceptMu.Lock()
- full := e.accepted != (accepted{}) && e.accepted.endpoints.Len() == e.accepted.cap
- e.acceptMu.Unlock()
- return full
+func (a *acceptQueue) isFull() bool {
+ return a.endpoints.Len() == a.capacity
}
// handleListenSegment is called when a listening endpoint receives a segment
// and needs to handle it.
//
-// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked.
+// +checklocks:e.mu
func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Error {
e.rcvQueueInfo.rcvQueueMu.Lock()
rcvClosed := e.rcvQueueInfo.RcvClosed
@@ -578,11 +452,95 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
}
opts := parseSynSegmentOptions(s)
- if !ctx.useSynCookies() {
- s.incRef()
- atomic.AddInt32(&e.synRcvdCount, 1)
- return e.handleSynSegment(ctx, s, &opts)
+
+ useSynCookies, err := func() (bool, tcpip.Error) {
+ var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies
+ if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil {
+ panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err))
+ }
+ if alwaysUseSynCookies {
+ return true, nil
+ }
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
+
+ // The capacity of the accepted queue would always be one greater than the
+ // listen backlog. But, the SYNRCVD connections count is always checked
+ // against the listen backlog value for Linux parity reasons.
+ // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280
+ if len(e.acceptQueue.pendingEndpoints) == e.acceptQueue.capacity-1 {
+ return true, nil
+ }
+
+ h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner)
+ if err != nil {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ return false, err
+ }
+
+ e.acceptQueue.pendingEndpoints[h.ep] = struct{}{}
+ e.pendingAccepted.Add(1)
+
+ go func() {
+ defer func() {
+ e.pendingAccepted.Done()
+
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
+ delete(e.acceptQueue.pendingEndpoints, h.ep)
+ }()
+
+ // Note that startHandshake returns a locked endpoint. The force call
+ // here just makes it so.
+ if err := h.complete(); err != nil { // +checklocksforce
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ ctx.cleanupFailedHandshake(h)
+ return
+ }
+ ctx.cleanupCompletedHandshake(h)
+ h.ep.startAcceptedLoop()
+ e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
+
+ // Deliver the endpoint to the accept queue.
+ //
+ // Drop the lock before notifying to avoid deadlock in user-specified
+ // callbacks.
+ delivered := func() bool {
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
+ for {
+ // The listener is transitioning out of the Listen state; bail.
+ if e.acceptQueue.capacity == 0 {
+ return false
+ }
+ if e.acceptQueue.isFull() {
+ e.acceptCond.Wait()
+ continue
+ }
+
+ e.acceptQueue.endpoints.PushBack(h.ep)
+ return true
+ }
+ }()
+
+ if delivered {
+ e.waiterQueue.Notify(waiter.ReadableEvents)
+ } else {
+ h.ep.notifyProtocolGoroutine(notifyReset)
+ }
+ }()
+
+ return false, nil
+ }()
+ if err != nil {
+ return err
}
+ if !useSynCookies {
+ return nil
+ }
+
route, err := e.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */)
if err != nil {
return err
@@ -600,10 +558,14 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
synOpts := header.TCPSynOptions{
WS: -1,
TS: opts.TS,
- TSVal: tcpTimeStamp(e.stack.Clock().NowMonotonic(), timeStampOffset(e.stack.Rand())),
TSEcr: opts.TSVal,
MSS: calculateAdvertisedMSS(e.userMSS, route),
}
+ if opts.TS {
+ offset := e.protocol.tsOffset(s.dstAddr, s.srcAddr)
+ now := e.stack.Clock().NowMonotonic()
+ synOpts.TSVal = offset.TSVal(now)
+ }
cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS))
fields := tcpFields{
id: s.id,
@@ -621,18 +583,6 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
return nil
case s.flags.Contains(header.TCPFlagAck):
- if e.acceptQueueIsFull() {
- // Silently drop the ack as the application can't accept
- // the connection at this point. The ack will be
- // retransmitted by the sender anyway and we can
- // complete the connection at the time of retransmit if
- // the backlog has space.
- e.stack.Stats().TCP.ListenOverflowAckDrop.Increment()
- e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment()
- e.stack.Stats().DroppedPackets.Increment()
- return nil
- }
-
iss := s.ackNumber - 1
irs := s.sequenceNumber - 1
@@ -668,9 +618,27 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
// ACK was received from the sender.
return replyWithReset(e.stack, s, e.sendTOS, e.ttl)
}
+
+ // Keep hold of acceptMu until the new endpoint is in the accept queue (or
+ // if there is an error), to guarantee that we will keep our spot in the
+ // queue even if another handshake from the syn queue completes.
+ e.acceptMu.Lock()
+ if e.acceptQueue.isFull() {
+ // Silently drop the ack as the application can't accept
+ // the connection at this point. The ack will be
+ // retransmitted by the sender anyway and we can
+ // complete the connection at the time of retransmit if
+ // the backlog has space.
+ e.acceptMu.Unlock()
+ e.stack.Stats().TCP.ListenOverflowAckDrop.Increment()
+ e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment()
+ e.stack.Stats().DroppedPackets.Increment()
+ return nil
+ }
+
e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment()
// Create newly accepted endpoint and deliver it.
- rcvdSynOptions := &header.TCPSynOptions{
+ rcvdSynOptions := header.TCPSynOptions{
MSS: mssTable[data],
// Disable Window scaling as original SYN is
// lost.
@@ -689,6 +657,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
n, err := ctx.createConnectingEndpoint(s, rcvdSynOptions, &waiter.Queue{})
if err != nil {
+ e.acceptMu.Unlock()
return err
}
@@ -700,6 +669,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
if !n.reserveTupleLocked() {
n.mu.Unlock()
+ e.acceptMu.Unlock()
n.Close()
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
@@ -717,6 +687,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
n.boundBindToDevice,
); err != nil {
n.mu.Unlock()
+ e.acceptMu.Unlock()
n.Close()
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
@@ -725,25 +696,22 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
}
n.isRegistered = true
-
- // clear the tsOffset for the newly created
- // endpoint as the Timestamp was already
- // randomly offset when the original SYN-ACK was
- // sent above.
- n.TSOffset = 0
+ n.TSOffset = n.protocol.tsOffset(s.dstAddr, s.srcAddr)
// Switch state to connected.
n.isConnectNotified = true
- n.transitionToStateEstablishedLocked(&handshake{
- ep: n,
- iss: iss,
- ackNum: irs + 1,
- rcvWnd: seqnum.Size(n.initialReceiveWindow()),
- sndWnd: s.window,
- rcvWndScale: e.rcvWndScaleForHandshake(),
- sndWndScale: rcvdSynOptions.WS,
- mss: rcvdSynOptions.MSS,
- })
+ h := handshake{
+ ep: n,
+ iss: iss,
+ ackNum: irs + 1,
+ rcvWnd: seqnum.Size(n.initialReceiveWindow()),
+ sndWnd: s.window,
+ rcvWndScale: e.rcvWndScaleForHandshake(),
+ sndWndScale: rcvdSynOptions.WS,
+ mss: rcvdSynOptions.MSS,
+ sampleRTTWithTSOnly: true,
+ }
+ h.transitionToStateEstablishedLocked(s)
// Requeue the segment if the ACK completing the handshake has more info
// to be procesed by the newly established endpoint.
@@ -752,20 +720,15 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
n.newSegmentWaker.Assert()
}
- // Do the delivery in a separate goroutine so
- // that we don't block the listen loop in case
- // the application is slow to accept or stops
- // accepting.
- //
- // NOTE: This won't result in an unbounded
- // number of goroutines as we do check before
- // entering here that there was at least some
- // space available in the backlog.
-
// Start the protocol goroutine.
n.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
- go e.deliverAccepted(n, true /*withSynCookie*/)
+
+ // Deliver the endpoint to the accept queue.
+ e.acceptQueue.endpoints.PushBack(n)
+ e.acceptMu.Unlock()
+
+ e.waiterQueue.Notify(waiter.ReadableEvents)
return nil
default:
@@ -779,17 +742,11 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
e.mu.Lock()
v6Only := e.ops.GetV6Only()
- ctx := newListenContext(e.stack, e, rcvWnd, v6Only, e.NetProto)
+ ctx := newListenContext(e.stack, e.protocol, e, rcvWnd, v6Only, e.NetProto)
defer func() {
- // Mark endpoint as closed. This will prevent goroutines running
- // handleSynSegment() from attempting to queue new connections
- // to the endpoint.
e.setEndpointState(StateClose)
- // Close any endpoints in SYN-RCVD state.
- ctx.closeAllPendingEndpoints()
-
// Do cleanup if needed.
e.completeWorkerLocked()
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 93ed161f9..80cd07218 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -30,6 +30,10 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// InitialRTO is the initial retransmission timeout.
+// https://github.com/torvalds/linux/blob/7c636d4d20f/include/net/tcp.h#L142
+const InitialRTO = time.Second
+
// maxSegmentsPerWake is the maximum number of segments to process in the main
// protocol goroutine per wake-up. Yielding [after this number of segments are
// processed] allows other events to be processed as well (e.g., timeouts,
@@ -105,6 +109,11 @@ type handshake struct {
// sendSYNOpts is the cached values for the SYN options to be sent.
sendSYNOpts header.TCPSynOptions
+
+ // sampleRTTWithTSOnly is true when the segment was retransmitted or we can't
+ // tell; then RTT can only be sampled when the incoming segment has timestamp
+ // options enabled.
+ sampleRTTWithTSOnly bool
}
func (e *endpoint) newHandshake() *handshake {
@@ -117,10 +126,12 @@ func (e *endpoint) newHandshake() *handshake {
h.resetState()
// Store reference to handshake state in endpoint.
e.h = h
+ // By the time handshake is created, e.ID is already initialized.
+ e.TSOffset = e.protocol.tsOffset(e.ID.LocalAddress, e.ID.RemoteAddress)
return h
}
-func (e *endpoint) newPassiveHandshake(isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) *handshake {
+func (e *endpoint) newPassiveHandshake(isn, irs seqnum.Value, opts header.TCPSynOptions, deferAccept time.Duration) *handshake {
h := e.newHandshake()
h.resetToSynRcvd(isn, irs, opts, deferAccept)
return h
@@ -150,20 +161,23 @@ func (h *handshake) resetState() {
h.flags = header.TCPFlagSyn
h.ackNum = 0
h.mss = 0
- h.iss = generateSecureISN(h.ep.TransportEndpointInfo.ID, h.ep.stack.Clock(), h.ep.stack.Seed())
+ h.iss = generateSecureISN(h.ep.TransportEndpointInfo.ID, h.ep.stack.Clock(), h.ep.protocol.seqnumSecret)
}
// generateSecureISN generates a secure Initial Sequence number based on the
// recommendation here https://tools.ietf.org/html/rfc6528#page-3.
func generateSecureISN(id stack.TransportEndpointID, clock tcpip.Clock, seed uint32) seqnum.Value {
isnHasher := jenkins.Sum32(seed)
- isnHasher.Write([]byte(id.LocalAddress))
- isnHasher.Write([]byte(id.RemoteAddress))
+ // Per hash.Hash.Writer:
+ //
+ // It never returns an error.
+ _, _ = isnHasher.Write([]byte(id.LocalAddress))
+ _, _ = isnHasher.Write([]byte(id.RemoteAddress))
portBuf := make([]byte, 2)
binary.LittleEndian.PutUint16(portBuf, id.LocalPort)
- isnHasher.Write(portBuf)
+ _, _ = isnHasher.Write(portBuf)
binary.LittleEndian.PutUint16(portBuf, id.RemotePort)
- isnHasher.Write(portBuf)
+ _, _ = isnHasher.Write(portBuf)
// The time period here is 64ns. This is similar to what linux uses
// generate a sequence number that overlaps less than one
// time per MSL (2 minutes).
@@ -190,7 +204,7 @@ func (h *handshake) effectiveRcvWndScale() uint8 {
// resetToSynRcvd resets the state of the handshake object to the SYN-RCVD
// state.
-func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) {
+func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts header.TCPSynOptions, deferAccept time.Duration) {
h.active = false
h.state = handshakeSynRcvd
h.flags = header.TCPFlagSyn | header.TCPFlagAck
@@ -251,10 +265,10 @@ func (h *handshake) synSentState(s *segment) tcpip.Error {
rcvSynOpts := parseSynSegmentOptions(s)
// Remember if the Timestamp option was negotiated.
- h.ep.maybeEnableTimestamp(&rcvSynOpts)
+ h.ep.maybeEnableTimestamp(rcvSynOpts)
// Remember if the SACKPermitted option was negotiated.
- h.ep.maybeEnableSACKPermitted(&rcvSynOpts)
+ h.ep.maybeEnableSACKPermitted(rcvSynOpts)
// Remember the sequence we'll ack from now on.
h.ackNum = s.sequenceNumber + 1
@@ -266,8 +280,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error {
// and the handshake is completed.
if s.flags.Contains(header.TCPFlagAck) {
h.state = handshakeCompleted
-
- h.ep.transitionToStateEstablishedLocked(h)
+ h.transitionToStateEstablishedLocked(s)
h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale())
return nil
@@ -283,7 +296,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error {
synOpts := header.TCPSynOptions{
WS: int(h.effectiveRcvWndScale()),
TS: rcvSynOpts.TS,
- TSVal: h.ep.timestamp(),
+ TSVal: h.ep.tsValNow(),
TSEcr: h.ep.recentTimestamp(),
// We only send SACKPermitted if the other side indicated it
@@ -353,7 +366,7 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error {
synOpts := header.TCPSynOptions{
WS: h.rcvWndScale,
TS: h.ep.SendTSOk,
- TSVal: h.ep.timestamp(),
+ TSVal: h.ep.tsValNow(),
TSEcr: h.ep.recentTimestamp(),
SACKPermitted: h.ep.SACKPermitted,
MSS: h.ep.amss,
@@ -402,9 +415,10 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error {
if h.ep.SendTSOk && s.parsedOptions.TS {
h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber)
}
+
h.state = handshakeCompleted
- h.ep.transitionToStateEstablishedLocked(h)
+ h.transitionToStateEstablishedLocked(s)
// Requeue the segment if the ACK completing the handshake has more info
// to be procesed by the newly established endpoint.
@@ -480,7 +494,7 @@ func (h *handshake) start() {
synOpts := header.TCPSynOptions{
WS: h.rcvWndScale,
TS: true,
- TSVal: h.ep.timestamp(),
+ TSVal: h.ep.tsValNow(),
TSEcr: h.ep.recentTimestamp(),
SACKPermitted: bool(sackEnabled),
MSS: h.ep.amss,
@@ -522,7 +536,7 @@ func (h *handshake) complete() tcpip.Error {
defer s.Done()
// Initialize the resend timer.
- timer, err := newBackoffTimer(h.ep.stack.Clock(), time.Second, MaxRTO, resendWaker.Assert)
+ timer, err := newBackoffTimer(h.ep.stack.Clock(), InitialRTO, MaxRTO, resendWaker.Assert)
if err != nil {
return err
}
@@ -557,6 +571,10 @@ func (h *handshake) complete() tcpip.Error {
ack: h.ackNum,
rcvWnd: h.rcvWnd,
}, h.sendSYNOpts)
+ // If we have ever retransmitted the SYN-ACK or
+ // SYN segment, we should only measure RTT if
+ // TS option is present.
+ h.sampleRTTWithTSOnly = true
}
case wakerForNotification:
@@ -564,6 +582,9 @@ func (h *handshake) complete() tcpip.Error {
if (n&notifyClose)|(n&notifyAbort) != 0 {
return &tcpip.ErrAborted{}
}
+ if n&notifyShutdown != 0 {
+ return &tcpip.ErrConnectionReset{}
+ }
if n&notifyDrain != 0 {
for !h.ep.segmentQueue.empty() {
s := h.ep.segmentQueue.dequeue()
@@ -600,6 +621,40 @@ func (h *handshake) complete() tcpip.Error {
return nil
}
+// transitionToStateEstablishedLocked transitions the endpoint of the handshake
+// to an established state given the last segment received from peer. It also
+// initializes sender/receiver.
+func (h *handshake) transitionToStateEstablishedLocked(s *segment) {
+ // Transfer handshake state to TCP connection. We disable
+ // receive window scaling if the peer doesn't support it
+ // (indicated by a negative send window scale).
+ h.ep.snd = newSender(h.ep, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
+
+ now := h.ep.stack.Clock().NowMonotonic()
+
+ var rtt time.Duration
+ if h.ep.SendTSOk && s.parsedOptions.TSEcr != 0 {
+ rtt = h.ep.elapsed(now, s.parsedOptions.TSEcr)
+ }
+ if !h.sampleRTTWithTSOnly && rtt == 0 {
+ rtt = now.Sub(h.startTime)
+ }
+
+ if rtt > 0 {
+ h.ep.snd.updateRTO(rtt)
+ }
+
+ h.ep.rcvQueueInfo.rcvQueueMu.Lock()
+ h.ep.rcv = newReceiver(h.ep, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale())
+ // Bootstrap the auto tuning algorithm. Starting at zero will
+ // result in a really large receive window after the first auto
+ // tuning adjustment.
+ h.ep.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes = int(h.rcvWnd)
+ h.ep.rcvQueueInfo.rcvQueueMu.Unlock()
+
+ h.ep.setEndpointState(StateEstablished)
+}
+
type backoffTimer struct {
timeout time.Duration
maxTimeout time.Duration
@@ -873,7 +928,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
// Ref: https://tools.ietf.org/html/rfc7323#section-5.4.
offset += header.EncodeNOP(options[offset:])
offset += header.EncodeNOP(options[offset:])
- offset += header.EncodeTSOption(e.timestamp(), e.recentTimestamp(), options[offset:])
+ offset += header.EncodeTSOption(e.tsValNow(), e.recentTimestamp(), options[offset:])
}
if e.SACKPermitted && len(sackBlocks) > 0 {
offset += header.EncodeNOP(options[offset:])
@@ -965,26 +1020,6 @@ func (e *endpoint) completeWorkerLocked() {
}
}
-// transitionToStateEstablisedLocked transitions a given endpoint
-// to an established state using the handshake parameters provided.
-// It also initializes sender/receiver.
-func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
- // Transfer handshake state to TCP connection. We disable
- // receive window scaling if the peer doesn't support it
- // (indicated by a negative send window scale).
- e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
-
- e.rcvQueueInfo.rcvQueueMu.Lock()
- e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale())
- // Bootstrap the auto tuning algorithm. Starting at zero will
- // result in a really large receive window after the first auto
- // tuning adjustment.
- e.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes = int(h.rcvWnd)
- e.rcvQueueInfo.rcvQueueMu.Unlock()
-
- e.setEndpointState(StateEstablished)
-}
-
// transitionToStateCloseLocked ensures that the endpoint is
// cleaned up from the transport demuxer, "before" moving to
// StateClose. This will ensure that no packet will be
@@ -1286,7 +1321,7 @@ func (e *endpoint) disableKeepaliveTimer() {
// protocolMainLoopDone is called at the end of protocolMainLoop.
// +checklocksrelease:e.mu
-func (e *endpoint) protocolMainLoopDone(closeTimer tcpip.Timer, closeWaker *sleep.Waker) {
+func (e *endpoint) protocolMainLoopDone(closeTimer tcpip.Timer) {
if e.snd != nil {
e.snd.resendTimer.cleanup()
e.snd.probeTimer.cleanup()
@@ -1314,7 +1349,7 @@ func (e *endpoint) protocolMainLoopDone(closeTimer tcpip.Timer, closeWaker *slee
// protocolMainLoop is the main loop of the TCP protocol. It runs in its own
// goroutine and is responsible for sending segments and handling received
// segments.
-func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) tcpip.Error {
+func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) {
var (
closeTimer tcpip.Timer
closeWaker sleep.Waker
@@ -1331,8 +1366,8 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
e.hardError = err
e.workerCleanup = true
- e.protocolMainLoopDone(closeTimer, &closeWaker)
- return err
+ e.protocolMainLoopDone(closeTimer)
+ return
}
}
@@ -1559,8 +1594,8 @@ loop:
// just want to terminate the loop and cleanup the
// endpoint.
cleanupOnError(nil)
- e.protocolMainLoopDone(closeTimer, &closeWaker)
- return nil
+ e.protocolMainLoopDone(closeTimer)
+ return
case StateTimeWait:
fallthrough
case StateClose:
@@ -1568,8 +1603,8 @@ loop:
default:
if err := funcs[v].f(); err != nil {
cleanupOnError(err)
- e.protocolMainLoopDone(closeTimer, &closeWaker)
- return nil
+ e.protocolMainLoopDone(closeTimer)
+ return
}
}
}
@@ -1592,21 +1627,19 @@ loop:
// Handle any StateError transition from StateTimeWait.
if e.EndpointState() == StateError {
cleanupOnError(nil)
- e.protocolMainLoopDone(closeTimer, &closeWaker)
- return nil
+ e.protocolMainLoopDone(closeTimer)
+ return
}
e.transitionToStateCloseLocked()
- e.protocolMainLoopDone(closeTimer, &closeWaker)
+ e.protocolMainLoopDone(closeTimer)
// A new SYN was received during TIME_WAIT and we need to abort
// the timewait and redirect the segment to the listener queue
if reuseTW != nil {
reuseTW()
}
-
- return nil
}
// handleTimeWaitSegments processes segments received during TIME_WAIT
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 044123185..6a798e980 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -15,12 +15,10 @@
package tcp
import (
- "container/list"
"encoding/binary"
"fmt"
"io"
"math"
- "math/rand"
"runtime"
"strings"
"sync/atomic"
@@ -188,6 +186,8 @@ const (
// say TIME_WAIT.
notifyTickleWorker
notifyError
+ // notifyShutdown means that a connecting socket was shutdown.
+ notifyShutdown
)
// SACKInfo holds TCP SACK related information for a given endpoint.
@@ -204,6 +204,8 @@ type SACKInfo struct {
}
// ReceiveErrors collect segment receive errors within transport layer.
+//
+// +stateify savable
type ReceiveErrors struct {
tcpip.ReceiveErrors
@@ -233,6 +235,8 @@ type ReceiveErrors struct {
}
// SendErrors collect segment send errors within the transport layer.
+//
+// +stateify savable
type SendErrors struct {
tcpip.SendErrors
@@ -256,6 +260,8 @@ type SendErrors struct {
}
// Stats holds statistics about the endpoint.
+//
+// +stateify savable
type Stats struct {
// SegmentsReceived is the number of TCP segments received that
// the transport layer successfully parsed.
@@ -310,15 +316,6 @@ type rcvQueueInfo struct {
rcvQueue segmentList `state:"wait"`
}
-// +stateify savable
-type accepted struct {
- // NB: this could be an endpointList, but ilist only permits endpoints to
- // belong to one list at a time, and endpoints are already stored in the
- // dispatcher's list.
- endpoints list.List `state:".([]*endpoint)"`
- cap int
-}
-
// endpoint represents a TCP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
@@ -334,7 +331,7 @@ type accepted struct {
// The following three mutexes can be acquired independent of e.mu but if
// acquired with e.mu then e.mu must be acquired first.
//
-// e.acceptMu -> protects accepted.
+// e.acceptMu -> Protects e.acceptQueue.
// e.rcvQueueMu -> Protects e.rcvQueue and associated fields.
// e.sndQueueMu -> Protects the e.sndQueue and associated fields.
// e.lastErrorMu -> Protects the lastError field.
@@ -378,6 +375,7 @@ type endpoint struct {
// The following fields are initialized at creation time and do not
// change throughout the lifetime of the endpoint.
stack *stack.Stack `state:"manual"`
+ protocol *protocol `state:"manual"`
waiterQueue *waiter.Queue `state:"wait"`
uniqueID uint64
@@ -497,10 +495,6 @@ type endpoint struct {
// and dropped when it is.
segmentQueue segmentQueue `state:"wait"`
- // synRcvdCount is the number of connections for this endpoint that are
- // in SYN-RCVD state; this is only accessed atomically.
- synRcvdCount int32
-
// userMSS if non-zero is the MSS value explicitly set by the user
// for this endpoint using the TCP_MAXSEG setsockopt.
userMSS uint16
@@ -573,7 +567,8 @@ type endpoint struct {
// accepted is used by a listening endpoint protocol goroutine to
// send newly accepted connections to the endpoint so that they can be
// read by Accept() calls.
- accepted accepted
+ // +checklocks:acceptMu
+ acceptQueue acceptQueue
// The following are only used from the protocol goroutine, and
// therefore don't need locks to protect them.
@@ -606,8 +601,7 @@ type endpoint struct {
gso stack.GSO
- // TODO(b/142022063): Add ability to save and restore per endpoint stats.
- stats Stats `state:"nosave"`
+ stats Stats
// tcpLingerTimeout is the maximum amount of a time a socket
// a socket stays in TIME_WAIT state before being marked
@@ -803,9 +797,10 @@ type keepalive struct {
waker sleep.Waker `state:"nosave"`
}
-func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
+func newEndpoint(s *stack.Stack, protocol *protocol, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
e := &endpoint{
- stack: s,
+ stack: s,
+ protocol: protocol,
TransportEndpointInfo: stack.TransportEndpointInfo{
NetProto: netProto,
TransProto: header.TCPProtocolNumber,
@@ -874,7 +869,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
}
e.segmentQueue.ep = e
- e.TSOffset = timeStampOffset(e.stack.Rand())
+
e.acceptCond = sync.NewCond(&e.acceptMu)
e.keepalive.timer.init(e.stack.Clock(), &e.keepalive.waker)
@@ -903,7 +898,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// Check if there's anything in the accepted queue.
if (mask & waiter.ReadableEvents) != 0 {
e.acceptMu.Lock()
- if e.accepted.endpoints.Len() != 0 {
+ if e.acceptQueue.endpoints.Len() != 0 {
result |= waiter.ReadableEvents
}
e.acceptMu.Unlock()
@@ -1086,20 +1081,20 @@ func (e *endpoint) closeNoShutdownLocked() {
// handshake but not yet been delivered to the application.
func (e *endpoint) closePendingAcceptableConnectionsLocked() {
e.acceptMu.Lock()
- acceptedCopy := e.accepted
- e.accepted = accepted{}
- e.acceptMu.Unlock()
-
- if acceptedCopy == (accepted{}) {
- return
+ // Close any endpoints in SYN-RCVD state.
+ for n := range e.acceptQueue.pendingEndpoints {
+ n.notifyProtocolGoroutine(notifyClose)
}
-
- e.acceptCond.Broadcast()
-
+ e.acceptQueue.pendingEndpoints = nil
// Reset all connections that are waiting to be accepted.
- for n := acceptedCopy.endpoints.Front(); n != nil; n = n.Next() {
+ for n := e.acceptQueue.endpoints.Front(); n != nil; n = n.Next() {
n.Value.(*endpoint).notifyProtocolGoroutine(notifyReset)
}
+ e.acceptQueue.endpoints.Init()
+ e.acceptMu.Unlock()
+
+ e.acceptCond.Broadcast()
+
// Wait for reset of all endpoints that are still waiting to be delivered to
// the now closed accepted.
e.pendingAccepted.Wait()
@@ -1717,6 +1712,27 @@ func (e *endpoint) OnSetReceiveBufferSize(rcvBufSz, oldSz int64) (newSz int64) {
return rcvBufSz
}
+// OnSetSendBufferSize implements tcpip.SocketOptionsHandler.OnSetSendBufferSize.
+func (e *endpoint) OnSetSendBufferSize(sz int64) int64 {
+ atomic.StoreUint32(&e.sndQueueInfo.TCPSndBufState.AutoTuneSndBufDisabled, 1)
+ return sz
+}
+
+// WakeupWriters implements tcpip.SocketOptionsHandler.WakeupWriters.
+func (e *endpoint) WakeupWriters() {
+ e.LockUser()
+ defer e.UnlockUser()
+
+ sendBufferSize := e.getSendBufferSize()
+ e.sndQueueInfo.sndQueueMu.Lock()
+ notify := (sendBufferSize - e.sndQueueInfo.SndBufUsed) >= e.sndQueueInfo.SndBufUsed>>1
+ e.sndQueueInfo.sndQueueMu.Unlock()
+
+ if notify {
+ e.waiterQueue.Notify(waiter.WritableEvents)
+ }
+}
+
// SetSockOptInt sets a socket option.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
// Lower 2 bits represents ECN bits. RFC 3168, section 23.1
@@ -2038,7 +2054,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
case *tcpip.OriginalDestinationOption:
e.LockUser()
ipt := e.stack.IPTables()
- addr, port, err := ipt.OriginalDst(e.TransportEndpointInfo.ID, e.NetProto)
+ addr, port, err := ipt.OriginalDst(e.TransportEndpointInfo.ID, e.NetProto, ProtocolNumber)
e.UnlockUser()
if err != nil {
return err
@@ -2177,7 +2193,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
portBuf := make([]byte, 2)
binary.LittleEndian.PutUint16(portBuf, e.ID.RemotePort)
- h := jenkins.Sum32(e.stack.Seed())
+ h := jenkins.Sum32(e.protocol.portOffsetSecret)
for _, s := range [][]byte{
[]byte(e.ID.LocalAddress),
[]byte(e.ID.RemoteAddress),
@@ -2329,6 +2345,9 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
e.segmentQueue.mu.Unlock()
e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0)
e.setEndpointState(StateEstablished)
+ // Set the new auto tuned send buffer size after entering
+ // established state.
+ e.ops.SetSendBufferSize(e.computeTCPSendBufferSize(), false /* notify */)
}
if run {
@@ -2355,6 +2374,18 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error {
func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
e.LockUser()
defer e.UnlockUser()
+
+ if e.EndpointState().connecting() {
+ // When calling shutdown(2) on a connecting socket, the endpoint must
+ // enter the error state. But this logic cannot belong to the shutdownLocked
+ // method because that method is called during a close(2) (and closing a
+ // connecting socket is not an error).
+ e.resetConnectionLocked(&tcpip.ErrConnectionReset{})
+ e.notifyProtocolGoroutine(notifyShutdown)
+ e.waiterQueue.Notify(waiter.WritableEvents | waiter.EventHUp | waiter.EventErr)
+ return nil
+ }
+
return e.shutdownLocked(flags)
}
@@ -2455,21 +2486,22 @@ func (e *endpoint) listen(backlog int) tcpip.Error {
if e.EndpointState() == StateListen && !e.closed {
e.acceptMu.Lock()
defer e.acceptMu.Unlock()
- if e.accepted == (accepted{}) {
- // listen is called after shutdown.
- e.accepted.cap = backlog
- e.shutdownFlags = 0
- e.rcvQueueInfo.rcvQueueMu.Lock()
- e.rcvQueueInfo.RcvClosed = false
- e.rcvQueueInfo.rcvQueueMu.Unlock()
- } else {
- // Adjust the size of the backlog iff we can fit
- // existing pending connections into the new one.
- if e.accepted.endpoints.Len() > backlog {
- return &tcpip.ErrInvalidEndpointState{}
- }
- e.accepted.cap = backlog
+
+ // Adjust the size of the backlog iff we can fit
+ // existing pending connections into the new one.
+ if e.acceptQueue.endpoints.Len() > backlog {
+ return &tcpip.ErrInvalidEndpointState{}
}
+ e.acceptQueue.capacity = backlog
+
+ if e.acceptQueue.pendingEndpoints == nil {
+ e.acceptQueue.pendingEndpoints = make(map[*endpoint]struct{})
+ }
+
+ e.shutdownFlags = 0
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ e.rcvQueueInfo.RcvClosed = false
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
// Notify any blocked goroutines that they can attempt to
// deliver endpoints again.
@@ -2505,8 +2537,11 @@ func (e *endpoint) listen(backlog int) tcpip.Error {
// may be pre-populated with some previously accepted (but not Accepted)
// endpoints.
e.acceptMu.Lock()
- if e.accepted == (accepted{}) {
- e.accepted.cap = backlog
+ if e.acceptQueue.pendingEndpoints == nil {
+ e.acceptQueue.pendingEndpoints = make(map[*endpoint]struct{})
+ }
+ if e.acceptQueue.capacity == 0 {
+ e.acceptQueue.capacity = backlog
}
e.acceptMu.Unlock()
@@ -2546,8 +2581,8 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.
// Get the new accepted endpoint.
var n *endpoint
e.acceptMu.Lock()
- if element := e.accepted.endpoints.Front(); element != nil {
- n = e.accepted.endpoints.Remove(element).(*endpoint)
+ if element := e.acceptQueue.endpoints.Front(); element != nil {
+ n = e.acceptQueue.endpoints.Remove(element).(*endpoint)
}
e.acceptMu.Unlock()
if n == nil {
@@ -2763,13 +2798,20 @@ func (e *endpoint) updateSndBufferUsage(v int) {
e.sndQueueInfo.sndQueueMu.Lock()
notify := e.sndQueueInfo.SndBufUsed >= sendBufferSize>>1
e.sndQueueInfo.SndBufUsed -= v
+
+ // Get the new send buffer size with auto tuning, but do not set it
+ // unless we decide to notify the writers.
+ newSndBufSz := e.computeTCPSendBufferSize()
+
// We only notify when there is half the sendBufferSize available after
// a full buffer event occurs. This ensures that we don't wake up
// writers to queue just 1-2 segments and go back to sleep.
- notify = notify && e.sndQueueInfo.SndBufUsed < sendBufferSize>>1
+ notify = notify && e.sndQueueInfo.SndBufUsed < int(newSndBufSz)>>1
e.sndQueueInfo.sndQueueMu.Unlock()
if notify {
+ // Set the new send buffer size calculated from auto tuning.
+ e.ops.SetSendBufferSize(newSndBufSz, false /* notify */)
e.waiterQueue.Notify(waiter.WritableEvents)
}
}
@@ -2873,46 +2915,29 @@ func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value,
// maybeEnableTimestamp marks the timestamp option enabled for this endpoint if
// the SYN options indicate that timestamp option was negotiated. It also
// initializes the recentTS with the value provided in synOpts.TSval.
-func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) {
+func (e *endpoint) maybeEnableTimestamp(synOpts header.TCPSynOptions) {
if synOpts.TS {
e.SendTSOk = true
e.setRecentTimestamp(synOpts.TSVal)
}
}
-// timestamp returns the timestamp value to be used in the TSVal field of the
-// timestamp option for outgoing TCP segments for a given endpoint.
-func (e *endpoint) timestamp() uint32 {
- return tcpTimeStamp(e.stack.Clock().NowMonotonic(), e.TSOffset)
+func (e *endpoint) tsVal(now tcpip.MonotonicTime) uint32 {
+ return e.TSOffset.TSVal(now)
}
-// tcpTimeStamp returns a timestamp offset by the provided offset. This is
-// not inlined above as it's used when SYN cookies are in use and endpoint
-// is not created at the time when the SYN cookie is sent.
-func tcpTimeStamp(curTime tcpip.MonotonicTime, offset uint32) uint32 {
- d := curTime.Sub(tcpip.MonotonicTime{})
- return uint32(d.Milliseconds()) + offset
+func (e *endpoint) tsValNow() uint32 {
+ return e.tsVal(e.stack.Clock().NowMonotonic())
}
-// timeStampOffset returns a randomized timestamp offset to be used when sending
-// timestamp values in a timestamp option for a TCP segment.
-func timeStampOffset(rng *rand.Rand) uint32 {
- // Initialize a random tsOffset that will be added to the recentTS
- // everytime the timestamp is sent when the Timestamp option is enabled.
- //
- // See https://tools.ietf.org/html/rfc7323#section-5.4 for details on
- // why this is required.
- //
- // NOTE: This is not completely to spec as normally this should be
- // initialized in a manner analogous to how sequence numbers are
- // randomized per connection basis. But for now this is sufficient.
- return rng.Uint32()
+func (e *endpoint) elapsed(now tcpip.MonotonicTime, tsEcr uint32) time.Duration {
+ return e.TSOffset.Elapsed(now, tsEcr)
}
// maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint
// if the SYN options indicate that the SACK option was negotiated and the TCP
// stack is configured to enable TCP SACK option.
-func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) {
+func (e *endpoint) maybeEnableSACKPermitted(synOpts header.TCPSynOptions) {
var v tcpip.TCPSACKEnabled
if err := e.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil {
// Stack doesn't support SACK. So just return.
@@ -2974,6 +2999,8 @@ func (e *endpoint) completeStateLocked() stack.TCPEndpointState {
}
s.Sender.RACKState = e.snd.rc.TCPRACKState
+ s.Sender.RetransmitTS = e.snd.retransmitTS
+ s.Sender.SpuriousRecovery = e.snd.spuriousRecovery
return s
}
@@ -3091,3 +3118,36 @@ func GetTCPReceiveBufferLimits(s tcpip.StackHandler) tcpip.ReceiveBufferSizeOpti
Max: ss.Max,
}
}
+
+// computeTCPSendBufferSize implements auto tuning of send buffer size and
+// returns the new send buffer size.
+func (e *endpoint) computeTCPSendBufferSize() int64 {
+ curSndBufSz := int64(e.getSendBufferSize())
+
+ // Auto tuning is disabled when the user explicitly sets the send
+ // buffer size with SO_SNDBUF option.
+ if disabled := atomic.LoadUint32(&e.sndQueueInfo.TCPSndBufState.AutoTuneSndBufDisabled); disabled == 1 {
+ return curSndBufSz
+ }
+
+ const packetOverheadFactor = 2
+ curMSS := e.snd.MaxPayloadSize
+ numSeg := InitialCwnd
+ if numSeg < e.snd.SndCwnd {
+ numSeg = e.snd.SndCwnd
+ }
+
+ // SndCwnd indicates the number of segments that can be sent. This means
+ // that the sender can send up to #SndCwnd segments and the send buffer
+ // size should be set to SndCwnd*MSS to accommodate sending of all the
+ // segments.
+ newSndBufSz := int64(numSeg * curMSS * packetOverheadFactor)
+ if newSndBufSz < curSndBufSz {
+ return curSndBufSz
+ }
+ if ss := GetTCPSendBufferLimits(e.stack); int64(ss.Max) < newSndBufSz {
+ newSndBufSz = int64(ss.Max)
+ }
+
+ return newSndBufSz
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 952ccacdd..94072a115 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -100,7 +100,7 @@ func (e *endpoint) beforeSave() {
}
// saveEndpoints is invoked by stateify.
-func (a *accepted) saveEndpoints() []*endpoint {
+func (a *acceptQueue) saveEndpoints() []*endpoint {
acceptedEndpoints := make([]*endpoint, a.endpoints.Len())
for i, e := 0, a.endpoints.Front(); e != nil; i, e = i+1, e.Next() {
acceptedEndpoints[i] = e.Value.(*endpoint)
@@ -109,7 +109,7 @@ func (a *accepted) saveEndpoints() []*endpoint {
}
// loadEndpoints is invoked by stateify.
-func (a *accepted) loadEndpoints(acceptedEndpoints []*endpoint) {
+func (a *acceptQueue) loadEndpoints(acceptedEndpoints []*endpoint) {
for _, ep := range acceptedEndpoints {
a.endpoints.PushBack(ep)
}
@@ -170,6 +170,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
snd.probeTimer.init(s.Clock(), &snd.probeWaker)
}
e.stack = s
+ e.protocol = protocolFromStack(s)
e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits)
e.segmentQueue.thaw()
epState := EndpointState(e.origEndpointState)
@@ -250,7 +251,9 @@ func (e *endpoint) Resume(s *stack.Stack) {
go func() {
connectedLoading.Wait()
bind()
- backlog := e.accepted.cap
+ e.acceptMu.Lock()
+ backlog := e.acceptQueue.capacity
+ e.acceptMu.Unlock()
if err := e.Listen(backlog); err != nil {
panic("endpoint listening failed: " + err.String())
}
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 2e709ed78..128ef09e3 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -54,7 +54,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward
maxInFlight: maxInFlight,
handler: handler,
inFlight: make(map[stack.TransportEndpointID]struct{}),
- listen: newListenContext(s, nil /* listenEP */, seqnum.Size(rcvWnd), true, 0),
+ listen: newListenContext(s, protocolFromStack(s), nil /* listenEP */, seqnum.Size(rcvWnd), true, 0),
}
}
@@ -152,7 +152,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
}
f := r.forwarder
- ep, err := f.listen.performHandshake(r.segment, &header.TCPSynOptions{
+ ep, err := f.listen.performHandshake(r.segment, header.TCPSynOptions{
MSS: r.synOptions.MSS,
WS: r.synOptions.WS,
TS: r.synOptions.TS,
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 18b834243..e4410ad93 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -23,8 +23,10 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/header/parse"
+ "gvisor.dev/gvisor/pkg/tcpip/internal/tcp"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/raw"
@@ -49,10 +51,6 @@ const (
// MaxBufferSize is the largest size a receive/send buffer can grow to.
MaxBufferSize = 4 << 20 // 4MB
- // MaxUnprocessedSegments is the maximum number of unprocessed segments
- // that can be queued for a given endpoint.
- MaxUnprocessedSegments = 300
-
// DefaultTCPLingerTimeout is the amount of time that sockets linger in
// FIN_WAIT_2 state before being marked closed.
DefaultTCPLingerTimeout = 60 * time.Second
@@ -96,6 +94,11 @@ type protocol struct {
maxRetries uint32
synRetries uint8
dispatcher dispatcher
+
+ // The following secrets are initialized once and stay unchanged after.
+ seqnumSecret uint32
+ portOffsetSecret uint32
+ tsOffsetSecret uint32
}
// Number returns the tcp protocol number.
@@ -105,7 +108,7 @@ func (*protocol) Number() tcpip.TransportProtocolNumber {
// NewEndpoint creates a new tcp endpoint.
func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
- return newEndpoint(p.stack, netProto, waiterQueue), nil
+ return newEndpoint(p.stack, p, netProto, waiterQueue), nil
}
// NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently
@@ -156,6 +159,24 @@ func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID,
return stack.UnknownDestinationPacketHandled
}
+// tsOffset returns the timestamp offset for the given address pair, computed
+// as a jenkins hash over src and dst that starts from p.tsOffsetSecret.
+//
+// The offset is added to the recentTS every time the timestamp is sent when
+// the Timestamp option is enabled. See
+// https://tools.ietf.org/html/rfc7323#section-5.4 for details on why this is
+// required.
+//
+// TODO(https://gvisor.dev/issues/6473): This is not really secure as
+// it does not use the recommended algorithm linked above.
+func (p *protocol) tsOffset(src, dst tcpip.Address) tcp.TSOffset {
+ hasher := jenkins.Sum32(p.tsOffsetSecret)
+ for _, addr := range []tcpip.Address{src, dst} {
+ // Per hash.Hash.Writer:
+ //
+ // It never returns an error.
+ //
+ // so the results of Write are deliberately discarded.
+ _, _ = hasher.Write([]byte(addr))
+ }
+ return tcp.NewTSOffset(hasher.Sum32())
+}
+
// replyWithReset replies to the given segment with a reset segment.
//
// If the passed TTL is 0, then the route's default TTL will be used.
@@ -292,22 +313,26 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) tcpip
case *tcpip.TCPMinRTOOption:
p.mu.Lock()
+ defer p.mu.Unlock()
if *v < 0 {
p.minRTO = MinRTO
+ } else if minRTO := time.Duration(*v); minRTO <= p.maxRTO {
+ p.minRTO = minRTO
} else {
- p.minRTO = time.Duration(*v)
+ return &tcpip.ErrInvalidOptionValue{}
}
- p.mu.Unlock()
return nil
case *tcpip.TCPMaxRTOOption:
p.mu.Lock()
+ defer p.mu.Unlock()
if *v < 0 {
p.maxRTO = MaxRTO
+ } else if maxRTO := time.Duration(*v); maxRTO >= p.minRTO {
+ p.maxRTO = maxRTO
} else {
- p.maxRTO = time.Duration(*v)
+ return &tcpip.ErrInvalidOptionValue{}
}
- p.mu.Unlock()
return nil
case *tcpip.TCPMaxRetriesOption:
@@ -479,7 +504,15 @@ func NewProtocol(s *stack.Stack) stack.TransportProtocol {
maxRTO: MaxRTO,
maxRetries: MaxRetries,
recovery: tcpip.TCPRACKLossDetection,
+ seqnumSecret: s.Rand().Uint32(),
+ portOffsetSecret: s.Rand().Uint32(),
+ tsOffsetSecret: s.Rand().Uint32(),
}
p.dispatcher.init(s.Rand(), runtime.GOMAXPROCS(0))
return &p
}
+
+// protocolFromStack returns the transport protocol instance that stack s has
+// registered under ProtocolNumber, asserted to *protocol. The type assertion
+// panics if the registered instance is of any other type.
+func protocolFromStack(s *stack.Stack) *protocol {
+ instance := s.TransportProtocolInstance(ProtocolNumber)
+ return instance.(*protocol)
+}
diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go
index 0da4eafaa..3b055c294 100644
--- a/pkg/tcpip/transport/tcp/rack.go
+++ b/pkg/tcpip/transport/tcp/rack.go
@@ -80,7 +80,6 @@ func (rc *rackControl) init(snd *sender, iss seqnum.Value) {
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-09#section-6.2
func (rc *rackControl) update(seg *segment, ackSeg *segment) {
rtt := rc.snd.ep.stack.Clock().NowMonotonic().Sub(seg.xmitTime)
- tsOffset := rc.snd.ep.TSOffset
// If the ACK is for a retransmitted packet, do not update if it is a
// spurious inference which is determined by below checks:
@@ -92,7 +91,7 @@ func (rc *rackControl) update(seg *segment, ackSeg *segment) {
// step 2
if seg.xmitCount > 1 {
if ackSeg.parsedOptions.TS && ackSeg.parsedOptions.TSEcr != 0 {
- if ackSeg.parsedOptions.TSEcr < tcpTimeStamp(seg.xmitTime, tsOffset) {
+ if ackSeg.parsedOptions.TSEcr < rc.snd.ep.tsVal(seg.xmitTime) {
return
}
}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index 9ce8fcae9..90e493978 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -477,7 +477,7 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) {
// segments. This ensures that we always leave some space for the inorder
// segments to arrive allowing pending segments to be processed and
// delivered to the user.
- if rcvBufSize := r.ep.ops.GetReceiveBufferSize(); rcvBufSize > 0 && r.PendingBufUsed < int(rcvBufSize)>>2 {
+ if rcvBufSize := r.ep.ops.GetReceiveBufferSize(); rcvBufSize > 0 && (r.PendingBufUsed+int(segLen)) < int(rcvBufSize)>>2 {
r.ep.rcvQueueInfo.rcvQueueMu.Lock()
r.PendingBufUsed += s.segMemSize()
r.ep.rcvQueueInfo.rcvQueueMu.Unlock()
diff --git a/pkg/tcpip/transport/tcp/rcv_test.go b/pkg/tcpip/transport/tcp/rcv_test.go
index 8a026ec46..e47a07030 100644
--- a/pkg/tcpip/transport/tcp/rcv_test.go
+++ b/pkg/tcpip/transport/tcp/rcv_test.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package rcv_test
+package tcp_test
import (
"testing"
diff --git a/pkg/tcpip/transport/tcp/segment_test.go b/pkg/tcpip/transport/tcp/segment_test.go
index 2e6ea06f5..2d5fdda19 100644
--- a/pkg/tcpip/transport/tcp/segment_test.go
+++ b/pkg/tcpip/transport/tcp/segment_test.go
@@ -34,7 +34,7 @@ func checkSegmentSize(t *testing.T, name string, seg *segment, want segmentSizeW
DataSize: seg.data.Size(),
SegMemSize: seg.segMemSize(),
}
- if diff := cmp.Diff(got, want); diff != "" {
+ if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("%s differs (-want +got):\n%s", name, diff)
}
}
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 92a66f17e..4377f07a0 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -144,6 +144,15 @@ type sender struct {
// probeTimer and probeWaker are used to schedule PTO for RACK TLP algorithm.
probeTimer timer `state:"nosave"`
probeWaker sleep.Waker `state:"nosave"`
+
+ // spuriousRecovery indicates whether the sender entered recovery
+ // spuriously as described in RFC3522 Section 3.2.
+ spuriousRecovery bool
+
+ // retransmitTS is the timestamp at which the sender sends retransmitted
+ // segment after entering an RTO for the first time as described in
+ // RFC3522 Section 3.2.
+ retransmitTS uint32
}
// rtt is a synchronization wrapper used to appease stateify. See the comment
@@ -382,6 +391,9 @@ func (s *sender) updateRTO(rtt time.Duration) {
if s.RTO < s.minRTO {
s.RTO = s.minRTO
}
+ if s.RTO > s.maxRTO {
+ s.RTO = s.maxRTO
+ }
}
// resendSegment resends the first unacknowledged segment.
@@ -422,6 +434,13 @@ func (s *sender) retransmitTimerExpired() bool {
return true
}
+ // Initialize the variables used to detect spurious recovery after
+ // entering RTO.
+ //
+ // See: https://www.rfc-editor.org/rfc/rfc3522.html#section-3.2 Step 1.
+ s.spuriousRecovery = false
+ s.retransmitTS = 0
+
// TODO(b/147297758): Band-aid fix, retransmitTimer can fire in some edge cases
// when writeList is empty. Remove this once we have a proper fix for this
// issue.
@@ -492,6 +511,10 @@ func (s *sender) retransmitTimerExpired() bool {
s.leaveRecovery()
}
+ // Record retransmitTS if the sender is not in recovery as per:
+ // https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 2
+ s.recordRetransmitTS()
+
s.state = tcpip.RTORecovery
s.cc.HandleRTOExpired()
@@ -955,6 +978,13 @@ func (s *sender) sendData() {
}
func (s *sender) enterRecovery() {
+ // Initialize the variables used to detect spurious recovery after
+ // entering recovery.
+ //
+ // See: https://www.rfc-editor.org/rfc/rfc3522.html#section-3.2 Step 1.
+ s.spuriousRecovery = false
+ s.retransmitTS = 0
+
s.FastRecovery.Active = true
// Save state to reflect we're now in fast recovery.
//
@@ -969,6 +999,11 @@ func (s *sender) enterRecovery() {
s.FastRecovery.MaxCwnd = s.SndCwnd + s.Outstanding
s.FastRecovery.HighRxt = s.SndUna
s.FastRecovery.RescueRxt = s.SndUna
+
+ // Record retransmitTS if the sender is not in recovery as per:
+ // https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 2
+ s.recordRetransmitTS()
+
if s.ep.SACKPermitted {
s.state = tcpip.SACKRecovery
s.ep.stack.Stats().TCP.SACKRecovery.Increment()
@@ -1144,13 +1179,15 @@ func (s *sender) isDupAck(seg *segment) bool {
// Iterate the writeList and update RACK for each segment which is newly acked
// either cumulatively or selectively. Loop through the segments which are
// sacked, and update the RACK related variables and check for reordering.
+// Returns true when the DSACK block has been detected in the received ACK.
//
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
// steps 2 and 3.
-func (s *sender) walkSACK(rcvdSeg *segment) {
+func (s *sender) walkSACK(rcvdSeg *segment) bool {
s.rc.setDSACKSeen(false)
// Look for DSACK block.
+ hasDSACK := false
idx := 0
n := len(rcvdSeg.parsedOptions.SACKBlocks)
if checkDSACK(rcvdSeg) {
@@ -1164,10 +1201,11 @@ func (s *sender) walkSACK(rcvdSeg *segment) {
s.rc.setDSACKSeen(true)
idx = 1
n--
+ hasDSACK = true
}
if n == 0 {
- return
+ return hasDSACK
}
// Sort the SACK blocks. The first block is the most recent unacked
@@ -1190,6 +1228,7 @@ func (s *sender) walkSACK(rcvdSeg *segment) {
seg = seg.Next()
}
}
+ return hasDSACK
}
// checkDSACK checks if a DSACK is reported.
@@ -1236,6 +1275,85 @@ func checkDSACK(rcvdSeg *segment) bool {
return false
}
+// recordRetransmitTS stores the value returned by s.ep.tsValNow() in
+// s.retransmitTS, unless the sender is already in a recovery state, in which
+// case s.retransmitTS is left untouched.
+func (s *sender) recordRetransmitTS() {
+ // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2
+ //
+ // The Eifel detection algorithm is used, only upon initiation of loss
+ // recovery, i.e., when either the timeout-based retransmit or the fast
+ // retransmit is sent. The Eifel detection algorithm MUST NOT be
+ // reinitiated after loss recovery has already started. In particular,
+ // it must not be reinitiated upon subsequent timeouts for the same
+ // segment, and not upon retransmitting segments other than the oldest
+ // outstanding segment, e.g., during selective loss recovery.
+ if !s.inRecovery() {
+ // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 2
+ //
+ // Set a "RetransmitTS" variable to the value of the Timestamp
+ // Value field of the Timestamps option included in the
+ // retransmit sent when loss recovery is initiated. A TCP sender
+ // must ensure that RetransmitTS does not get overwritten as loss
+ // recovery progresses, e.g., in case of a second timeout and
+ // subsequent second retransmit of the same octet.
+ s.retransmitTS = s.ep.tsValNow()
+ }
+}
+
+// detectSpuriousRecovery applies steps 4 to 6 of the Eifel detection
+// algorithm (RFC 3522 Section 3.2) to an ACK. When the ACK shows that
+// recovery was entered spuriously it sets s.spuriousRecovery and increments
+// the stack's SpuriousRecovery counter; otherwise it changes nothing.
+//
+// hasDSACK reports whether the ACK carried a DSACK block and tsEchoReply is
+// the ACK's Timestamp Echo Reply value.
+func (s *sender) detectSpuriousRecovery(hasDSACK bool, tsEchoReply uint32) {
+ switch {
+ case s.spuriousRecovery:
+ // Spurious recovery has already been detected.
+ return
+ case tsEchoReply >= s.retransmitTS:
+ // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 4
+ //
+ // Only an ACK whose Timestamp Echo Reply is smaller than
+ // RetransmitTS proceeds to the next step.
+ return
+ case hasDSACK:
+ // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 5
+ //
+ // If the acceptable ACK carries a DSACK option [RFC2883], then
+ // return.
+ return
+ }
+
+ // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 5
+ //
+ // If during the lifetime of the TCP connection the TCP sender has
+ // previously received an ACK with a DSACK option, or the acceptable ACK
+ // does not acknowledge all outstanding data, then proceed to next step,
+ // else return.
+ //
+ // NOTE(review): the DSACK count is read from the stack's stats, so it
+ // looks shared by every endpoint on the stack rather than tracked per
+ // connection -- confirm this is intended.
+ dsackNeverSeen := s.ep.stack.Stats().TCP.SegmentsAckedWithDSACK.Value() == 0
+ if dsackNeverSeen && s.SndUna == s.SndNxt {
+ return
+ }
+
+ // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 6
+ //
+ // If the loss recovery has been initiated with a timeout-based
+ // retransmit, then set
+ // SpuriousRecovery <- SPUR_TO (equal 1),
+ // else set
+ // SpuriousRecovery <- dupacks+1
+ // A plain boolean is used here because fast, SACK and RTO recovery are
+ // not differentiated.
+ s.spuriousRecovery = true
+ s.ep.stack.Stats().TCP.SpuriousRecovery.Increment()
+}
+
+// inRecovery reports whether the sender's state is one of RTORecovery,
+// FastRecovery or SACKRecovery.
+func (s *sender) inRecovery() bool {
+ return s.state == tcpip.RTORecovery || s.state == tcpip.FastRecovery || s.state == tcpip.SACKRecovery
+}
+
// handleRcvdSegment is called when a segment is received; it is responsible for
// updating the send-related state.
func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
@@ -1251,6 +1369,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
}
// Insert SACKBlock information into our scoreboard.
+ hasDSACK := false
if s.ep.SACKPermitted {
for _, sb := range rcvdSeg.parsedOptions.SACKBlocks {
// Only insert the SACK block if the following holds
@@ -1285,7 +1404,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// RACK.fack, then the corresponding packet has been
// reordered and RACK.reord is set to TRUE.
if s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 {
- s.walkSACK(rcvdSeg)
+ hasDSACK = s.walkSACK(rcvdSeg)
}
s.SetPipe()
}
@@ -1342,10 +1461,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// some new data, i.e., only if it advances the left edge of
// the send window.
if s.ep.SendTSOk && rcvdSeg.parsedOptions.TSEcr != 0 {
- // TSVal/Ecr values sent by Netstack are at a millisecond
- // granularity.
- elapsed := time.Duration(s.ep.timestamp()-rcvdSeg.parsedOptions.TSEcr) * time.Millisecond
- s.updateRTO(elapsed)
+ s.updateRTO(s.ep.elapsed(s.ep.stack.Clock().NowMonotonic(), rcvdSeg.parsedOptions.TSEcr))
}
if s.shouldSchedulePTO() {
@@ -1415,12 +1531,14 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
ackLeft -= datalen
}
- // Update the send buffer usage and notify potential waiters.
- s.ep.updateSndBufferUsage(int(acked))
-
// Clear SACK information for all acked data.
s.ep.scoreboard.Delete(s.SndUna)
+ // Detect if the sender entered recovery spuriously.
+ if s.inRecovery() {
+ s.detectSpuriousRecovery(hasDSACK, rcvdSeg.parsedOptions.TSEcr)
+ }
+
// If we are not in fast recovery then update the congestion
// window based on the number of acknowledged packets.
if !s.FastRecovery.Active {
@@ -1437,6 +1555,9 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
}
}
+ // Update the send buffer usage and notify potential waiters.
+ s.ep.updateSndBufferUsage(int(acked))
+
// It is possible for s.outstanding to drop below zero if we get
// a retransmit timeout, reset outstanding to zero but later
// get an ack that cover previously sent data.
diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go
index 89e9fb886..0d36d0dd0 100644
--- a/pkg/tcpip/transport/tcp/tcp_rack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go
@@ -33,7 +33,6 @@ const (
tsOptionSize = 12
maxTCPOptionSize = 40
mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload
- latency = 5 * time.Millisecond
)
func setStackTCPRecovery(t *testing.T, c *context.Context, recovery int) {
@@ -163,7 +162,10 @@ func sendAndReceiveWithSACK(t *testing.T, c *context.Context, numPackets int, en
if !enableRACK {
setStackTCPRecovery(t, c, 0)
}
- createConnectedWithSACKAndTS(c)
+ // The delay should be below the initial RTO (1s), otherwise retransmission
+ // will start. Choose a relatively large value so that the estimated RTT
+ // stays high even after a few rounds of undelayed RTT samples.
+ c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true}, 800*time.Millisecond /* delay */)
data := make([]byte, numPackets*maxPayload)
for i := range data {
@@ -181,9 +183,6 @@ func sendAndReceiveWithSACK(t *testing.T, c *context.Context, numPackets int, en
for i := 0; i < numPackets; i++ {
c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
bytesRead += maxPayload
- // This delay is added to increase RTT as low RTT can cause TLP
- // before sending ACK.
- time.Sleep(latency)
}
return data
@@ -1060,16 +1059,17 @@ func TestRACKWithWindowFull(t *testing.T) {
for i := 0; i < numPkts; i++ {
c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
bytesRead += maxPayload
- if i == 0 {
- // Send ACK for the first packet to establish RTT.
- c.SendAck(seq, maxPayload)
- }
}
- // SACK for #10 packet.
- start := c.IRS.Add(seqnum.Size(1 + (numPkts-1)*maxPayload))
+ // Expect retransmission of last packet due to TLP.
+ c.ReceiveAndCheckPacketWithOptions(data, (numPkts-1)*maxPayload, maxPayload, tsOptionSize)
+
+ // SACK for first and last packet.
+ start := c.IRS.Add(seqnum.Size(maxPayload))
end := start.Add(seqnum.Size(maxPayload))
- c.SendAckWithSACK(seq, 2*maxPayload, []header.SACKBlock{{start, end}})
+ dsackStart := c.IRS.Add(seqnum.Size(1 + (numPkts-1)*maxPayload))
+ dsackEnd := dsackStart.Add(seqnum.Size(maxPayload))
+ c.SendAckWithSACK(seq, 2*maxPayload, []header.SACKBlock{{dsackStart, dsackEnd}, {start, end}})
var info tcpip.TCPInfoOption
if err := c.EP.GetSockOpt(&info); err != nil {
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index 83e0653b9..896249d2d 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -23,6 +23,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -35,13 +36,13 @@ import (
// SACKPermitted option enabled if the stack in the context has the SACK support
// enabled.
func createConnectedWithSACKPermittedOption(c *context.Context) *context.RawEndpoint {
- return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled()})
+ return c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{SACKPermitted: c.SACKEnabled()})
}
// createConnectedWithSACKAndTS creates and connects c.ep with the SACK & TS
// option enabled if the stack in the context has SACK and TS enabled.
func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint {
- return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true})
+ return c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true})
}
func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) {
@@ -108,7 +109,7 @@ func TestSackDisabledConnect(t *testing.T) {
setStackSACKPermitted(t, c, sackEnabled)
setStackTCPRecovery(t, c, 0)
- rep := c.CreateConnectedWithOptions(header.TCPSynOptions{})
+ rep := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{})
data := []byte{1, 2, 3}
@@ -170,7 +171,7 @@ func TestSackPermittedAccept(t *testing.T) {
setStackSACKPermitted(t, c, sackEnabled)
setStackTCPRecovery(t, c, 0)
- rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted})
+ rep := c.AcceptWithOptionsNoDelay(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted})
// Now verify no SACK blocks are
// received when sack is disabled.
data := []byte{1, 2, 3}
@@ -244,7 +245,7 @@ func TestSackDisabledAccept(t *testing.T) {
setStackSACKPermitted(t, c, sackEnabled)
setStackTCPRecovery(t, c, 0)
- rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
+ rep := c.AcceptWithOptionsNoDelay(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
// Now verify no SACK blocks are
// received when sack is disabled.
@@ -702,3 +703,257 @@ func TestRecoveryEntry(t *testing.T) {
t.Error(err)
}
}
+
+func verifySpuriousRecoveryMetric(t *testing.T, c *context.Context, numSpuriousRecovery uint64) {
+ t.Helper()
+
+ metricPollFn := func() error {
+ tcpStats := c.Stack().Stats().TCP
+ stats := []struct {
+ stat *tcpip.StatCounter
+ name string
+ want uint64
+ }{
+ {tcpStats.SpuriousRecovery, "stats.TCP.SpuriousRecovery", numSpuriousRecovery},
+ }
+ for _, s := range stats {
+ if got, want := s.stat.Value(), s.want; got != want {
+ return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
+ }
+ }
+ return nil
+ }
+
+ if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+ t.Error(err)
+ }
+}
+
+// checkReceivedPacket validates a data packet b emitted by the endpoint under
+// test. It checks that the packet is an IPv4/TCP segment addressed to
+// context.TestPort whose sequence number is c.IRS+1+bytesRead, whose ACK
+// number is context.TestInitialSequenceNumber+1 and which has the ACK flag
+// set (PSH is ignored), and that the payload of tcpHdr equals the slice of
+// data starting at offset bytesRead.
+//
+// tcpHdr is the TCP header (with payload) of b.
+func checkReceivedPacket(t *testing.T, c *context.Context, tcpHdr header.TCP, bytesRead uint32, b, data []byte) {
+ // Attribute failures to the calling test rather than to this helper.
+ t.Helper()
+
+ payloadLen := uint32(len(tcpHdr.Payload()))
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPSeqNum(uint32(c.IRS)+1+bytesRead),
+ checker.TCPAckNum(context.TestInitialSequenceNumber+1),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
+ ),
+ )
+ pdata := data[bytesRead : bytesRead+payloadLen]
+ if p := tcpHdr.Payload(); !bytes.Equal(pdata, p) {
+ t.Fatalf("got data = %v, want = %v", p, pdata)
+ }
+}
+
+// buildTSOptionFromHeader builds a 12-byte TCP option block (two NOPs
+// followed by a Timestamp option) answering tcpHdr: the option's TSVal is
+// tcpHdr's TSEcr plus one and its TSEcr is tcpHdr's TSVal.
+func buildTSOptionFromHeader(tcpHdr header.TCP) []byte {
+ rcvd := tcpHdr.ParsedOptions()
+ opt := make([]byte, 12)
+ // Two NOPs pad the Timestamp option that fills the rest of the block.
+ opt[0], opt[1] = header.TCPOptionNOP, header.TCPOptionNOP
+ header.EncodeTSOption(rcvd.TSEcr+1, rcvd.TSVal, opt[2:])
+ return opt
+}
+
+// TestDetectSpuriousRecoveryWithRTO checks that a recovery entered through a
+// retransmission timeout is flagged as spurious when the ACK that follows
+// echoes the timestamp of the original (not the retransmitted) segment.
+func TestDetectSpuriousRecoveryWithRTO(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+
+ probeDone := make(chan struct{})
+ c.Stack().AddTCPProbe(func(s stack.TCPEndpointState) {
+ // Use Errorf rather than Fatalf: the probe may run on a goroutine
+ // other than the test's, where FailNow must not be called, and
+ // probeDone has to be closed in every case so the test cannot hang.
+ if s.Sender.RetransmitTS == 0 {
+ t.Errorf("RetransmitTS did not get updated, got: 0 want > 0")
+ }
+ if !s.Sender.SpuriousRecovery {
+ t.Errorf("Spurious recovery was not detected")
+ }
+ close(probeDone)
+ })
+
+ setStackSACKPermitted(t, c, true)
+ createConnectedWithSACKAndTS(c)
+ numPackets := 5
+ data := make([]byte, numPackets*maxPayload)
+ for i := range data {
+ data[i] = byte(i)
+ }
+ // Write the data.
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ var options []byte
+ var bytesRead uint32
+ for i := 0; i < numPackets; i++ {
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ checkReceivedPacket(t, c, tcpHdr, bytesRead, b, data)
+
+ // Get options only for the first packet. This will be sent with
+ // the ACK to indicate the acknowledgement is for the original
+ // packet.
+ if i == 0 && c.TimeStampEnabled {
+ options = buildTSOptionFromHeader(tcpHdr)
+ }
+ bytesRead += uint32(len(tcpHdr.Payload()))
+ }
+
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ // Expect #5 segment with TLP.
+ c.ReceiveAndCheckPacketWithOptions(data, 4*maxPayload, maxPayload, tsOptionSize)
+
+ // Expect #1 segment because of RTO.
+ c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize)
+
+ info := tcpip.TCPInfoOption{}
+ if err := c.EP.GetSockOpt(&info); err != nil {
+ t.Fatalf("c.EP.GetSockOpt(&%T) = %s", info, err)
+ }
+
+ if info.CcState != tcpip.RTORecovery {
+ t.Fatalf("Loss recovery did not happen, got: %v want: %v", info.CcState, tcpip.RTORecovery)
+ }
+
+ // Acknowledge the data.
+ rcvWnd := seqnum.Size(30000)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seq,
+ AckNum: c.IRS.Add(1 + seqnum.Size(maxPayload)),
+ RcvWnd: rcvWnd,
+ TCPOpts: options,
+ })
+
+ // Wait for the probe function to finish processing the
+ // ACK before the test completes.
+ <-probeDone
+
+ verifySpuriousRecoveryMetric(t, c, 1 /* numSpuriousRecovery */)
+}
+
+// TestSACKDetectSpuriousRecoveryWithDupACK checks that a SACK recovery
+// entered after three duplicate ACKs is flagged as spurious when the ACK that
+// follows echoes the timestamp of the original (not the retransmitted)
+// segment.
+func TestSACKDetectSpuriousRecoveryWithDupACK(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+
+ numAck := 0
+ probeDone := make(chan struct{})
+ c.Stack().AddTCPProbe(func(s stack.TCPEndpointState) {
+ // Skip the first three probe invocations; they correspond to the
+ // duplicate ACKs sent below, before recovery is acknowledged.
+ if numAck < 3 {
+ numAck++
+ return
+ }
+
+ // Use Errorf rather than Fatalf: the probe may run on a goroutine
+ // other than the test's, where FailNow must not be called, and
+ // probeDone has to be closed in every case so the test cannot hang.
+ if s.Sender.RetransmitTS == 0 {
+ t.Errorf("RetransmitTS did not get updated, got: 0 want > 0")
+ }
+ if !s.Sender.SpuriousRecovery {
+ t.Errorf("Spurious recovery was not detected")
+ }
+ close(probeDone)
+ })
+
+ setStackSACKPermitted(t, c, true)
+ createConnectedWithSACKAndTS(c)
+ numPackets := 5
+ data := make([]byte, numPackets*maxPayload)
+ for i := range data {
+ data[i] = byte(i)
+ }
+ // Write the data.
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ var options []byte
+ var bytesRead uint32
+ for i := 0; i < numPackets; i++ {
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ checkReceivedPacket(t, c, tcpHdr, bytesRead, b, data)
+
+ // Get options only for the first packet. This will be sent with
+ // the ACK to indicate the acknowledgement is for the original
+ // packet.
+ if i == 0 && c.TimeStampEnabled {
+ options = buildTSOptionFromHeader(tcpHdr)
+ }
+ bytesRead += uint32(len(tcpHdr.Payload()))
+ }
+
+ // Receive the retransmitted packet after TLP.
+ c.ReceiveAndCheckPacketWithOptions(data, 4*maxPayload, maxPayload, tsOptionSize)
+
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ // Send an ACK that SACKs the 4th and 5th segments (data offsets
+ // 3*maxPayload up to 5*maxPayload) to avoid entering TLP.
+ start := c.IRS.Add(3*maxPayload + 1)
+ end := start.Add(2 * maxPayload)
+ c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}})
+
+ c.SendAck(seq, 0 /* bytesReceived */)
+ c.SendAck(seq, 0 /* bytesReceived */)
+
+ // Receive the retransmitted packet after three duplicate ACKs.
+ c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize)
+
+ info := tcpip.TCPInfoOption{}
+ if err := c.EP.GetSockOpt(&info); err != nil {
+ t.Fatalf("c.EP.GetSockOpt(&%T) = %s", info, err)
+ }
+
+ if info.CcState != tcpip.SACKRecovery {
+ t.Fatalf("Loss recovery did not happen, got: %v want: %v", info.CcState, tcpip.SACKRecovery)
+ }
+
+ // Acknowledge the data.
+ rcvWnd := seqnum.Size(30000)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seq,
+ AckNum: c.IRS.Add(1 + seqnum.Size(maxPayload)),
+ RcvWnd: rcvWnd,
+ TCPOpts: options,
+ })
+
+ // Wait for the probe function to finish processing the
+ // ACK before the test completes.
+ <-probeDone
+
+ verifySpuriousRecoveryMetric(t, c, 1 /* numSpuriousRecovery */)
+}
+
+// TestNoSpuriousRecoveryWithDSACK checks that the SpuriousRecovery counter
+// stays at zero when the ACK received during recovery carries a DSACK block.
+func TestNoSpuriousRecoveryWithDSACK(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+ setStackSACKPermitted(t, c, true)
+ // NOTE(review): sendAndReceiveWithSACK below also creates a connection
+ // (via c.CreateConnectedWithOptions), so the endpoint appears to be
+ // connected twice -- confirm this call is still needed.
+ createConnectedWithSACKAndTS(c)
+ numPackets := 5
+ data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */)
+
+ // Receive the retransmitted packet after TLP.
+ c.ReceiveAndCheckPacketWithOptions(data, 4*maxPayload, maxPayload, tsOptionSize)
+
+ // Send an ACK that SACKs the 4th and 5th segments (data offsets
+ // 3*maxPayload up to 5*maxPayload) to avoid entering TLP.
+ start := c.IRS.Add(3*maxPayload + 1)
+ end := start.Add(2 * maxPayload)
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}})
+
+ c.SendAck(seq, 0 /* bytesReceived */)
+ c.SendAck(seq, 0 /* bytesReceived */)
+
+ // Receive the retransmitted packet after three duplicate ACKs.
+ c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize)
+
+ // Acknowledge the data with a DSACK block: the block starts at data
+ // offset maxPayload and spans 2*maxPayload bytes, which lies below the
+ // cumulative ACK sent with it.
+ start = c.IRS.Add(maxPayload + 1)
+ end = start.Add(2 * maxPayload)
+ seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ c.SendAckWithSACK(seq, 6*maxPayload, []header.SACKBlock{{start, end}})
+
+ verifySpuriousRecoveryMetric(t, c, 0 /* numSpuriousRecovery */)
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 031f01357..6f1ee3816 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -28,6 +28,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
@@ -1381,8 +1382,12 @@ func TestListenerReadinessOnEvent(t *testing.T) {
if err := s.CreateNIC(id, ep); err != nil {
t.Fatalf("CreateNIC(%d, %T): %s", id, ep, err)
}
- if err := s.AddAddress(id, ipv4.ProtocolNumber, context.StackAddr); err != nil {
- t.Fatalf("AddAddress(%d, ipv4.ProtocolNumber, %s): %s", id, context.StackAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.Address(context.StackAddr).WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(id, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", id, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{
{Destination: header.IPv4EmptySubnet, NIC: id},
@@ -1651,6 +1656,71 @@ func TestConnectBindToDevice(t *testing.T) {
}
}
+// TestShutdownConnectingSocket checks that shutting down an endpoint that is
+// still in SYN-SENT aborts the connection attempt for every shutdown mode: the
+// endpoint moves to the error state, waiters are notified of the hang-up, reads
+// report a connection reset and no further packets are emitted.
+func TestShutdownConnectingSocket(t *testing.T) {
+	testCases := []struct {
+		name         string
+		shutdownMode tcpip.ShutdownFlags
+	}{
+		{"ShutdownRead", tcpip.ShutdownRead},
+		{"ShutdownWrite", tcpip.ShutdownWrite},
+		{"ShutdownReadWrite", tcpip.ShutdownRead | tcpip.ShutdownWrite},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			c := context.New(t, defaultMTU)
+			defer c.Cleanup()
+
+			// Only create the endpoint; the handshake is deliberately left
+			// unfinished so that the shutdown lands in the middle of it.
+			c.Create(-1)
+
+			hupEntry, hupCh := waiter.NewChannelEntry(nil)
+			c.WQ.EventRegister(&hupEntry, waiter.EventHUp)
+			defer c.WQ.EventUnregister(&hupEntry)
+
+			// Kick off the connection attempt.
+			addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
+			connectErr := c.EP.Connect(addr)
+			if d := cmp.Diff(&tcpip.ErrConnectStarted{}, connectErr); d != "" {
+				t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d)
+			}
+
+			// The first packet on the wire must be the SYN.
+			checker.IPv4(t, c.GetPacket(),
+				checker.TCP(
+					checker.DstPort(context.TestPort),
+					checker.TCPFlags(header.TCPFlagSyn),
+				),
+			)
+
+			if state := tcp.EndpointState(c.EP.State()); state != tcp.StateSynSent {
+				t.Fatalf("got State() = %s, want %s", state, tcp.StateSynSent)
+			}
+
+			if err := c.EP.Shutdown(tc.shutdownMode); err != nil {
+				t.Fatalf("Shutdown failed: %s", err)
+			}
+
+			// Shutdown is expected to update the endpoint state synchronously.
+			if state := tcp.EndpointState(c.EP.State()); state != tcp.StateError {
+				t.Fatalf("got State() = %s, want %s", state, tcp.StateError)
+			}
+
+			// The hang-up notification must already be pending; do not block.
+			select {
+			case <-hupCh:
+			default:
+				t.Fatal("endpoint was not notified")
+			}
+
+			tester := endpointTester{c.EP}
+			tester.CheckReadError(t, &tcpip.ErrConnectionReset{})
+
+			// An endpoint that was not shut down properly would keep trying to
+			// connect and emit another packet (a retransmitted SYN) once the
+			// initial RTO expires, so wait slightly longer than that.
+			quietPeriod := tcp.InitialRTO + (500 * time.Millisecond)
+			c.CheckNoPacketTimeout("got an unexpected packet", quietPeriod)
+		})
+	}
+}
+
func TestSynSent(t *testing.T) {
for _, test := range []struct {
name string
@@ -1674,7 +1744,7 @@ func TestSynSent(t *testing.T) {
addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
err := c.EP.Connect(addr)
- if d := cmp.Diff(err, &tcpip.ErrConnectStarted{}); d != "" {
+ if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d)
}
@@ -1990,7 +2060,9 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
)
// Cause a FIN to be generated.
- c.EP.Shutdown(tcpip.ShutdownWrite)
+ if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
+ t.Fatalf("Shutdown failed: %s", err)
+ }
// Make sure we get the FIN but DON't ACK IT.
checker.IPv4(t, c.GetPacket(),
@@ -2006,7 +2078,9 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
// Cause a RST to be generated by closing the read end now since we have
// unread data.
- c.EP.Shutdown(tcpip.ShutdownRead)
+ if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
+ t.Fatalf("Shutdown failed: %s", err)
+ }
// Make sure we get the RST
checker.IPv4(t, c.GetPacket(),
@@ -2127,6 +2201,214 @@ func TestFullWindowReceive(t *testing.T) {
)
}
+// TestSmallReceiveBufferReadiness exercises readiness notifications on a
+// loopback TCP connection whose client receive buffer is shrunk to 8, 4, 2 and
+// 1 KiB in turn. For each size the server first writes until neither queue
+// accepts more data, then a writer goroutine (server side) and a reader
+// goroutine (client side) drain the connection concurrently. The test fails if
+// the client reports itself readable without its wait channel having been
+// signalled.
+func TestSmallReceiveBufferReadiness(t *testing.T) {
+	s := stack.New(stack.Options{
+		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+		TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
+	})
+
+	ep := loopback.New()
+	if testing.Verbose() {
+		ep = sniffer.New(ep)
+	}
+
+	const nicID = 1
+	nicOpts := stack.NICOptions{Name: "nic1"}
+	if err := s.CreateNICWithOptions(nicID, ep, nicOpts); err != nil {
+		t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %s", nicOpts, err)
+	}
+
+	protocolAddr := tcpip.ProtocolAddress{
+		Protocol: ipv4.ProtocolNumber,
+		AddressWithPrefix: tcpip.AddressWithPrefix{
+			Address:   tcpip.Address("\x7f\x00\x00\x01"),
+			PrefixLen: 8,
+		},
+	}
+	if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+		t.Fatalf("AddProtocolAddress(%d, %+v, {}) failed: %s", nicID, protocolAddr, err)
+	}
+
+	{
+		subnet, err := tcpip.NewSubnet("\x7f\x00\x00\x00", "\xff\x00\x00\x00")
+		if err != nil {
+			t.Fatalf("tcpip.NewSubnet failed: %s", err)
+		}
+		s.SetRouteTable([]tcpip.Route{
+			{
+				Destination: subnet,
+				NIC:         nicID,
+			},
+		})
+	}
+
+	listenerEntry, listenerCh := waiter.NewChannelEntry(nil)
+	var listenerWQ waiter.Queue
+	listener, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
+	if err != nil {
+		t.Fatalf("NewEndpoint failed: %s", err)
+	}
+	defer listener.Close()
+	listenerWQ.EventRegister(&listenerEntry, waiter.ReadableEvents)
+	defer listenerWQ.EventUnregister(&listenerEntry)
+
+	if err := listener.Bind(tcpip.FullAddress{}); err != nil {
+		t.Fatalf("Bind failed: %s", err)
+	}
+	if err := listener.Listen(1); err != nil {
+		t.Fatalf("Listen failed: %s", err)
+	}
+
+	localAddress, err := listener.GetLocalAddress()
+	if err != nil {
+		t.Fatalf("GetLocalAddress failed: %s", err)
+	}
+
+	// Halve the receive buffer on every round: 8 KiB, 4 KiB, 2 KiB, 1 KiB.
+	for i := 8; i > 0; i /= 2 {
+		size := int64(i << 10)
+		t.Run(fmt.Sprintf("size=%d", size), func(t *testing.T) {
+			var clientWQ waiter.Queue
+			client, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
+			if err != nil {
+				t.Fatalf("NewEndpoint failed: %s", err)
+			}
+			defer client.Close()
+			switch err := client.Connect(localAddress).(type) {
+			case nil:
+				t.Fatal("Connect returned nil error")
+			case *tcpip.ErrConnectStarted:
+			default:
+				t.Fatalf("Connect failed: %s", err)
+			}
+
+			<-listenerCh
+			server, serverWQ, err := listener.Accept(nil)
+			if err != nil {
+				t.Fatalf("Accept failed: %s", err)
+			}
+			defer server.Close()
+
+			client.SocketOptions().SetReceiveBufferSize(size, true)
+			// Send buffer size doesn't seem to affect this test.
+			// server.SocketOptions().SetSendBufferSize(size, true)
+
+			clientEntry, clientCh := waiter.NewChannelEntry(nil)
+			clientWQ.EventRegister(&clientEntry, waiter.ReadableEvents)
+			defer clientWQ.EventUnregister(&clientEntry)
+
+			serverEntry, serverCh := waiter.NewChannelEntry(nil)
+			serverWQ.EventRegister(&serverEntry, waiter.WritableEvents)
+			defer serverWQ.EventUnregister(&serverEntry)
+
+			// Phase 1: write from the server until the write blocks and stays
+			// blocked for 100ms, i.e. nothing more can be queued.
+			var total int64
+			for {
+				var b [64 << 10]byte
+				var r bytes.Reader
+				r.Reset(b[:])
+				switch n, err := server.Write(&r, tcpip.WriteOptions{}); err.(type) {
+				case nil:
+					t.Logf("wrote %d bytes", n)
+					total += n
+					continue
+				case *tcpip.ErrWouldBlock:
+					select {
+					case <-serverCh:
+						continue
+					case <-time.After(100 * time.Millisecond):
+						// Well and truly full.
+						t.Logf("send and receive queues are full")
+					}
+				default:
+					t.Fatalf("Write failed: %s", err)
+				}
+				break
+			}
+			t.Logf("wrote %d bytes in total", total)
+
+			// Phase 2: drain concurrently. The deferred Wait runs before the
+			// deferred Close/EventUnregister calls above, so both goroutines
+			// are finished before the endpoints are torn down.
+			var wg sync.WaitGroup
+			defer wg.Wait()
+
+			wg.Add(2)
+			go func() {
+				defer wg.Done()
+
+				var b [64 << 10]byte
+				var r bytes.Reader
+				r.Reset(b[:])
+				if err := func() error {
+					var total int64
+					// Log through a closure so the final value of total is
+					// reported; arguments of a deferred call are evaluated when
+					// the defer statement executes.
+					defer func() { t.Logf("wrote %d bytes in total", total) }()
+					for r.Len() != 0 {
+						switch n, err := server.Write(&r, tcpip.WriteOptions{}); err.(type) {
+						case nil:
+							t.Logf("wrote %d bytes", n)
+							total += n
+						case *tcpip.ErrWouldBlock:
+							for {
+								t.Logf("waiting on server")
+								select {
+								case <-serverCh:
+								case <-time.After(time.Second):
+									if readiness := server.Readiness(waiter.WritableEvents); readiness != 0 {
+										t.Logf("server.Readiness(%b) = %b but channel not signaled", waiter.WritableEvents, readiness)
+									}
+									continue
+								}
+								break
+							}
+						default:
+							return fmt.Errorf("server.Write failed: %s", err)
+						}
+					}
+					if err := server.Shutdown(tcpip.ShutdownWrite); err != nil {
+						return fmt.Errorf("server.Shutdown failed: %s", err)
+					}
+					t.Logf("server end shutdown done")
+					return nil
+				}(); err != nil {
+					t.Error(err)
+				}
+			}()
+
+			go func() {
+				defer wg.Done()
+
+				if err := func() error {
+					total := 0
+					// See the writer goroutine: the closure defers evaluation of
+					// total until the function returns.
+					defer func() { t.Logf("read %d bytes in total", total) }()
+					for {
+						switch res, err := client.Read(ioutil.Discard, tcpip.ReadOptions{}); err.(type) {
+						case nil:
+							t.Logf("read %d bytes", res.Count)
+							total += res.Count
+							t.Logf("read total %d bytes till now", total)
+						case *tcpip.ErrClosedForReceive:
+							return nil
+						case *tcpip.ErrWouldBlock:
+							for {
+								t.Logf("waiting on client")
+								select {
+								case <-clientCh:
+								case <-time.After(time.Second):
+									// Readable without a notification is the
+									// failure this test is looking for.
+									if readiness := client.Readiness(waiter.ReadableEvents); readiness != 0 {
+										return fmt.Errorf("client.Readiness(%b) = %b but channel not signaled", waiter.ReadableEvents, readiness)
+									}
+									continue
+								}
+								break
+							}
+						default:
+							return fmt.Errorf("client.Read failed: %s", err)
+						}
+					}
+				}(); err != nil {
+					t.Error(err)
+				}
+			}()
+		})
+	}
+}
+
// Test the stack receive window advertisement on receiving segments smaller than
// segment overhead. It tests for the right edge of the window to not grow when
// the endpoint is not being read from.
@@ -2143,7 +2425,7 @@ func TestSmallSegReceiveWindowAdvertisement(t *testing.T) {
t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
}
- c.AcceptWithOptions(tcp.FindWndScale(seqnum.Size(opt.Default)), header.TCPSynOptions{MSS: defaultIPv4MSS})
+ c.AcceptWithOptionsNoDelay(tcp.FindWndScale(seqnum.Size(opt.Default)), header.TCPSynOptions{MSS: defaultIPv4MSS})
// Bump up the receive buffer size such that, when the receive window grows,
// the scaled window exceeds maxUint16.
@@ -2535,7 +2817,7 @@ func TestScaledWindowAccept(t *testing.T) {
// Do 3-way handshake.
// wndScale expected is 3 as 65535 * 3 * 2 < 65535 * 2^3 but > 65535 *2 *2
- c.PassiveConnectWithOptions(100, 3 /* wndScale */, header.TCPSynOptions{MSS: defaultIPv4MSS})
+ c.PassiveConnectWithOptions(100, 3 /* wndScale */, header.TCPSynOptions{MSS: defaultIPv4MSS}, 0 /* delay */)
// Try to accept the connection.
we, ch := waiter.NewChannelEntry(nil)
@@ -3532,6 +3814,12 @@ func TestMaxRetransmitsTimeout(t *testing.T) {
t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
+ // Wait for the connection to timeout after MaxRetries retransmits.
+ initRTO := time.Second
+ minRTOOpt := tcpip.TCPMinRTOOption(initRTO)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err)
+ }
c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
@@ -3554,8 +3842,6 @@ func TestMaxRetransmitsTimeout(t *testing.T) {
),
)
}
- // Wait for the connection to timeout after MaxRetries retransmits.
- initRTO := 1 * time.Second
select {
case <-notifyCh:
case <-time.After((2 << numRetries) * initRTO):
@@ -3590,9 +3876,13 @@ func TestMaxRTO(t *testing.T) {
defer c.Cleanup()
rto := 1 * time.Second
- opt := tcpip.TCPMaxRTOOption(rto)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
+ minRTOOpt := tcpip.TCPMinRTOOption(rto / 2)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err)
+ }
+ maxRTOOpt := tcpip.TCPMaxRTOOption(rto)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &maxRTOOpt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, maxRTOOpt, maxRTOOpt, err)
}
c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
@@ -3618,8 +3908,8 @@ func TestMaxRTO(t *testing.T) {
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
- if time.Since(start).Round(time.Second).Seconds() != rto.Seconds() {
- t.Errorf("Retransmit interval not capped to MaxRTO.\n")
+ if elapsed := time.Since(start); elapsed.Round(time.Second).Seconds() != rto.Seconds() {
+ t.Errorf("Retransmit interval not capped to MaxRTO(%s). %s", rto, elapsed)
}
}
}
@@ -3670,6 +3960,10 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
+ minRTOOpt := tcpip.TCPMinRTOOption(time.Second)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err)
+ }
c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
// Disabling PMTU discovery causes all packets sent from this socket to
@@ -4736,13 +5030,17 @@ func makeStack() (*stack.Stack, tcpip.Error) {
}
for _, ct := range []struct {
- number tcpip.NetworkProtocolNumber
- address tcpip.Address
+ number tcpip.NetworkProtocolNumber
+ addrWithPrefix tcpip.AddressWithPrefix
}{
- {ipv4.ProtocolNumber, context.StackAddr},
- {ipv6.ProtocolNumber, context.StackV6Addr},
+ {ipv4.ProtocolNumber, context.StackAddrWithPrefix},
+ {ipv6.ProtocolNumber, context.StackV6AddrWithPrefix},
} {
- if err := s.AddAddress(1, ct.number, ct.address); err != nil {
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ct.number,
+ AddressWithPrefix: ct.addrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
return nil, err
}
}
@@ -4946,7 +5244,7 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
t.Fatalf("got s.SetPortRange(%d, %d) = %s, want = nil", start, end, err)
}
for i := start; i <= end; i++ {
- if makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil {
+ if err := makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil {
t.Fatalf("Bind(%d) failed: %s", i, err)
}
}
@@ -6304,7 +6602,7 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
- c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS})
+ c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS}, 0 /* delay */)
// Try to accept the connection.
we, ch := waiter.NewChannelEntry(nil)
@@ -6385,7 +6683,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// maximum buffer size defined above.
c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize))
- rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4})
+ rawEP := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{TS: true, WS: 4})
// NOTE: The timestamp values in the sent packets are meaningless to the
// peer so we just increment the timestamp value by 1 every batch as we
@@ -6515,7 +6813,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
// maximum buffer size used by stack.
c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize))
- rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4})
+ rawEP := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{TS: true, WS: 4})
tsVal := rawEP.TSVal
rawEP.NextSeqNum--
rawEP.SendPacketWithTS(nil, tsVal)
@@ -7430,6 +7728,11 @@ func TestTCPUserTimeout(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
+ initRTO := 1 * time.Second
+ minRTOOpt := tcpip.TCPMinRTOOption(initRTO)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err)
+ }
c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
@@ -7440,7 +7743,6 @@ func TestTCPUserTimeout(t *testing.T) {
// Ensure that on the next retransmit timer fire, the user timeout has
// expired.
- initRTO := 1 * time.Second
userTimeout := initRTO / 2
v := tcpip.TCPUserTimeoutOption(userTimeout)
if err := c.EP.SetSockOpt(&v); err != nil {
@@ -7954,6 +8256,151 @@ func TestSetStackTimeWaitReuse(t *testing.T) {
}
}
+// TestHandshakeRTT checks the RTT estimate an endpoint reports via
+// TCPInfoOption right after the three-way handshake, for every combination of
+// active/passive open, timestamp option, SYN cookies and a handshake delay of
+// 800ms or 1200ms (the "retrans" cases). SYN cookies are only exercised on the
+// passive side.
+func TestHandshakeRTT(t *testing.T) {
+	type testCase struct {
+		connect   bool
+		tsEnabled bool
+		useCookie bool
+		retrans   bool
+		delay     time.Duration
+		wantRTT   time.Duration
+	}
+	var testCases []testCase
+	for _, connect := range []bool{false, true} {
+		for _, tsEnabled := range []bool{false, true} {
+			for _, useCookie := range []bool{false, true} {
+				for _, retrans := range []bool{false, true} {
+					if connect && useCookie {
+						continue
+					}
+					delay := 800 * time.Millisecond
+					if retrans {
+						delay = 1200 * time.Millisecond
+					}
+					wantRTT := delay
+					// If syncookie is enabled, sample RTT only when TS option is enabled.
+					if !retrans && useCookie && !tsEnabled {
+						wantRTT = 0
+					}
+					// If retransmitted, sample RTT only when TS option is enabled.
+					if retrans && !tsEnabled {
+						wantRTT = 0
+					}
+					testCases = append(testCases, testCase{connect, tsEnabled, useCookie, retrans, delay, wantRTT})
+				}
+			}
+		}
+	}
+	for _, tt := range testCases {
+		tt := tt
+		t.Run(fmt.Sprintf("connect=%t,TS=%t,cookie=%t,retrans=%t", tt.connect, tt.tsEnabled, tt.useCookie, tt.retrans), func(t *testing.T) {
+			t.Parallel()
+			c := context.New(t, defaultMTU)
+			// Release the stack and endpoint of this subtest once it is done.
+			defer c.Cleanup()
+			if tt.useCookie {
+				opt := tcpip.TCPAlwaysUseSynCookies(true)
+				if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+					t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
+				}
+			}
+			synOpts := header.TCPSynOptions{}
+			if tt.tsEnabled {
+				synOpts.TS = true
+				synOpts.TSVal = 42
+			}
+			if tt.connect {
+				c.CreateConnectedWithOptions(synOpts, tt.delay)
+			} else {
+				synOpts.MSS = defaultIPv4MSS
+				synOpts.WS = -1
+				c.AcceptWithOptions(-1, synOpts, tt.delay)
+			}
+			var info tcpip.TCPInfoOption
+			if err := c.EP.GetSockOpt(&info); err != nil {
+				t.Fatalf("c.EP.GetSockOpt(&%T) = %s", info, err)
+			}
+			// Round to a multiple of the expected RTT to absorb scheduling
+			// jitter; with wantRTT == 0 Round returns the value unchanged, so
+			// the sample must then be exactly zero.
+			if got := info.RTT.Round(tt.wantRTT); got != tt.wantRTT {
+				t.Fatalf("got info.RTT=%s, expect %s", got, tt.wantRTT)
+			}
+			if info.RTTVar != 0 && tt.wantRTT == 0 {
+				t.Fatalf("got info.RTTVar=%s, expect 0", info.RTTVar)
+			}
+			if info.RTTVar == 0 && tt.wantRTT != 0 {
+				t.Fatalf("got info.RTTVar=0, expect non zero")
+			}
+		})
+	}
+}
+
+// TestSetRTO verifies that the stack-wide TCP minimum and maximum RTO options
+// reject values that would cross each other (min above the current max, max
+// below the current min) and that accepted values can be read back.
+func TestSetRTO(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+	// The stack defaults are the reference for building valid/invalid values.
+	minRTO, maxRTO := tcpRTOMinMax(t, c)
+	for _, tt := range []struct {
+		name   string
+		RTO    time.Duration
+		minRTO time.Duration
+		maxRTO time.Duration
+		err    tcpip.Error
+	}{
+		{
+			name:   "invalid minRTO",
+			minRTO: maxRTO + time.Second,
+			err:    &tcpip.ErrInvalidOptionValue{},
+		},
+		{
+			name:   "invalid maxRTO",
+			maxRTO: minRTO - time.Millisecond,
+			err:    &tcpip.ErrInvalidOptionValue{},
+		},
+		{
+			name:   "valid minRTO",
+			minRTO: maxRTO - time.Second,
+		},
+		{
+			name:   "valid maxRTO",
+			maxRTO: minRTO + time.Millisecond,
+		},
+	} {
+		t.Run(tt.name, func(t *testing.T) {
+			// Use a fresh stack so that every case starts from the defaults.
+			c := context.New(t, defaultMTU)
+			defer c.Cleanup()
+			var opt tcpip.SettableTransportProtocolOption
+			if tt.minRTO > 0 {
+				min := tcpip.TCPMinRTOOption(tt.minRTO)
+				opt = &min
+			}
+			if tt.maxRTO > 0 {
+				max := tcpip.TCPMaxRTOOption(tt.maxRTO)
+				opt = &max
+			}
+			err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt)
+			// Compare by value: the expected errors are pointers to empty
+			// structs, for which pointer equality is not a reliable check.
+			if d := cmp.Diff(tt.err, err); d != "" {
+				t.Fatalf("c.Stack().SetTransportProtocolOption(TCP, &%T(%v)) mismatch (-want +got):\n%s", opt, opt, d)
+			}
+			if tt.err == nil {
+				minRTO, maxRTO := tcpRTOMinMax(t, c)
+				if tt.minRTO > 0 && tt.minRTO != minRTO {
+					t.Fatalf("got minRTO = %s, want %s", minRTO, tt.minRTO)
+				}
+				if tt.maxRTO > 0 && tt.maxRTO != maxRTO {
+					t.Fatalf("got maxRTO = %s, want %s", maxRTO, tt.maxRTO)
+				}
+			}
+		})
+	}
+}
+
+// tcpRTOMinMax reads the TCP minimum and maximum RTO currently configured on
+// c's stack and returns them in that order. Any lookup failure is fatal to t.
+func tcpRTOMinMax(t *testing.T, c *context.Context) (time.Duration, time.Duration) {
+	t.Helper()
+
+	var minOpt tcpip.TCPMinRTOOption
+	if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &minOpt); err != nil {
+		t.Fatalf("c.Stack().TransportProtocolOption(TCP, %T): %s", minOpt, err)
+	}
+
+	var maxOpt tcpip.TCPMaxRTOOption
+	if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &maxOpt); err != nil {
+		t.Fatalf("c.Stack().TransportProtocolOption(TCP, %T): %s", maxOpt, err)
+	}
+
+	lo, hi := time.Duration(minOpt), time.Duration(maxOpt)
+	return lo, hi
+}
+
// generateRandomPayload generates a random byte slice of the specified length
// causing a fatal test failure if it is unable to do so.
func generateRandomPayload(t *testing.T, n int) []byte {
@@ -7964,3 +8411,192 @@ func generateRandomPayload(t *testing.T, n int) []byte {
}
return buf
}
+
+// TestSendBufferTuning fills the send buffer of a connected endpoint, ACKs
+// everything that was sent and then checks the resulting send buffer size:
+// when the size was set explicitly through SetSendBufferSize
+// ("autoTuningDisabled") it must stay at the configured default; otherwise it
+// must equal SndCwnd * packetOverheadFactor * maxPayload.
+func TestSendBufferTuning(t *testing.T) {
+	const maxPayload = 536
+	const mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload
+	const packetOverheadFactor = 2
+
+	testCases := []struct {
+		name               string
+		autoTuningDisabled bool
+	}{
+		{"autoTuningDisabled", true},
+		{"autoTuningEnabled", false},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			c := context.New(t, mtu)
+			defer c.Cleanup()
+
+			// Set the stack option for send buffer size.
+			const defaultSndBufSz = maxPayload * tcp.InitialCwnd
+			const maxSndBufSz = defaultSndBufSz * 10
+			{
+				opt := tcpip.TCPSendBufferSizeRangeOption{Min: 1, Default: defaultSndBufSz, Max: maxSndBufSz}
+				if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+					t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
+				}
+			}
+
+			c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
+
+			oldSz := c.EP.SocketOptions().GetSendBufferSize()
+			if oldSz != defaultSndBufSz {
+				t.Fatalf("Wrong send buffer size got %d want %d", oldSz, defaultSndBufSz)
+			}
+
+			if tc.autoTuningDisabled {
+				c.EP.SocketOptions().SetSendBufferSize(defaultSndBufSz, true /* notify */)
+			}
+
+			data := make([]byte, maxPayload)
+			for i := range data {
+				data[i] = byte(i)
+			}
+
+			w, ch := waiter.NewChannelEntry(nil)
+			c.WQ.EventRegister(&w, waiter.WritableEvents)
+			defer c.WQ.EventUnregister(&w)
+
+			bytesRead := 0
+			for {
+				// Packets will be sent till the send buffer
+				// size is reached.
+				var r bytes.Reader
+				r.Reset(data[bytesRead : bytesRead+maxPayload])
+				_, err := c.EP.Write(&r, tcpip.WriteOptions{})
+				if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
+					break
+				}
+				// Any other error means the write itself is broken; report it
+				// here instead of as a missing-packet failure below.
+				if err != nil {
+					t.Fatalf("Write failed: %s", err)
+				}
+
+				c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, 0)
+				bytesRead += maxPayload
+				// Grow the source so the next slice [bytesRead, bytesRead+maxPayload)
+				// is always in range.
+				data = append(data, data...)
+			}
+
+			// Send an ACK and wait for connection to become writable again.
+			c.SendAck(seqnum.Value(context.TestInitialSequenceNumber).Add(1), bytesRead)
+			select {
+			case <-ch:
+				if err := c.EP.LastError(); err != nil {
+					t.Fatalf("Write failed: %s", err)
+				}
+			case <-time.After(1 * time.Second):
+				t.Fatalf("Timed out waiting for connection")
+			}
+
+			outSz := int64(defaultSndBufSz)
+			if !tc.autoTuningDisabled {
+				// Calculate the new auto tuned send buffer.
+				var info tcpip.TCPInfoOption
+				if err := c.EP.GetSockOpt(&info); err != nil {
+					t.Fatalf("GetSockOpt failed: %v", err)
+				}
+				outSz = int64(info.SndCwnd) * packetOverheadFactor * maxPayload
+			}
+
+			if newSz := c.EP.SocketOptions().GetSendBufferSize(); newSz != outSz {
+				t.Fatalf("Wrong send buffer size, got %d want %d", newSz, outSz)
+			}
+		})
+	}
+}
+
+// TestTimestampSynCookies checks that an endpoint accepted through SYN cookies
+// keeps using the timestamp offset that was in effect for its SYN-ACK. It runs
+// on a manual clock: the offset is derived from the SYN-ACK's TSVal, the clock
+// is advanced, and the TSVal of the next data segment must equal the new clock
+// reading plus that offset.
+func TestTimestampSynCookies(t *testing.T) {
+	clock := faketime.NewManualClock()
+	// tsNow returns the clock's monotonic time in milliseconds, truncated to
+	// 32 bits.
+	tsNow := func() uint32 {
+		return uint32(clock.NowMonotonic().Sub(tcpip.MonotonicTime{}).Milliseconds())
+	}
+	// Advance the clock so that NowMonotonic is non-zero.
+	clock.Advance(time.Second)
+	c := context.NewWithOpts(t, context.Options{
+		EnableV4: true,
+		EnableV6: true,
+		MTU:      defaultMTU,
+		Clock:    clock,
+	})
+	defer c.Cleanup()
+	opt := tcpip.TCPAlwaysUseSynCookies(true)
+	if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+		t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
+	}
+	wq := &waiter.Queue{}
+	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+	if err != nil {
+		t.Fatalf("NewEndpoint failed: %s", err)
+	}
+	defer ep.Close()
+
+	// Two NOPs followed by the timestamp option encoded from offset 2.
+	tcpOpts := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
+	header.EncodeTSOption(42, 0, tcpOpts[2:])
+	if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+		t.Fatalf("Bind failed: %s", err)
+	}
+	if err := ep.Listen(10); err != nil {
+		t.Fatalf("Listen failed: %s", err)
+	}
+	iss := seqnum.Value(context.TestInitialSequenceNumber)
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagSyn,
+		RcvWnd:  seqnum.Size(512),
+		SeqNum:  iss,
+		TCPOpts: tcpOpts[:],
+	})
+	// Get the TSVal of SYN-ACK.
+	b := c.GetPacket()
+	tcpHdr := header.TCP(header.IPv4(b).Payload())
+	c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+	initialTSVal := tcpHdr.ParsedOptions().TSVal
+	// derive the tsOffset.
+	tsOffset := initialTSVal - tsNow()
+
+	// Complete the handshake, echoing the SYN-ACK's TSVal back.
+	header.EncodeTSOption(420, initialTSVal, tcpOpts[2:])
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagAck,
+		RcvWnd:  seqnum.Size(512),
+		SeqNum:  iss + 1,
+		AckNum:  c.IRS + 1,
+		TCPOpts: tcpOpts[:],
+	})
+
+	// Register for readiness before the first Accept so that a connection
+	// which becomes ready right after Accept returns ErrWouldBlock still
+	// signals the channel.
+	we, ch := waiter.NewChannelEntry(nil)
+	wq.EventRegister(&we, waiter.ReadableEvents)
+	defer wq.EventUnregister(&we)
+
+	// Try to accept the connection.
+	c.EP, _, err = ep.Accept(nil)
+	if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
+		// Wait for connection to be established.
+		select {
+		case <-ch:
+			c.EP, _, err = ep.Accept(nil)
+			if err != nil {
+				t.Fatalf("Accept failed: %s", err)
+			}
+
+		case <-time.After(1 * time.Second):
+			t.Fatalf("Timed out waiting for accept")
+		}
+	} else if err != nil {
+		t.Fatalf("failed to accept: %s", err)
+	}
+
+	// Advance the clock again so that we expect the next TSVal to change.
+	clock.Advance(time.Second)
+	data := []byte{1, 2, 3}
+	var r bytes.Reader
+	r.Reset(data)
+	if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
+		t.Fatalf("Write failed: %s", err)
+	}
+
+	// The endpoint should have a correct TSOffset so that the received TSVal
+	// should match our expectation.
+	if got, want := header.TCP(header.IPv4(c.GetPacket()).Payload()).ParsedOptions().TSVal, tsNow()+tsOffset; got != want {
+		t.Fatalf("got TSVal = %d, want %d", got, want)
+	}
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
index 1deb1fe4d..65925daa5 100644
--- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
@@ -32,7 +32,7 @@ import (
// createConnectedWithTimestampOption creates and connects c.ep with the
// timestamp option enabled.
func createConnectedWithTimestampOption(c *context.Context) *context.RawEndpoint {
- return c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, TSVal: 1})
+ return c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{TS: true, TSVal: 1})
}
// TestTimeStampEnabledConnect tests that netstack sends the timestamp option on
@@ -131,7 +131,7 @@ func TestTimeStampDisabledConnect(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnectedWithOptions(header.TCPSynOptions{})
+ c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{})
}
func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) {
@@ -147,7 +147,7 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS
t.Logf("Test w/ CookieEnabled = %v", cookieEnabled)
tsVal := rand.Uint32()
- c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, TS: true, TSVal: tsVal})
+ c.AcceptWithOptionsNoDelay(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, TS: true, TSVal: tsVal})
// Now send some data and validate that timestamp is echoed correctly in the ACK.
data := []byte{1, 2, 3}
@@ -209,7 +209,7 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd
}
t.Logf("Test w/ CookieEnabled = %v", cookieEnabled)
- c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
+ c.AcceptWithOptionsNoDelay(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
// Now send some data with the accepted connection endpoint and validate
// that no timestamp option is sent in the TCP segment.
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 96e4849d2..88bb99354 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -122,6 +122,9 @@ type Options struct {
// MTU indicates the maximum transmission unit on the link layer.
MTU uint32
+
+ // Clock that is used by Stack.
+ Clock tcpip.Clock
}
// Context provides an initialized Network stack and a link layer endpoint
@@ -182,6 +185,7 @@ func NewWithOpts(t *testing.T, opts Options) *Context {
stackOpts := stack.Options{
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
+ Clock: opts.Clock,
}
if opts.EnableV4 {
stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv4.NewProtocol)
@@ -239,8 +243,8 @@ func NewWithOpts(t *testing.T, opts Options) *Context {
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: StackAddrWithPrefix,
}
- if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err)
+ if err := s.AddProtocolAddress(1, v4ProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", v4ProtocolAddr, err)
}
routeTable = append(routeTable, tcpip.Route{
Destination: header.IPv4EmptySubnet,
@@ -253,8 +257,8 @@ func NewWithOpts(t *testing.T, opts Options) *Context {
Protocol: ipv6.ProtocolNumber,
AddressWithPrefix: StackV6AddrWithPrefix,
}
- if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err)
+ if err := s.AddProtocolAddress(1, v6ProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", v6ProtocolAddr, err)
}
routeTable = append(routeTable, tcpip.Route{
Destination: header.IPv6EmptySubnet,
@@ -879,13 +883,21 @@ func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) {
)
}
+// CreateConnectedWithOptionsNoDelay is a convenience wrapper around
+// CreateConnectedWithOptions that passes a zero delay.
+func (c *Context) CreateConnectedWithOptionsNoDelay(wantOptions header.TCPSynOptions) *RawEndpoint {
+	const noDelay = 0
+	rawEP := c.CreateConnectedWithOptions(wantOptions, noDelay)
+	return rawEP
+}
+
// CreateConnectedWithOptions creates and connects c.ep with the specified TCP
// options enabled and returns a RawEndpoint which represents the other end of
-// the connection.
+// the connection. It delays before a SYNACK is sent. This makes c.EP have a
+// higher RTT estimate so that spurious TLPs aren't sent in tests, which helps
+// reduce flakiness.
//
// It also verifies where required(eg.Timestamp) that the ACK to the SYN-ACK
// does not carry an option that was not requested.
-func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *RawEndpoint {
+func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions, delay time.Duration) *RawEndpoint {
var err tcpip.Error
c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
@@ -911,18 +923,17 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
// TS value.
mss := uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize)
- checker.IPv4(c.t, b,
- checker.TCP(
- checker.DstPort(TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- checker.TCPSynOptions(header.TCPSynOptions{
- MSS: mss,
- TS: true,
- WS: int(c.WindowScale),
- SACKPermitted: c.SACKEnabled(),
- }),
- ),
+ synChecker := checker.TCP(
+ checker.DstPort(TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ checker.TCPSynOptions(header.TCPSynOptions{
+ MSS: mss,
+ TS: true,
+ WS: int(c.WindowScale),
+ SACKPermitted: c.SACKEnabled(),
+ }),
)
+ checker.IPv4(c.t, b, synChecker)
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
}
@@ -948,6 +959,10 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
// Build SYN-ACK.
c.IRS = seqnum.Value(tcpSeg.SequenceNumber())
iss := seqnum.Value(TestInitialSequenceNumber)
+ if delay > 0 {
+ // Sleep so that RTT is increased.
+ time.Sleep(delay)
+ }
c.SendPacket(nil, &Headers{
SrcPort: tcpSeg.DestinationPort(),
DstPort: tcpSeg.SourcePort(),
@@ -959,7 +974,17 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
})
// Read ACK.
- ackPacket := c.GetPacket()
+ var ackPacket []byte
+ // Ignore retransmitted SYN packets.
+ for {
+ packet := c.GetPacket()
+ if header.TCP(header.IPv4(packet).Payload()).Flags()&header.TCPFlagSyn != 0 {
+ checker.IPv4(c.t, packet, synChecker)
+ } else {
+ ackPacket = packet
+ break
+ }
+ }
// Verify TCP header fields.
tcpCheckers := []checker.TransportChecker{
@@ -1016,13 +1041,19 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
}
}
-// AcceptWithOptions initializes a listening endpoint and connects to it with the
-// provided options enabled. It also verifies that the SYN-ACK has the expected
-// values for the provided options.
+// AcceptWithOptionsNoDelay calls AcceptWithOptions with no delay.
+func (c *Context) AcceptWithOptionsNoDelay(wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
+ return c.AcceptWithOptions(wndScale, synOptions, 0 /* delay */)
+}
+
+// AcceptWithOptions initializes a listening endpoint and connects to it with
+// the provided options enabled. It delays before the final ACK of the 3WHS is
+// sent. It also verifies that the SYN-ACK has the expected values for the
+// provided options.
//
// The function returns a RawEndpoint representing the other end of the accepted
// endpoint.
-func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
+func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOptions, delay time.Duration) *RawEndpoint {
// Create EP and start listening.
wq := &waiter.Queue{}
ep, err := c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
@@ -1045,7 +1076,7 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption
c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
}
- rep := c.PassiveConnectWithOptions(100, wndScale, synOptions)
+ rep := c.PassiveConnectWithOptions(100, wndScale, synOptions, delay)
// Try to accept the connection.
we, ch := waiter.NewChannelEntry(nil)
@@ -1077,13 +1108,14 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption
// PassiveConnectWithOptions.
func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCPSynOptions) {
synOptions.WS = -1
- c.PassiveConnectWithOptions(maxPayload, wndScale, synOptions)
+ c.PassiveConnectWithOptions(maxPayload, wndScale, synOptions, 0 /* delay */)
}
// PassiveConnectWithOptions initiates a new connection (with the specified TCP
// options enabled) to the port on which the Context.ep is listening for new
// connections. It also validates that the SYN-ACK has the expected values for
-// the enabled options.
+// the enabled options. The final ACK of the handshake is delayed by the
+// specified duration.
//
// NOTE: MSS is not a negotiated option and it can be asymmetric
// in each direction. This function uses the maxPayload to set the MSS to be
@@ -1093,7 +1125,7 @@ func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCP
// wndScale is the expected window scale in the SYN-ACK and synOptions.WS is the
// value of the window scaling option to be sent in the SYN. If synOptions.WS >
// 0 then we send the WindowScale option.
-func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
+func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions, delay time.Duration) *RawEndpoint {
c.t.Helper()
opts := make([]byte, header.TCPOptionsMaximumSize)
offset := 0
@@ -1180,7 +1212,10 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions
ackHeaders.TCPOpts = opts[:]
}
- // Send ACK.
+ // Send ACK, delay if needed.
+ if delay > 0 {
+ time.Sleep(delay)
+ }
c.SendPacket(nil, ackHeaders)
c.RcvdWindowScale = uint8(rcvdSynOptions.WS)
diff --git a/pkg/tcpip/transport/transport.go b/pkg/tcpip/transport/transport.go
new file mode 100644
index 000000000..4c2ae87f4
--- /dev/null
+++ b/pkg/tcpip/transport/transport.go
@@ -0,0 +1,16 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package transport supports transport protocols.
+package transport
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index cdc344ab7..d2c0963b0 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -35,6 +35,8 @@ go_library(
"//pkg/tcpip/header/parse",
"//pkg/tcpip/ports",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport",
+ "//pkg/tcpip/transport/internal/network",
"//pkg/tcpip/transport/raw",
"//pkg/waiter",
],
@@ -61,5 +63,6 @@ go_test(
"//pkg/tcpip/transport/icmp",
"//pkg/waiter",
"@com_github_google_go_cmp//cmp:go_default_library",
+ "@org_golang_x_time//rate:go_default_library",
],
)
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 82a3f2287..39b1e08c0 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -15,8 +15,8 @@
package udp
import (
+ "fmt"
"io"
- "sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/sync"
@@ -25,12 +25,15 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network"
"gvisor.dev/gvisor/pkg/waiter"
)
// +stateify savable
type udpPacket struct {
udpPacketEntry
+ netProto tcpip.NetworkProtocolNumber
senderAddress tcpip.FullAddress
destinationAddress tcpip.FullAddress
packetInfo tcpip.IPPacketInfo
@@ -40,36 +43,6 @@ type udpPacket struct {
tos uint8
}
-// EndpointState represents the state of a UDP endpoint.
-type EndpointState tcpip.EndpointState
-
-// Endpoint states. Note that are represented in a netstack-specific manner and
-// may not be meaningful externally. Specifically, they need to be translated to
-// Linux's representation for these states if presented to userspace.
-const (
- _ EndpointState = iota
- StateInitial
- StateBound
- StateConnected
- StateClosed
-)
-
-// String implements fmt.Stringer.
-func (s EndpointState) String() string {
- switch s {
- case StateInitial:
- return "INITIAL"
- case StateBound:
- return "BOUND"
- case StateConnected:
- return "CONNECTING"
- case StateClosed:
- return "CLOSED"
- default:
- return "UNKNOWN"
- }
-}
-
// endpoint represents a UDP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
@@ -79,7 +52,6 @@ func (s EndpointState) String() string {
//
// +stateify savable
type endpoint struct {
- stack.TransportEndpointInfo
tcpip.DefaultSocketOptionsHandler
// The following fields are initialized at creation time and do not
@@ -87,6 +59,9 @@ type endpoint struct {
stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue
uniqueID uint64
+ net network.Endpoint
+ stats tcpip.TransportEndpointStats
+ ops tcpip.SocketOptions
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.
@@ -96,37 +71,19 @@ type endpoint struct {
rcvBufSize int
rcvClosed bool
- // The following fields are protected by the mu mutex.
- mu sync.RWMutex `state:"nosave"`
- // state must be read/set using the EndpointState()/setEndpointState()
- // methods.
- state uint32
- route *stack.Route `state:"manual"`
- dstPort uint16
- ttl uint8
- multicastTTL uint8
- multicastAddr tcpip.Address
- multicastNICID tcpip.NICID
- portFlags ports.Flags
-
lastErrorMu sync.Mutex `state:"nosave"`
lastError tcpip.Error
+ // The following fields are protected by the mu mutex.
+ mu sync.RWMutex `state:"nosave"`
+ portFlags ports.Flags
+
// Values used to reserve a port or register a transport endpoint.
// (which ever happens first).
boundBindToDevice tcpip.NICID
boundPortFlags ports.Flags
- // sendTOS represents IPv4 TOS or IPv6 TrafficClass,
- // applied while sending packets. Defaults to 0 as on Linux.
- sendTOS uint8
-
- // shutdownFlags represent the current shutdown state of the endpoint.
- shutdownFlags tcpip.ShutdownFlags
-
- // multicastMemberships that need to be remvoed when the endpoint is
- // closed. Protected by the mu mutex.
- multicastMemberships map[multicastMembership]struct{}
+ readShutdown bool
// effectiveNetProtos contains the network protocols actually in use. In
// most cases it will only contain "netProto", but in cases like IPv6
@@ -136,55 +93,25 @@ type endpoint struct {
// address).
effectiveNetProtos []tcpip.NetworkProtocolNumber
- // TODO(b/142022063): Add ability to save and restore per endpoint stats.
- stats tcpip.TransportEndpointStats `state:"nosave"`
-
- // owner is used to get uid and gid of the packet.
- owner tcpip.PacketOwner
-
- // ops is used to get socket level options.
- ops tcpip.SocketOptions
-
// frozen indicates if the packets should be delivered to the endpoint
// during restore.
frozen bool
-}
-// +stateify savable
-type multicastMembership struct {
- nicID tcpip.NICID
- multicastAddr tcpip.Address
+ localPort uint16
+ remotePort uint16
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
e := &endpoint{
- stack: s,
- TransportEndpointInfo: stack.TransportEndpointInfo{
- NetProto: netProto,
- TransProto: header.UDPProtocolNumber,
- },
+ stack: s,
waiterQueue: waiterQueue,
- // RFC 1075 section 5.4 recommends a TTL of 1 for membership
- // requests.
- //
- // RFC 5135 4.2.1 appears to assume that IGMP messages have a
- // TTL of 1.
- //
- // RFC 5135 Appendix A defines TTL=1: A multicast source that
- // wants its traffic to not traverse a router (e.g., leave a
- // home network) may find it useful to send traffic with IP
- // TTL=1.
- //
- // Linux defaults to TTL=1.
- multicastTTL: 1,
- multicastMemberships: make(map[multicastMembership]struct{}),
- state: uint32(StateInitial),
- uniqueID: s.UniqueID(),
+ uniqueID: s.UniqueID(),
}
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
e.ops.SetMulticastLoop(true)
e.ops.SetSendBufferSize(32*1024, false /* notify */)
e.ops.SetReceiveBufferSize(32*1024, false /* notify */)
+ e.net.Init(s, netProto, header.UDPProtocolNumber, &e.ops)
// Override with stack defaults.
var ss tcpip.SendBufferSizeOption
@@ -200,20 +127,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
return e
}
-// setEndpointState updates the state of the endpoint to state atomically. This
-// method is unexported as the only place we should update the state is in this
-// package but we allow the state to be read freely without holding e.mu.
-//
-// Precondition: e.mu must be held to call this method.
-func (e *endpoint) setEndpointState(state EndpointState) {
- atomic.StoreUint32(&e.state, uint32(state))
-}
-
-// EndpointState() returns the current state of the endpoint.
-func (e *endpoint) EndpointState() EndpointState {
- return EndpointState(atomic.LoadUint32(&e.state))
-}
-
// UniqueID implements stack.TransportEndpoint.
func (e *endpoint) UniqueID() uint64 {
return e.uniqueID
@@ -244,16 +157,22 @@ func (e *endpoint) Abort() {
// associated with it.
func (e *endpoint) Close() {
e.mu.Lock()
- e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
- switch e.EndpointState() {
- case StateBound, StateConnected:
- e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial:
+ case transport.DatagramEndpointStateClosed:
+ e.mu.Unlock()
+ return
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ id := e.net.Info().ID
+ id.LocalPort = e.localPort
+ id.RemotePort = e.remotePort
+ e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, id, e, e.boundPortFlags, e.boundBindToDevice)
portRes := ports.Reservation{
Networks: e.effectiveNetProtos,
Transport: ProtocolNumber,
- Addr: e.ID.LocalAddress,
- Port: e.ID.LocalPort,
+ Addr: id.LocalAddress,
+ Port: id.LocalPort,
Flags: e.boundPortFlags,
BindToDevice: e.boundBindToDevice,
Dest: tcpip.FullAddress{},
@@ -261,13 +180,10 @@ func (e *endpoint) Close() {
e.stack.ReleasePort(portRes)
e.boundBindToDevice = 0
e.boundPortFlags = ports.Flags{}
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
}
- for mem := range e.multicastMemberships {
- e.stack.LeaveGroup(e.NetProto, mem.nicID, mem.multicastAddr)
- }
- e.multicastMemberships = make(map[multicastMembership]struct{})
-
// Close the receive list and drain it.
e.rcvMu.Lock()
e.rcvClosed = true
@@ -278,14 +194,9 @@ func (e *endpoint) Close() {
}
e.rcvMu.Unlock()
- if e.route != nil {
- e.route.Release()
- e.route = nil
- }
-
- // Update the state.
- e.setEndpointState(StateClosed)
-
+ e.net.Shutdown()
+ e.net.Close()
+ e.readShutdown = true
e.mu.Unlock()
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
@@ -322,21 +233,38 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// Control Messages
cm := tcpip.ControlMessages{
HasTimestamp: true,
- Timestamp: p.receivedAt.UnixNano(),
- }
- if e.ops.GetReceiveTOS() {
- cm.HasTOS = true
- cm.TOS = p.tos
- }
- if e.ops.GetReceiveTClass() {
- cm.HasTClass = true
- // Although TClass is an 8-bit value it's read in the CMsg as a uint32.
- cm.TClass = uint32(p.tos)
+ Timestamp: p.receivedAt,
}
- if e.ops.GetReceivePacketInfo() {
- cm.HasIPPacketInfo = true
- cm.PacketInfo = p.packetInfo
+
+ switch p.netProto {
+ case header.IPv4ProtocolNumber:
+ if e.ops.GetReceiveTOS() {
+ cm.HasTOS = true
+ cm.TOS = p.tos
+ }
+
+ if e.ops.GetReceivePacketInfo() {
+ cm.HasIPPacketInfo = true
+ cm.PacketInfo = p.packetInfo
+ }
+ case header.IPv6ProtocolNumber:
+ if e.ops.GetReceiveTClass() {
+ cm.HasTClass = true
+ // Although TClass is an 8-bit value it's read in the CMsg as a uint32.
+ cm.TClass = uint32(p.tos)
+ }
+
+ if e.ops.GetIPv6ReceivePacketInfo() {
+ cm.HasIPv6PacketInfo = true
+ cm.IPv6PacketInfo = tcpip.IPv6PacketInfo{
+ NIC: p.packetInfo.NIC,
+ Addr: p.packetInfo.DestinationAddr,
+ }
+ }
+ default:
+ panic(fmt.Sprintf("unrecognized network protocol = %d", p.netProto))
}
+
if e.ops.GetReceiveOriginalDstAddress() {
cm.HasOriginalDstAddress = true
cm.OriginalDstAddress = p.destinationAddress
@@ -359,19 +287,19 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
return res, nil
}
-// prepareForWrite prepares the endpoint for sending data. In particular, it
-// binds it if it's still in the initial state. To do so, it must first
+// prepareForWriteInner prepares the endpoint for sending data. In particular,
+// it binds it if it's still in the initial state. To do so, it must first
// reacquire the mutex in exclusive mode.
//
// Returns true for retry if preparation should be retried.
// +checklocks:e.mu
-func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
- switch e.EndpointState() {
- case StateInitial:
- case StateConnected:
+func (e *endpoint) prepareForWriteInner(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
+ switch e.net.State() {
+ case transport.DatagramEndpointStateInitial:
+ case transport.DatagramEndpointStateConnected:
return false, nil
- case StateBound:
+ case transport.DatagramEndpointStateBound:
if to == nil {
return false, &tcpip.ErrDestinationRequired{}
}
@@ -386,7 +314,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip
// The state changed when we released the shared locked and re-acquired
// it in exclusive mode. Try again.
- if e.EndpointState() != StateInitial {
+ if e.net.State() != transport.DatagramEndpointStateInitial {
return true, nil
}
@@ -398,33 +326,6 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip
return true, nil
}
-// connectRoute establishes a route to the specified interface or the
-// configured multicast interface if no interface is specified and the
-// specified address is a multicast address.
-func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) {
- localAddr := e.ID.LocalAddress
- if e.isBroadcastOrMulticast(nicID, netProto, localAddr) {
- // A packet can only originate from a unicast address (i.e., an interface).
- localAddr = ""
- }
-
- if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) {
- if nicID == 0 {
- nicID = e.multicastNICID
- }
- if localAddr == "" && nicID == 0 {
- localAddr = e.multicastAddr
- }
- }
-
- // Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.ops.GetMulticastLoop())
- if err != nil {
- return nil, 0, err
- }
- return r, nicID, nil
-}
-
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
@@ -448,18 +349,13 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
return n, err
}
-func (e *endpoint) buildUDPPacketInfo(p tcpip.Payloader, opts tcpip.WriteOptions) (udpPacketInfo, tcpip.Error) {
+func (e *endpoint) prepareForWrite(p tcpip.Payloader, opts tcpip.WriteOptions) (udpPacketInfo, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- // If we've shutdown with SHUT_WR we are in an invalid state for sending.
- if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
- return udpPacketInfo{}, &tcpip.ErrClosedForSend{}
- }
-
// Prepare for write.
for {
- retry, err := e.prepareForWrite(opts.To)
+ retry, err := e.prepareForWriteInner(opts.To)
if err != nil {
return udpPacketInfo{}, err
}
@@ -469,49 +365,28 @@ func (e *endpoint) buildUDPPacketInfo(p tcpip.Payloader, opts tcpip.WriteOptions
}
}
- route := e.route
- dstPort := e.dstPort
+ dst, connected := e.net.GetRemoteAddress()
+ dst.Port = e.remotePort
if opts.To != nil {
- // Reject destination address if it goes through a different
- // NIC than the endpoint was bound to.
- nicID := opts.To.NIC
- if nicID == 0 {
- nicID = tcpip.NICID(e.ops.GetBindToDevice())
- }
- if e.BindNICID != 0 {
- if nicID != 0 && nicID != e.BindNICID {
- return udpPacketInfo{}, &tcpip.ErrNoRoute{}
- }
-
- nicID = e.BindNICID
- }
-
if opts.To.Port == 0 {
// Port 0 is an invalid port to send to.
return udpPacketInfo{}, &tcpip.ErrInvalidEndpointState{}
}
- dst, netProto, err := e.checkV4MappedLocked(*opts.To)
- if err != nil {
- return udpPacketInfo{}, err
- }
-
- r, _, err := e.connectRoute(nicID, dst, netProto)
- if err != nil {
- return udpPacketInfo{}, err
- }
- defer r.Release()
-
- route = r
- dstPort = dst.Port
+ dst = *opts.To
+ } else if !connected {
+ return udpPacketInfo{}, &tcpip.ErrDestinationRequired{}
}
- if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() {
- return udpPacketInfo{}, &tcpip.ErrBroadcastDisabled{}
+ ctx, err := e.net.AcquireContextForWrite(opts)
+ if err != nil {
+ return udpPacketInfo{}, err
}
+ // TODO(https://gvisor.dev/issue/6538): Avoid this allocation.
v := make([]byte, p.Len())
if _, err := io.ReadFull(p, v); err != nil {
+ ctx.Release()
return udpPacketInfo{}, &tcpip.ErrBadBuffer{}
}
if len(v) > header.UDPMaximumPacketSize {
@@ -520,50 +395,25 @@ func (e *endpoint) buildUDPPacketInfo(p tcpip.Payloader, opts tcpip.WriteOptions
if so.GetRecvError() {
so.QueueLocalErr(
&tcpip.ErrMessageTooLong{},
- route.NetProto(),
+ e.net.NetProto(),
header.UDPMaximumPacketSize,
- tcpip.FullAddress{
- NIC: route.NICID(),
- Addr: route.RemoteAddress(),
- Port: dstPort,
- },
+ dst,
v,
)
}
+ ctx.Release()
return udpPacketInfo{}, &tcpip.ErrMessageTooLong{}
}
- ttl := e.ttl
- useDefaultTTL := ttl == 0
- if header.IsV4MulticastAddress(route.RemoteAddress()) || header.IsV6MulticastAddress(route.RemoteAddress()) {
- ttl = e.multicastTTL
- // Multicast allows a 0 TTL.
- useDefaultTTL = false
- }
-
return udpPacketInfo{
- route: route,
- data: buffer.View(v),
- localPort: e.ID.LocalPort,
- remotePort: dstPort,
- ttl: ttl,
- useDefaultTTL: useDefaultTTL,
- tos: e.sendTOS,
- owner: e.owner,
- noChecksum: e.SocketOptions().GetNoChecksum(),
+ ctx: ctx,
+ data: v,
+ localPort: e.localPort,
+ remotePort: dst.Port,
}, nil
}
func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
- if err := e.LastError(); err != nil {
- return 0, err
- }
-
- // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
- if opts.More {
- return 0, &tcpip.ErrInvalidOptionValue{}
- }
-
// Do not hold lock when sending as loopback is synchronous and if the UDP
// datagram ends up generating an ICMP response then it can result in a
// deadlock where the ICMP response handling ends up acquiring this endpoint's
@@ -574,15 +424,53 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
//
// See: https://golang.org/pkg/sync/#RWMutex for details on why recursive read
// locking is prohibited.
- u, err := e.buildUDPPacketInfo(p, opts)
- if err != nil {
+
+ if err := e.LastError(); err != nil {
return 0, err
}
- n, err := u.send()
+
+ udpInfo, err := e.prepareForWrite(p, opts)
if err != nil {
return 0, err
}
- return int64(n), nil
+ defer udpInfo.ctx.Release()
+
+ pktInfo := udpInfo.ctx.PacketInfo()
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.UDPMinimumSize + int(pktInfo.MaxHeaderLength),
+ Data: udpInfo.data.ToVectorisedView(),
+ })
+
+ // Initialize the UDP header.
+ udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
+ pkt.TransportProtocolNumber = ProtocolNumber
+
+ length := uint16(pkt.Size())
+ udp.Encode(&header.UDPFields{
+ SrcPort: udpInfo.localPort,
+ DstPort: udpInfo.remotePort,
+ Length: length,
+ })
+
+ // Set the checksum field unless TX checksum offload is enabled.
+ // On IPv4, UDP checksum is optional, and a zero value indicates the
+ // transmitter skipped the checksum generation (RFC768).
+ // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
+ if pktInfo.RequiresTXTransportChecksum &&
+ (!e.ops.GetNoChecksum() || pktInfo.NetProto == header.IPv6ProtocolNumber) {
+ udp.SetChecksum(^udp.CalculateChecksum(header.ChecksumCombine(
+ header.PseudoHeaderChecksum(ProtocolNumber, pktInfo.LocalAddress, pktInfo.RemoteAddress, length),
+ pkt.Data().AsRange().Checksum(),
+ )))
+ }
+ if err := udpInfo.ctx.WritePacket(pkt, false /* headerIncluded */); err != nil {
+ e.stack.Stats().UDP.PacketSendErrors.Increment()
+ return 0, err
+ }
+
+ // Track count of packets sent.
+ e.stack.Stats().UDP.PacketsSent.Increment()
+ return int64(len(udpInfo.data)), nil
}
// OnReuseAddressSet implements tcpip.SocketOptionsHandler.
@@ -601,36 +489,7 @@ func (e *endpoint) OnReusePortSet(v bool) {
// SetSockOptInt implements tcpip.Endpoint.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
- switch opt {
- case tcpip.MTUDiscoverOption:
- // Return not supported if the value is not disabling path
- // MTU discovery.
- if v != tcpip.PMTUDiscoveryDont {
- return &tcpip.ErrNotSupported{}
- }
-
- case tcpip.MulticastTTLOption:
- e.mu.Lock()
- e.multicastTTL = uint8(v)
- e.mu.Unlock()
-
- case tcpip.TTLOption:
- e.mu.Lock()
- e.ttl = uint8(v)
- e.mu.Unlock()
-
- case tcpip.IPv4TOSOption:
- e.mu.Lock()
- e.sendTOS = uint8(v)
- e.mu.Unlock()
-
- case tcpip.IPv6TrafficClassOption:
- e.mu.Lock()
- e.sendTOS = uint8(v)
- e.mu.Unlock()
- }
-
- return nil
+ return e.net.SetSockOptInt(opt, v)
}
var _ tcpip.SocketOptionsHandler = (*endpoint)(nil)
@@ -642,145 +501,12 @@ func (e *endpoint) HasNIC(id int32) bool {
// SetSockOpt implements tcpip.Endpoint.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
- switch v := opt.(type) {
- case *tcpip.MulticastInterfaceOption:
- e.mu.Lock()
- defer e.mu.Unlock()
-
- fa := tcpip.FullAddress{Addr: v.InterfaceAddr}
- fa, netProto, err := e.checkV4MappedLocked(fa)
- if err != nil {
- return err
- }
- nic := v.NIC
- addr := fa.Addr
-
- if nic == 0 && addr == "" {
- e.multicastAddr = ""
- e.multicastNICID = 0
- break
- }
-
- if nic != 0 {
- if !e.stack.CheckNIC(nic) {
- return &tcpip.ErrBadLocalAddress{}
- }
- } else {
- nic = e.stack.CheckLocalAddress(0, netProto, addr)
- if nic == 0 {
- return &tcpip.ErrBadLocalAddress{}
- }
- }
-
- if e.BindNICID != 0 && e.BindNICID != nic {
- return &tcpip.ErrInvalidEndpointState{}
- }
-
- e.multicastNICID = nic
- e.multicastAddr = addr
-
- case *tcpip.AddMembershipOption:
- if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
- return &tcpip.ErrInvalidOptionValue{}
- }
-
- nicID := v.NIC
-
- if v.InterfaceAddr.Unspecified() {
- if nicID == 0 {
- if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.NetProto, false /* multicastLoop */); err == nil {
- nicID = r.NICID()
- r.Release()
- }
- }
- } else {
- nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr)
- }
- if nicID == 0 {
- return &tcpip.ErrUnknownDevice{}
- }
-
- memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
-
- e.mu.Lock()
- defer e.mu.Unlock()
-
- if _, ok := e.multicastMemberships[memToInsert]; ok {
- return &tcpip.ErrPortInUse{}
- }
-
- if err := e.stack.JoinGroup(e.NetProto, nicID, v.MulticastAddr); err != nil {
- return err
- }
-
- e.multicastMemberships[memToInsert] = struct{}{}
-
- case *tcpip.RemoveMembershipOption:
- if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
- return &tcpip.ErrInvalidOptionValue{}
- }
-
- nicID := v.NIC
- if v.InterfaceAddr.Unspecified() {
- if nicID == 0 {
- if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.NetProto, false /* multicastLoop */); err == nil {
- nicID = r.NICID()
- r.Release()
- }
- }
- } else {
- nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr)
- }
- if nicID == 0 {
- return &tcpip.ErrUnknownDevice{}
- }
-
- memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
-
- e.mu.Lock()
- defer e.mu.Unlock()
-
- if _, ok := e.multicastMemberships[memToRemove]; !ok {
- return &tcpip.ErrBadLocalAddress{}
- }
-
- if err := e.stack.LeaveGroup(e.NetProto, nicID, v.MulticastAddr); err != nil {
- return err
- }
-
- delete(e.multicastMemberships, memToRemove)
-
- case *tcpip.SocketDetachFilterOption:
- return nil
- }
- return nil
+ return e.net.SetSockOpt(opt)
}
// GetSockOptInt implements tcpip.Endpoint.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
switch opt {
- case tcpip.IPv4TOSOption:
- e.mu.RLock()
- v := int(e.sendTOS)
- e.mu.RUnlock()
- return v, nil
-
- case tcpip.IPv6TrafficClassOption:
- e.mu.RLock()
- v := int(e.sendTOS)
- e.mu.RUnlock()
- return v, nil
-
- case tcpip.MTUDiscoverOption:
- // The only supported setting is path MTU discovery disabled.
- return tcpip.PMTUDiscoveryDont, nil
-
- case tcpip.MulticastTTLOption:
- e.mu.Lock()
- v := int(e.multicastTTL)
- e.mu.Unlock()
- return v, nil
-
case tcpip.ReceiveQueueSizeOption:
v := 0
e.rcvMu.Lock()
@@ -791,108 +517,22 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
e.rcvMu.Unlock()
return v, nil
- case tcpip.TTLOption:
- e.mu.Lock()
- v := int(e.ttl)
- e.mu.Unlock()
- return v, nil
-
default:
- return -1, &tcpip.ErrUnknownProtocolOption{}
+ return e.net.GetSockOptInt(opt)
}
}
// GetSockOpt implements tcpip.Endpoint.
func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
- switch o := opt.(type) {
- case *tcpip.MulticastInterfaceOption:
- e.mu.Lock()
- *o = tcpip.MulticastInterfaceOption{
- NIC: e.multicastNICID,
- InterfaceAddr: e.multicastAddr,
- }
- e.mu.Unlock()
-
- default:
- return &tcpip.ErrUnknownProtocolOption{}
- }
- return nil
+ return e.net.GetSockOpt(opt)
}
-// udpPacketInfo contains all information required to send a UDP packet.
-//
-// This should be used as a value-only type, which exists in order to simplify
-// return value syntax. It should not be exported or extended.
+// udpPacketInfo holds information needed to send a UDP packet.
type udpPacketInfo struct {
- route *stack.Route
- data buffer.View
- localPort uint16
- remotePort uint16
- ttl uint8
- useDefaultTTL bool
- tos uint8
- owner tcpip.PacketOwner
- noChecksum bool
-}
-
-// send sends the given packet.
-func (u *udpPacketInfo) send() (int, tcpip.Error) {
- vv := u.data.ToVectorisedView()
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.UDPMinimumSize + int(u.route.MaxHeaderLength()),
- Data: vv,
- })
- pkt.Owner = u.owner
-
- // Initialize the UDP header.
- udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
- pkt.TransportProtocolNumber = ProtocolNumber
-
- length := uint16(pkt.Size())
- udp.Encode(&header.UDPFields{
- SrcPort: u.localPort,
- DstPort: u.remotePort,
- Length: length,
- })
-
- // Set the checksum field unless TX checksum offload is enabled.
- // On IPv4, UDP checksum is optional, and a zero value indicates the
- // transmitter skipped the checksum generation (RFC768).
- // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
- if u.route.RequiresTXTransportChecksum() &&
- (!u.noChecksum || u.route.NetProto() == header.IPv6ProtocolNumber) {
- xsum := u.route.PseudoHeaderChecksum(ProtocolNumber, length)
- for _, v := range vv.Views() {
- xsum = header.Checksum(v, xsum)
- }
- udp.SetChecksum(^udp.CalculateChecksum(xsum))
- }
-
- if u.useDefaultTTL {
- u.ttl = u.route.DefaultTTL()
- }
- if err := u.route.WritePacket(stack.NetworkHeaderParams{
- Protocol: ProtocolNumber,
- TTL: u.ttl,
- TOS: u.tos,
- }, pkt); err != nil {
- u.route.Stats().UDP.PacketSendErrors.Increment()
- return 0, err
- }
-
- // Track count of packets sent.
- u.route.Stats().UDP.PacketsSent.Increment()
- return len(u.data), nil
-}
-
-// checkV4MappedLocked determines the effective network protocol and converts
-// addr to its canonical form.
-func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
- unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only())
- if err != nil {
- return tcpip.FullAddress{}, 0, err
- }
- return unwrapped, netProto, nil
+ ctx network.WriteContext
+ data buffer.View
+ localPort uint16
+ remotePort uint16
}
// Disconnect implements tcpip.Endpoint.
@@ -900,7 +540,7 @@ func (e *endpoint) Disconnect() tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- if e.EndpointState() != StateConnected {
+ if e.net.State() != transport.DatagramEndpointStateConnected {
return nil
}
var (
@@ -913,26 +553,28 @@ func (e *endpoint) Disconnect() tcpip.Error {
boundPortFlags := e.boundPortFlags
// Exclude ephemerally bound endpoints.
- if e.BindNICID != 0 || e.ID.LocalAddress == "" {
+ info := e.net.Info()
+ info.ID.LocalPort = e.localPort
+ info.ID.RemotePort = e.remotePort
+ if e.net.WasBound() {
var err tcpip.Error
id = stack.TransportEndpointID{
- LocalPort: e.ID.LocalPort,
- LocalAddress: e.ID.LocalAddress,
+ LocalPort: info.ID.LocalPort,
+ LocalAddress: info.ID.LocalAddress,
}
id, btd, err = e.registerWithStack(e.effectiveNetProtos, id)
if err != nil {
return err
}
- e.setEndpointState(StateBound)
boundPortFlags = e.boundPortFlags
} else {
- if e.ID.LocalPort != 0 {
+ if info.ID.LocalPort != 0 {
// Release the ephemeral port.
portRes := ports.Reservation{
Networks: e.effectiveNetProtos,
Transport: ProtocolNumber,
- Addr: e.ID.LocalAddress,
- Port: e.ID.LocalPort,
+ Addr: info.ID.LocalAddress,
+ Port: info.ID.LocalPort,
Flags: boundPortFlags,
BindToDevice: e.boundBindToDevice,
Dest: tcpip.FullAddress{},
@@ -940,15 +582,14 @@ func (e *endpoint) Disconnect() tcpip.Error {
e.stack.ReleasePort(portRes)
e.boundPortFlags = ports.Flags{}
}
- e.setEndpointState(StateInitial)
}
- e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice)
- e.ID = id
+ e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, info.ID, e, boundPortFlags, e.boundBindToDevice)
e.boundBindToDevice = btd
- e.route.Release()
- e.route = nil
- e.dstPort = 0
+ e.localPort = id.LocalPort
+ e.remotePort = id.RemotePort
+
+ e.net.Disconnect()
return nil
}
@@ -958,88 +599,48 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- nicID := addr.NIC
- var localPort uint16
- switch e.EndpointState() {
- case StateInitial:
- case StateBound, StateConnected:
- localPort = e.ID.LocalPort
- if e.BindNICID == 0 {
- break
- }
-
- if nicID != 0 && nicID != e.BindNICID {
- return &tcpip.ErrInvalidEndpointState{}
+ err := e.net.ConnectAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error {
+ nextID.LocalPort = e.localPort
+ nextID.RemotePort = addr.Port
+
+ // Even if we're connected, this endpoint can still be used to send
+ // packets on a different network protocol, so we register both even if
+ // v6only is set to false and this is an ipv6 endpoint.
+ netProtos := []tcpip.NetworkProtocolNumber{netProto}
+ if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() {
+ netProtos = []tcpip.NetworkProtocolNumber{
+ header.IPv4ProtocolNumber,
+ header.IPv6ProtocolNumber,
+ }
}
- nicID = e.BindNICID
- default:
- return &tcpip.ErrInvalidEndpointState{}
- }
+ oldPortFlags := e.boundPortFlags
- addr, netProto, err := e.checkV4MappedLocked(addr)
- if err != nil {
- return err
- }
-
- r, nicID, err := e.connectRoute(nicID, addr, netProto)
- if err != nil {
- return err
- }
-
- id := stack.TransportEndpointID{
- LocalAddress: e.ID.LocalAddress,
- LocalPort: localPort,
- RemotePort: addr.Port,
- RemoteAddress: r.RemoteAddress(),
- }
-
- if e.EndpointState() == StateInitial {
- id.LocalAddress = r.LocalAddress()
- }
-
- // Even if we're connected, this endpoint can still be used to send
- // packets on a different network protocol, so we register both even if
- // v6only is set to false and this is an ipv6 endpoint.
- netProtos := []tcpip.NetworkProtocolNumber{netProto}
- if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() {
- netProtos = []tcpip.NetworkProtocolNumber{
- header.IPv4ProtocolNumber,
- header.IPv6ProtocolNumber,
+ nextID, btd, err := e.registerWithStack(netProtos, nextID)
+ if err != nil {
+ return err
}
- }
- oldPortFlags := e.boundPortFlags
+ // Remove the old registration.
+ if e.localPort != 0 {
+ previousID.LocalPort = e.localPort
+ previousID.RemotePort = e.remotePort
+ e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, previousID, e, oldPortFlags, e.boundBindToDevice)
+ }
- id, btd, err := e.registerWithStack(netProtos, id)
+ e.localPort = nextID.LocalPort
+ e.remotePort = nextID.RemotePort
+ e.boundBindToDevice = btd
+ e.effectiveNetProtos = netProtos
+ return nil
+ })
if err != nil {
- r.Release()
return err
}
- // Remove the old registration.
- if e.ID.LocalPort != 0 {
- e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, oldPortFlags, e.boundBindToDevice)
- }
-
- e.ID = id
- e.boundBindToDevice = btd
- if e.route != nil {
- // If the endpoint was already connected then make sure we release the
- // previous route.
- e.route.Release()
- }
- e.route = r
- e.dstPort = addr.Port
- e.RegisterNICID = nicID
- e.effectiveNetProtos = netProtos
-
- e.setEndpointState(StateConnected)
-
e.rcvMu.Lock()
e.rcvReady = true
e.rcvMu.Unlock()
-
return nil
}
@@ -1054,15 +655,23 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- // A socket in the bound state can still receive multicast messages,
- // so we need to notify waiters on shutdown.
- if state := e.EndpointState(); state != StateBound && state != StateConnected {
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
return &tcpip.ErrNotConnected{}
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
}
- e.shutdownFlags |= flags
+ if flags&tcpip.ShutdownWrite != 0 {
+ if err := e.net.Shutdown(); err != nil {
+ return err
+ }
+ }
if flags&tcpip.ShutdownRead != 0 {
+ e.readShutdown = true
+
e.rcvMu.Lock()
wasClosed := e.rcvClosed
e.rcvClosed = true
@@ -1088,7 +697,7 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi
func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, tcpip.Error) {
bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
- if e.ID.LocalPort == 0 {
+ if e.localPort == 0 {
portRes := ports.Reservation{
Networks: netProtos,
Transport: ProtocolNumber,
@@ -1126,56 +735,43 @@ func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id
func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error {
// Don't allow binding once endpoint is not in the initial state
// anymore.
- if e.EndpointState() != StateInitial {
+ if e.net.State() != transport.DatagramEndpointStateInitial {
return &tcpip.ErrInvalidEndpointState{}
}
- addr, netProto, err := e.checkV4MappedLocked(addr)
- if err != nil {
- return err
- }
-
- // Expand netProtos to include v4 and v6 if the caller is binding to a
- // wildcard (empty) address, and this is an IPv6 endpoint with v6only
- // set to false.
- netProtos := []tcpip.NetworkProtocolNumber{netProto}
- if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && addr.Addr == "" {
- netProtos = []tcpip.NetworkProtocolNumber{
- header.IPv6ProtocolNumber,
- header.IPv4ProtocolNumber,
+ err := e.net.BindAndThen(addr, func(boundNetProto tcpip.NetworkProtocolNumber, boundAddr tcpip.Address) tcpip.Error {
+ // Expand netProtos to include v4 and v6 if the caller is binding to a
+ // wildcard (empty) address, and this is an IPv6 endpoint with v6only
+ // set to false.
+ netProtos := []tcpip.NetworkProtocolNumber{boundNetProto}
+ if boundNetProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && boundAddr == "" {
+ netProtos = []tcpip.NetworkProtocolNumber{
+ header.IPv6ProtocolNumber,
+ header.IPv4ProtocolNumber,
+ }
}
- }
- nicID := addr.NIC
- if len(addr.Addr) != 0 && !e.isBroadcastOrMulticast(addr.NIC, netProto, addr.Addr) {
- // A local unicast address was specified, verify that it's valid.
- nicID = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
- if nicID == 0 {
- return &tcpip.ErrBadLocalAddress{}
+ id := stack.TransportEndpointID{
+ LocalPort: addr.Port,
+ LocalAddress: boundAddr,
+ }
+ id, btd, err := e.registerWithStack(netProtos, id)
+ if err != nil {
+ return err
}
- }
- id := stack.TransportEndpointID{
- LocalPort: addr.Port,
- LocalAddress: addr.Addr,
- }
- id, btd, err := e.registerWithStack(netProtos, id)
+ e.localPort = id.LocalPort
+ e.boundBindToDevice = btd
+ e.effectiveNetProtos = netProtos
+ return nil
+ })
if err != nil {
return err
}
- e.ID = id
- e.boundBindToDevice = btd
- e.RegisterNICID = nicID
- e.effectiveNetProtos = netProtos
-
- // Mark endpoint as bound.
- e.setEndpointState(StateBound)
-
e.rcvMu.Lock()
e.rcvReady = true
e.rcvMu.Unlock()
-
return nil
}
@@ -1190,9 +786,6 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
return err
}
- // Save the effective NICID generated by bindLocked.
- e.BindNICID = e.RegisterNICID
-
return nil
}
@@ -1201,16 +794,9 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- addr := e.ID.LocalAddress
- if e.EndpointState() == StateConnected {
- addr = e.route.LocalAddress()
- }
-
- return tcpip.FullAddress{
- NIC: e.RegisterNICID,
- Addr: addr,
- Port: e.ID.LocalPort,
- }, nil
+ addr := e.net.GetLocalAddress()
+ addr.Port = e.localPort
+ return addr, nil
}
// GetRemoteAddress returns the address to which the endpoint is connected.
@@ -1218,15 +804,13 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- if e.EndpointState() != StateConnected || e.dstPort == 0 {
+ addr, connected := e.net.GetRemoteAddress()
+ if !connected || e.remotePort == 0 {
return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
}
- return tcpip.FullAddress{
- NIC: e.RegisterNICID,
- Addr: e.ID.RemoteAddress,
- Port: e.ID.RemotePort,
- }, nil
+ addr.Port = e.remotePort
+ return addr, nil
}
// Readiness returns the current readiness of the endpoint. For example, if
@@ -1321,6 +905,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
// Push new packet into receive list and increment the buffer size.
packet := &udpPacket{
+ netProto: pkt.NetworkProtocolNumber,
senderAddress: tcpip.FullAddress{
NIC: pkt.NICID,
Addr: id.RemoteAddress,
@@ -1376,19 +961,20 @@ func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, p
payload = udp.Payload()
}
+ id := e.net.Info().ID
e.SocketOptions().QueueErr(&tcpip.SockError{
Err: err,
Cause: transErr,
Payload: payload,
Dst: tcpip.FullAddress{
NIC: pkt.NICID,
- Addr: e.ID.RemoteAddress,
- Port: e.ID.RemotePort,
+ Addr: id.RemoteAddress,
+ Port: e.remotePort,
},
Offender: tcpip.FullAddress{
NIC: pkt.NICID,
- Addr: e.ID.LocalAddress,
- Port: e.ID.LocalPort,
+ Addr: id.LocalAddress,
+ Port: e.localPort,
},
NetProto: pkt.NetworkProtocolNumber,
})
@@ -1403,7 +989,7 @@ func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketB
// TODO(gvisor.dev/issues/5270): Handle all transport errors.
switch transErr.Kind() {
case stack.DestinationPortUnreachableTransportError:
- if e.EndpointState() == StateConnected {
+ if e.net.State() == transport.DatagramEndpointStateConnected {
e.onICMPError(&tcpip.ErrConnectionRefused{}, transErr, pkt)
}
}
@@ -1411,16 +997,17 @@ func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketB
// State implements tcpip.Endpoint.
func (e *endpoint) State() uint32 {
- return uint32(e.EndpointState())
+ return uint32(e.net.State())
}
// Info returns a copy of the endpoint info.
func (e *endpoint) Info() tcpip.EndpointInfo {
e.mu.RLock()
- // Make a copy of the endpoint info.
- ret := e.TransportEndpointInfo
- e.mu.RUnlock()
- return &ret
+ defer e.mu.RUnlock()
+ info := e.net.Info()
+ info.ID.LocalPort = e.localPort
+ info.ID.RemotePort = e.remotePort
+ return &info
}
// Stats returns a pointer to the endpoint stats.
@@ -1431,13 +1018,9 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
// Wait implements tcpip.Endpoint.
func (*endpoint) Wait() {}
-func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
- return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) || e.stack.IsSubnetBroadcast(nicID, netProto, addr)
-}
-
// SetOwner implements tcpip.Endpoint.
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
- e.owner = owner
+ e.net.SetOwner(owner)
}
// SocketOptions implements tcpip.Endpoint.
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 1f638c3f6..2ff8b0482 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -15,12 +15,13 @@
package udp
import (
+ "fmt"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
)
// saveReceivedAt is invoked by stateify.
@@ -35,17 +36,11 @@ func (p *udpPacket) loadReceivedAt(nsec int64) {
// saveData saves udpPacket.data field.
func (p *udpPacket) saveData() buffer.VectorisedView {
- // We cannot save p.data directly as p.data.views may alias to p.views,
- // which is not allowed by state framework (in-struct pointer).
return p.data.Clone(nil)
}
// loadData loads udpPacket.data field.
func (p *udpPacket) loadData(data buffer.VectorisedView) {
- // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization
- // here because data.views is not guaranteed to be loaded by now. Plus,
- // data.views will be allocated anyway so there really is little point
- // of utilizing p.views for data.views.
p.data = data
}
@@ -66,50 +61,28 @@ func (e *endpoint) Resume(s *stack.Stack) {
e.mu.Lock()
defer e.mu.Unlock()
+ e.net.Resume(s)
+
e.stack = s
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
- for m := range e.multicastMemberships {
- if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil {
- panic(err)
- }
- }
-
- state := e.EndpointState()
- if state != StateBound && state != StateConnected {
- return
- }
-
- netProto := e.effectiveNetProtos[0]
- // Connect() and bindLocked() both assert
- //
- // netProto == header.IPv6ProtocolNumber
- //
- // before creating a multi-entry effectiveNetProtos.
- if len(e.effectiveNetProtos) > 1 {
- netProto = header.IPv6ProtocolNumber
- }
-
- var err tcpip.Error
- if state == StateConnected {
- e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.ops.GetMulticastLoop())
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ // Our saved state had a port, but we don't actually have a
+ // reservation. We need to remove the port from our state, but still
+ // pass it to the reservation machinery.
+ var err tcpip.Error
+ id := e.net.Info().ID
+ id.LocalPort = e.localPort
+ id.RemotePort = e.remotePort
+ id, e.boundBindToDevice, err = e.registerWithStack(e.effectiveNetProtos, id)
if err != nil {
panic(err)
}
- } else if len(e.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.RegisterNICID, netProto, e.ID.LocalAddress) { // stateBound
- // A local unicast address is specified, verify that it's valid.
- if e.stack.CheckLocalAddress(e.RegisterNICID, netProto, e.ID.LocalAddress) == 0 {
- panic(&tcpip.ErrBadLocalAddress{})
- }
- }
-
- // Our saved state had a port, but we don't actually have a
- // reservation. We need to remove the port from our state, but still
- // pass it to the reservation machinery.
- id := e.ID
- e.ID.LocalPort = 0
- e.ID, e.boundBindToDevice, err = e.registerWithStack(e.effectiveNetProtos, id)
- if err != nil {
- panic(err)
+ e.localPort = id.LocalPort
+ e.remotePort = id.RemotePort
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
}
}
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index 7c357cb09..7238fc019 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -70,28 +70,29 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID {
// CreateEndpoint creates a connected UDP endpoint for the session request.
func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
+ ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue)
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
netHdr := r.pkt.Network()
- route, err := r.stack.FindRoute(r.pkt.NICID, netHdr.DestinationAddress(), netHdr.SourceAddress(), r.pkt.NetworkProtocolNumber, false /* multicastLoop */)
- if err != nil {
+ if err := ep.net.Bind(tcpip.FullAddress{NIC: r.pkt.NICID, Addr: netHdr.DestinationAddress(), Port: r.id.LocalPort}); err != nil {
+ return nil, err
+ }
+
+ if err := ep.net.Connect(tcpip.FullAddress{NIC: r.pkt.NICID, Addr: netHdr.SourceAddress(), Port: r.id.RemotePort}); err != nil {
return nil, err
}
- ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue)
if err := r.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil {
ep.Close()
- route.Release()
return nil, err
}
- ep.ID = r.id
- ep.route = route
- ep.dstPort = r.id.RemotePort
+ ep.localPort = r.id.LocalPort
+ ep.remotePort = r.id.RemotePort
ep.effectiveNetProtos = []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}
- ep.RegisterNICID = r.pkt.NICID
ep.boundPortFlags = ep.portFlags
- ep.state = uint32(StateConnected)
-
ep.rcvMu.Lock()
ep.rcvReady = true
ep.rcvMu.Unlock()
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 4008cacf2..b3199489c 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -22,6 +22,7 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
+ "golang.org/x/time/rate"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -290,6 +291,7 @@ type testContext struct {
t *testing.T
linkEP *channel.Endpoint
s *stack.Stack
+ nicID tcpip.NICID
ep tcpip.Endpoint
wq waiter.Queue
@@ -301,6 +303,8 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext {
}
func newDualTestContextWithHandleLocal(t *testing.T, mtu uint32, handleLocal bool) *testContext {
+ const nicID = 1
+
t.Helper()
options := stack.Options{
@@ -310,38 +314,50 @@ func newDualTestContextWithHandleLocal(t *testing.T, mtu uint32, handleLocal boo
Clock: &faketime.NullClock{},
}
s := stack.New(options)
+ // Disable ICMP rate limiter because we're using Null clock, which never advances time and thus
+ // never allows ICMP messages.
+ s.SetICMPLimit(rate.Inf)
ep := channel.New(256, mtu, "")
wep := stack.LinkEndpoint(ep)
if testing.Verbose() {
wep = sniffer.New(ep)
}
- if err := s.CreateNIC(1, wep); err != nil {
- t.Fatalf("CreateNIC failed: %s", err)
+ if err := s.CreateNIC(nicID, wep); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress failed: %s", err)
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.Address(stackAddr).WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV4, err)
}
- if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil {
- t.Fatalf("AddAddress failed: %s", err)
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: tcpip.Address(stackV6Addr).WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV6, err)
}
s.SetRouteTable([]tcpip.Route{
{
Destination: header.IPv4EmptySubnet,
- NIC: 1,
+ NIC: nicID,
},
{
Destination: header.IPv6EmptySubnet,
- NIC: 1,
+ NIC: nicID,
},
})
return &testContext{
t: t,
s: s,
+ nicID: nicID,
linkEP: ep,
}
}
@@ -1353,64 +1369,70 @@ func TestReadIncrementsPacketsReceived(t *testing.T) {
func TestReadIPPacketInfo(t *testing.T) {
tests := []struct {
- name string
- proto tcpip.NetworkProtocolNumber
- flow testFlow
- expectedLocalAddr tcpip.Address
- expectedDestAddr tcpip.Address
+ name string
+ proto tcpip.NetworkProtocolNumber
+ flow testFlow
+ checker func(tcpip.NICID) checker.ControlMessagesChecker
}{
{
- name: "IPv4 unicast",
- proto: header.IPv4ProtocolNumber,
- flow: unicastV4,
- expectedLocalAddr: stackAddr,
- expectedDestAddr: stackAddr,
+ name: "IPv4 unicast",
+ proto: header.IPv4ProtocolNumber,
+ flow: unicastV4,
+ checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
+ return checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
+ NIC: id,
+ LocalAddr: stackAddr,
+ DestinationAddr: stackAddr,
+ })
+ },
},
{
name: "IPv4 multicast",
proto: header.IPv4ProtocolNumber,
flow: multicastV4,
- // This should actually be a unicast address assigned to the interface.
- //
- // TODO(gvisor.dev/issue/3556): This check is validating incorrect
- // behaviour. We still include the test so that once the bug is
- // resolved, this test will start to fail and the individual tasked
- // with fixing this bug knows to also fix this test :).
- expectedLocalAddr: multicastAddr,
- expectedDestAddr: multicastAddr,
+ checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
+ return checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
+ NIC: id,
+ // TODO(gvisor.dev/issue/3556): Check for a unicast address.
+ LocalAddr: multicastAddr,
+ DestinationAddr: multicastAddr,
+ })
+ },
},
{
name: "IPv4 broadcast",
proto: header.IPv4ProtocolNumber,
flow: broadcast,
- // This should actually be a unicast address assigned to the interface.
- //
- // TODO(gvisor.dev/issue/3556): This check is validating incorrect
- // behaviour. We still include the test so that once the bug is
- // resolved, this test will start to fail and the individual tasked
- // with fixing this bug knows to also fix this test :).
- expectedLocalAddr: broadcastAddr,
- expectedDestAddr: broadcastAddr,
+ checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
+ return checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
+ NIC: id,
+ // TODO(gvisor.dev/issue/3556): Check for a unicast address.
+ LocalAddr: broadcastAddr,
+ DestinationAddr: broadcastAddr,
+ })
+ },
},
{
- name: "IPv6 unicast",
- proto: header.IPv6ProtocolNumber,
- flow: unicastV6,
- expectedLocalAddr: stackV6Addr,
- expectedDestAddr: stackV6Addr,
+ name: "IPv6 unicast",
+ proto: header.IPv6ProtocolNumber,
+ flow: unicastV6,
+ checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
+ return checker.ReceiveIPv6PacketInfo(tcpip.IPv6PacketInfo{
+ NIC: id,
+ Addr: stackV6Addr,
+ })
+ },
},
{
name: "IPv6 multicast",
proto: header.IPv6ProtocolNumber,
flow: multicastV6,
- // This should actually be a unicast address assigned to the interface.
- //
- // TODO(gvisor.dev/issue/3556): This check is validating incorrect
- // behaviour. We still include the test so that once the bug is
- // resolved, this test will start to fail and the individual tasked
- // with fixing this bug knows to also fix this test :).
- expectedLocalAddr: multicastV6Addr,
- expectedDestAddr: multicastV6Addr,
+ checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
+ return checker.ReceiveIPv6PacketInfo(tcpip.IPv6PacketInfo{
+ NIC: id,
+ Addr: multicastV6Addr,
+ })
+ },
},
}
@@ -1433,13 +1455,16 @@ func TestReadIPPacketInfo(t *testing.T) {
}
}
- c.ep.SocketOptions().SetReceivePacketInfo(true)
+ switch f := test.flow.netProto(); f {
+ case header.IPv4ProtocolNumber:
+ c.ep.SocketOptions().SetReceivePacketInfo(true)
+ case header.IPv6ProtocolNumber:
+ c.ep.SocketOptions().SetIPv6ReceivePacketInfo(true)
+ default:
+ t.Fatalf("unhandled protocol number = %d", f)
+ }
- testRead(c, test.flow, checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
- NIC: 1,
- LocalAddr: test.expectedLocalAddr,
- DestinationAddr: test.expectedDestAddr,
- }))
+ testRead(c, test.flow, test.checker(c.nicID))
if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 {
t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got)
@@ -1644,8 +1669,10 @@ func TestSetTTL(t *testing.T) {
}
}
+var v4PacketFlows = [...]testFlow{unicastV4, multicastV4, broadcast, unicastV4in6, multicastV4in6, broadcastIn6}
+
func TestSetTOS(t *testing.T) {
- for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
+ for _, flow := range v4PacketFlows {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
@@ -1680,8 +1707,10 @@ func TestSetTOS(t *testing.T) {
}
}
+var v6PacketFlows = [...]testFlow{unicastV6, unicastV6Only, multicastV6}
+
func TestSetTClass(t *testing.T) {
- for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} {
+ for _, flow := range v6PacketFlows {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
@@ -1725,8 +1754,14 @@ func TestReceiveTosTClass(t *testing.T) {
name string
tests []testFlow
}{
- {RcvTOSOpt, []testFlow{unicastV4, broadcast}},
- {RcvTClassOpt, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}},
+ {
+ name: RcvTOSOpt,
+ tests: v4PacketFlows[:],
+ },
+ {
+ name: RcvTClassOpt,
+ tests: v6PacketFlows[:],
+ },
}
for _, testCase := range testCases {
for _, flow := range testCase.tests {
@@ -1737,6 +1772,14 @@ func TestReceiveTosTClass(t *testing.T) {
c.createEndpointForFlow(flow)
name := testCase.name
+ if flow.isMulticast() {
+ netProto := flow.netProto()
+ addr := flow.getMcastAddr()
+ if err := c.s.JoinGroup(netProto, c.nicID, addr); err != nil {
+ c.t.Fatalf("JoinGroup(%d, %d, %s): %s", netProto, c.nicID, addr, err)
+ }
+ }
+
var optionGetter func() bool
var optionSetter func(bool)
switch name {
@@ -2482,8 +2525,8 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
if err := s.CreateNIC(nicID1, e); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
}
- if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err)
+ if err := s.AddProtocolAddress(nicID1, test.nicAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, test.nicAddr, err)
}
s.SetRouteTable(test.routes)