summary refs log tree commit diff homepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/BUILD2
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go112
-rw-r--r--pkg/tcpip/checker/checker.go2
-rw-r--r--pkg/tcpip/header/checksum.go62
-rw-r--r--pkg/tcpip/header/checksum_test.go203
-rw-r--r--pkg/tcpip/header/eth.go6
-rw-r--r--pkg/tcpip/header/eth_test.go4
-rw-r--r--pkg/tcpip/header/interfaces.go38
-rw-r--r--pkg/tcpip/header/ipv4.go12
-rw-r--r--pkg/tcpip/header/ndp_options.go150
-rw-r--r--pkg/tcpip/header/ndp_router_advert.go94
-rw-r--r--pkg/tcpip/header/ndp_test.go339
-rw-r--r--pkg/tcpip/header/tcp.go29
-rw-r--r--pkg/tcpip/header/udp.go29
-rw-r--r--pkg/tcpip/internal/tcp/BUILD12
-rw-r--r--pkg/tcpip/internal/tcp/tcp.go48
-rw-r--r--pkg/tcpip/link/channel/channel.go24
-rw-r--r--pkg/tcpip/link/ethernet/ethernet.go28
-rw-r--r--pkg/tcpip/link/fdbased/BUILD1
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go243
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_test.go1
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_unsafe.go1
-rw-r--r--pkg/tcpip/link/fdbased/mmap.go22
-rw-r--r--pkg/tcpip/link/fdbased/mmap_stub.go1
-rw-r--r--pkg/tcpip/link/fdbased/mmap_unsafe.go10
-rw-r--r--pkg/tcpip/link/fdbased/packet_dispatchers.go54
-rw-r--r--pkg/tcpip/link/loopback/loopback.go31
-rw-r--r--pkg/tcpip/link/muxed/injectable.go5
-rw-r--r--pkg/tcpip/link/nested/nested.go15
-rw-r--r--pkg/tcpip/link/packetsocket/BUILD14
-rw-r--r--pkg/tcpip/link/packetsocket/endpoint.go50
-rw-r--r--pkg/tcpip/link/pipe/pipe.go3
-rw-r--r--pkg/tcpip/link/qdisc/fifo/endpoint.go17
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_amd64.s2
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_arm64.s2
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go1
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go6
-rw-r--r--pkg/tcpip/link/rawfile/errors.go1
-rw-r--r--pkg/tcpip/link/rawfile/errors_test.go1
-rw-r--r--pkg/tcpip/link/rawfile/rawfile_unsafe.go131
-rw-r--r--pkg/tcpip/link/sharedmem/rx.go1
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go4
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go1
-rw-r--r--pkg/tcpip/link/sniffer/pcap.go59
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go47
-rw-r--r--pkg/tcpip/link/tun/BUILD2
-rw-r--r--pkg/tcpip/link/tun/device.go36
-rw-r--r--pkg/tcpip/link/tun/tun_unsafe.go1
-rw-r--r--pkg/tcpip/link/waitable/waitable.go12
-rw-r--r--pkg/tcpip/link/waitable/waitable_test.go5
-rw-r--r--pkg/tcpip/network/BUILD2
-rw-r--r--pkg/tcpip/network/arp/arp.go1
-rw-r--r--pkg/tcpip/network/arp/arp_test.go16
-rw-r--r--pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go10
-rw-r--r--pkg/tcpip/network/internal/testutil/testutil.go5
-rw-r--r--pkg/tcpip/network/ip_test.go215
-rw-r--r--pkg/tcpip/network/ipv4/BUILD1
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go5
-rw-r--r--pkg/tcpip/network/ipv4/igmp_test.go28
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go32
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go85
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go4
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go72
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go33
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go92
-rw-r--r--pkg/tcpip/network/ipv6/mld_test.go33
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go293
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go61
-rw-r--r--pkg/tcpip/network/multicast_group_test.go21
-rw-r--r--pkg/tcpip/ports/BUILD1
-rw-r--r--pkg/tcpip/ports/ports.go40
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/main.go9
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/main.go17
-rw-r--r--pkg/tcpip/socketops.go98
-rw-r--r--pkg/tcpip/stack/BUILD3
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state.go26
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state_test.go4
-rw-r--r--pkg/tcpip/stack/conntrack.go67
-rw-r--r--pkg/tcpip/stack/forwarding_test.go30
-rw-r--r--pkg/tcpip/stack/iptables.go4
-rw-r--r--pkg/tcpip/stack/iptables_targets.go105
-rw-r--r--pkg/tcpip/stack/iptables_types.go2
-rw-r--r--pkg/tcpip/stack/ndp_test.go1894
-rw-r--r--pkg/tcpip/stack/nic.go169
-rw-r--r--pkg/tcpip/stack/nic_test.go5
-rw-r--r--pkg/tcpip/stack/packet_buffer.go14
-rw-r--r--pkg/tcpip/stack/packet_buffer_test.go26
-rw-r--r--pkg/tcpip/stack/registration.go52
-rw-r--r--pkg/tcpip/stack/stack.go132
-rw-r--r--pkg/tcpip/stack/stack_test.go690
-rw-r--r--pkg/tcpip/stack/tcp.go12
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go73
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go68
-rw-r--r--pkg/tcpip/stack/transport_test.go33
-rw-r--r--pkg/tcpip/tcpip.go14
-rw-r--r--pkg/tcpip/tests/integration/forward_test.go20
-rw-r--r--pkg/tcpip/tests/integration/iptables_test.go72
-rw-r--r--pkg/tcpip/tests/integration/link_resolution_test.go24
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go58
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go20
-rw-r--r--pkg/tcpip/tests/integration/route_test.go38
-rw-r--r--pkg/tcpip/tests/utils/utils.go32
-rw-r--r--pkg/tcpip/transport/BUILD13
-rw-r--r--pkg/tcpip/transport/datagram.go49
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go21
-rw-r--r--pkg/tcpip/transport/icmp/icmp_test.go8
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go2
-rw-r--r--pkg/tcpip/transport/internal/network/BUILD45
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint.go811
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint_state.go58
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint_test.go318
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go219
-rw-r--r--pkg/tcpip/transport/packet/endpoint_state.go18
-rw-r--r--pkg/tcpip/transport/raw/BUILD2
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go445
-rw-r--r--pkg/tcpip/transport/raw/endpoint_state.go30
-rw-r--r--pkg/tcpip/transport/tcp/BUILD2
-rw-r--r--pkg/tcpip/transport/tcp/accept.go68
-rw-r--r--pkg/tcpip/transport/tcp/connect.go229
-rw-r--r--pkg/tcpip/transport/tcp/dispatcher.go2
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go295
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go1
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go9
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go54
-rw-r--r--pkg/tcpip/transport/tcp/rack.go3
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go3
-rw-r--r--pkg/tcpip/transport/tcp/snd.go21
-rw-r--r--pkg/tcpip/transport/tcp/tcp_noracedetector_test.go1
-rw-r--r--pkg/tcpip/transport/tcp/tcp_rack_test.go64
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go15
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go810
-rw-r--r--pkg/tcpip/transport/tcp/tcp_timestamp_test.go8
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go91
-rw-r--r--pkg/tcpip/transport/transport.go16
-rw-r--r--pkg/tcpip/transport/udp/BUILD2
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go900
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go65
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go21
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go58
139 files changed, 7531 insertions, 4081 deletions
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index ed4d7e958..dbe4506cc 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -46,7 +46,6 @@ deps_test(
"//pkg/gohacks",
"//pkg/goid",
"//pkg/ilist",
- "//pkg/iovec",
"//pkg/linewriter",
"//pkg/log",
"//pkg/rand",
@@ -70,7 +69,6 @@ deps_test(
"//pkg/tcpip/header",
"//pkg/tcpip/link/fdbased",
"//pkg/tcpip/link/loopback",
- "//pkg/tcpip/link/packetsocket",
"//pkg/tcpip/link/qdisc/fifo",
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/arp",
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index 48b24692b..c8460e63c 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -137,7 +137,13 @@ func TestCloseReader(t *testing.T) {
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
l, e := ListenTCP(s, addr, ipv4.ProtocolNumber)
if e != nil {
@@ -190,7 +196,13 @@ func TestCloseReaderWithForwarder(t *testing.T) {
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
done := make(chan struct{})
@@ -244,7 +256,13 @@ func TestCloseRead(t *testing.T) {
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
var wq waiter.Queue
@@ -288,7 +306,13 @@ func TestCloseWrite(t *testing.T) {
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
var wq waiter.Queue
@@ -349,10 +373,22 @@ func TestUDPForwarder(t *testing.T) {
ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr1 := tcpip.FullAddress{NICID, ip1, 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip1)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip1.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr1, err)
+ }
ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4())
addr2 := tcpip.FullAddress{NICID, ip2, 11311}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip2)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr2, err)
+ }
done := make(chan struct{})
fwd := udp.NewForwarder(s, func(r *udp.ForwarderRequest) {
@@ -410,7 +446,13 @@ func TestDeadlineChange(t *testing.T) {
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
l, e := ListenTCP(s, addr, ipv4.ProtocolNumber)
if e != nil {
@@ -465,10 +507,22 @@ func TestPacketConnTransfer(t *testing.T) {
ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr1 := tcpip.FullAddress{NICID, ip1, 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip1)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip1.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr1, err)
+ }
ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4())
addr2 := tcpip.FullAddress{NICID, ip2, 11311}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip2)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr2, err)
+ }
c1, err := DialUDP(s, &addr1, nil, ipv4.ProtocolNumber)
if err != nil {
@@ -521,7 +575,13 @@ func TestConnectedPacketConnTransfer(t *testing.T) {
ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr := tcpip.FullAddress{NICID, ip, 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
c1, err := DialUDP(s, &addr, nil, ipv4.ProtocolNumber)
if err != nil {
@@ -565,24 +625,30 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) {
ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr := tcpip.FullAddress{NICID, ip, 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ return nil, nil, nil, fmt.Errorf("AddProtocolAddress(%d, %+v, {}): %w", NICID, protocolAddr, err)
+ }
l, err := ListenTCP(s, addr, ipv4.ProtocolNumber)
if err != nil {
- return nil, nil, nil, fmt.Errorf("NewListener: %v", err)
+ return nil, nil, nil, fmt.Errorf("NewListener: %w", err)
}
c1, err = DialTCP(s, addr, ipv4.ProtocolNumber)
if err != nil {
l.Close()
- return nil, nil, nil, fmt.Errorf("DialTCP: %v", err)
+ return nil, nil, nil, fmt.Errorf("DialTCP: %w", err)
}
c2, err = l.Accept()
if err != nil {
l.Close()
c1.Close()
- return nil, nil, nil, fmt.Errorf("l.Accept: %v", err)
+ return nil, nil, nil, fmt.Errorf("l.Accept: %w", err)
}
stop = func() {
@@ -594,7 +660,7 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) {
if err := l.Close(); err != nil {
stop()
- return nil, nil, nil, fmt.Errorf("l.Close(): %v", err)
+ return nil, nil, nil, fmt.Errorf("l.Close(): %w", err)
}
return c1, c2, stop, nil
@@ -681,7 +747,13 @@ func TestDialContextTCPCanceled(t *testing.T) {
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
@@ -703,7 +775,13 @@ func TestDialContextTCPTimeout(t *testing.T) {
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
time.Sleep(time.Second)
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index e0dfe5813..2f34bf8dd 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -729,7 +729,7 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp
return
}
l := int(opts[i+1])
- if i < 2 || i+l > limit {
+ if l < 2 || i+l > limit {
return
}
i += l
diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go
index 6aa9acfa8..e2c85e220 100644
--- a/pkg/tcpip/header/checksum.go
+++ b/pkg/tcpip/header/checksum.go
@@ -18,6 +18,7 @@ package header
import (
"encoding/binary"
+ "fmt"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -234,3 +235,64 @@ func PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, srcAddr tcpip.
return Checksum([]byte{0, uint8(protocol)}, xsum)
}
+
+// checksumUpdate2ByteAlignedUint16 updates a uint16 value in a calculated
+// checksum.
+//
+// The value MUST begin at a 2-byte boundary in the original buffer.
+func checksumUpdate2ByteAlignedUint16(xsum, old, new uint16) uint16 {
+ // As per RFC 1071 page 4,
+ // (4) Incremental Update
+ //
+ // ...
+ //
+ // To update the checksum, simply add the differences of the
+ // sixteen bit integers that have been changed. To see why this
+ // works, observe that every 16-bit integer has an additive inverse
+ // and that addition is associative. From this it follows that
+ // given the original value m, the new value m', and the old
+ // checksum C, the new checksum C' is:
+ //
+ // C' = C + (-m) + m' = C + (m' - m)
+ return ChecksumCombine(xsum, ChecksumCombine(new, ^old))
+}
+
+// checksumUpdate2ByteAlignedAddress updates an address in a calculated
+// checksum.
+//
+// The addresses must have the same length and must contain an even number
+// of bytes. The address MUST begin at a 2-byte boundary in the original buffer.
+func checksumUpdate2ByteAlignedAddress(xsum uint16, old, new tcpip.Address) uint16 {
+ const uint16Bytes = 2
+
+ if len(old) != len(new) {
+ panic(fmt.Sprintf("buffer lengths are different; old = %d, new = %d", len(old), len(new)))
+ }
+
+ if len(old)%uint16Bytes != 0 {
+ panic(fmt.Sprintf("buffer has an odd number of bytes; got = %d", len(old)))
+ }
+
+ // As per RFC 1071 page 4,
+ // (4) Incremental Update
+ //
+ // ...
+ //
+ // To update the checksum, simply add the differences of the
+ // sixteen bit integers that have been changed. To see why this
+ // works, observe that every 16-bit integer has an additive inverse
+ // and that addition is associative. From this it follows that
+ // given the original value m, the new value m', and the old
+ // checksum C, the new checksum C' is:
+ //
+ // C' = C + (-m) + m' = C + (m' - m)
+ for len(old) != 0 {
+ // Convert each 2-byte sequence to a uint16 value, then apply the
+ // incremental update.
+ xsum = checksumUpdate2ByteAlignedUint16(xsum, (uint16(old[0])<<8)+uint16(old[1]), (uint16(new[0])<<8)+uint16(new[1]))
+ old = old[uint16Bytes:]
+ new = new[uint16Bytes:]
+ }
+
+ return xsum
+}
diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go
index d267dabd0..3445511f4 100644
--- a/pkg/tcpip/header/checksum_test.go
+++ b/pkg/tcpip/header/checksum_test.go
@@ -23,6 +23,7 @@ import (
"sync"
"testing"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
@@ -256,3 +257,205 @@ func TestICMPv6Checksum(t *testing.T) {
})
}, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView()))
}
+
+func randomAddress(size int) tcpip.Address {
+ s := make([]byte, size)
+ for i := 0; i < size; i++ {
+ s[i] = byte(rand.Uint32())
+ }
+ return tcpip.Address(s)
+}
+
+func TestChecksummableNetworkUpdateAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ update func(header.IPv4, tcpip.Address)
+ }{
+ {
+ name: "SetSourceAddressWithChecksumUpdate",
+ update: header.IPv4.SetSourceAddressWithChecksumUpdate,
+ },
+ {
+ name: "SetDestinationAddressWithChecksumUpdate",
+ update: header.IPv4.SetDestinationAddressWithChecksumUpdate,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for i := 0; i < 1000; i++ {
+ var origBytes [header.IPv4MinimumSize]byte
+ header.IPv4(origBytes[:]).Encode(&header.IPv4Fields{
+ TOS: 1,
+ TotalLength: header.IPv4MinimumSize,
+ ID: 2,
+ Flags: 3,
+ FragmentOffset: 4,
+ TTL: 5,
+ Protocol: 6,
+ Checksum: 0,
+ SrcAddr: randomAddress(header.IPv4AddressSize),
+ DstAddr: randomAddress(header.IPv4AddressSize),
+ })
+
+ addr := randomAddress(header.IPv4AddressSize)
+
+ bytesCopy := origBytes
+ h := header.IPv4(bytesCopy[:])
+ origXSum := h.CalculateChecksum()
+ h.SetChecksum(^origXSum)
+
+ test.update(h, addr)
+ got := ^h.Checksum()
+ h.SetChecksum(0)
+ want := h.CalculateChecksum()
+ if got != want {
+ t.Errorf("got h.Checksum() = 0x%x, want = 0x%x; originalBytes = 0x%x, new addr = %s", got, want, origBytes, addr)
+ }
+ }
+ })
+ }
+}
+
+func TestChecksummableTransportUpdatePort(t *testing.T) {
+ // The fields in the pseudo header are not tested here so we just use 0.
+ const pseudoHeaderXSum = 0
+
+ tests := []struct {
+ name string
+ transportHdr func(_, _ uint16) (header.ChecksummableTransport, func(uint16) uint16)
+ proto tcpip.TransportProtocolNumber
+ }{
+ {
+ name: "TCP",
+ transportHdr: func(src, dst uint16) (header.ChecksummableTransport, func(uint16) uint16) {
+ h := header.TCP(make([]byte, header.TCPMinimumSize))
+ h.Encode(&header.TCPFields{
+ SrcPort: src,
+ DstPort: dst,
+ SeqNum: 1,
+ AckNum: 2,
+ DataOffset: header.TCPMinimumSize,
+ Flags: 3,
+ WindowSize: 4,
+ Checksum: 0,
+ UrgentPointer: 5,
+ })
+ h.SetChecksum(^h.CalculateChecksum(pseudoHeaderXSum))
+ return h, h.CalculateChecksum
+ },
+ proto: header.TCPProtocolNumber,
+ },
+ {
+ name: "UDP",
+ transportHdr: func(src, dst uint16) (header.ChecksummableTransport, func(uint16) uint16) {
+ h := header.UDP(make([]byte, header.UDPMinimumSize))
+ h.Encode(&header.UDPFields{
+ SrcPort: src,
+ DstPort: dst,
+ Length: 0,
+ Checksum: 0,
+ })
+ h.SetChecksum(^h.CalculateChecksum(pseudoHeaderXSum))
+ return h, h.CalculateChecksum
+ },
+ proto: header.UDPProtocolNumber,
+ },
+ }
+
+ for i := 0; i < 1000; i++ {
+ origSrcPort := uint16(rand.Uint32())
+ origDstPort := uint16(rand.Uint32())
+ newPort := uint16(rand.Uint32())
+
+ t.Run(fmt.Sprintf("OrigSrcPort=%d,OrigDstPort=%d,NewPort=%d", origSrcPort, origDstPort, newPort), func(*testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, subTest := range []struct {
+ name string
+ update func(header.ChecksummableTransport)
+ }{
+ {
+ name: "Source port",
+ update: func(h header.ChecksummableTransport) { h.SetSourcePortWithChecksumUpdate(newPort) },
+ },
+ {
+ name: "Destination port",
+ update: func(h header.ChecksummableTransport) { h.SetDestinationPortWithChecksumUpdate(newPort) },
+ },
+ } {
+ t.Run(subTest.name, func(t *testing.T) {
+ h, calcXSum := test.transportHdr(origSrcPort, origDstPort)
+ subTest.update(h)
+ // TCP and UDP hold the 1s complement of the fully calculated
+ // checksum.
+ got := ^h.Checksum()
+ h.SetChecksum(0)
+
+ if want := calcXSum(pseudoHeaderXSum); got != want {
+ h, _ := test.transportHdr(origSrcPort, origDstPort)
+ t.Errorf("got Checksum() = 0x%x, want = 0x%x; originalBytes = %#v, new port = %d", got, want, h, newPort)
+ }
+ })
+ }
+ })
+ }
+ })
+ }
+}
+
+func TestChecksummableTransportUpdatePseudoHeaderAddress(t *testing.T) {
+ const addressSize = 6
+
+ tests := []struct {
+ name string
+ transportHdr func() header.ChecksummableTransport
+ proto tcpip.TransportProtocolNumber
+ }{
+ {
+ name: "TCP",
+ transportHdr: func() header.ChecksummableTransport { return header.TCP(make([]byte, header.TCPMinimumSize)) },
+ proto: header.TCPProtocolNumber,
+ },
+ {
+ name: "UDP",
+ transportHdr: func() header.ChecksummableTransport { return header.UDP(make([]byte, header.UDPMinimumSize)) },
+ proto: header.UDPProtocolNumber,
+ },
+ }
+
+ for i := 0; i < 1000; i++ {
+ permanent := randomAddress(addressSize)
+ old := randomAddress(addressSize)
+ new := randomAddress(addressSize)
+
+ t.Run(fmt.Sprintf("Permanent=%q,Old=%q,New=%q", permanent, old, new), func(t *testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, fullChecksum := range []bool{true, false} {
+ t.Run(fmt.Sprintf("FullChecksum=%t", fullChecksum), func(t *testing.T) {
+ initialXSum := header.PseudoHeaderChecksum(test.proto, permanent, old, 0)
+ if fullChecksum {
+ // TCP and UDP hold the 1s complement of the fully calculated
+ // checksum.
+ initialXSum = ^initialXSum
+ }
+
+ h := test.transportHdr()
+ h.SetChecksum(initialXSum)
+ h.UpdateChecksumPseudoHeaderAddress(old, new, fullChecksum)
+
+ got := h.Checksum()
+ if fullChecksum {
+ got = ^got
+ }
+ if want := header.PseudoHeaderChecksum(test.proto, permanent, new, 0); got != want {
+ t.Errorf("got Checksum() = 0x%x, want = 0x%x; h = %#v", got, want, h)
+ }
+ })
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go
index 95ade0e5c..1f18213e5 100644
--- a/pkg/tcpip/header/eth.go
+++ b/pkg/tcpip/header/eth.go
@@ -49,9 +49,9 @@ const (
// EthernetAddressSize is the size, in bytes, of an ethernet address.
EthernetAddressSize = 6
- // unspecifiedEthernetAddress is the unspecified ethernet address
+ // UnspecifiedEthernetAddress is the unspecified ethernet address
// (all bits set to 0).
- unspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")
+ UnspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")
// EthernetBroadcastAddress is an ethernet address that addresses every node
// on a local link.
@@ -134,7 +134,7 @@ func IsValidUnicastEthernetAddress(addr tcpip.LinkAddress) bool {
return false
}
- if addr == unspecifiedEthernetAddress {
+ if addr == UnspecifiedEthernetAddress {
return false
}
diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go
index bf9ccbf1a..adc04e855 100644
--- a/pkg/tcpip/header/eth_test.go
+++ b/pkg/tcpip/header/eth_test.go
@@ -44,7 +44,7 @@ func TestIsValidUnicastEthernetAddress(t *testing.T) {
},
{
"Unspecified",
- unspecifiedEthernetAddress,
+ UnspecifiedEthernetAddress,
false,
},
{
@@ -91,7 +91,7 @@ func TestIsMulticastEthernetAddress(t *testing.T) {
},
{
"Unspecified",
- unspecifiedEthernetAddress,
+ UnspecifiedEthernetAddress,
false,
},
{
diff --git a/pkg/tcpip/header/interfaces.go b/pkg/tcpip/header/interfaces.go
index 861cbbb70..3a41adfc4 100644
--- a/pkg/tcpip/header/interfaces.go
+++ b/pkg/tcpip/header/interfaces.go
@@ -53,6 +53,31 @@ type Transport interface {
Payload() []byte
}
+// ChecksummableTransport is a Transport that supports checksumming.
+type ChecksummableTransport interface {
+ Transport
+
+ // SetSourcePortWithChecksumUpdate sets the source port and updates
+ // the checksum.
+ //
+ // The receiver's checksum must be a fully calculated checksum.
+ SetSourcePortWithChecksumUpdate(port uint16)
+
+ // SetDestinationPortWithChecksumUpdate sets the destination port and updates
+ // the checksum.
+ //
+ // The receiver's checksum must be a fully calculated checksum.
+ SetDestinationPortWithChecksumUpdate(port uint16)
+
+ // UpdateChecksumPseudoHeaderAddress updates the checksum to reflect an
+ // updated address in the pseudo header.
+ //
+ // If fullChecksum is true, the receiver's checksum field is assumed to hold a
+ // fully calculated checksum. Otherwise, it is assumed to hold a partially
+ // calculated checksum which only reflects the pseudo header.
+ UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool)
+}
+
// Network offers generic methods to query and/or update the fields of the
// header of a network protocol buffer.
type Network interface {
@@ -90,3 +115,16 @@ type Network interface {
// SetTOS sets the values of the "type of service" and "flow label" fields.
SetTOS(t uint8, l uint32)
}
+
+// ChecksummableNetwork is a Network that supports checksumming.
+type ChecksummableNetwork interface {
+ Network
+
+ // SetSourceAddressWithChecksumUpdate sets the source address and updates the
+ // checksum to reflect the new address.
+ SetSourceAddressWithChecksumUpdate(tcpip.Address)
+
+ // SetDestinationAddressWithChecksumUpdate sets the destination address and
+ // updates the checksum to reflect the new address.
+ SetDestinationAddressWithChecksumUpdate(tcpip.Address)
+}
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index e9abbb709..dcc549c7b 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -305,6 +305,18 @@ func (b IPv4) DestinationAddress() tcpip.Address {
return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize])
}
+// SetSourceAddressWithChecksumUpdate implements ChecksummableNetwork.
+func (b IPv4) SetSourceAddressWithChecksumUpdate(new tcpip.Address) {
+ b.SetChecksum(^checksumUpdate2ByteAlignedAddress(^b.Checksum(), b.SourceAddress(), new))
+ b.SetSourceAddress(new)
+}
+
+// SetDestinationAddressWithChecksumUpdate implements ChecksummableNetwork.
+func (b IPv4) SetDestinationAddressWithChecksumUpdate(new tcpip.Address) {
+ b.SetChecksum(^checksumUpdate2ByteAlignedAddress(^b.Checksum(), b.DestinationAddress(), new))
+ b.SetDestinationAddress(new)
+}
+
// padIPv4OptionsLength returns the total length for IPv4 options of length l
// after applying padding according to RFC 791:
// The internet header padding is used to ensure that the internet
diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go
index d6cad3a94..a647ea968 100644
--- a/pkg/tcpip/header/ndp_options.go
+++ b/pkg/tcpip/header/ndp_options.go
@@ -148,15 +148,10 @@ const (
// NDP option. That is, the length field for NDP options is in units of
// 8 octets, as per RFC 4861 section 4.6.
lengthByteUnits = 8
-)
-var (
// NDPInfiniteLifetime is a value that represents infinity for the
// 4-byte lifetime fields found in various NDP options. Its value is
// (2^32 - 1)s = 4294967295s.
- //
- // This is a variable instead of a constant so that tests can change
- // this value to a smaller value. It should only be modified by tests.
NDPInfiniteLifetime = time.Second * math.MaxUint32
)
@@ -238,6 +233,17 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) {
case ndpNonceOptionType:
return NDPNonceOption(body), false, nil
+ case ndpRouteInformationType:
+ if numBodyBytes > ndpRouteInformationMaxLength {
+ return nil, true, fmt.Errorf("got %d bytes for NDP Route Information option's body, expected at max %d bytes: %w", numBodyBytes, ndpRouteInformationMaxLength, ErrNDPOptMalformedBody)
+ }
+ opt := NDPRouteInformation(body)
+ if err := opt.hasError(); err != nil {
+ return nil, true, err
+ }
+
+ return opt, false, nil
+
case ndpPrefixInformationType:
// Make sure the length of a Prefix Information option
// body is ndpPrefixInformationLength, as per RFC 4861
@@ -935,3 +941,137 @@ func isUpperLetter(b byte) bool {
func isDigit(b byte) bool {
return b >= '0' && b <= '9'
}
+
+// As per RFC 4191 section 2.3,
+//
+// 2.3. Route Information Option
+//
+// 0 1 2 3
+// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Type | Length | Prefix Length |Resvd|Prf|Resvd|
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Route Lifetime |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Prefix (Variable Length) |
+// . .
+// . .
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+//
+// Fields:
+//
+// Type 24
+//
+//
+// Length 8-bit unsigned integer. The length of the option
+// (including the Type and Length fields) in units of 8
+// octets. The Length field is 1, 2, or 3 depending on the
+// Prefix Length. If Prefix Length is greater than 64, then
+// Length must be 3. If Prefix Length is greater than 0,
+// then Length must be 2 or 3. If Prefix Length is zero,
+// then Length must be 1, 2, or 3.
+const (
+ ndpRouteInformationType = ndpOptionIdentifier(24)
+ ndpRouteInformationMaxLength = 22
+
+ ndpRouteInformationPrefixLengthIdx = 0
+ ndpRouteInformationFlagsIdx = 1
+ ndpRouteInformationPrfShift = 3
+ ndpRouteInformationPrfMask = 3 << ndpRouteInformationPrfShift
+ ndpRouteInformationRouteLifetimeIdx = 2
+ ndpRouteInformationRoutePrefixIdx = 6
+)
+
+// NDPRouteInformation is the NDP Route Information option, as defined by
+// RFC 4191 section 2.3.
+type NDPRouteInformation []byte
+
+func (NDPRouteInformation) kind() ndpOptionIdentifier {
+ return ndpRouteInformationType
+}
+
+func (o NDPRouteInformation) length() int {
+ return len(o)
+}
+
+func (o NDPRouteInformation) serializeInto(b []byte) int {
+ return copy(b, o)
+}
+
+// String implements fmt.Stringer.
+func (o NDPRouteInformation) String() string {
+ return fmt.Sprintf("%T", o)
+}
+
+// PrefixLength returns the length of the prefix.
+func (o NDPRouteInformation) PrefixLength() uint8 {
+ return o[ndpRouteInformationPrefixLengthIdx]
+}
+
+// RoutePreference returns the preference of the route over other routes to the
+// same destination but through a different router.
+func (o NDPRouteInformation) RoutePreference() NDPRoutePreference {
+ return NDPRoutePreference((o[ndpRouteInformationFlagsIdx] & ndpRouteInformationPrfMask) >> ndpRouteInformationPrfShift)
+}
+
+// RouteLifetime returns the lifetime of the route.
+//
+// Note, a value of 0 implies the route is now invalid and a value of
+// infinity/forever is represented by NDPInfiniteLifetime.
+func (o NDPRouteInformation) RouteLifetime() time.Duration {
+ return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpRouteInformationRouteLifetimeIdx:]))
+}
+
+// Prefix returns the prefix of the destination subnet this route is for.
+func (o NDPRouteInformation) Prefix() (tcpip.Subnet, error) {
+ prefixLength := int(o.PrefixLength())
+ if max := IPv6AddressSize * 8; prefixLength > max {
+ return tcpip.Subnet{}, fmt.Errorf("got prefix length = %d, want <= %d", prefixLength, max)
+ }
+
+ prefix := o[ndpRouteInformationRoutePrefixIdx:]
+ var addrBytes [IPv6AddressSize]byte
+ if n := copy(addrBytes[:], prefix); n != len(prefix) {
+ panic(fmt.Sprintf("got copy(addrBytes, prefix) = %d, want = %d", n, len(prefix)))
+ }
+
+ return tcpip.AddressWithPrefix{
+ Address: tcpip.Address(addrBytes[:]),
+ PrefixLen: prefixLength,
+ }.Subnet(), nil
+}
+
+func (o NDPRouteInformation) hasError() error {
+ l := len(o)
+ if l < ndpRouteInformationRoutePrefixIdx {
+ return fmt.Errorf("%T too small, got = %d bytes: %w", o, l, ErrNDPOptMalformedBody)
+ }
+
+ prefixLength := int(o.PrefixLength())
+ if max := IPv6AddressSize * 8; prefixLength > max {
+ return fmt.Errorf("got prefix length = %d, want <= %d: %w", prefixLength, max, ErrNDPOptMalformedBody)
+ }
+
+ // Length 8-bit unsigned integer. The length of the option
+ // (including the Type and Length fields) in units of 8
+ // octets. The Length field is 1, 2, or 3 depending on the
+ // Prefix Length. If Prefix Length is greater than 64, then
+ // Length must be 3. If Prefix Length is greater than 0,
+ // then Length must be 2 or 3. If Prefix Length is zero,
+ // then Length must be 1, 2, or 3.
+ l += 2 // Add 2 bytes for the type and length bytes.
+ lengthField := l / lengthByteUnits
+ if prefixLength > 64 {
+ if lengthField != 3 {
+ return fmt.Errorf("Length field must be 3 when Prefix Length (%d) is > 64 (got = %d): %w", prefixLength, lengthField, ErrNDPOptMalformedBody)
+ }
+ } else if prefixLength > 0 {
+ if lengthField != 2 && lengthField != 3 {
+ return fmt.Errorf("Length field must be 2 or 3 when Prefix Length (%d) is between 0 and 64 (got = %d): %w", prefixLength, lengthField, ErrNDPOptMalformedBody)
+ }
+ } else if lengthField == 0 || lengthField > 3 {
+ return fmt.Errorf("Length field must be 1, 2, or 3 when Prefix Length is zero (got = %d): %w", lengthField, ErrNDPOptMalformedBody)
+ }
+
+ return nil
+}
diff --git a/pkg/tcpip/header/ndp_router_advert.go b/pkg/tcpip/header/ndp_router_advert.go
index bf7610863..7d6efa083 100644
--- a/pkg/tcpip/header/ndp_router_advert.go
+++ b/pkg/tcpip/header/ndp_router_advert.go
@@ -16,15 +16,94 @@ package header
import (
"encoding/binary"
+ "fmt"
"time"
)
+var _ fmt.Stringer = NDPRoutePreference(0)
+
+// NDPRoutePreference is the preference values for default routers or
+// more-specific routes.
+//
+// As per RFC 4191 section 2.1,
+//
+// Default router preferences and preferences for more-specific routes
+// are encoded the same way.
+//
+// Preference values are encoded as a two-bit signed integer, as
+// follows:
+//
+// 01 High
+// 00 Medium (default)
+// 11 Low
+// 10 Reserved - MUST NOT be sent
+//
+// Note that implementations can treat the value as a two-bit signed
+// integer.
+//
+// Having just three values reinforces that they are not metrics and
+// more values do not appear to be necessary for reasonable scenarios.
+type NDPRoutePreference uint8
+
+const (
+ // HighRoutePreference indicates a high preference, as per
+ // RFC 4191 section 2.1.
+ HighRoutePreference NDPRoutePreference = 0b01
+
+ // MediumRoutePreference indicates a medium preference, as per
+ // RFC 4191 section 2.1.
+ //
+ // This is the default preference value.
+ MediumRoutePreference NDPRoutePreference = 0b00
+
+ // LowRoutePreference indicates a low preference, as per
+ // RFC 4191 section 2.1.
+ LowRoutePreference NDPRoutePreference = 0b11
+
+ // ReservedRoutePreference is a reserved preference value, as per
+ // RFC 4191 section 2.1.
+ //
+ // It MUST NOT be sent.
+ ReservedRoutePreference NDPRoutePreference = 0b10
+)
+
+// String implements fmt.Stringer.
+func (p NDPRoutePreference) String() string {
+ switch p {
+ case HighRoutePreference:
+ return "HighRoutePreference"
+ case MediumRoutePreference:
+ return "MediumRoutePreference"
+ case LowRoutePreference:
+ return "LowRoutePreference"
+ case ReservedRoutePreference:
+ return "ReservedRoutePreference"
+ default:
+ return fmt.Sprintf("NDPRoutePreference(%d)", p)
+ }
+}
+
// NDPRouterAdvert is an NDP Router Advertisement message. It will only contain
// the body of an ICMPv6 packet.
//
-// See RFC 4861 section 4.2 for more details.
+// See RFC 4861 section 4.2 and RFC 4191 section 2.2 for more details.
type NDPRouterAdvert []byte
+// As per RFC 4191 section 2.2,
+//
+// 0 1 2 3
+// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Type | Code | Checksum |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Cur Hop Limit |M|O|H|Prf|Resvd| Router Lifetime |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Reachable Time |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Retrans Timer |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Options ...
+// +-+-+-+-+-+-+-+-+-+-+-+-
const (
// NDPRAMinimumSize is the minimum size of a valid NDP Router
// Advertisement message (body of an ICMPv6 packet).
@@ -47,6 +126,14 @@ const (
// within the bit-field/flags byte of an NDPRouterAdvert.
ndpRAOtherConfFlagMask = (1 << 6)
+ // ndpDefaultRouterPreferenceShift is the shift of the Prf (Default Router
+ // Preference) field within the flags byte of an NDPRouterAdvert.
+ ndpDefaultRouterPreferenceShift = 3
+
+ // ndpDefaultRouterPreferenceMask is the mask of the Prf (Default Router
+ // Preference) field within the flags byte of an NDPRouterAdvert.
+ ndpDefaultRouterPreferenceMask = (0b11 << ndpDefaultRouterPreferenceShift)
+
// ndpRARouterLifetimeOffset is the start of the 2-byte Router Lifetime
// field within an NDPRouterAdvert.
ndpRARouterLifetimeOffset = 2
@@ -80,6 +167,11 @@ func (b NDPRouterAdvert) OtherConfFlag() bool {
return b[ndpRAFlagsOffset]&ndpRAOtherConfFlagMask != 0
}
+// DefaultRouterPreference returns the Default Router Preference field.
+func (b NDPRouterAdvert) DefaultRouterPreference() NDPRoutePreference {
+ return NDPRoutePreference((b[ndpRAFlagsOffset] & ndpDefaultRouterPreferenceMask) >> ndpDefaultRouterPreferenceShift)
+}
+
// RouterLifetime returns the lifetime associated with the default router. A
// value of 0 means the source of the Router Advertisement is not a default
// router and SHOULD NOT appear on the default router list. Note, a value of 0
diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go
index 1b5093e58..2a897e938 100644
--- a/pkg/tcpip/header/ndp_test.go
+++ b/pkg/tcpip/header/ndp_test.go
@@ -21,6 +21,7 @@ import (
"fmt"
"io"
"regexp"
+ "strings"
"testing"
"time"
@@ -58,6 +59,224 @@ func TestNDPNeighborSolicit(t *testing.T) {
}
}
+func TestNDPRouteInformationOption(t *testing.T) {
+ tests := []struct {
+ name string
+
+ length uint8
+ prefixLength uint8
+ prf NDPRoutePreference
+ lifetimeS uint32
+ prefixBytes []byte
+ expectedPrefix tcpip.Subnet
+
+ expectedErr error
+ }{
+ {
+ name: "Length=1 with Prefix Length = 0",
+ length: 1,
+ prefixLength: 0,
+ prf: MediumRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: IPv6EmptySubnet,
+ },
+ {
+ name: "Length=1 but Prefix Length > 0",
+ length: 1,
+ prefixLength: 1,
+ prf: MediumRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedErr: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "Length=2 with Prefix Length = 0",
+ length: 2,
+ prefixLength: 0,
+ prf: MediumRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: IPv6EmptySubnet,
+ },
+ {
+ name: "Length=2 with Prefix Length in [1, 64] (1)",
+ length: 2,
+ prefixLength: 1,
+ prf: LowRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)),
+ PrefixLen: 1,
+ }.Subnet(),
+ },
+ {
+ name: "Length=2 with Prefix Length in [1, 64] (64)",
+ length: 2,
+ prefixLength: 64,
+ prf: HighRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)),
+ PrefixLen: 64,
+ }.Subnet(),
+ },
+ {
+ name: "Length=2 with Prefix Length > 64",
+ length: 2,
+ prefixLength: 65,
+ prf: HighRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedErr: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "Length=3 with Prefix Length = 0",
+ length: 3,
+ prefixLength: 0,
+ prf: MediumRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: IPv6EmptySubnet,
+ },
+ {
+ name: "Length=3 with Prefix Length in [1, 64] (1)",
+ length: 3,
+ prefixLength: 1,
+ prf: LowRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)),
+ PrefixLen: 1,
+ }.Subnet(),
+ },
+ {
+ name: "Length=3 with Prefix Length in [1, 64] (64)",
+ length: 3,
+ prefixLength: 64,
+ prf: HighRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)),
+ PrefixLen: 64,
+ }.Subnet(),
+ },
+ {
+ name: "Length=3 with Prefix Length in [65, 128] (65)",
+ length: 3,
+ prefixLength: 65,
+ prf: HighRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)),
+ PrefixLen: 65,
+ }.Subnet(),
+ },
+ {
+ name: "Length=3 with Prefix Length in [65, 128] (128)",
+ length: 3,
+ prefixLength: 128,
+ prf: HighRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)),
+ PrefixLen: 128,
+ }.Subnet(),
+ },
+ {
+ name: "Length=3 with (invalid) Prefix Length > 128",
+ length: 3,
+ prefixLength: 129,
+ prf: HighRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedErr: ErrNDPOptMalformedBody,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ expectedRouteInformationBytes := [...]byte{
+ // Type, Length
+ 24, test.length,
+
+ // Prefix Length, Prf
+ uint8(test.prefixLength), uint8(test.prf) << 3,
+
+ // Route Lifetime
+ 0, 0, 0, 0,
+
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ }
+ binary.BigEndian.PutUint32(expectedRouteInformationBytes[4:], test.lifetimeS)
+ _ = copy(expectedRouteInformationBytes[8:], test.prefixBytes)
+
+ opts := NDPOptions(expectedRouteInformationBytes[:test.length*lengthByteUnits])
+ it, err := opts.Iter(false)
+ if err != nil {
+ t.Fatalf("got Iter(false) = (_, %s), want = (_, nil)", err)
+ }
+ opt, done, err := it.Next()
+ if !errors.Is(err, test.expectedErr) {
+ t.Fatalf("got Next() = (_, _, %s), want = (_, _, %s)", err, test.expectedErr)
+ }
+ if want := test.expectedErr != nil; done != want {
+ t.Fatalf("got Next() = (_, %t, _), want = (_, %t, _)", done, want)
+ }
+ if test.expectedErr != nil {
+ return
+ }
+
+ if got := opt.kind(); got != ndpRouteInformationType {
+ t.Errorf("got kind() = %d, want = %d", got, ndpRouteInformationType)
+ }
+
+ ri, ok := opt.(NDPRouteInformation)
+ if !ok {
+ t.Fatalf("got opt = %T, want = NDPRouteInformation", opt)
+ }
+
+ if got := ri.PrefixLength(); got != test.prefixLength {
+ t.Errorf("got PrefixLength() = %d, want = %d", got, test.prefixLength)
+ }
+ if got := ri.RoutePreference(); got != test.prf {
+ t.Errorf("got RoutePreference() = %d, want = %d", got, test.prf)
+ }
+ if got, want := ri.RouteLifetime(), time.Duration(test.lifetimeS)*time.Second; got != want {
+ t.Errorf("got RouteLifetime() = %s, want = %s", got, want)
+ }
+ if got, err := ri.Prefix(); err != nil {
+ t.Errorf("Prefix(): %s", err)
+ } else if got != test.expectedPrefix {
+ t.Errorf("got Prefix() = %s, want = %s", got, test.expectedPrefix)
+ }
+
+ // Iterator should not return anything else.
+ {
+ next, done, err := it.Next()
+ if err != nil {
+ t.Errorf("got Next() = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if !done {
+ t.Error("got Next() = (_, false, _), want = (_, true, _)")
+ }
+ if next != nil {
+ t.Errorf("got Next() = (%x, _, _), want = (nil, _, _)", next)
+ }
+ }
+ })
+ }
+}
+
// TestNDPNeighborAdvert tests the functions of NDPNeighborAdvert.
func TestNDPNeighborAdvert(t *testing.T) {
b := []byte{
@@ -126,36 +345,83 @@ func TestNDPNeighborAdvert(t *testing.T) {
}
func TestNDPRouterAdvert(t *testing.T) {
- b := []byte{
- 64, 128, 1, 2,
- 3, 4, 5, 6,
- 7, 8, 9, 10,
+ tests := []struct {
+ hopLimit uint8
+ managedFlag, otherConfFlag bool
+ prf NDPRoutePreference
+ routerLifetimeS uint16
+ reachableTimeMS, retransTimerMS uint32
+ }{
+ {
+ hopLimit: 1,
+ managedFlag: false,
+ otherConfFlag: true,
+ prf: HighRoutePreference,
+ routerLifetimeS: 2,
+ reachableTimeMS: 3,
+ retransTimerMS: 4,
+ },
+ {
+ hopLimit: 64,
+ managedFlag: true,
+ otherConfFlag: false,
+ prf: LowRoutePreference,
+ routerLifetimeS: 258,
+ reachableTimeMS: 78492,
+ retransTimerMS: 13213,
+ },
}
- ra := NDPRouterAdvert(b)
+ for i, test := range tests {
+ t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
+ flags := uint8(0)
+ if test.managedFlag {
+ flags |= 1 << 7
+ }
+ if test.otherConfFlag {
+ flags |= 1 << 6
+ }
+ flags |= uint8(test.prf) << 3
+
+ b := []byte{
+ test.hopLimit, flags, 1, 2,
+ 3, 4, 5, 6,
+ 7, 8, 9, 10,
+ }
+ binary.BigEndian.PutUint16(b[2:], test.routerLifetimeS)
+ binary.BigEndian.PutUint32(b[4:], test.reachableTimeMS)
+ binary.BigEndian.PutUint32(b[8:], test.retransTimerMS)
- if got := ra.CurrHopLimit(); got != 64 {
- t.Errorf("got ra.CurrHopLimit = %d, want = 64", got)
- }
+ ra := NDPRouterAdvert(b)
- if got := ra.ManagedAddrConfFlag(); !got {
- t.Errorf("got ManagedAddrConfFlag = false, want = true")
- }
+ if got := ra.CurrHopLimit(); got != test.hopLimit {
+ t.Errorf("got ra.CurrHopLimit() = %d, want = %d", got, test.hopLimit)
+ }
- if got := ra.OtherConfFlag(); got {
- t.Errorf("got OtherConfFlag = true, want = false")
- }
+ if got := ra.ManagedAddrConfFlag(); got != test.managedFlag {
+ t.Errorf("got ManagedAddrConfFlag() = %t, want = %t", got, test.managedFlag)
+ }
- if got, want := ra.RouterLifetime(), time.Second*258; got != want {
- t.Errorf("got ra.RouterLifetime = %d, want = %d", got, want)
- }
+ if got := ra.OtherConfFlag(); got != test.otherConfFlag {
+ t.Errorf("got OtherConfFlag() = %t, want = %t", got, test.otherConfFlag)
+ }
- if got, want := ra.ReachableTime(), time.Millisecond*50595078; got != want {
- t.Errorf("got ra.ReachableTime = %d, want = %d", got, want)
- }
+ if got := ra.DefaultRouterPreference(); got != test.prf {
+ t.Errorf("got DefaultRouterPreference() = %d, want = %d", got, test.prf)
+ }
+
+ if got, want := ra.RouterLifetime(), time.Second*time.Duration(test.routerLifetimeS); got != want {
+ t.Errorf("got ra.RouterLifetime() = %d, want = %d", got, want)
+ }
+
+ if got, want := ra.ReachableTime(), time.Millisecond*time.Duration(test.reachableTimeMS); got != want {
+ t.Errorf("got ra.ReachableTime() = %d, want = %d", got, want)
+ }
- if got, want := ra.RetransTimer(), time.Millisecond*117967114; got != want {
- t.Errorf("got ra.RetransTimer = %d, want = %d", got, want)
+ if got, want := ra.RetransTimer(), time.Millisecond*time.Duration(test.retransTimerMS); got != want {
+ t.Errorf("got ra.RetransTimer() = %d, want = %d", got, want)
+ }
+ })
}
}
@@ -1451,3 +1717,32 @@ func TestNDPOptionsIter(t *testing.T) {
t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
}
}
+
+func TestNDPRoutePreferenceStringer(t *testing.T) {
+ p := NDPRoutePreference(0)
+ for {
+ var wantStr string
+ switch p {
+ case 0b01:
+ wantStr = "HighRoutePreference"
+ case 0b00:
+ wantStr = "MediumRoutePreference"
+ case 0b11:
+ wantStr = "LowRoutePreference"
+ case 0b10:
+ wantStr = "ReservedRoutePreference"
+ default:
+ wantStr = fmt.Sprintf("NDPRoutePreference(%d)", p)
+ }
+
+ if gotStr := p.String(); gotStr != wantStr {
+ t.Errorf("got NDPRoutePreference(%d).String() = %s, want = %s", p, gotStr, wantStr)
+ }
+
+ p++
+ if p == 0 {
+ // Overflowed, we hit all values.
+ break
+ }
+ }
+}
diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go
index 8dabe3354..a75e51a28 100644
--- a/pkg/tcpip/header/tcp.go
+++ b/pkg/tcpip/header/tcp.go
@@ -390,6 +390,35 @@ func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32
b.SetChecksum(^checksum)
}
+// SetSourcePortWithChecksumUpdate implements ChecksummableTransport.
+func (b TCP) SetSourcePortWithChecksumUpdate(new uint16) {
+ old := b.SourcePort()
+ b.SetSourcePort(new)
+ b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
+}
+
+// SetDestinationPortWithChecksumUpdate implements ChecksummableTransport.
+func (b TCP) SetDestinationPortWithChecksumUpdate(new uint16) {
+ old := b.DestinationPort()
+ b.SetDestinationPort(new)
+ b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
+}
+
+// UpdateChecksumPseudoHeaderAddress implements ChecksummableTransport.
+func (b TCP) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) {
+ xsum := b.Checksum()
+ if fullChecksum {
+ xsum = ^xsum
+ }
+
+ xsum = checksumUpdate2ByteAlignedAddress(xsum, old, new)
+ if fullChecksum {
+ xsum = ^xsum
+ }
+
+ b.SetChecksum(xsum)
+}
+
// ParseSynOptions parses the options received in a SYN segment and returns the
// relevant ones. opts should point to the option part of the TCP header.
func ParseSynOptions(opts []byte, isAck bool) TCPSynOptions {
diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go
index ae9d167ff..f69d53314 100644
--- a/pkg/tcpip/header/udp.go
+++ b/pkg/tcpip/header/udp.go
@@ -130,3 +130,32 @@ func (b UDP) Encode(u *UDPFields) {
binary.BigEndian.PutUint16(b[udpLength:], u.Length)
binary.BigEndian.PutUint16(b[udpChecksum:], u.Checksum)
}
+
+// SetSourcePortWithChecksumUpdate implements ChecksummableTransport.
+func (b UDP) SetSourcePortWithChecksumUpdate(new uint16) {
+ old := b.SourcePort()
+ b.SetSourcePort(new)
+ b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
+}
+
+// SetDestinationPortWithChecksumUpdate implements ChecksummableTransport.
+func (b UDP) SetDestinationPortWithChecksumUpdate(new uint16) {
+ old := b.DestinationPort()
+ b.SetDestinationPort(new)
+ b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
+}
+
+// UpdateChecksumPseudoHeaderAddress implements ChecksummableTransport.
+func (b UDP) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) {
+ xsum := b.Checksum()
+ if fullChecksum {
+ xsum = ^xsum
+ }
+
+ xsum = checksumUpdate2ByteAlignedAddress(xsum, old, new)
+ if fullChecksum {
+ xsum = ^xsum
+ }
+
+ b.SetChecksum(xsum)
+}
diff --git a/pkg/tcpip/internal/tcp/BUILD b/pkg/tcpip/internal/tcp/BUILD
new file mode 100644
index 000000000..9ae258a0b
--- /dev/null
+++ b/pkg/tcpip/internal/tcp/BUILD
@@ -0,0 +1,12 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "tcp",
+ srcs = ["tcp.go"],
+ visibility = ["//pkg/tcpip:__subpackages__"],
+ deps = [
+ "//pkg/tcpip",
+ ],
+)
diff --git a/pkg/tcpip/internal/tcp/tcp.go b/pkg/tcpip/internal/tcp/tcp.go
new file mode 100644
index 000000000..0616d368c
--- /dev/null
+++ b/pkg/tcpip/internal/tcp/tcp.go
@@ -0,0 +1,48 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tcp contains internal type definitions that are not expected to be
+// used by anyone else outside pkg/tcpip.
+package tcp
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// TSOffset is an offset applied to the value of the TSVal field in the TCP
+// Timestamp option.
+//
+// +stateify savable
+type TSOffset struct {
+ milliseconds uint32
+}
+
+// NewTSOffset creates a new TSOffset from milliseconds.
+func NewTSOffset(milliseconds uint32) TSOffset {
+ return TSOffset{
+ milliseconds: milliseconds,
+ }
+}
+
+// TSVal applies the offset to now and returns the timestamp in milliseconds.
+func (offset TSOffset) TSVal(now tcpip.MonotonicTime) uint32 {
+ return uint32(now.Sub(tcpip.MonotonicTime{}).Milliseconds()) + offset.milliseconds
+}
+
+// Elapsed calculates the elapsed time given now and the echoed back timestamp.
+func (offset TSOffset) Elapsed(now tcpip.MonotonicTime, tsEcr uint32) time.Duration {
+ return time.Duration(offset.TSVal(now)-tsEcr) * time.Millisecond
+}
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index f26c857eb..658557d62 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -28,7 +28,9 @@ import (
// PacketInfo holds all the information about an outbound packet.
type PacketInfo struct {
- Pkt *stack.PacketBuffer
+ Pkt *stack.PacketBuffer
+
+ // TODO(https://gvisor.dev/issue/6537): Remove these fields.
Proto tcpip.NetworkProtocolNumber
Route stack.RouteInfo
}
@@ -244,7 +246,10 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol
Route: r,
}
- e.q.Write(p)
+ // Write returns false if the queue is full. A full queue is not an error
+ // from the perspective of a LinkEndpoint so we ignore Write's return
+ // value and always return nil from this method.
+ _ = e.q.Write(p)
return nil
}
@@ -290,3 +295,18 @@ func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
// AddHeader implements stack.LinkEndpoint.AddHeader.
func (*Endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (e *Endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error {
+ p := PacketInfo{
+ Pkt: pkt,
+ Proto: pkt.NetworkProtocolNumber,
+ }
+
+ // Write returns false if the queue is full. A full queue is not an error
+ // from the perspective of a LinkEndpoint so we ignore Write's return
+ // value and always return nil from this method.
+ _ = e.q.Write(p)
+
+ return nil
+}
diff --git a/pkg/tcpip/link/ethernet/ethernet.go b/pkg/tcpip/link/ethernet/ethernet.go
index b427c6170..8211a2031 100644
--- a/pkg/tcpip/link/ethernet/ethernet.go
+++ b/pkg/tcpip/link/ethernet/ethernet.go
@@ -42,6 +42,14 @@ type Endpoint struct {
nested.Endpoint
}
+// LinkAddress implements stack.LinkEndpoint.
+func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
+ if l := e.Endpoint.LinkAddress(); len(l) != 0 {
+ return l
+ }
+ return header.UnspecifiedEthernetAddress
+}
+
// DeliverNetworkPacket implements stack.NetworkDispatcher.
func (e *Endpoint) DeliverNetworkPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
@@ -57,18 +65,22 @@ func (e *Endpoint) DeliverNetworkPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkP
// Capabilities implements stack.LinkEndpoint.
func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return stack.CapabilityResolutionRequired | e.Endpoint.Capabilities()
+ c := e.Endpoint.Capabilities()
+ if c&stack.CapabilityLoopback == 0 {
+ c |= stack.CapabilityResolutionRequired
+ }
+ return c
}
// WritePacket implements stack.LinkEndpoint.
func (e *Endpoint) WritePacket(r stack.RouteInfo, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress, proto, pkt)
+ e.AddHeader(e.LinkAddress(), r.RemoteLinkAddress, proto, pkt)
return e.Endpoint.WritePacket(r, proto, pkt)
}
// WritePackets implements stack.LinkEndpoint.
func (e *Endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
- linkAddr := e.Endpoint.LinkAddress()
+ linkAddr := e.LinkAddress()
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
e.AddHeader(linkAddr, r.RemoteLinkAddress, proto, pkt)
@@ -83,7 +95,10 @@ func (e *Endpoint) MaxHeaderLength() uint16 {
}
// ARPHardwareType implements stack.LinkEndpoint.
-func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
+func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
+ if a := e.Endpoint.ARPHardwareType(); a != header.ARPHardwareNone {
+ return a
+ }
return header.ARPHardwareEther
}
@@ -97,3 +112,8 @@ func (*Endpoint) AddHeader(local, remote tcpip.LinkAddress, proto tcpip.NetworkP
}
eth.Encode(&fields)
}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (e *Endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error {
+ return e.Endpoint.WriteRawPacket(pkt)
+}
diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD
index d971194e6..1d0163823 100644
--- a/pkg/tcpip/link/fdbased/BUILD
+++ b/pkg/tcpip/link/fdbased/BUILD
@@ -14,7 +14,6 @@ go_library(
],
visibility = ["//visibility:public"],
deps = [
- "//pkg/iovec",
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index 735c28da1..058242f96 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build linux
// +build linux
// Package fdbased provides the implemention of data-link layer endpoints
@@ -44,7 +45,6 @@ import (
"sync/atomic"
"golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/iovec"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -56,6 +56,7 @@ import (
// linkDispatcher reads packets from the link FD and dispatches them to the
// NetworkDispatcher.
type linkDispatcher interface {
+ stop()
dispatch() (bool, tcpip.Error)
}
@@ -138,6 +139,20 @@ type endpoint struct {
// gsoKind is the supported kind of GSO.
gsoKind stack.SupportedGSO
+
+ // maxSyscallHeaderBytes has the same meaning as
+ // Options.MaxSyscallHeaderBytes.
+ maxSyscallHeaderBytes uintptr
+
+ // writevMaxIovs is the maximum number of iovecs that may be passed to
+ // rawfile.NonBlockingWriteIovec, as possibly limited by
+ // maxSyscallHeaderBytes. (No analogous limit is defined for
+ // rawfile.NonBlockingSendMMsg, since in that case the maximum number of
+ // iovecs also depends on the number of mmsghdrs. Instead, if sendBatch
+ // encounters a packet whose iovec count is limited by
+ // maxSyscallHeaderBytes, it falls back to writing the packet using writev
+ // via WritePacket.)
+ writevMaxIovs int
}
// Options specify the details about the fd-based endpoint to be created.
@@ -186,6 +201,11 @@ type Options struct {
// RXChecksumOffload if true, indicates that this endpoints capability
// set should include CapabilityRXChecksumOffload.
RXChecksumOffload bool
+
+ // If MaxSyscallHeaderBytes is non-zero, it is the maximum number of bytes
+ // of struct iovec, msghdr, and mmsghdr that may be passed by each host
+ // system call.
+ MaxSyscallHeaderBytes int
}
// fanoutID is used for AF_PACKET based endpoints to enable PACKET_FANOUT
@@ -235,14 +255,25 @@ func New(opts *Options) (stack.LinkEndpoint, error) {
return nil, fmt.Errorf("opts.FD is empty, at least one FD must be specified")
}
+ if opts.MaxSyscallHeaderBytes < 0 {
+ return nil, fmt.Errorf("opts.MaxSyscallHeaderBytes is negative")
+ }
+
e := &endpoint{
- fds: opts.FDs,
- mtu: opts.MTU,
- caps: caps,
- closed: opts.ClosedFunc,
- addr: opts.Address,
- hdrSize: hdrSize,
- packetDispatchMode: opts.PacketDispatchMode,
+ fds: opts.FDs,
+ mtu: opts.MTU,
+ caps: caps,
+ closed: opts.ClosedFunc,
+ addr: opts.Address,
+ hdrSize: hdrSize,
+ packetDispatchMode: opts.PacketDispatchMode,
+ maxSyscallHeaderBytes: uintptr(opts.MaxSyscallHeaderBytes),
+ writevMaxIovs: rawfile.MaxIovs,
+ }
+ if e.maxSyscallHeaderBytes != 0 {
+ if max := int(e.maxSyscallHeaderBytes / rawfile.SizeofIovec); max < e.writevMaxIovs {
+ e.writevMaxIovs = max
+ }
}
// Increment fanoutID to ensure that we don't re-use the same fanoutID for
@@ -351,16 +382,27 @@ func isSocketFD(fd int) (bool, error) {
// Attach launches the goroutine that reads packets from the file descriptor and
// dispatches them via the provided dispatcher.
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
- e.dispatcher = dispatcher
- // Link endpoints are not savable. When transportation endpoints are
- // saved, they stop sending outgoing packets and all incoming packets
- // are rejected.
- for i := range e.inboundDispatchers {
- e.wg.Add(1)
- go func(i int) { // S/R-SAFE: See above.
- e.dispatchLoop(e.inboundDispatchers[i])
- e.wg.Done()
- }(i)
+ // nil means the NIC is being removed.
+ if dispatcher == nil && e.dispatcher != nil {
+ for _, d := range e.inboundDispatchers {
+ d.stop()
+ }
+ e.Wait()
+ e.dispatcher = nil
+ return
+ }
+ if dispatcher != nil && e.dispatcher == nil {
+ e.dispatcher = dispatcher
+ // Link endpoints are not savable. When transport endpoints are
+ // saved, they stop sending outgoing packets and all incoming packets
+ // are rejected.
+ for i := range e.inboundDispatchers {
+ e.wg.Add(1)
+ go func(i int) { // S/R-SAFE: See above.
+ e.dispatchLoop(e.inboundDispatchers[i])
+ e.wg.Done()
+ }(i)
+ }
}
}
@@ -463,6 +505,9 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
}
}
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} }
+
// WritePacket writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
@@ -470,9 +515,8 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol
e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
}
- var builder iovec.Builder
-
fd := e.fds[pkt.Hash%uint32(len(e.fds))]
+ var vnetHdrBuf []byte
if e.gsoKind == stack.HWGSOSupported {
vnetHdr := virtioNetHdr{}
if pkt.GSOOptions.Type != stack.GSONone {
@@ -494,71 +538,123 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol
vnetHdr.gsoSize = pkt.GSOOptions.MSS
}
}
+ vnetHdrBuf = vnetHdr.marshal()
+ }
- vnetHdrBuf := vnetHdr.marshal()
- builder.Add(vnetHdrBuf)
+ views := pkt.Views()
+ numIovecs := len(views)
+ if len(vnetHdrBuf) != 0 {
+ numIovecs++
+ }
+ if numIovecs > e.writevMaxIovs {
+ numIovecs = e.writevMaxIovs
}
- for _, v := range pkt.Views() {
- builder.Add(v)
+ // Allocate small iovec arrays on the stack.
+ var iovecsArr [8]unix.Iovec
+ iovecs := iovecsArr[:0]
+ if numIovecs > len(iovecsArr) {
+ iovecs = make([]unix.Iovec, 0, numIovecs)
}
- return rawfile.NonBlockingWriteIovec(fd, builder.Build())
+ iovecs = rawfile.AppendIovecFromBytes(iovecs, vnetHdrBuf, numIovecs)
+ for _, v := range views {
+ iovecs = rawfile.AppendIovecFromBytes(iovecs, v, numIovecs)
+ }
+ return rawfile.NonBlockingWriteIovec(fd, iovecs)
}
-func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, tcpip.Error) {
+func (e *endpoint) sendBatch(batchFD int, pkts []*stack.PacketBuffer) (int, tcpip.Error) {
// Send a batch of packets through batchFD.
- mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch))
- for _, pkt := range batch {
- if e.hdrSize > 0 {
- e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt)
- }
+ mmsgHdrsStorage := make([]rawfile.MMsgHdr, 0, len(pkts))
+ packets := 0
+ for packets < len(pkts) {
+ mmsgHdrs := mmsgHdrsStorage
+ batch := pkts[packets:]
+ syscallHeaderBytes := uintptr(0)
+ for _, pkt := range batch {
+ if e.hdrSize > 0 {
+ e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt)
+ }
- var vnetHdrBuf []byte
- if e.gsoKind == stack.HWGSOSupported {
- vnetHdr := virtioNetHdr{}
- if pkt.GSOOptions.Type != stack.GSONone {
- vnetHdr.hdrLen = uint16(pkt.HeaderSize())
- if pkt.GSOOptions.NeedsCsum {
- vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM
- vnetHdr.csumStart = header.EthernetMinimumSize + pkt.GSOOptions.L3HdrLen
- vnetHdr.csumOffset = pkt.GSOOptions.CsumOffset
- }
- if pkt.GSOOptions.Type != stack.GSONone && uint16(pkt.Data().Size()) > pkt.GSOOptions.MSS {
- switch pkt.GSOOptions.Type {
- case stack.GSOTCPv4:
- vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4
- case stack.GSOTCPv6:
- vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6
- default:
- panic(fmt.Sprintf("Unknown gso type: %v", pkt.GSOOptions.Type))
+ var vnetHdrBuf []byte
+ if e.gsoKind == stack.HWGSOSupported {
+ vnetHdr := virtioNetHdr{}
+ if pkt.GSOOptions.Type != stack.GSONone {
+ vnetHdr.hdrLen = uint16(pkt.HeaderSize())
+ if pkt.GSOOptions.NeedsCsum {
+ vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM
+ vnetHdr.csumStart = header.EthernetMinimumSize + pkt.GSOOptions.L3HdrLen
+ vnetHdr.csumOffset = pkt.GSOOptions.CsumOffset
+ }
+ if pkt.GSOOptions.Type != stack.GSONone && uint16(pkt.Data().Size()) > pkt.GSOOptions.MSS {
+ switch pkt.GSOOptions.Type {
+ case stack.GSOTCPv4:
+ vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4
+ case stack.GSOTCPv6:
+ vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6
+ default:
+ panic(fmt.Sprintf("Unknown gso type: %v", pkt.GSOOptions.Type))
+ }
+ vnetHdr.gsoSize = pkt.GSOOptions.MSS
}
- vnetHdr.gsoSize = pkt.GSOOptions.MSS
}
+ vnetHdrBuf = vnetHdr.marshal()
}
- vnetHdrBuf = vnetHdr.marshal()
- }
- var builder iovec.Builder
- builder.Add(vnetHdrBuf)
- for _, v := range pkt.Views() {
- builder.Add(v)
- }
- iovecs := builder.Build()
+ views := pkt.Views()
+ numIovecs := len(views)
+ if len(vnetHdrBuf) != 0 {
+ numIovecs++
+ }
+ if numIovecs > rawfile.MaxIovs {
+ numIovecs = rawfile.MaxIovs
+ }
+ if e.maxSyscallHeaderBytes != 0 {
+ syscallHeaderBytes += rawfile.SizeofMMsgHdr + uintptr(numIovecs)*rawfile.SizeofIovec
+ if syscallHeaderBytes > e.maxSyscallHeaderBytes {
+ // We can't fit this packet into this call to sendmmsg().
+ // We could potentially do so if we reduced numIovecs
+ // further, but this might incur considerable extra
+ // copying. Leave it to the next batch instead.
+ break
+ }
+ }
- var mmsgHdr rawfile.MMsgHdr
- mmsgHdr.Msg.Iov = &iovecs[0]
- mmsgHdr.Msg.SetIovlen((len(iovecs)))
- mmsgHdrs = append(mmsgHdrs, mmsgHdr)
- }
+ // We can't easily allocate iovec arrays on the stack here since
+ // they will escape this loop iteration via mmsgHdrs.
+ iovecs := make([]unix.Iovec, 0, numIovecs)
+ iovecs = rawfile.AppendIovecFromBytes(iovecs, vnetHdrBuf, numIovecs)
+ for _, v := range views {
+ iovecs = rawfile.AppendIovecFromBytes(iovecs, v, numIovecs)
+ }
- packets := 0
- for len(mmsgHdrs) > 0 {
- sent, err := rawfile.NonBlockingSendMMsg(batchFD, mmsgHdrs)
- if err != nil {
- return packets, err
+ var mmsgHdr rawfile.MMsgHdr
+ mmsgHdr.Msg.Iov = &iovecs[0]
+ mmsgHdr.Msg.SetIovlen(len(iovecs))
+ mmsgHdrs = append(mmsgHdrs, mmsgHdr)
+ }
+
+ if len(mmsgHdrs) == 0 {
+ // We can't fit batch[0] into a mmsghdr while staying under
+ // e.maxSyscallHeaderBytes. Use WritePacket, which will avoid the
+ // mmsghdr (by using writev) and re-buffer iovecs more aggressively
+ // if necessary (by using e.writevMaxIovs instead of
+ // rawfile.MaxIovs).
+ pkt := batch[0]
+ if err := e.WritePacket(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil {
+ return packets, err
+ }
+ packets++
+ } else {
+ for len(mmsgHdrs) > 0 {
+ sent, err := rawfile.NonBlockingSendMMsg(batchFD, mmsgHdrs)
+ if err != nil {
+ return packets, err
+ }
+ packets += sent
+ mmsgHdrs = mmsgHdrs[sent:]
+ }
}
- packets += sent
- mmsgHdrs = mmsgHdrs[sent:]
}
return packets, nil
@@ -676,8 +772,9 @@ func NewInjectable(fd int, mtu uint32, capabilities stack.LinkEndpointCapabiliti
unix.SetNonblock(fd, true)
return &InjectableEndpoint{endpoint: endpoint{
- fds: []int{fd},
- mtu: mtu,
- caps: capabilities,
+ fds: []int{fd},
+ mtu: mtu,
+ caps: capabilities,
+ writevMaxIovs: rawfile.MaxIovs,
}}
}
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
index 8aad338b6..eccd21579 100644
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build linux
// +build linux
package fdbased
diff --git a/pkg/tcpip/link/fdbased/endpoint_unsafe.go b/pkg/tcpip/link/fdbased/endpoint_unsafe.go
index df14eaad1..904393faa 100644
--- a/pkg/tcpip/link/fdbased/endpoint_unsafe.go
+++ b/pkg/tcpip/link/fdbased/endpoint_unsafe.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build linux
// +build linux
package fdbased
diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go
index 5d698a5e9..3f516cab5 100644
--- a/pkg/tcpip/link/fdbased/mmap.go
+++ b/pkg/tcpip/link/fdbased/mmap.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build (linux && amd64) || (linux && arm64)
// +build linux,amd64 linux,arm64
package fdbased
@@ -113,6 +114,7 @@ func (t tPacketHdr) Payload() []byte {
// packetMMapDispatcher uses PACKET_RX_RING's to read/dispatch inbound packets.
// See: mmap_amd64_unsafe.go for implementation details.
type packetMMapDispatcher struct {
+ stopFd
// fd is the file descriptor used to send and receive packets.
fd int
@@ -128,18 +130,18 @@ type packetMMapDispatcher struct {
ringOffset int
}
-func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, tcpip.Error) {
+func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, bool, tcpip.Error) {
hdr := tPacketHdr(d.ringBuffer[d.ringOffset*tpFrameSize:])
for hdr.tpStatus()&tpStatusUser == 0 {
- event := rawfile.PollEvent{
- FD: int32(d.fd),
- Events: unix.POLLIN | unix.POLLERR,
- }
- if _, errno := rawfile.BlockingPoll(&event, 1, nil); errno != 0 {
+ stopped, errno := rawfile.BlockingPollUntilStopped(d.efd, d.fd, unix.POLLIN|unix.POLLERR)
+ if errno != 0 {
if errno == unix.EINTR {
continue
}
- return nil, rawfile.TranslateErrno(errno)
+ return nil, stopped, rawfile.TranslateErrno(errno)
+ }
+ if stopped {
+ return nil, true, nil
}
if hdr.tpStatus()&tpStatusCopy != 0 {
// This frame is truncated so skip it after flipping the
@@ -157,14 +159,14 @@ func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, tcpip.Error) {
// Release packet to kernel.
hdr.setTPStatus(tpStatusKernel)
d.ringOffset = (d.ringOffset + 1) % tpFrameNR
- return pkt, nil
+ return pkt, false, nil
}
// dispatch reads packets from an mmaped ring buffer and dispatches them to the
// network stack.
func (d *packetMMapDispatcher) dispatch() (bool, tcpip.Error) {
- pkt, err := d.readMMappedPacket()
- if err != nil {
+ pkt, stopped, err := d.readMMappedPacket()
+ if err != nil || stopped {
return false, err
}
var (
diff --git a/pkg/tcpip/link/fdbased/mmap_stub.go b/pkg/tcpip/link/fdbased/mmap_stub.go
index 67be52d67..9d8679502 100644
--- a/pkg/tcpip/link/fdbased/mmap_stub.go
+++ b/pkg/tcpip/link/fdbased/mmap_stub.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build !linux || (!amd64 && !arm64)
// +build !linux !amd64,!arm64
package fdbased
diff --git a/pkg/tcpip/link/fdbased/mmap_unsafe.go b/pkg/tcpip/link/fdbased/mmap_unsafe.go
index 1293f68a2..5b786169a 100644
--- a/pkg/tcpip/link/fdbased/mmap_unsafe.go
+++ b/pkg/tcpip/link/fdbased/mmap_unsafe.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build (linux && amd64) || (linux && arm64)
// +build linux,amd64 linux,arm64
package fdbased
@@ -46,9 +47,14 @@ func (t tPacketHdr) setTPStatus(status uint32) {
}
func newPacketMMapDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
+ stopFd, err := newStopFd()
+ if err != nil {
+ return nil, err
+ }
d := &packetMMapDispatcher{
- fd: fd,
- e: e,
+ stopFd: stopFd,
+ fd: fd,
+ e: e,
}
pageSize := unix.Getpagesize()
if tpBlockSize%pageSize != 0 {
diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go
index 4b7ef3aac..fab34c5fa 100644
--- a/pkg/tcpip/link/fdbased/packet_dispatchers.go
+++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go
@@ -12,11 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build linux
// +build linux
package fdbased
import (
+ "fmt"
+
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -113,9 +116,36 @@ func (b *iovecBuffer) pullViews(n int) buffer.VectorisedView {
return buffer.NewVectorisedView(n, views)
}
+// stopFd is an eventfd used to signal the stop of a dispatcher.
+type stopFd struct {
+ efd int
+}
+
+func newStopFd() (stopFd, error) {
+ efd, err := unix.Eventfd(0, unix.EFD_NONBLOCK)
+ if err != nil {
+ return stopFd{efd: -1}, fmt.Errorf("failed to create eventfd: %w", err)
+ }
+ return stopFd{efd: efd}, nil
+}
+
+// stop writes to the eventfd and notifies the dispatcher to stop. It does not
+// block.
+func (s *stopFd) stop() {
+ increment := []byte{1, 0, 0, 0, 0, 0, 0, 0}
+ if n, err := unix.Write(s.efd, increment); n != len(increment) || err != nil {
+ // There are two possible errors documented in eventfd(2) for writing:
+ // 1. We are writing 8 bytes and not 0xffffffffffffffff, thus no EINVAL.
+ // 2. stop is only supposed to be called once, it can't reach the limit,
+ // thus no EAGAIN.
+ panic(fmt.Sprintf("write(efd) = (%d, %s), want (%d, nil)", n, err, len(increment)))
+ }
+}
+
// readVDispatcher uses readv() system call to read inbound packets and
// dispatches them.
type readVDispatcher struct {
+ stopFd
// fd is the file descriptor used to send and receive packets.
fd int
@@ -127,7 +157,15 @@ type readVDispatcher struct {
}
func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
- d := &readVDispatcher{fd: fd, e: e}
+ stopFd, err := newStopFd()
+ if err != nil {
+ return nil, err
+ }
+ d := &readVDispatcher{
+ stopFd: stopFd,
+ fd: fd,
+ e: e,
+ }
skipsVnetHdr := d.e.gsoKind == stack.HWGSOSupported
d.buf = newIovecBuffer(BufConfig, skipsVnetHdr)
return d, nil
@@ -135,8 +173,8 @@ func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
// dispatch reads one packet from the file descriptor and dispatches it.
func (d *readVDispatcher) dispatch() (bool, tcpip.Error) {
- n, err := rawfile.BlockingReadv(d.fd, d.buf.nextIovecs())
- if n == 0 || err != nil {
+ n, err := rawfile.BlockingReadvUntilStopped(d.efd, d.fd, d.buf.nextIovecs())
+ if n <= 0 || err != nil {
return false, err
}
@@ -183,6 +221,7 @@ func (d *readVDispatcher) dispatch() (bool, tcpip.Error) {
// recvMMsgDispatcher uses the recvmmsg system call to read inbound packets and
// dispatches them.
type recvMMsgDispatcher struct {
+ stopFd
// fd is the file descriptor used to send and receive packets.
fd int
@@ -206,7 +245,12 @@ const (
)
func newRecvMMsgDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
+ stopFd, err := newStopFd()
+ if err != nil {
+ return nil, err
+ }
d := &recvMMsgDispatcher{
+ stopFd: stopFd,
fd: fd,
e: e,
bufs: make([]*iovecBuffer, MaxMsgsPerRecv),
@@ -234,8 +278,8 @@ func (d *recvMMsgDispatcher) dispatch() (bool, tcpip.Error) {
d.msgHdrs[k].Msg.SetIovlen(iovLen)
}
- nMsgs, err := rawfile.BlockingRecvMMsg(d.fd, d.msgHdrs)
- if err != nil {
+ nMsgs, err := rawfile.BlockingRecvMMsgUntilStopped(d.efd, d.fd, d.msgHdrs)
+ if nMsgs == -1 || err != nil {
return false, err
}
// Process each of received packets.
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index 7012d8829..ca1f9c08d 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -76,19 +76,8 @@ func (*endpoint) Wait() {}
// WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound
// packets to the network-layer dispatcher.
-func (e *endpoint) WritePacket(_ stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- // Construct data as the unparsed portion for the loopback packet.
- data := buffer.NewVectorisedView(pkt.Size(), pkt.Views())
-
- // Because we're immediately turning around and writing the packet back
- // to the rx path, we intentionally don't preserve the remote and local
- // link addresses from the stack.Route we're passed.
- newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: data,
- })
- e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, newPkt)
-
- return nil
+func (e *endpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+ return e.WriteRawPacket(pkt)
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
@@ -103,3 +92,19 @@ func (*endpoint) ARPHardwareType() header.ARPHardwareType {
func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (e *endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error {
+ // Construct data as the unparsed portion for the loopback packet.
+ data := buffer.NewVectorisedView(pkt.Size(), pkt.Views())
+
+ // Because we're immediately turning around and writing the packet back
+ // to the rx path, we intentionally don't preserve the remote and local
+ // link addresses from the stack.Route we're passed.
+ newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: data,
+ })
+ e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, pkt.NetworkProtocolNumber, newPkt)
+
+ return nil
+}
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index 3e2a1aa94..844f5959b 100644
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -131,6 +131,11 @@ func (*InjectableEndpoint) ARPHardwareType() header.ARPHardwareType {
func (*InjectableEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
}
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*InjectableEndpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
+}
+
// NewInjectableEndpoint creates a new multi-endpoint injectable endpoint.
func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint {
return &InjectableEndpoint{
diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go
index 3e816b0c7..83a6c1cc8 100644
--- a/pkg/tcpip/link/nested/nested.go
+++ b/pkg/tcpip/link/nested/nested.go
@@ -60,16 +60,6 @@ func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
}
}
-// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
-func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
- e.mu.RLock()
- d := e.dispatcher
- e.mu.RUnlock()
- if d != nil {
- d.DeliverOutboundPacket(remote, local, protocol, pkt)
- }
-}
-
// Attach implements stack.LinkEndpoint.
func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.mu.Lock()
@@ -152,3 +142,8 @@ func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
e.child.AddHeader(local, remote, protocol, pkt)
}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (e *Endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error {
+ return e.child.WriteRawPacket(pkt)
+}
diff --git a/pkg/tcpip/link/packetsocket/BUILD b/pkg/tcpip/link/packetsocket/BUILD
deleted file mode 100644
index 6fff160ce..000000000
--- a/pkg/tcpip/link/packetsocket/BUILD
+++ /dev/null
@@ -1,14 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "packetsocket",
- srcs = ["endpoint.go"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/link/nested",
- "//pkg/tcpip/stack",
- ],
-)
diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go
deleted file mode 100644
index e01837e2d..000000000
--- a/pkg/tcpip/link/packetsocket/endpoint.go
+++ /dev/null
@@ -1,50 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package packetsocket provides a link layer endpoint that provides the ability
-// to loop outbound packets to any AF_PACKET sockets that may be interested in
-// the outgoing packet.
-package packetsocket
-
-import (
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/link/nested"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
-)
-
-type endpoint struct {
- nested.Endpoint
-}
-
-// New creates a new packetsocket LinkEndpoint.
-func New(lower stack.LinkEndpoint) stack.LinkEndpoint {
- e := &endpoint{}
- e.Endpoint.Init(lower, e)
- return e
-}
-
-// WritePacket implements stack.LinkEndpoint.WritePacket.
-func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt)
- return e.Endpoint.WritePacket(r, protocol, pkt)
-}
-
-// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
- for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt)
- }
-
- return e.Endpoint.WritePackets(r, pkts, proto)
-}
diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go
index 5030b6ba1..3ed0aa3fe 100644
--- a/pkg/tcpip/link/pipe/pipe.go
+++ b/pkg/tcpip/link/pipe/pipe.go
@@ -121,3 +121,6 @@ func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
// AddHeader implements stack.LinkEndpoint.
func (*Endpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*Endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} }
diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go
index b1a28491d..b41e3e2fa 100644
--- a/pkg/tcpip/link/qdisc/fifo/endpoint.go
+++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go
@@ -108,13 +108,15 @@ func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt)
}
-// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
-func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
- e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt)
-}
-
// Attach implements stack.LinkEndpoint.Attach.
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ // nil means the NIC is being removed.
+ if dispatcher == nil {
+ e.lower.Attach(nil)
+ e.Wait()
+ e.dispatcher = nil
+ return
+ }
e.dispatcher = dispatcher
e.lower.Attach(e)
}
@@ -221,3 +223,8 @@ func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
e.lower.AddHeader(local, remote, protocol, pkt)
}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (e *endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error {
+ return e.lower.WriteRawPacket(pkt)
+}
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_amd64.s b/pkg/tcpip/link/rawfile/blockingpoll_amd64.s
index 298bad55d..f2c230720 100644
--- a/pkg/tcpip/link/rawfile/blockingpoll_amd64.s
+++ b/pkg/tcpip/link/rawfile/blockingpoll_amd64.s
@@ -27,7 +27,7 @@ TEXT ·BlockingPoll(SB),NOSPLIT,$0-40
MOVQ $0x0, R10 // sigmask parameter which isn't used here
MOVQ $0x10f, AX // SYS_PPOLL
SYSCALL
- CMPQ AX, $0xfffffffffffff001
+ CMPQ AX, $0xfffffffffffff002
JLS ok
MOVQ $-1, n+24(FP)
NEGQ AX
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_arm64.s b/pkg/tcpip/link/rawfile/blockingpoll_arm64.s
index b62888b93..8807586c7 100644
--- a/pkg/tcpip/link/rawfile/blockingpoll_arm64.s
+++ b/pkg/tcpip/link/rawfile/blockingpoll_arm64.s
@@ -27,7 +27,7 @@ TEXT ·BlockingPoll(SB),NOSPLIT,$0-40
MOVD $0x0, R3 // sigmask parameter which isn't used here
MOVD $0x49, R8 // SYS_PPOLL
SVC
- CMP $0xfffffffffffff001, R0
+ CMP $0xfffffffffffff002, R0
BLS ok
MOVD $-1, R1
MOVD R1, n+24(FP)
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go
index 2206fe0e6..c1438da21 100644
--- a/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go
+++ b/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build linux && !amd64 && !arm64
// +build linux,!amd64,!arm64
package rawfile
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
index 5002245a1..0b7b9e3de 100644
--- a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
+++ b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
@@ -12,11 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build ((linux && amd64) || (linux && arm64)) && go1.12
// +build linux,amd64 linux,arm64
// +build go1.12
-// +build !go1.18
-// Check go:linkname function signatures when updating Go version.
+// //go:linkname directives type-checked by checklinkname. Any other
+// non-linkname assumptions outside the Go 1 compatibility guarantee should
+// have an accompanied vet check or version guard build tag.
package rawfile
diff --git a/pkg/tcpip/link/rawfile/errors.go b/pkg/tcpip/link/rawfile/errors.go
index 9743e70ea..7e21a78d4 100644
--- a/pkg/tcpip/link/rawfile/errors.go
+++ b/pkg/tcpip/link/rawfile/errors.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build linux
// +build linux
package rawfile
diff --git a/pkg/tcpip/link/rawfile/errors_test.go b/pkg/tcpip/link/rawfile/errors_test.go
index 8f4bd60da..1b88c309b 100644
--- a/pkg/tcpip/link/rawfile/errors_test.go
+++ b/pkg/tcpip/link/rawfile/errors_test.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build linux
// +build linux
package rawfile
diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
index ba92aedbc..87a0b9a62 100644
--- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go
+++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build linux
// +build linux
// Package rawfile contains utilities for using the netstack with raw host
@@ -19,12 +20,66 @@
package rawfile
import (
+ "reflect"
"unsafe"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip"
)
+// SizeofIovec is the size of a unix.Iovec in bytes.
+const SizeofIovec = unsafe.Sizeof(unix.Iovec{})
+
+// MaxIovs is UIO_MAXIOV, the maximum number of iovecs that may be passed to a
+// host system call in a single array.
+const MaxIovs = 1024
+
+// IovecFromBytes returns a unix.Iovec representing bs.
+//
+// Preconditions: len(bs) > 0.
+func IovecFromBytes(bs []byte) unix.Iovec {
+ iov := unix.Iovec{
+ Base: &bs[0],
+ }
+ iov.SetLen(len(bs))
+ return iov
+}
+
+func bytesFromIovec(iov unix.Iovec) (bs []byte) {
+ sh := (*reflect.SliceHeader)(unsafe.Pointer(&bs))
+ sh.Data = uintptr(unsafe.Pointer(iov.Base))
+ sh.Len = int(iov.Len)
+ sh.Cap = int(iov.Len)
+ return
+}
+
+// AppendIovecFromBytes returns append(iovs, IovecFromBytes(bs)). If len(bs) ==
+// 0, AppendIovecFromBytes returns iovs without modification. If len(iovs) >=
+// max, AppendIovecFromBytes replaces the final iovec in iovs with one that
+// also includes the contents of bs. Note that this implies that
+// AppendIovecFromBytes is only usable when the returned iovec slice is used as
+// the source of a write.
+func AppendIovecFromBytes(iovs []unix.Iovec, bs []byte, max int) []unix.Iovec {
+ if len(bs) == 0 {
+ return iovs
+ }
+ if len(iovs) < max {
+ return append(iovs, IovecFromBytes(bs))
+ }
+ iovs[len(iovs)-1] = IovecFromBytes(append(bytesFromIovec(iovs[len(iovs)-1]), bs...))
+ return iovs
+}
+
+// MMsgHdr represents the mmsghdr structure required by recvmmsg() on linux.
+type MMsgHdr struct {
+ Msg unix.Msghdr
+ Len uint32
+ _ [4]byte
+}
+
+// SizeofMMsgHdr is the size of a MMsgHdr in bytes.
+const SizeofMMsgHdr = unsafe.Sizeof(MMsgHdr{})
+
// GetMTU determines the MTU of a network interface device.
func GetMTU(name string) (uint32, error) {
fd, err := unix.Socket(unix.AF_UNIX, unix.SOCK_DGRAM, 0)
@@ -115,53 +170,77 @@ func BlockingRead(fd int, b []byte) (int, tcpip.Error) {
}
}
-// BlockingReadv reads from a file descriptor that is set up as non-blocking and
-// stores the data in a list of iovecs buffers. If no data is available, it will
-// block in a poll() syscall until the file descriptor becomes readable.
-func BlockingReadv(fd int, iovecs []unix.Iovec) (int, tcpip.Error) {
+// BlockingReadvUntilStopped reads from a file descriptor that is set up as
+// non-blocking and stores the data in a list of iovecs buffers. If no data is
+// available, it will block in a poll() syscall until the file descriptor
+// becomes readable or stop is signalled (efd becomes readable). Returns -1 in
+// the latter case.
+func BlockingReadvUntilStopped(efd int, fd int, iovecs []unix.Iovec) (int, tcpip.Error) {
for {
n, _, e := unix.RawSyscall(unix.SYS_READV, uintptr(fd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(len(iovecs)))
if e == 0 {
return int(n), nil
}
-
- event := PollEvent{
- FD: int32(fd),
- Events: 1, // POLLIN
+ if e != 0 && e != unix.EWOULDBLOCK {
+ return 0, TranslateErrno(e)
+ }
+ stopped, e := BlockingPollUntilStopped(efd, fd, unix.POLLIN)
+ if stopped {
+ return -1, nil
}
-
- _, e = BlockingPoll(&event, 1, nil)
if e != 0 && e != unix.EINTR {
return 0, TranslateErrno(e)
}
}
}
-// MMsgHdr represents the mmsg_hdr structure required by recvmmsg() on linux.
-type MMsgHdr struct {
- Msg unix.Msghdr
- Len uint32
- _ [4]byte
-}
-
-// BlockingRecvMMsg reads from a file descriptor that is set up as non-blocking
-// and stores the received messages in a slice of MMsgHdr structures. If no data
-// is available, it will block in a poll() syscall until the file descriptor
-// becomes readable.
-func BlockingRecvMMsg(fd int, msgHdrs []MMsgHdr) (int, tcpip.Error) {
+// BlockingRecvMMsgUntilStopped reads from a file descriptor that is set up as
+// non-blocking and stores the received messages in a slice of MMsgHdr
+// structures. If no data is available, it will block in a poll() syscall until
+// the file descriptor becomes readable or stop is signalled (efd becomes
+// readable). Returns -1 in the latter case.
+func BlockingRecvMMsgUntilStopped(efd int, fd int, msgHdrs []MMsgHdr) (int, tcpip.Error) {
for {
n, _, e := unix.RawSyscall6(unix.SYS_RECVMMSG, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), unix.MSG_DONTWAIT, 0, 0)
if e == 0 {
return int(n), nil
}
- event := PollEvent{
- FD: int32(fd),
- Events: 1, // POLLIN
+ if e != 0 && e != unix.EWOULDBLOCK {
+ return 0, TranslateErrno(e)
}
- if _, e := BlockingPoll(&event, 1, nil); e != 0 && e != unix.EINTR {
+ stopped, e := BlockingPollUntilStopped(efd, fd, unix.POLLIN)
+ if stopped {
+ return -1, nil
+ }
+ if e != 0 && e != unix.EINTR {
return 0, TranslateErrno(e)
}
}
}
+
+// BlockingPollUntilStopped polls for events on fd or until a stop is signalled
+// on the event fd efd. Returns true if stopped, i.e., efd has event POLLIN.
+func BlockingPollUntilStopped(efd int, fd int, events int16) (bool, unix.Errno) {
+ pevents := [...]PollEvent{
+ {
+ FD: int32(efd),
+ Events: unix.POLLIN,
+ },
+ {
+ FD: int32(fd),
+ Events: events,
+ },
+ }
+ _, errno := BlockingPoll(&pevents[0], len(pevents), nil)
+ if errno != 0 {
+ return pevents[0].Revents&unix.POLLIN != 0, errno
+ }
+
+ if pevents[1].Revents&unix.POLLHUP != 0 || pevents[1].Revents&unix.POLLERR != 0 {
+ errno = unix.ECONNRESET
+ }
+
+ return pevents[0].Revents&unix.POLLIN != 0, errno
+}
diff --git a/pkg/tcpip/link/sharedmem/rx.go b/pkg/tcpip/link/sharedmem/rx.go
index 8e6f3e5e3..e882a128c 100644
--- a/pkg/tcpip/link/sharedmem/rx.go
+++ b/pkg/tcpip/link/sharedmem/rx.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build linux
// +build linux
package sharedmem
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index df9a0b90a..66efe6472 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build linux
// +build linux
// Package sharedmem provides the implemention of data-link layer endpoints
@@ -201,6 +202,9 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
eth.Encode(ethHdr)
}
+// WriteRawPacket implements stack.LinkEndpoint; it always returns ErrNotSupported.
+func (*endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} }
+
// WritePacket writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index 0f72d4e95..d6d953085 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build linux
// +build linux
package sharedmem
diff --git a/pkg/tcpip/link/sniffer/pcap.go b/pkg/tcpip/link/sniffer/pcap.go
index c16c19647..d3edede63 100644
--- a/pkg/tcpip/link/sniffer/pcap.go
+++ b/pkg/tcpip/link/sniffer/pcap.go
@@ -14,7 +14,14 @@
package sniffer
-import "time"
+import (
+ "encoding"
+ "encoding/binary"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
type pcapHeader struct {
// MagicNumber is the file magic number.
@@ -39,28 +46,38 @@ type pcapHeader struct {
Network uint32
}
-const pcapPacketHeaderLen = 16
-
-type pcapPacketHeader struct {
- // Seconds is the timestamp seconds.
- Seconds uint32
-
- // Microseconds is the timestamp microseconds.
- Microseconds uint32
+var _ encoding.BinaryMarshaler = (*pcapPacket)(nil)
- // IncludedLength is the number of octets of packet saved in file.
- IncludedLength uint32
-
- // OriginalLength is the actual length of packet.
- OriginalLength uint32
+type pcapPacket struct {
+ timestamp time.Time
+ packet *stack.PacketBuffer
+ maxCaptureLen int
}
-func newPCAPPacketHeader(incLen, orgLen uint32) pcapPacketHeader {
- now := time.Now()
- return pcapPacketHeader{
- Seconds: uint32(now.Unix()),
- Microseconds: uint32(now.Nanosecond() / 1000),
- IncludedLength: incLen,
- OriginalLength: orgLen,
+func (p *pcapPacket) MarshalBinary() ([]byte, error) { // Returns a 16-byte record header followed by the captured bytes.
+ packetSize := p.packet.Size()
+ captureLen := p.maxCaptureLen // Capture at most maxCaptureLen bytes of the packet.
+ if packetSize < captureLen {
+ captureLen = packetSize
+ }
+ b := make([]byte, 16+captureLen)
+ binary.BigEndian.PutUint32(b[0:4], uint32(p.timestamp.Unix())) // Timestamp seconds.
+ binary.BigEndian.PutUint32(b[4:8], uint32(p.timestamp.Nanosecond()/1000)) // Timestamp microseconds.
+ binary.BigEndian.PutUint32(b[8:12], uint32(captureLen)) // Number of packet bytes stored in this record.
+ binary.BigEndian.PutUint32(b[12:16], uint32(packetSize)) // Full length of the packet.
+ w := tcpip.SliceWriter(b[16:])
+ for _, v := range p.packet.Views() {
+ if captureLen == 0 {
+ break
+ }
+ if len(v) > captureLen { // Truncate the view that crosses the capture limit.
+ v = v[:captureLen]
+ }
+ n, err := w.Write(v)
+ if err != nil {
+ panic(err)
+ }
+ captureLen -= n // Remaining capture budget.
}
+ return b, nil
}
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index 2d6a3a833..2afa95af0 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -87,11 +87,7 @@ func NewWithPrefix(lower stack.LinkEndpoint, logPrefix string) stack.LinkEndpoin
}
func zoneOffset() (int32, error) {
- loc, err := time.LoadLocation("Local")
- if err != nil {
- return 0, err
- }
- date := time.Date(0, 0, 0, 0, 0, 0, 0, loc)
+ date := time.Date(0, 0, 0, 0, 0, 0, 0, time.Local) // time.Local replaces the LoadLocation("Local") lookup and its error path.
_, offset := date.Zone()
return int32(offset), nil
}
@@ -117,8 +113,9 @@ func writePCAPHeader(w io.Writer, maxLen uint32) error {
// NewWithWriter creates a new sniffer link-layer endpoint. It wraps around
// another endpoint and logs packets as they traverse the endpoint.
//
-// Packets are logged to writer in the pcap format. A sniffer created with this
-// function will not emit packets using the standard log package.
+// Each packet is written to writer in the pcap format in a single Write call
+// without synchronization. A sniffer created with this function will not emit
+// packets using the standard log package.
//
// snapLen is the maximum amount of a packet to be saved. Packets with a length
// less than or equal to snapLen will be saved in their entirety. Longer
@@ -143,43 +140,23 @@ func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
e.Endpoint.DeliverNetworkPacket(remote, local, protocol, pkt)
}
-// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
-func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
- e.Endpoint.DeliverOutboundPacket(remote, local, protocol, pkt)
-}
-
func (e *endpoint) dumpPacket(dir direction, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
writer := e.writer
if writer == nil && atomic.LoadUint32(&LogPackets) == 1 {
logPacket(e.logPrefix, dir, protocol, pkt)
}
if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 {
- totalLength := pkt.Size()
- length := totalLength
- if max := int(e.maxPCAPLen); length > max {
- length = max
+ packet := pcapPacket{
+ timestamp: time.Now(),
+ packet: pkt,
+ maxCaptureLen: int(e.maxPCAPLen), // Upper bound on the packet bytes saved.
}
- if err := binary.Write(writer, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(totalLength))); err != nil {
+ b, err := packet.MarshalBinary() // Record header and truncated payload in one buffer.
+ if err != nil {
panic(err)
}
- write := func(b []byte) {
- if len(b) > length {
- b = b[:length]
- }
- for len(b) != 0 {
- n, err := writer.Write(b)
- if err != nil {
- panic(err)
- }
- b = b[n:]
- length -= n
- }
- }
- for _, v := range pkt.Views() {
- if length == 0 {
- break
- }
- write(v)
+ if _, err := writer.Write(b); err != nil { // One Write call per packet record.
+ panic(err)
}
}
}
diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD
index 7656cca6a..c3e4c3455 100644
--- a/pkg/tcpip/link/tun/BUILD
+++ b/pkg/tcpip/link/tun/BUILD
@@ -26,11 +26,11 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/log",
"//pkg/refs",
"//pkg/refsvfs2",
"//pkg/sync",
- "//pkg/syserror",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index 36af2a029..fa2131c28 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -18,8 +18,8 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -88,12 +88,12 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags Flags) error {
defer d.mu.Unlock()
if d.endpoint != nil {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
// Input validation.
if flags.TAP && flags.TUN || !flags.TAP && !flags.TUN {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
prefix := "tun"
@@ -108,7 +108,7 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags Flags) error {
endpoint, err := attachOrCreateNIC(s, name, prefix, linkCaps)
if err != nil {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
d.endpoint = endpoint
@@ -125,7 +125,7 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE
endpoint, ok := linkEP.(*tunEndpoint)
if !ok {
// Not a NIC created by tun device.
- return nil, syserror.EOPNOTSUPP
+ return nil, linuxerr.EOPNOTSUPP
}
if !endpoint.TryIncRef() {
// Race detected: NIC got deleted in between.
@@ -159,7 +159,7 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE
// Race detected: A NIC has been created in between.
continue
default:
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
}
}
@@ -170,10 +170,10 @@ func (d *Device) Write(data []byte) (int64, error) {
endpoint := d.endpoint
d.mu.RUnlock()
if endpoint == nil {
- return 0, syserror.EBADFD
+ return 0, linuxerr.EBADFD
}
if !endpoint.IsAttached() {
- return 0, syserror.EIO
+ return 0, linuxerr.EIO
}
dataLen := int64(len(data))
@@ -207,6 +207,15 @@ func (d *Device) Write(data []byte) (int64, error) {
protocol = pktInfoHdr.Protocol()
case ethHdr != nil:
protocol = ethHdr.Type()
+ case d.flags.TUN:
+ // TUN interface with IFF_NO_PI enabled, thus
+ // we need to determine protocol from version field
+ version := data[0] >> 4
+ if version == 4 {
+ protocol = header.IPv4ProtocolNumber
+ } else if version == 6 {
+ protocol = header.IPv6ProtocolNumber
+ }
}
// Try to determine remote link address, default zero.
@@ -233,13 +242,13 @@ func (d *Device) Read() ([]byte, error) {
endpoint := d.endpoint
d.mu.RUnlock()
if endpoint == nil {
- return nil, syserror.EBADFD
+ return nil, linuxerr.EBADFD
}
for {
info, ok := endpoint.Read()
if !ok {
- return nil, syserror.ErrWouldBlock
+ return nil, linuxerr.ErrWouldBlock
}
v, ok := d.encodePkt(&info)
@@ -264,13 +273,6 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) {
vv.AppendView(buffer.View(hdr))
}
- // If the packet does not already have link layer header, and the route
- // does not exist, we can't compute it. This is possibly a raw packet, tun
- // device doesn't support this at the moment.
- if info.Pkt.LinkHeader().View().IsEmpty() && len(info.Route.RemoteLinkAddress) == 0 {
- return nil, false
- }
-
// Ethernet header (TAP only).
if d.flags.TAP {
// Add ethernet header if not provided.
diff --git a/pkg/tcpip/link/tun/tun_unsafe.go b/pkg/tcpip/link/tun/tun_unsafe.go
index 0591fbd63..db4338e79 100644
--- a/pkg/tcpip/link/tun/tun_unsafe.go
+++ b/pkg/tcpip/link/tun/tun_unsafe.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build linux
// +build linux
// Package tun contains methods to open TAP and TUN devices.
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index a95602aa5..116e4defb 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -59,15 +59,6 @@ func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
e.dispatchGate.Leave()
}
-// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
-func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
- if !e.dispatchGate.Enter() {
- return
- }
- e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt)
- e.dispatchGate.Leave()
-}
-
// Attach implements stack.LinkEndpoint.Attach. It saves the dispatcher and
// registers with the lower endpoint as its dispatcher so that "e" is called
// for inbound packets.
@@ -155,3 +146,6 @@ func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
e.lower.AddHeader(local, remote, protocol, pkt)
}
+
+// WriteRawPacket implements stack.LinkEndpoint; it always returns ErrNotSupported.
+func (*Endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} }
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index a71400ee9..b0e4237bd 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -80,6 +80,11 @@ func (e *countedEndpoint) WritePackets(_ stack.RouteInfo, pkts stack.PacketBuffe
return pkts.Len(), nil
}
+// WriteRawPacket implements stack.LinkEndpoint; it always returns ErrNotSupported.
+func (*countedEndpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
+}
+
// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType {
panic("unimplemented")
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index 7b1ff44f4..c0179104a 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -23,8 +23,10 @@ go_test(
"//pkg/tcpip/stack",
"//pkg/tcpip/testutil",
"//pkg/tcpip/transport/icmp",
+ "//pkg/tcpip/transport/raw",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
+ "//pkg/waiter",
"@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 6515c31e5..e08243547 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -272,7 +272,6 @@ type protocol struct {
func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber }
func (p *protocol) MinimumPacketSize() int { return header.ARPSize }
-func (p *protocol) DefaultPrefixLen() int { return 0 }
func (*protocol) ParseAddresses(buffer.View) (src, dst tcpip.Address) {
return "", ""
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 5fcbfeaa2..061cc35ae 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -153,8 +153,12 @@ func makeTestContext(t *testing.T, eventDepth int, packetDepth int) testContext
t.Fatalf("CreateNIC failed: %s", err)
}
- if err := tc.s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress for ipv4 failed: %s", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: stackAddr.WithPrefix(),
+ }
+ if err := tc.s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
tc.s.SetRouteTable([]tcpip.Route{{
@@ -569,8 +573,12 @@ func TestLinkAddressRequest(t *testing.T) {
}
if len(test.nicAddr) != 0 {
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, test.nicAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, test.nicAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: test.nicAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go
index 0b51563cd..1261ad414 100644
--- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go
+++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go
@@ -126,7 +126,7 @@ func (m *mockMulticastGroupProtocol) sendQueuedReports() {
// Precondition: m.mu must be read locked.
func (m *mockMulticastGroupProtocol) Enabled() bool {
if m.mu.TryLock() {
- m.mu.Unlock()
+ m.mu.Unlock() // +checklocksforce: TryLock.
m.t.Fatal("got write lock, expected to not take the lock; generic multicast protocol must take the read or write lock before calling Enabled")
}
@@ -138,11 +138,11 @@ func (m *mockMulticastGroupProtocol) Enabled() bool {
// Precondition: m.mu must be locked.
func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) {
if m.mu.TryLock() {
- m.mu.Unlock()
+ m.mu.Unlock() // +checklocksforce: TryLock.
m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
}
if m.mu.TryRLock() {
- m.mu.RUnlock()
+ m.mu.RUnlock() // +checklocksforce: TryLock.
m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
}
@@ -155,11 +155,11 @@ func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (boo
// Precondition: m.mu must be locked.
func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) tcpip.Error {
if m.mu.TryLock() {
- m.mu.Unlock()
+ m.mu.Unlock() // +checklocksforce: TryLock.
m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress)
}
if m.mu.TryRLock() {
- m.mu.RUnlock()
+ m.mu.RUnlock() // +checklocksforce: TryLock.
m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress)
}
diff --git a/pkg/tcpip/network/internal/testutil/testutil.go b/pkg/tcpip/network/internal/testutil/testutil.go
index 605e9ef8d..4d4d98caf 100644
--- a/pkg/tcpip/network/internal/testutil/testutil.go
+++ b/pkg/tcpip/network/internal/testutil/testutil.go
@@ -101,6 +101,11 @@ func (*MockLinkEndpoint) ARPHardwareType() header.ARPHardwareType { return heade
func (*MockLinkEndpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
}
+// WriteRawPacket implements stack.LinkEndpoint; it always returns ErrNotSupported.
+func (*MockLinkEndpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
+}
+
// MakeRandPkt generates a randomized packet. transportHeaderLength indicates
// how many random bytes will be copied in the Transport Header.
// extraHeaderReserveLength indicates how much extra space will be reserved for
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index bd63e0289..87f650661 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -15,6 +15,7 @@
package ip_test
import (
+ "bytes"
"fmt"
"strings"
"testing"
@@ -32,8 +33,10 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/raw"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
)
const nicID = 1
@@ -88,6 +91,7 @@ type testObject struct {
dataCalls int
controlCalls int
+ rawCalls int
}
// checkValues verifies that the transport protocol, data contents, src & dst
@@ -148,6 +152,10 @@ func (t *testObject) DeliverTransportError(local, remote tcpip.Address, net tcpi
t.controlCalls++
}
+func (t *testObject) DeliverRawPacket(tcpip.TransportProtocolNumber, *stack.PacketBuffer) {
+ t.rawCalls++ // Only counts the call; both arguments are ignored.
+}
+
// Attach is only implemented to satisfy the LinkEndpoint interface.
func (*testObject) Attach(stack.NetworkDispatcher) {}
@@ -225,7 +233,13 @@ func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) {
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
})
s.CreateNIC(nicID, loopback.New())
- s.AddAddress(nicID, ipv4.ProtocolNumber, local)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: local.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ return nil, err
+ }
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv4EmptySubnet,
Gateway: ipv4Gateway,
@@ -241,7 +255,13 @@ func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) {
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
})
s.CreateNIC(nicID, loopback.New())
- s.AddAddress(nicID, ipv6.ProtocolNumber, local)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: local.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ return nil, err
+ }
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv6EmptySubnet,
Gateway: ipv6Gateway,
@@ -264,13 +284,13 @@ func buildDummyStackWithLinkEndpoint(t *testing.T, mtu uint32) (*stack.Stack, *c
}
v4Addr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: localIPv4AddrWithPrefix}
- if err := s.AddProtocolAddress(nicID, v4Addr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v4Addr, err)
+ if err := s.AddProtocolAddress(nicID, v4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, v4Addr, err)
}
v6Addr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: localIPv6AddrWithPrefix}
- if err := s.AddProtocolAddress(nicID, v6Addr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v6Addr, err)
+ if err := s.AddProtocolAddress(nicID, v6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, v6Addr, err)
}
return s, e
@@ -705,8 +725,8 @@ func TestReceive(t *testing.T) {
if !ok {
t.Fatalf("expected network endpoint with number = %d to implement stack.AddressableEndpoint", test.protoNum)
}
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", test.epAddr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", test.epAddr, err)
} else {
ep.DecRef()
}
@@ -717,7 +737,10 @@ func TestReceive(t *testing.T) {
}
test.handlePacket(t, ep, &nic)
if nic.testObject.dataCalls != 1 {
- t.Errorf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
+ t.Errorf("Bad number of data calls: got %d, want 1", nic.testObject.dataCalls)
+ }
+ if nic.testObject.rawCalls != 1 {
+ t.Errorf("Bad number of raw calls: got %d, want 1", nic.testObject.rawCalls)
}
if got := stat.Value(); got != 1 {
t.Errorf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 1", got)
@@ -874,8 +897,8 @@ func TestIPv4ReceiveControl(t *testing.T) {
t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint")
}
addr := localIPv4Addr.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
} else {
ep.DecRef()
}
@@ -960,15 +983,18 @@ func TestIPv4FragmentationReceive(t *testing.T) {
t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint")
}
addr := localIPv4Addr.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
} else {
ep.DecRef()
}
ep.HandlePacket(pkt)
if nic.testObject.dataCalls != 0 {
- t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls)
+ t.Fatalf("Bad number of data calls: got %d, want 0", nic.testObject.dataCalls)
+ }
+ if nic.testObject.rawCalls != 0 {
+ t.Errorf("Bad number of raw calls: got %d, want 0", nic.testObject.rawCalls)
}
// Send second segment.
@@ -977,7 +1003,10 @@ func TestIPv4FragmentationReceive(t *testing.T) {
})
ep.HandlePacket(pkt)
if nic.testObject.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
+ t.Fatalf("Bad number of data calls: got %d, want 1", nic.testObject.dataCalls)
+ }
+ if nic.testObject.rawCalls != 1 {
+ t.Errorf("Bad number of raw calls: got %d, want 1", nic.testObject.rawCalls)
}
}
@@ -1220,8 +1249,8 @@ func TestIPv6ReceiveControl(t *testing.T) {
t.Fatal("expected IPv6 network endpoint to implement stack.AddressableEndpoint")
}
addr := localIPv6Addr.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
} else {
ep.DecRef()
}
@@ -1287,7 +1316,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name string
protoFactory stack.NetworkProtocolFactory
protoNum tcpip.NetworkProtocolNumber
- nicAddr tcpip.Address
+ nicAddr tcpip.AddressWithPrefix
remoteAddr tcpip.Address
pktGen func(*testing.T, tcpip.Address) buffer.VectorisedView
checker func(*testing.T, *stack.PacketBuffer, tcpip.Address)
@@ -1297,7 +1326,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv4MinimumSize + len(data)
@@ -1310,7 +1339,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
Protocol: transportProto,
TTL: ipv4.DefaultTTL,
SrcAddr: src,
- DstAddr: header.IPv4Any,
+ DstAddr: remoteIPv4Addr,
})
return hdr.View().ToVectorisedView()
},
@@ -1338,7 +1367,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4 with IHL too small",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv4MinimumSize + len(data)
@@ -1351,7 +1380,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
Protocol: transportProto,
TTL: ipv4.DefaultTTL,
SrcAddr: src,
- DstAddr: header.IPv4Any,
+ DstAddr: remoteIPv4Addr,
})
ip.SetHeaderLength(header.IPv4MinimumSize - 1)
return hdr.View().ToVectorisedView()
@@ -1362,7 +1391,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4 too small",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
@@ -1370,7 +1399,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
Protocol: transportProto,
TTL: ipv4.DefaultTTL,
SrcAddr: src,
- DstAddr: header.IPv4Any,
+ DstAddr: remoteIPv4Addr,
})
return buffer.View(ip[:len(ip)-1]).ToVectorisedView()
},
@@ -1380,7 +1409,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4 minimum size",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
@@ -1388,7 +1417,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
Protocol: transportProto,
TTL: ipv4.DefaultTTL,
SrcAddr: src,
- DstAddr: header.IPv4Any,
+ DstAddr: remoteIPv4Addr,
})
return buffer.View(ip).ToVectorisedView()
},
@@ -1416,7 +1445,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4 with options",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ipHdrLen := int(header.IPv4MinimumSize + ipv4Options.Length())
@@ -1430,7 +1459,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
Protocol: transportProto,
TTL: ipv4.DefaultTTL,
SrcAddr: src,
- DstAddr: header.IPv4Any,
+ DstAddr: remoteIPv4Addr,
Options: ipv4Options,
})
return hdr.View().ToVectorisedView()
@@ -1461,7 +1490,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4 with options and data across views",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.Length()))
@@ -1469,7 +1498,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
Protocol: transportProto,
TTL: ipv4.DefaultTTL,
SrcAddr: src,
- DstAddr: header.IPv4Any,
+ DstAddr: remoteIPv4Addr,
Options: ipv4Options,
})
vv := buffer.View(ip).ToVectorisedView()
@@ -1502,7 +1531,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv6",
protoFactory: ipv6.NewProtocol,
protoNum: ipv6.ProtocolNumber,
- nicAddr: localIPv6Addr,
+ nicAddr: localIPv6AddrWithPrefix,
remoteAddr: remoteIPv6Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv6MinimumSize + len(data)
@@ -1515,7 +1544,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
TransportProtocol: transportProto,
HopLimit: ipv6.DefaultTTL,
SrcAddr: src,
- DstAddr: header.IPv4Any,
+ DstAddr: remoteIPv6Addr,
})
return hdr.View().ToVectorisedView()
},
@@ -1542,7 +1571,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv6 with extension header",
protoFactory: ipv6.NewProtocol,
protoNum: ipv6.ProtocolNumber,
- nicAddr: localIPv6Addr,
+ nicAddr: localIPv6AddrWithPrefix,
remoteAddr: remoteIPv6Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data)
@@ -1560,7 +1589,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
TransportProtocol: tcpip.TransportProtocolNumber(header.IPv6FragmentExtHdrIdentifier),
HopLimit: ipv6.DefaultTTL,
SrcAddr: src,
- DstAddr: header.IPv4Any,
+ DstAddr: remoteIPv6Addr,
})
return hdr.View().ToVectorisedView()
},
@@ -1587,7 +1616,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv6 minimum size",
protoFactory: ipv6.NewProtocol,
protoNum: ipv6.ProtocolNumber,
- nicAddr: localIPv6Addr,
+ nicAddr: localIPv6AddrWithPrefix,
remoteAddr: remoteIPv6Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
@@ -1595,7 +1624,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
TransportProtocol: transportProto,
HopLimit: ipv6.DefaultTTL,
SrcAddr: src,
- DstAddr: header.IPv4Any,
+ DstAddr: remoteIPv6Addr,
})
return buffer.View(ip).ToVectorisedView()
},
@@ -1622,7 +1651,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv6 too small",
protoFactory: ipv6.NewProtocol,
protoNum: ipv6.ProtocolNumber,
- nicAddr: localIPv6Addr,
+ nicAddr: localIPv6AddrWithPrefix,
remoteAddr: remoteIPv6Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
@@ -1630,7 +1659,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
TransportProtocol: transportProto,
HopLimit: ipv6.DefaultTTL,
SrcAddr: src,
- DstAddr: header.IPv4Any,
+ DstAddr: remoteIPv4Addr,
})
return buffer.View(ip[:len(ip)-1]).ToVectorisedView()
},
@@ -1646,11 +1675,11 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
}{
{
name: "unspecified source",
- srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr))),
+ srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr.Address))),
},
{
name: "random source",
- srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr))),
+ srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr.Address))),
},
}
@@ -1663,15 +1692,19 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, test.protoNum, test.nicAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.protoNum, test.nicAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: test.protoNum,
+ AddressWithPrefix: test.nicAddr,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{Destination: test.remoteAddr.WithPrefix().Subnet(), NIC: nicID}})
- r, err := s.FindRoute(nicID, test.nicAddr, test.remoteAddr, test.protoNum, false /* multicastLoop */)
+ r, err := s.FindRoute(nicID, test.nicAddr.Address, test.remoteAddr, test.protoNum, false /* multicastLoop */)
if err != nil {
- t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr, test.protoNum, err)
+ t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.nicAddr.Address, test.remoteAddr, test.protoNum, err)
}
defer r.Release()
@@ -2018,3 +2051,97 @@ func TestJoinLeaveAllRoutersGroup(t *testing.T) {
})
}
}
+
+func TestSetNICIDBeforeDeliveringToRawEndpoint(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ proto tcpip.NetworkProtocolNumber
+ addr tcpip.AddressWithPrefix
+ payloadOffset int
+ }{
+ {
+ name: "IPv4",
+ proto: header.IPv4ProtocolNumber,
+ addr: localIPv4AddrWithPrefix,
+ payloadOffset: header.IPv4MinimumSize,
+ },
+ {
+ name: "IPv6",
+ proto: header.IPv6ProtocolNumber,
+ addr: localIPv6AddrWithPrefix,
+ payloadOffset: 0,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ ipv4.NewProtocol,
+ ipv6.NewProtocol,
+ },
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ RawFactory: raw.EndpointFactory{},
+ })
+ if err := s.CreateNIC(nicID, loopback.New()); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: test.proto,
+ AddressWithPrefix: test.addr,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: test.addr.Subnet(),
+ NIC: nicID,
+ },
+ })
+
+ var wq waiter.Queue
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.ReadableEvents)
+ ep, err := s.NewRawEndpoint(udp.ProtocolNumber, test.proto, &wq, true /* associated */)
+ if err != nil {
+ t.Fatalf("NewRawEndpoint(%d, %d, _, true): %s", udp.ProtocolNumber, test.proto, err)
+ }
+ defer ep.Close()
+
+ writeOpts := tcpip.WriteOptions{
+ To: &tcpip.FullAddress{
+ Addr: test.addr.Address,
+ },
+ }
+ data := []byte{1, 2, 3, 4}
+ var r bytes.Reader
+ r.Reset(data)
+ if n, err := ep.Write(&r, writeOpts); err != nil {
+ t.Fatalf("ep.Write(_, _): %s", err)
+ } else if want := int64(len(data)); n != want {
+ t.Fatalf("got ep.Write(_, _) = (%d, nil), want = (%d, nil)", n, want)
+ }
+
+ // Wait for the endpoint to become readable.
+ <-ch
+
+ var w bytes.Buffer
+ rr, err := ep.Read(&w, tcpip.ReadOptions{
+ NeedRemoteAddr: true,
+ })
+ if err != nil {
+ t.Fatalf("ep.Read(...): %s", err)
+ }
+ if diff := cmp.Diff(data, w.Bytes()[test.payloadOffset:]); diff != "" {
+ t.Errorf("payload mismatch (-want +got):\n%s", diff)
+ }
+ if diff := cmp.Diff(tcpip.FullAddress{Addr: test.addr.Address, NIC: nicID}, rr.RemoteAddr); diff != "" {
+ t.Errorf("remote addr mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index c90974693..2257f728e 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -39,7 +39,6 @@ go_test(
"//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/loopback",
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/arp",
"//pkg/tcpip/network/internal/testutil",
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 5f6b0c6af..2aa38eb98 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -173,9 +173,8 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
received := e.stats.icmp.packetsReceived
- // TODO(gvisor.dev/issue/170): ICMP packets don't have their
- // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
- // full explanation.
+ // ICMP packets don't have their TransportHeader fields set. See
+ // icmp/protocol.go:protocol.Parse for a full explanation.
v, ok := pkt.Data().PullUp(header.ICMPv4MinimumSize)
if !ok {
received.invalid.Increment()
diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go
index 4bd6f462e..c6576fcbc 100644
--- a/pkg/tcpip/network/ipv4/igmp_test.go
+++ b/pkg/tcpip/network/ipv4/igmp_test.go
@@ -120,9 +120,12 @@ func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, ma
// cycles.
func TestIGMPV1Present(t *testing.T) {
e, s, clock := createStack(t, true)
- addr := tcpip.AddressWithPrefix{Address: stackAddr, PrefixLen: defaultPrefixLength}
- if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{Address: stackAddr, PrefixLen: defaultPrefixLength},
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
@@ -215,8 +218,15 @@ func TestSendQueuedIGMPReports(t *testing.T) {
// The initial set of IGMP reports that were queued should be sent once an
// address is assigned.
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, stackAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: stackAddr,
+ PrefixLen: defaultPrefixLength,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if got := reportStat.Value(); got != 1 {
t.Errorf("got reportStat.Value() = %d, want = 1", got)
@@ -350,8 +360,12 @@ func TestIGMPPacketValidation(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
e, s, _ := createStack(t, true)
for _, address := range test.stackAddresses {
- if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, address); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, address, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: address,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
stats := s.Stats()
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index c99297a51..aef789b4c 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -240,7 +240,7 @@ func (e *endpoint) Enable() tcpip.Error {
}
// Create an endpoint to receive broadcast packets on this interface.
- ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(ipv4BroadcastAddr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */)
+ ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(ipv4BroadcastAddr, stack.AddressProperties{PEB: stack.NeverPrimaryEndpoint})
if err != nil {
return err
}
@@ -429,9 +429,9 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams,
// based on destination address and do not send the packet to link
// layer.
//
- // TODO(gvisor.dev/issue/170): We should do this for every
- // packet, rather than only NATted packets, but removing this check
- // short circuits broadcasts before they are sent out to other hosts.
+ // We should do this for every packet, rather than only NATted packets, but
+ // removing this check short circuits broadcasts before they are sent out to
+ // other hosts.
if pkt.NatDone {
netHeader := header.IPv4(pkt.NetworkHeader().View())
if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil {
@@ -614,10 +614,6 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
ipH.SetSourceAddress(r.LocalAddress())
}
- // Set the destination. If the packet already included a destination, it will
- // be part of the route anyways.
- ipH.SetDestinationAddress(r.RemoteAddress())
-
// Set the packet ID when zero.
if ipH.ID() == 0 {
// RFC 6864 section 4.3 mandates uniqueness of ID values for
@@ -861,6 +857,14 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum
func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer, inNICName string) {
pkt.NICID = e.nic.ID()
+
+ // Raw socket packets are delivered based solely on the transport protocol
+ // number. We only require that the packet be valid IPv4 and that it not be
+ // fragmented.
+ if !h.More() && h.FragmentOffset() == 0 {
+ e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt)
+ }
+
stats := e.stats
stats.ip.ValidPacketsReceived.Increment()
@@ -995,6 +999,9 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer,
// to do it here.
h.SetTotalLength(uint16(pkt.Data().Size() + len(h)))
h.SetFlagsFragmentOffset(0, 0)
+
+ // Now that the packet is reassembled, it can be sent to raw sockets.
+ e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt)
}
stats.ip.PacketsDelivered.Increment()
@@ -1068,11 +1075,11 @@ func (e *endpoint) Close() {
}
// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint.
-func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
+func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties stack.AddressProperties) (stack.AddressEndpoint, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
+ ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, properties)
if err == nil {
e.mu.igmp.sendQueuedReports()
}
@@ -1219,11 +1226,6 @@ func (p *protocol) MinimumPacketSize() int {
return header.IPv4MinimumSize
}
-// DefaultPrefixLen returns the IPv4 default prefix length.
-func (p *protocol) DefaultPrefixLen() int {
- return header.IPv4AddressSize * 8
-}
-
// ParseAddresses implements stack.NetworkProtocol.
func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
h := header.IPv4(v)
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 4a4448cf9..e7b5b3ea2 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -32,7 +32,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
iptestutil "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
@@ -102,8 +101,12 @@ func TestExcludeBroadcast(t *testing.T) {
defer ep.Close()
// Add a valid primary endpoint address, now we can connect.
- if err := s.AddAddress(1, ipv4.ProtocolNumber, "\x0a\x00\x00\x02"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.Address("\x0a\x00\x00\x02").WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
if err := ep.Connect(randomAddr); err != nil {
t.Errorf("Connect failed: %v", err)
@@ -357,8 +360,8 @@ func TestForwarding(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err)
}
incomingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: incomingIPv4Addr}
- if err := s.AddProtocolAddress(incomingNICID, incomingIPv4ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv4ProtoAddr, err)
+ if err := s.AddProtocolAddress(incomingNICID, incomingIPv4ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingIPv4ProtoAddr, err)
}
expectedEmittedPacketCount := 1
@@ -370,8 +373,8 @@ func TestForwarding(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err)
}
outgoingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: outgoingIPv4Addr}
- if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv4ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv4ProtoAddr, err)
+ if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv4ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingIPv4ProtoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -1185,8 +1188,8 @@ func TestIPv4Sanity(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr}
- if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err)
+ if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtoAddr, err)
}
// Default routes for IPv4 so ICMP can find a route to the remote
@@ -1746,8 +1749,8 @@ func TestInvalidFragments(t *testing.T) {
const (
nicID = 1
linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
- addr1 = "\x0a\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x02"
+ addr1 = tcpip.Address("\x0a\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x02")
tos = 0
ident = 1
ttl = 48
@@ -2013,8 +2016,12 @@ func TestInvalidFragments(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
for _, f := range test.fragments {
@@ -2062,8 +2069,8 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
const (
nicID = 1
linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
- addr1 = "\x0a\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x02"
+ addr1 = tcpip.Address("\x0a\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x02")
tos = 0
ident = 1
ttl = 48
@@ -2238,8 +2245,12 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv4EmptySubnet,
@@ -2309,9 +2320,9 @@ func TestReceiveFragments(t *testing.T) {
const (
nicID = 1
- addr1 = "\x0c\xa8\x00\x01" // 192.168.0.1
- addr2 = "\x0c\xa8\x00\x02" // 192.168.0.2
- addr3 = "\x0c\xa8\x00\x03" // 192.168.0.3
+ addr1 = tcpip.Address("\x0c\xa8\x00\x01") // 12.168.0.1
+ addr2 = tcpip.Address("\x0c\xa8\x00\x02") // 12.168.0.2
+ addr3 = tcpip.Address("\x0c\xa8\x00\x03") // 12.168.0.3
)
// Build and return a UDP header containing payload.
@@ -2704,8 +2715,12 @@ func TestReceiveFragments(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
wq := waiter.Queue{}
@@ -2986,11 +3001,15 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route {
t.Fatalf("CreateNIC(1, _) failed: %s", err)
}
const (
- src = "\x10\x00\x00\x01"
- dst = "\x10\x00\x00\x02"
+ src = tcpip.Address("\x10\x00\x00\x01")
+ dst = tcpip.Address("\x10\x00\x00\x02")
)
- if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil {
- t.Fatalf("AddAddress(1, %d, %s) failed: %s", ipv4.ProtocolNumber, src, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: src.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
mask := tcpip.AddressMask(header.IPv4Broadcast)
@@ -3162,8 +3181,8 @@ func TestPacketQueuing(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddProtocolAddress(nicID, host1IPv4Addr); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv4Addr, err)
+ if err := s.AddProtocolAddress(nicID, host1IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv4Addr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -3286,8 +3305,12 @@ func TestCloseLocking(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
}
- if err := s.AddAddress(nicID1, ipv4.ProtocolNumber, src); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) failed: %s", nicID1, ipv4.ProtocolNumber, src, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: src.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{
@@ -3339,7 +3362,7 @@ func TestCloseLocking(t *testing.T) {
defer wg.Done()
for i := 0; i < iterations; i++ {
- if err := s.CreateNIC(nicID2, loopback.New()); err != nil {
+ if err := s.CreateNIC(nicID2, stack.LinkEndpoint(channel.New(0, defaultMTU, ""))); err != nil {
t.Errorf("CreateNIC(%d, _): %s", nicID2, err)
return
}
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 23fc94303..94caaae6c 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -285,8 +285,8 @@ func isMLDValid(pkt *stack.PacketBuffer, iph header.IPv6, routerAlert *header.IP
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, routerAlert *header.IPv6RouterAlertOption) {
sent := e.stats.icmp.packetsSent
received := e.stats.icmp.packetsReceived
- // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader
- // fields set. See icmp/protocol.go:protocol.Parse for a full explanation.
+ // ICMP packets don't have their TransportHeader fields set. See
+ // icmp/protocol.go:protocol.Parse for a full explanation.
v, ok := pkt.Data().PullUp(header.ICMPv6HeaderSize)
if !ok {
received.invalid.Increment()
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index c2e9544c1..3b4c235fa 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -90,6 +90,10 @@ func (*stubDispatcher) DeliverTransportPacket(tcpip.TransportProtocolNumber, *st
return stack.TransportPacketHandled
}
+func (*stubDispatcher) DeliverRawPacket(tcpip.TransportProtocolNumber, *stack.PacketBuffer) {
+ // No-op.
+}
+
var _ stack.NetworkInterface = (*testInterface)(nil)
type testInterface struct {
@@ -221,8 +225,8 @@ func TestICMPCounts(t *testing.T) {
t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint")
}
addr := lladdr0.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
} else {
ep.DecRef()
}
@@ -403,8 +407,12 @@ func newTestContext(t *testing.T) *testContext {
if err := c.s0.CreateNIC(nicID, wrappedEP0); err != nil {
t.Fatalf("CreateNIC s0: %v", err)
}
- if err := c.s0.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress lladdr0: %v", err)
+ llProtocolAddr0 := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := c.s0.AddProtocolAddress(nicID, llProtocolAddr0, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, llProtocolAddr0, err)
}
c.linkEP1 = channel.New(defaultChannelSize, defaultMTU, linkAddr1)
@@ -412,8 +420,12 @@ func newTestContext(t *testing.T) *testContext {
if err := c.s1.CreateNIC(nicID, wrappedEP1); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
- if err := c.s1.AddAddress(nicID, ProtocolNumber, lladdr1); err != nil {
- t.Fatalf("AddAddress lladdr1: %v", err)
+ llProtocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr1.WithPrefix(),
+ }
+ if err := c.s1.AddProtocolAddress(nicID, llProtocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, llProtocolAddr1, err)
}
subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -686,8 +698,12 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
t.Fatalf("CreateNIC(_, _) = %s", err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -879,8 +895,12 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
t.Fatalf("CreateNIC(_, _) = %s", err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -1061,8 +1081,12 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -1236,8 +1260,12 @@ func TestLinkAddressRequest(t *testing.T) {
}
if len(test.nicAddr) != 0 {
- if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: test.nicAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
@@ -1411,8 +1439,8 @@ func TestPacketQueing(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddProtocolAddress(nicID, host1IPv6Addr); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv6Addr, err)
+ if err := s.AddProtocolAddress(nicID, host1IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv6Addr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -1665,8 +1693,12 @@ func TestCallsToNeighborCache(t *testing.T) {
if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
t.Fatalf("CreateNIC(_, _) = %s", err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
{
@@ -1700,8 +1732,8 @@ func TestCallsToNeighborCache(t *testing.T) {
t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint")
}
addr := lladdr0.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
} else {
ep.DecRef()
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 12763add6..c824e27fa 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -344,7 +344,10 @@ func (e *endpoint) onAddressAssignedLocked(addr tcpip.Address) {
func (e *endpoint) InvalidateDefaultRouter(rtr tcpip.Address) {
e.mu.Lock()
defer e.mu.Unlock()
- e.mu.ndp.invalidateDefaultRouter(rtr)
+
+ // We represent default routers with a default (off-link) route through the
+ // router.
+ e.mu.ndp.invalidateOffLinkRoute(offLinkRoute{dest: header.IPv6EmptySubnet, router: rtr})
}
// SetNDPConfigurations implements NDPEndpoint.
@@ -755,9 +758,9 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams,
// based on destination address and do not send the packet to link
// layer.
//
- // TODO(gvisor.dev/issue/170): We should do this for every
- // packet, rather than only NATted packets, but removing this check
- // short circuits broadcasts before they are sent out to other hosts.
+ // We should do this for every packet, rather than only NATted packets, but
+ // removing this check short circuits broadcasts before they are sent out to
+ // other hosts.
if pkt.NatDone {
netHeader := header.IPv6(pkt.NetworkHeader().View())
if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil {
@@ -928,10 +931,6 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
ipH.SetSourceAddress(r.LocalAddress())
}
- // Set the destination. If the packet already included a destination, it will
- // be part of the route anyways.
- ipH.SetDestinationAddress(r.RemoteAddress())
-
// Populate the packet buffer's network header and don't allow an invalid
// packet to be sent.
//
@@ -1129,6 +1128,11 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum
func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer, inNICName string) {
pkt.NICID = e.nic.ID()
+
+ // Raw socket packets are delivered based solely on the transport protocol
+ // number. We only require that the packet be valid IPv6.
+ e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt)
+
stats := e.stats.ip
stats.ValidPacketsReceived.Increment()
@@ -1624,12 +1628,12 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
}
// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint.
-func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
+func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties stack.AddressProperties) (stack.AddressEndpoint, tcpip.Error) {
// TODO(b/169350103): add checks here after making sure we no longer receive
// an empty address.
e.mu.Lock()
defer e.mu.Unlock()
- return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated)
+ return e.addAndAcquirePermanentAddressLocked(addr, properties)
}
// addAndAcquirePermanentAddressLocked is like AddAndAcquirePermanentAddress but
@@ -1639,8 +1643,8 @@ func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, p
// solicited-node multicast group and start duplicate address detection.
//
// Precondition: e.mu must be write locked.
-func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
- addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
+func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, properties stack.AddressProperties) (stack.AddressEndpoint, tcpip.Error) {
+ addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, properties)
if err != nil {
return nil, err
}
@@ -2007,11 +2011,6 @@ func (p *protocol) MinimumPacketSize() int {
return header.IPv6MinimumSize
}
-// DefaultPrefixLen returns the IPv6 default prefix length.
-func (p *protocol) DefaultPrefixLen() int {
- return header.IPv6AddressSize * 8
-}
-
// ParseAddresses implements stack.NetworkProtocol.
func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
h := header.IPv6(v)
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index d2a23fd4f..0735ebb23 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -41,12 +41,12 @@ import (
)
const (
- addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
// The least significant 3 bytes are the same as addr2 so both addr2 and
// addr3 will have the same solicited-node address.
- addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02"
- addr4 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x03"
+ addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02")
+ addr4 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x03")
// Tests use the extension header identifier values as uint8 instead of
// header.IPv6ExtensionHeaderIdentifier.
@@ -298,16 +298,24 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
// addr2/addr3 yet as we haven't added those addresses.
test.rxf(t, s, e, addr1, snmc, 0)
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr2, err)
}
// Should receive a packet destined to the solicited node address of
// addr2/addr3 now that we have added addr2.
test.rxf(t, s, e, addr1, snmc, 1)
- if err := s.AddAddress(nicID, ProtocolNumber, addr3); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr3, err)
+ protocolAddr3 := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr3.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr3, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr3, err)
}
// Should still receive a packet destined to the solicited node address of
@@ -374,8 +382,12 @@ func TestAddIpv6Address(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, test.addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, nil) = %s", nicID, ProtocolNumber, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: test.addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if addr, err := s.GetMainNICAddress(nicID, ProtocolNumber); err != nil {
@@ -898,8 +910,12 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
// Add a default route so that a return packet knows where to go.
@@ -1992,8 +2008,12 @@ func TestReceiveIPv6Fragments(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
wq := waiter.Queue{}
@@ -2060,8 +2080,8 @@ func TestReceiveIPv6Fragments(t *testing.T) {
func TestInvalidIPv6Fragments(t *testing.T) {
const (
- addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
nicID = 1
hoplimit = 255
@@ -2150,8 +2170,12 @@ func TestInvalidIPv6Fragments(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv6EmptySubnet,
@@ -2216,8 +2240,8 @@ func TestInvalidIPv6Fragments(t *testing.T) {
func TestFragmentReassemblyTimeout(t *testing.T) {
const (
- addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
nicID = 1
hoplimit = 255
@@ -2402,8 +2426,12 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv6EmptySubnet,
@@ -2645,11 +2673,15 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route {
t.Fatalf("CreateNIC(1, _) failed: %s", err)
}
const (
- src = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ src = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ dst = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
)
- if err := s.AddAddress(1, ProtocolNumber, src); err != nil {
- t.Fatalf("AddAddress(1, %d, %s) failed: %s", ProtocolNumber, src, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: src.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
mask := tcpip.AddressMask("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff")
@@ -3297,8 +3329,8 @@ func TestForwarding(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err)
}
incomingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: incomingIPv6Addr}
- if err := s.AddProtocolAddress(incomingNICID, incomingIPv6ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv6ProtoAddr, err)
+ if err := s.AddProtocolAddress(incomingNICID, incomingIPv6ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingIPv6ProtoAddr, err)
}
outgoingEndpoint := channel.New(1, header.IPv6MinimumMTU, "")
@@ -3306,8 +3338,8 @@ func TestForwarding(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err)
}
outgoingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: outgoingIPv6Addr}
- if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv6ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv6ProtoAddr, err)
+ if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv6ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingIPv6ProtoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go
index bc9cf6999..3e5c438d3 100644
--- a/pkg/tcpip/network/ipv6/mld_test.go
+++ b/pkg/tcpip/network/ipv6/mld_test.go
@@ -75,8 +75,12 @@ func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) {
// The stack will join an address's solicited node multicast address when
// an address is added. An MLD report message should be sent for the
// solicited-node group.
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: linkLocalAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if p, ok := e.Read(); !ok {
t.Fatal("expected a report message to be sent")
@@ -216,8 +220,13 @@ func TestSendQueuedMLDReports(t *testing.T) {
// Note, we will still expect to send a report for the global address's
// solicited node address from the unspecified address as per RFC 3590
// section 4.
- if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint, err)
+ properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint}
+ globalProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: globalAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, globalProtocolAddr, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, globalProtocolAddr, properties, err)
}
reportCounter++
if got := reportStat.Value(); got != reportCounter {
@@ -252,8 +261,12 @@ func TestSendQueuedMLDReports(t *testing.T) {
// Adding a link-local address should send a report for its solicited node
// address and globalMulticastAddr.
- if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint, err)
+ linkLocalProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: linkLocalAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, linkLocalProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, linkLocalProtocolAddr, err)
}
if dadResolutionTime != 0 {
reportCounter++
@@ -567,8 +580,12 @@ func TestMLDSkipProtocol(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: linkLocalAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if p, ok := e.Read(); !ok {
t.Fatal("expected a report message to be sent")
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index 851cd6e75..938427420 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -54,6 +54,11 @@ const (
// Advertisements, as a host.
defaultDiscoverDefaultRouters = true
+ // defaultDiscoverMoreSpecificRoutes is the default configuration for
+ // whether or not to discover more-specific routes from incoming Router
+ // Advertisements, as a host.
+ defaultDiscoverMoreSpecificRoutes = true
+
// defaultDiscoverOnLinkPrefixes is the default configuration for
// whether or not to discover on-link prefixes from incoming Router
// Advertisements' Prefix Information option, as a host.
@@ -78,13 +83,13 @@ const (
// we cannot have a negative delay.
minimumMaxRtrSolicitationDelay = 0
- // MaxDiscoveredDefaultRouters is the maximum number of discovered
- // default routers. The stack should stop discovering new routers after
- // discovering MaxDiscoveredDefaultRouters routers.
+ // MaxDiscoveredOffLinkRoutes is the maximum number of discovered off-link
+ // routes. The stack should stop discovering new off-link routes after
+ // this limit is reached.
//
// This value MUST be at minimum 2 as per RFC 4861 section 6.3.4, and
// SHOULD be more.
- MaxDiscoveredDefaultRouters = 10
+ MaxDiscoveredOffLinkRoutes = 10
// MaxDiscoveredOnLinkPrefixes is the maximum number of discovered
// on-link prefixes. The stack should stop discovering new on-link
@@ -127,25 +132,17 @@ const (
// maxSLAACAddrLocalRegenAttempts is the maximum number of times to attempt
// SLAAC address regenerations in response to an IPv6 endpoint-local conflict.
maxSLAACAddrLocalRegenAttempts = 10
-)
-var (
// MinPrefixInformationValidLifetimeForUpdate is the minimum Valid
// Lifetime to update the valid lifetime of a generated address by
// SLAAC.
//
- // This is exported as a variable (instead of a constant) so tests
- // can update it to a smaller value.
- //
// Min = 2hrs.
MinPrefixInformationValidLifetimeForUpdate = 2 * time.Hour
// MaxDesyncFactor is the upper bound for the preferred lifetime's desync
// factor for temporary SLAAC addresses.
//
- // This is exported as a variable (instead of a constant) so tests
- // can update it to a smaller value.
- //
// Must be greater than 0.
//
// Max = 10m (from RFC 4941 section 5).
@@ -154,9 +151,6 @@ var (
// MinMaxTempAddrPreferredLifetime is the minimum value allowed for the
// maximum preferred lifetime for temporary SLAAC addresses.
//
- // This is exported as a variable (instead of a constant) so tests
- // can update it to a smaller value.
- //
// This value guarantees that a temporary address is preferred for at
// least 1hr if the SLAAC prefix is valid for at least that time.
MinMaxTempAddrPreferredLifetime = defaultRegenAdvanceDuration + MaxDesyncFactor + time.Hour
@@ -164,9 +158,6 @@ var (
// MinMaxTempAddrValidLifetime is the minimum value allowed for the
// maximum valid lifetime for temporary SLAAC addresses.
//
- // This is exported as a variable (instead of a constant) so tests
- // can update it to a smaller value.
- //
// This value guarantees that a temporary address is valid for at least
// 2hrs if the SLAAC prefix is valid for at least that time.
MinMaxTempAddrValidLifetime = 2 * time.Hour
@@ -214,28 +205,23 @@ type NDPDispatcher interface {
// is also not permitted to call into the stack.
OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult)
- // OnDefaultRouterDiscovered is called when a new default router is
- // discovered. Implementations must return true if the newly discovered
- // router should be remembered.
+ // OnOffLinkRouteUpdated is called when an off-link route is discovered or updated.
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
- OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool
+ OnOffLinkRouteUpdated(tcpip.NICID, tcpip.Subnet, tcpip.Address, header.NDPRoutePreference)
- // OnDefaultRouterInvalidated is called when a discovered default router that
- // was remembered is invalidated.
+ // OnOffLinkRouteInvalidated is called when an off-link route is invalidated.
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
- OnDefaultRouterInvalidated(tcpip.NICID, tcpip.Address)
+ OnOffLinkRouteInvalidated(tcpip.NICID, tcpip.Subnet, tcpip.Address)
// OnOnLinkPrefixDiscovered is called when a new on-link prefix is discovered.
- // Implementations must return true if the newly discovered on-link prefix
- // should be remembered.
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
- OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool
+ OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet)
// OnOnLinkPrefixInvalidated is called when a discovered on-link prefix that
// was remembered is invalidated.
@@ -245,13 +231,11 @@ type NDPDispatcher interface {
OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet)
// OnAutoGenAddress is called when a new prefix with its autonomous address-
- // configuration flag set is received and SLAAC was performed. Implementations
- // may prevent the stack from assigning the address to the NIC by returning
- // false.
+ // configuration flag set is received and SLAAC was performed.
//
// This function is not permitted to block indefinitely. It must not
// call functions on the stack itself.
- OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool
+ OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix)
// OnAutoGenAddressDeprecated is called when an auto-generated address (SLAAC)
// is deprecated, but is still considered valid. Note, if an address is
@@ -373,12 +357,18 @@ type NDPConfigurations struct {
// DiscoverDefaultRouters determines whether or not default routers are
// discovered from Router Advertisements, as per RFC 4861 section 6. This
- // configuration is ignored if HandleRAs is false.
+ // configuration is ignored if RAs will not be processed (see HandleRAs).
DiscoverDefaultRouters bool
+ // DiscoverMoreSpecificRoutes determines whether or not more specific routes
+ // are discovered from Router Advertisements, as per RFC 4191. This
+ // configuration is ignored if RAs will not be processed (see HandleRAs).
+ DiscoverMoreSpecificRoutes bool
+
// DiscoverOnLinkPrefixes determines whether or not on-link prefixes are
// discovered from Router Advertisements' Prefix Information option, as per
- // RFC 4861 section 6. This configuration is ignored if HandleRAs is false.
+ // RFC 4861 section 6. This configuration is ignored if RAs will not be
+ // processed (see HandleRAs).
DiscoverOnLinkPrefixes bool
// AutoGenGlobalAddresses determines whether or not an IPv6 endpoint performs
@@ -429,6 +419,7 @@ func DefaultNDPConfigurations() NDPConfigurations {
MaxRtrSolicitationDelay: defaultMaxRtrSolicitationDelay,
HandleRAs: defaultHandleRAs,
DiscoverDefaultRouters: defaultDiscoverDefaultRouters,
+ DiscoverMoreSpecificRoutes: defaultDiscoverMoreSpecificRoutes,
DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes,
AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses,
AutoGenTempGlobalAddresses: defaultAutoGenTempGlobalAddresses,
@@ -469,6 +460,11 @@ type timer struct {
timer tcpip.Timer
}
+type offLinkRoute struct {
+ dest tcpip.Subnet
+ router tcpip.Address
+}
+
// ndpState is the per-Interface NDP state.
type ndpState struct {
// Do not allow overwriting this state.
@@ -483,8 +479,8 @@ type ndpState struct {
// The DAD timers to send the next NS message, or resolve the address.
dad ip.DAD
- // The default routers discovered through Router Advertisements.
- defaultRouters map[tcpip.Address]defaultRouterState
+ // The off-link routes discovered through Router Advertisements.
+ offLinkRoutes map[offLinkRoute]offLinkRouteState
// rtrSolicitTimer is the timer used to send the next router solicitation
// message.
@@ -512,10 +508,12 @@ type ndpState struct {
temporaryAddressDesyncFactor time.Duration
}
-// defaultRouterState holds data associated with a default router discovered by
+// offLinkRouteState holds data associated with an off-link route discovered by
// a Router Advertisement (RA).
-type defaultRouterState struct {
- // Job to invalidate the default router.
+type offLinkRouteState struct {
+ prf header.NDPRoutePreference
+
+ // Job to invalidate the route.
//
// Must not be nil.
invalidationJob *tcpip.Job
@@ -571,11 +569,11 @@ type slaacPrefixState struct {
// Must not be nil.
invalidationJob *tcpip.Job
- // Nonzero only when the address is not valid forever.
- validUntil tcpip.MonotonicTime
+ // nil iff the address is valid forever.
+ validUntil *tcpip.MonotonicTime
- // Nonzero only when the address is not preferred forever.
- preferredUntil tcpip.MonotonicTime
+ // nil iff the address is preferred forever.
+ preferredUntil *tcpip.MonotonicTime
// State associated with the stable address generated for the prefix.
stableAddr struct {
@@ -733,30 +731,22 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
// Is the IPv6 endpoint configured to discover default routers?
if ndp.configs.DiscoverDefaultRouters {
- rtr, ok := ndp.defaultRouters[ip]
- rl := ra.RouterLifetime()
- switch {
- case !ok && rl != 0:
- // This is a new default router we are discovering.
+ prf := ra.DefaultRouterPreference()
+ if prf == header.ReservedRoutePreference {
+ // As per RFC 4191 section 2.2,
//
- // Only remember it if we currently know about less than
- // MaxDiscoveredDefaultRouters routers.
- if len(ndp.defaultRouters) < MaxDiscoveredDefaultRouters {
- ndp.rememberDefaultRouter(ip, rl)
- }
-
- case ok && rl != 0:
- // This is an already discovered default router. Update
- // the invalidation job.
- rtr.invalidationJob.Cancel()
- rtr.invalidationJob.Schedule(rl)
- ndp.defaultRouters[ip] = rtr
-
- case ok && rl == 0:
- // We know about the router but it is no longer to be
- // used as a default router so invalidate it.
- ndp.invalidateDefaultRouter(ip)
+ // Prf (Default Router Preference)
+ //
+ // If the Reserved (10) value is received, the receiver MUST treat the
+ // value as if it were (00).
+ //
+ // Note that the value 00 is the medium (default) router preference value.
+ prf = header.MediumRoutePreference
}
+
+ // We represent default routers with a default (off-link) route through the
+ // router.
+ ndp.handleOffLinkRouteDiscovery(offLinkRoute{dest: header.IPv6EmptySubnet, router: ip}, ra.RouterLifetime(), prf)
}
// TODO(b/141556115): Do (RetransTimer, ReachableTime) Parameter
@@ -808,61 +798,107 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
if opt.AutonomousAddressConfigurationFlag() {
ndp.handleAutonomousPrefixInformation(opt)
}
+
+ case header.NDPRouteInformation:
+ if !ndp.configs.DiscoverMoreSpecificRoutes {
+ continue
+ }
+
+ dest, err := opt.Prefix()
+ if err != nil {
+ panic(fmt.Sprintf("%T.Prefix(): %s", opt, err))
+ }
+
+ prf := opt.RoutePreference()
+ if prf == header.ReservedRoutePreference {
+ // As per RFC 4191 section 2.3,
+ //
+ // Prf (Route Preference)
+ // 2-bit signed integer. The Route Preference indicates
+ // whether to prefer the router associated with this prefix
+ // over others, when multiple identical prefixes (for
+ // different routers) have been received. If the Reserved
+ // (10) value is received, the Route Information Option MUST
+ // be ignored.
+ continue
+ }
+
+ ndp.handleOffLinkRouteDiscovery(offLinkRoute{dest: dest, router: ip}, opt.RouteLifetime(), prf)
}
// TODO(b/141556115): Do (MTU) Parameter Discovery.
}
}
-// invalidateDefaultRouter invalidates a discovered default router.
+// invalidateOffLinkRoute invalidates a discovered off-link route.
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
-func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
- rtr, ok := ndp.defaultRouters[ip]
-
- // Is the router still discovered?
+func (ndp *ndpState) invalidateOffLinkRoute(route offLinkRoute) {
+ state, ok := ndp.offLinkRoutes[route]
if !ok {
- // ...Nope, do nothing further.
return
}
- rtr.invalidationJob.Cancel()
- delete(ndp.defaultRouters, ip)
+ state.invalidationJob.Cancel()
+ delete(ndp.offLinkRoutes, route)
- // Let the integrator know a discovered default router is invalidated.
+ // Let the integrator know a discovered off-link route is invalidated.
if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
- ndpDisp.OnDefaultRouterInvalidated(ndp.ep.nic.ID(), ip)
+ ndpDisp.OnOffLinkRouteInvalidated(ndp.ep.nic.ID(), route.dest, route.router)
}
}
-// rememberDefaultRouter remembers a newly discovered default router with IPv6
-// link-local address ip with lifetime rl.
-//
-// The router identified by ip MUST NOT already be known by the IPv6 endpoint.
+// handleOffLinkRouteDiscovery handles the discovery of an off-link route.
//
-// The IPv6 endpoint that ndp belongs to MUST be locked.
-func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
+// Precondition: ndp.ep.mu must be locked.
+func (ndp *ndpState) handleOffLinkRouteDiscovery(route offLinkRoute, lifetime time.Duration, prf header.NDPRoutePreference) {
ndpDisp := ndp.ep.protocol.options.NDPDisp
if ndpDisp == nil {
return
}
- // Inform the integrator when we discovered a default router.
- if !ndpDisp.OnDefaultRouterDiscovered(ndp.ep.nic.ID(), ip) {
- // Informed by the integrator to not remember the router, do
- // nothing further.
- return
- }
+ state, ok := ndp.offLinkRoutes[route]
+ switch {
+ case !ok && lifetime != 0:
+ // This is a new route we are discovering.
+ //
+ // Only remember it if we currently know about less than
+ // MaxDiscoveredOffLinkRoutes routes.
+ if len(ndp.offLinkRoutes) < MaxDiscoveredOffLinkRoutes {
+ // Inform the integrator when we discovered an off-link route.
+ ndpDisp.OnOffLinkRouteUpdated(ndp.ep.nic.ID(), route.dest, route.router, prf)
+
+ state := offLinkRouteState{
+ prf: prf,
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
+ ndp.invalidateOffLinkRoute(route)
+ }),
+ }
- state := defaultRouterState{
- invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
- ndp.invalidateDefaultRouter(ip)
- }),
- }
+ state.invalidationJob.Schedule(lifetime)
+
+ ndp.offLinkRoutes[route] = state
+ }
- state.invalidationJob.Schedule(rl)
+ case ok && lifetime != 0:
+ // This is an already discovered off-link route. Update the lifetime.
+ state.invalidationJob.Cancel()
+ state.invalidationJob.Schedule(lifetime)
- ndp.defaultRouters[ip] = state
+ if prf != state.prf {
+ state.prf = prf
+
+ // Inform the integrator about route preference updates.
+ ndpDisp.OnOffLinkRouteUpdated(ndp.ep.nic.ID(), route.dest, route.router, prf)
+ }
+
+ ndp.offLinkRoutes[route] = state
+
+ case ok && lifetime == 0:
+ // The already discovered off-link route is no longer considered valid so we
+ // invalidate it immediately.
+ ndp.invalidateOffLinkRoute(route)
+ }
}
// rememberOnLinkPrefix remembers a newly discovered on-link prefix with IPv6
@@ -878,11 +914,7 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration)
}
// Inform the integrator when we discovered an on-link prefix.
- if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.ep.nic.ID(), prefix) {
- // Informed by the integrator to not remember the prefix, do
- // nothing further.
- return
- }
+ ndpDisp.OnOnLinkPrefixDiscovered(ndp.ep.nic.ID(), prefix)
state := onLinkPrefixState{
invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
@@ -1055,7 +1087,8 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
// The time an address is preferred until is needed to properly generate the
// address.
if pl < header.NDPInfiniteLifetime {
- state.preferredUntil = now.Add(pl)
+ t := now.Add(pl)
+ state.preferredUntil = &t
}
if !ndp.generateSLAACAddr(prefix, &state) {
@@ -1073,7 +1106,8 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
if vl < header.NDPInfiniteLifetime {
state.invalidationJob.Schedule(vl)
- state.validUntil = now.Add(vl)
+ t := now.Add(vl)
+ state.validUntil = &t
}
// If the address is assigned (DAD resolved), generate a temporary address.
@@ -1096,16 +1130,17 @@ func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, config
return nil
}
- if !ndpDisp.OnAutoGenAddress(ndp.ep.nic.ID(), addr) {
- // Informed by the integrator not to add the address.
- return nil
- }
-
- addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.FirstPrimaryEndpoint, configType, deprecated)
+ addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.AddressProperties{
+ PEB: stack.FirstPrimaryEndpoint,
+ ConfigType: configType,
+ Deprecated: deprecated,
+ })
if err != nil {
panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", addr, err))
}
+ ndpDisp.OnAutoGenAddress(ndp.ep.nic.ID(), addr)
+
return addressEndpoint
}
@@ -1181,7 +1216,8 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt
state.stableAddr.localGenerationFailures++
}
- if addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaac, ndp.ep.protocol.stack.Clock().NowMonotonic().Sub(state.preferredUntil) >= 0 /* deprecated */); addressEndpoint != nil {
+ deprecated := state.preferredUntil != nil && !state.preferredUntil.After(ndp.ep.protocol.stack.Clock().NowMonotonic())
+ if addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaac, deprecated); addressEndpoint != nil {
state.stableAddr.addressEndpoint = addressEndpoint
state.generationAttempts++
return true
@@ -1242,7 +1278,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
// address is the lower of the valid lifetime of the stable address or the
// maximum temporary address valid lifetime.
vl := ndp.configs.MaxTempAddrValidLifetime
- if prefixState.validUntil != (tcpip.MonotonicTime{}) {
+ if prefixState.validUntil != nil {
if prefixVL := prefixState.validUntil.Sub(now); vl > prefixVL {
vl = prefixVL
}
@@ -1258,7 +1294,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
// maximum temporary address preferred lifetime - the temporary address desync
// factor.
pl := ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor
- if prefixState.preferredUntil != (tcpip.MonotonicTime{}) {
+ if prefixState.preferredUntil != nil {
if prefixPL := prefixState.preferredUntil.Sub(now); pl > prefixPL {
// Respect the preferred lifetime of the prefix, as per RFC 4941 section
// 3.3 step 4.
@@ -1400,9 +1436,10 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
if !deprecated {
prefixState.deprecationJob.Schedule(pl)
}
- prefixState.preferredUntil = now.Add(pl)
+ t := now.Add(pl)
+ prefixState.preferredUntil = &t
} else {
- prefixState.preferredUntil = tcpip.MonotonicTime{}
+ prefixState.preferredUntil = nil
}
// As per RFC 4862 section 5.5.3.e, update the valid lifetime for prefix:
@@ -1420,14 +1457,14 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// Handle the infinite valid lifetime separately as we do not schedule a
// job in this case.
prefixState.invalidationJob.Cancel()
- prefixState.validUntil = tcpip.MonotonicTime{}
+ prefixState.validUntil = nil
} else {
var effectiveVl time.Duration
var rl time.Duration
// If the prefix was originally set to be valid forever, assume the
// remaining time to be the maximum possible value.
- if prefixState.validUntil == (tcpip.MonotonicTime{}) {
+ if prefixState.validUntil == nil {
rl = header.NDPInfiniteLifetime
} else {
rl = prefixState.validUntil.Sub(now)
@@ -1442,7 +1479,8 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
if effectiveVl != 0 {
prefixState.invalidationJob.Cancel()
prefixState.invalidationJob.Schedule(effectiveVl)
- prefixState.validUntil = now.Add(effectiveVl)
+ t := now.Add(effectiveVl)
+ prefixState.validUntil = &t
}
}
@@ -1462,8 +1500,8 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// maximum temporary address valid lifetime. Note, the valid lifetime of a
// temporary address is relative to the address's creation time.
validUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrValidLifetime)
- if prefixState.validUntil != (tcpip.MonotonicTime{}) && validUntil.Sub(prefixState.validUntil) > 0 {
- validUntil = prefixState.validUntil
+ if prefixState.validUntil != nil && prefixState.validUntil.Before(validUntil) {
+ validUntil = *prefixState.validUntil
}
// If the address is no longer valid, invalidate it immediately. Otherwise,
@@ -1482,14 +1520,15 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// desync factor. Note, the preferred lifetime of a temporary address is
// relative to the address's creation time.
preferredUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor)
- if prefixState.preferredUntil != (tcpip.MonotonicTime{}) && preferredUntil.Sub(prefixState.preferredUntil) > 0 {
- preferredUntil = prefixState.preferredUntil
+ if prefixState.preferredUntil != nil && prefixState.preferredUntil.Before(preferredUntil) {
+ preferredUntil = *prefixState.preferredUntil
}
// If the address is no longer preferred, deprecate it immediately.
// Otherwise, schedule the deprecation job again.
newPreferredLifetime := preferredUntil.Sub(now)
tempAddrState.deprecationJob.Cancel()
+
if newPreferredLifetime <= 0 {
ndp.deprecateSLAACAddress(tempAddrState.addressEndpoint)
} else {
@@ -1679,12 +1718,12 @@ func (ndp *ndpState) cleanupState() {
panic(fmt.Sprintf("ndp: still have discovered on-link prefixes after cleaning up; found = %d", got))
}
- for router := range ndp.defaultRouters {
- ndp.invalidateDefaultRouter(router)
+ for route := range ndp.offLinkRoutes {
+ ndp.invalidateOffLinkRoute(route)
}
- if got := len(ndp.defaultRouters); got != 0 {
- panic(fmt.Sprintf("ndp: still have discovered default routers after cleaning up; found = %d", got))
+ if got := len(ndp.offLinkRoutes); got != 0 {
+ panic(fmt.Sprintf("ndp: still have discovered off-link routes after cleaning up; found = %d", got))
}
ndp.dhcpv6Configuration = 0
@@ -1847,21 +1886,19 @@ func (ndp *ndpState) stopSolicitingRouters() {
}
func (ndp *ndpState) init(ep *endpoint, dadOptions ip.DADOptions) {
- if ndp.defaultRouters != nil {
+ if ndp.offLinkRoutes != nil {
panic("attempted to initialize NDP state twice")
}
ndp.ep = ep
ndp.configs = ep.protocol.options.NDPConfigs
ndp.dad.Init(&ndp.ep.mu, ep.protocol.options.DADConfigs, dadOptions)
- ndp.defaultRouters = make(map[tcpip.Address]defaultRouterState)
+ ndp.offLinkRoutes = make(map[offLinkRoute]offLinkRouteState)
ndp.onLinkPrefixes = make(map[tcpip.Subnet]onLinkPrefixState)
ndp.slaacPrefixes = make(map[tcpip.Subnet]slaacPrefixState)
header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.options.TempIIDSeed, ndp.ep.nic.ID())
- if MaxDesyncFactor != 0 {
- ndp.temporaryAddressDesyncFactor = time.Duration(ep.protocol.stack.Rand().Int63n(int64(MaxDesyncFactor)))
- }
+ ndp.temporaryAddressDesyncFactor = time.Duration(ep.protocol.stack.Rand().Int63n(int64(MaxDesyncFactor)))
}
func (ndp *ndpState) SendDADMessage(addr tcpip.Address, nonce []byte) tcpip.Error {
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 3438deb79..8297a7e10 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -42,24 +42,21 @@ type testNDPDispatcher struct {
func (*testNDPDispatcher) OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) {
}
-func (t *testNDPDispatcher) OnDefaultRouterDiscovered(_ tcpip.NICID, addr tcpip.Address) bool {
+func (t *testNDPDispatcher) OnOffLinkRouteUpdated(_ tcpip.NICID, _ tcpip.Subnet, addr tcpip.Address, _ header.NDPRoutePreference) {
t.addr = addr
- return true
}
-func (t *testNDPDispatcher) OnDefaultRouterInvalidated(_ tcpip.NICID, addr tcpip.Address) {
+func (t *testNDPDispatcher) OnOffLinkRouteInvalidated(_ tcpip.NICID, _ tcpip.Subnet, addr tcpip.Address) {
t.addr = addr
}
-func (*testNDPDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool {
- return false
+func (*testNDPDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) {
}
func (*testNDPDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) {
}
-func (*testNDPDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool {
- return false
+func (*testNDPDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) {
}
func (*testNDPDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) {
@@ -96,7 +93,7 @@ func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) {
ipv6EP := ep.(*endpoint)
ipv6EP.mu.Lock()
- ipv6EP.mu.ndp.rememberDefaultRouter(lladdr1, time.Hour)
+ ipv6EP.mu.ndp.handleOffLinkRouteDiscovery(offLinkRoute{dest: header.IPv6EmptySubnet, router: lladdr1}, time.Hour, header.MediumRoutePreference)
ipv6EP.mu.Unlock()
if ndpDisp.addr != lladdr1 {
@@ -147,8 +144,12 @@ func TestNeighborSolicitationWithSourceLinkLayerOption(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf)
@@ -409,8 +410,12 @@ func TestNeighborSolicitationResponse(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: nicAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -605,8 +610,12 @@ func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf)
@@ -834,8 +843,12 @@ func TestNDPValidation(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber)
@@ -965,8 +978,12 @@ func TestNeighborAdvertisementValidation(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
ndpNASize := header.ICMPv6NeighborAdvertMinimumSize
@@ -1286,8 +1303,12 @@ func TestCheckDuplicateAddress(t *testing.T) {
checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonces[dadPacketsSent])}),
))
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
checkDADMsg()
diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go
index 1b96b1fb8..26640b7ee 100644
--- a/pkg/tcpip/network/multicast_group_test.go
+++ b/pkg/tcpip/network/multicast_group_test.go
@@ -151,15 +151,22 @@ func createStackWithLinkEndpoint(t *testing.T, v4, mgpEnabled bool, e stack.Link
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- addr := tcpip.AddressWithPrefix{
- Address: stackIPv4Addr,
- PrefixLen: defaultIPv4PrefixLength,
+ addr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: stackIPv4Addr,
+ PrefixLen: defaultIPv4PrefixLength,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err)
}
- if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: linkLocalIPv6Addr1.WithPrefix(),
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalIPv6Addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, linkLocalIPv6Addr1, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
return s, clock
diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD
index b7f6d52ae..fe98a52af 100644
--- a/pkg/tcpip/ports/BUILD
+++ b/pkg/tcpip/ports/BUILD
@@ -12,6 +12,7 @@ go_library(
deps = [
"//pkg/sync",
"//pkg/tcpip",
+ "//pkg/tcpip/header",
],
)
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
index 854d6a6ba..fb8ef1ee2 100644
--- a/pkg/tcpip/ports/ports.go
+++ b/pkg/tcpip/ports/ports.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
)
const (
@@ -122,7 +123,7 @@ type deviceToDest map[tcpip.NICID]destToCounter
// If either of the port reuse flags is enabled on any of the nodes, all nodes
// sharing a port must share at least one reuse flag. This matches Linux's
// behavior.
-func (dd deviceToDest) isAvailable(res Reservation) bool {
+func (dd deviceToDest) isAvailable(res Reservation, portSpecified bool) bool {
flagBits := res.Flags.Bits()
if res.BindToDevice == 0 {
intersection := FlagMask
@@ -138,6 +139,9 @@ func (dd deviceToDest) isAvailable(res Reservation) bool {
return false
}
}
+ if !portSpecified && res.Transport == header.TCPProtocolNumber {
+ return false
+ }
return true
}
@@ -146,16 +150,26 @@ func (dd deviceToDest) isAvailable(res Reservation) bool {
if dests, ok := dd[0]; ok {
var count int
intersection, count = dests.intersectionFlags(res)
- if count > 0 && intersection&flagBits == 0 {
- return false
+ if count > 0 {
+ if intersection&flagBits == 0 {
+ return false
+ }
+ if !portSpecified && res.Transport == header.TCPProtocolNumber {
+ return false
+ }
}
}
if dests, ok := dd[res.BindToDevice]; ok {
flags, count := dests.intersectionFlags(res)
intersection &= flags
- if count > 0 && intersection&flagBits == 0 {
- return false
+ if count > 0 {
+ if intersection&flagBits == 0 {
+ return false
+ }
+ if !portSpecified && res.Transport == header.TCPProtocolNumber {
+ return false
+ }
}
}
@@ -168,12 +182,12 @@ type addrToDevice map[tcpip.Address]deviceToDest
// isAvailable checks whether an IP address is available to bind to. If the
// address is the "any" address, check all other addresses. Otherwise, just
// check against the "any" address and the provided address.
-func (ad addrToDevice) isAvailable(res Reservation) bool {
+func (ad addrToDevice) isAvailable(res Reservation, portSpecified bool) bool {
if res.Addr == anyIPAddress {
// If binding to the "any" address then check that there are no
// conflicts with all addresses.
for _, devices := range ad {
- if !devices.isAvailable(res) {
+ if !devices.isAvailable(res, portSpecified) {
return false
}
}
@@ -182,14 +196,14 @@ func (ad addrToDevice) isAvailable(res Reservation) bool {
// Check that there is no conflict with the "any" address.
if devices, ok := ad[anyIPAddress]; ok {
- if !devices.isAvailable(res) {
+ if !devices.isAvailable(res, portSpecified) {
return false
}
}
// Check that this is no conflict with the provided address.
if devices, ok := ad[res.Addr]; ok {
- if !devices.isAvailable(res) {
+ if !devices.isAvailable(res, portSpecified) {
return false
}
}
@@ -310,7 +324,7 @@ func (pm *PortManager) ReservePort(rng *rand.Rand, res Reservation, testPort Por
// If a port is specified, just try to reserve it for all network
// protocols.
if res.Port != 0 {
- if !pm.reserveSpecificPortLocked(res) {
+ if !pm.reserveSpecificPortLocked(res, true /* portSpecified */) {
return 0, &tcpip.ErrPortInUse{}
}
if testPort != nil {
@@ -330,7 +344,7 @@ func (pm *PortManager) ReservePort(rng *rand.Rand, res Reservation, testPort Por
// A port wasn't specified, so try to find one.
return pm.PickEphemeralPort(rng, func(p uint16) (bool, tcpip.Error) {
res.Port = p
- if !pm.reserveSpecificPortLocked(res) {
+ if !pm.reserveSpecificPortLocked(res, false /* portSpecified */) {
return false, nil
}
if testPort != nil {
@@ -350,12 +364,12 @@ func (pm *PortManager) ReservePort(rng *rand.Rand, res Reservation, testPort Por
// reserveSpecificPortLocked tries to reserve the given port on all given
// protocols.
-func (pm *PortManager) reserveSpecificPortLocked(res Reservation) bool {
+func (pm *PortManager) reserveSpecificPortLocked(res Reservation, portSpecified bool) bool {
// Make sure the port is available.
for _, network := range res.Networks {
desc := portDescriptor{network, res.Transport, res.Port}
if addrs, ok := pm.allocatedPorts[desc]; ok {
- if !addrs.isAvailable(res) {
+ if !addrs.isAvailable(res, portSpecified) {
return false
}
}
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
index b9a24ff56..05b879543 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build linux
// +build linux
// This sample creates a stack with TCP and IPv4 protocols on top of a TUN
@@ -145,8 +146,12 @@ func main() {
log.Fatal(err)
}
- if err := s.AddAddress(1, ipv4.ProtocolNumber, addr); err != nil {
- log.Fatal(err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ log.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
// Add default route.
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
index ef1bfc186..a72afadda 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/main.go
+++ b/pkg/tcpip/sample/tun_tcp_echo/main.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build linux
// +build linux
// This sample creates a stack with TCP and IPv4 protocols on top of a TUN
@@ -123,13 +124,13 @@ func main() {
log.Fatalf("Bad IP address: %v", addrName)
}
- var addr tcpip.Address
+ var addrWithPrefix tcpip.AddressWithPrefix
var proto tcpip.NetworkProtocolNumber
if parsedAddr.To4() != nil {
- addr = tcpip.Address(parsedAddr.To4())
+ addrWithPrefix = tcpip.Address(parsedAddr.To4()).WithPrefix()
proto = ipv4.ProtocolNumber
} else if parsedAddr.To16() != nil {
- addr = tcpip.Address(parsedAddr.To16())
+ addrWithPrefix = tcpip.Address(parsedAddr.To16()).WithPrefix()
proto = ipv6.ProtocolNumber
} else {
log.Fatalf("Unknown IP type: %v", addrName)
@@ -175,11 +176,15 @@ func main() {
log.Fatal(err)
}
- if err := s.AddAddress(1, proto, addr); err != nil {
- log.Fatal(err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: proto,
+ AddressWithPrefix: addrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ log.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
- subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addr))), tcpip.AddressMask(strings.Repeat("\x00", len(addr))))
+ subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addrWithPrefix.Address))), tcpip.AddressMask(strings.Repeat("\x00", len(addrWithPrefix.Address))))
if err != nil {
log.Fatal(err)
}
diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go
index 0ea85f9ed..34ac62444 100644
--- a/pkg/tcpip/socketops.go
+++ b/pkg/tcpip/socketops.go
@@ -15,17 +15,12 @@
package tcpip
import (
- "math"
"sync/atomic"
"gvisor.dev/gvisor/pkg/atomicbitops"
"gvisor.dev/gvisor/pkg/sync"
)
-// PacketOverheadFactor is used to multiply the value provided by the user on a
-// SetSockOpt for setting the send/receive buffer sizes sockets.
-const PacketOverheadFactor = 2
-
// SocketOptionsHandler holds methods that help define endpoint specific
// behavior for socket level socket options. These must be implemented by
// endpoints to get notified when socket level options are set.
@@ -60,8 +55,13 @@ type SocketOptionsHandler interface {
// buffer size. It also returns the newly set value.
OnSetSendBufferSize(v int64) (newSz int64)
- // OnSetReceiveBufferSize is invoked to set the SO_RCVBUFSIZE.
+ // OnSetReceiveBufferSize is invoked by SO_RCVBUF and SO_RCVBUFFORCE.
OnSetReceiveBufferSize(v, oldSz int64) (newSz int64)
+
+ // WakeupWriters is invoked when the send buffer size for an endpoint is
+ // changed. The handler notifies the writers if the send buffer size is
+ // increased with setsockopt(2) for TCP endpoints.
+ WakeupWriters()
}
// DefaultSocketOptionsHandler is an embeddable type that implements no-op
@@ -103,6 +103,9 @@ func (*DefaultSocketOptionsHandler) OnSetSendBufferSize(v int64) (newSz int64) {
return v
}
+// WakeupWriters implements SocketOptionsHandler.WakeupWriters.
+func (*DefaultSocketOptionsHandler) WakeupWriters() {}
+
// OnSetReceiveBufferSize implements SocketOptionsHandler.OnSetReceiveBufferSize.
func (*DefaultSocketOptionsHandler) OnSetReceiveBufferSize(v, oldSz int64) (newSz int64) {
return v
@@ -617,39 +620,23 @@ func (so *SocketOptions) GetSendBufferSize() int64 {
return so.sendBufferSize.Load()
}
+// SendBufferLimits returns the inclusive [min, max] range of allowable send
+// buffer sizes.
+func (so *SocketOptions) SendBufferLimits() (min, max int64) {
+ limits := so.getSendBufferLimits(so.stackHandler)
+ return int64(limits.Min), int64(limits.Max)
+}
+
// SetSendBufferSize sets value for SO_SNDBUF option. notify indicates if the
// stack handler should be invoked to set the send buffer size.
func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) {
- v := sendBufferSize
-
- if !notify {
- so.sendBufferSize.Store(v)
- return
- }
-
- // Make sure the send buffer size is within the min and max
- // allowed.
- ss := so.getSendBufferLimits(so.stackHandler)
- min := int64(ss.Min)
- max := int64(ss.Max)
- // Validate the send buffer size with min and max values.
- // Multiply it by factor of 2.
- if v > max {
- v = max
+ if notify {
+ sendBufferSize = so.handler.OnSetSendBufferSize(sendBufferSize)
}
-
- if v < math.MaxInt32/PacketOverheadFactor {
- v *= PacketOverheadFactor
- if v < min {
- v = min
- }
- } else {
- v = math.MaxInt32
+ so.sendBufferSize.Store(sendBufferSize)
+ if notify {
+ so.handler.WakeupWriters()
}
-
- // Notify endpoint about change in buffer size.
- newSz := so.handler.OnSetSendBufferSize(v)
- so.sendBufferSize.Store(newSz)
}
// GetReceiveBufferSize gets value for SO_RCVBUF option.
@@ -657,36 +644,19 @@ func (so *SocketOptions) GetReceiveBufferSize() int64 {
return so.receiveBufferSize.Load()
}
-// SetReceiveBufferSize sets value for SO_RCVBUF option.
-func (so *SocketOptions) SetReceiveBufferSize(receiveBufferSize int64, notify bool) {
- if !notify {
- so.receiveBufferSize.Store(receiveBufferSize)
- return
- }
-
- // Make sure the send buffer size is within the min and max
- // allowed.
- v := receiveBufferSize
- ss := so.getReceiveBufferLimits(so.stackHandler)
- min := int64(ss.Min)
- max := int64(ss.Max)
- // Validate the send buffer size with min and max values.
- if v > max {
- v = max
- }
+// ReceiveBufferLimits returns the inclusive [min, max] range of allowable
+// receive buffer sizes.
+func (so *SocketOptions) ReceiveBufferLimits() (min, max int64) {
+ limits := so.getReceiveBufferLimits(so.stackHandler)
+ return int64(limits.Min), int64(limits.Max)
+}
- // Multiply it by factor of 2.
- if v < math.MaxInt32/PacketOverheadFactor {
- v *= PacketOverheadFactor
- if v < min {
- v = min
- }
- } else {
- v = math.MaxInt32
+// SetReceiveBufferSize sets the value of the SO_RCVBUF option, optionally
+// notifying the owning endpoint.
+func (so *SocketOptions) SetReceiveBufferSize(receiveBufferSize int64, notify bool) {
+ if notify {
+ oldSz := so.receiveBufferSize.Load()
+ receiveBufferSize = so.handler.OnSetReceiveBufferSize(receiveBufferSize, oldSz)
}
-
- oldSz := so.receiveBufferSize.Load()
- // Notify endpoint about change in buffer size.
- newSz := so.handler.OnSetReceiveBufferSize(v, oldSz)
- so.receiveBufferSize.Store(newSz)
+ so.receiveBufferSize.Store(receiveBufferSize)
}
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 395ff9a07..6c42ab29b 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -85,6 +85,7 @@ go_library(
"//pkg/tcpip/buffer",
"//pkg/tcpip/hash/jenkins",
"//pkg/tcpip/header",
+ "//pkg/tcpip/internal/tcp",
"//pkg/tcpip/ports",
"//pkg/tcpip/seqnum",
"//pkg/tcpip/transport/tcpconntrack",
@@ -95,7 +96,7 @@ go_library(
go_test(
name = "stack_x_test",
- size = "medium",
+ size = "small",
srcs = [
"addressable_endpoint_state_test.go",
"ndp_test.go",
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
index ce9cebdaa..7e4b5bf74 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -117,10 +117,10 @@ func (a *AddressableEndpointState) releaseAddressStateLocked(addrState *addressS
}
// AddAndAcquirePermanentAddress implements AddressableEndpoint.
-func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, tcpip.Error) {
+func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties AddressProperties) (AddressEndpoint, tcpip.Error) {
a.mu.Lock()
defer a.mu.Unlock()
- ep, err := a.addAndAcquireAddressLocked(addr, peb, configType, deprecated, true /* permanent */)
+ ep, err := a.addAndAcquireAddressLocked(addr, properties, true /* permanent */)
// From https://golang.org/doc/faq#nil_error:
//
// Under the covers, interfaces are implemented as two elements, a type T and
@@ -149,7 +149,7 @@ func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.Addr
func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) (AddressEndpoint, tcpip.Error) {
a.mu.Lock()
defer a.mu.Unlock()
- ep, err := a.addAndAcquireAddressLocked(addr, peb, AddressConfigStatic, false /* deprecated */, false /* permanent */)
+ ep, err := a.addAndAcquireAddressLocked(addr, AddressProperties{PEB: peb}, false /* permanent */)
// From https://golang.org/doc/faq#nil_error:
//
// Under the covers, interfaces are implemented as two elements, a type T and
@@ -180,7 +180,7 @@ func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.Addr
// returned, regardless the kind of address that is being added.
//
// Precondition: a.mu must be write locked.
-func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated, permanent bool) (*addressState, tcpip.Error) {
+func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, properties AddressProperties, permanent bool) (*addressState, tcpip.Error) {
// attemptAddToPrimary is false when the address is already in the primary
// address list.
attemptAddToPrimary := true
@@ -208,7 +208,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
// We now promote the address.
for i, s := range a.mu.primary {
if s == addrState {
- switch peb {
+ switch properties.PEB {
case CanBePrimaryEndpoint:
// The address is already in the primary address list.
attemptAddToPrimary = false
@@ -222,7 +222,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
case NeverPrimaryEndpoint:
a.mu.primary = append(a.mu.primary[:i], a.mu.primary[i+1:]...)
default:
- panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb))
+ panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", properties.PEB))
}
break
}
@@ -249,7 +249,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
// or we are adding a new temporary or permanent address.
//
// The address MUST be write locked at this point.
- defer addrState.mu.Unlock()
+ defer addrState.mu.Unlock() // +checklocksforce
if permanent {
if addrState.mu.kind.IsPermanent() {
@@ -262,11 +262,11 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
}
// Acquire the address before returning it.
addrState.mu.refs++
- addrState.mu.deprecated = deprecated
- addrState.mu.configType = configType
+ addrState.mu.deprecated = properties.Deprecated
+ addrState.mu.configType = properties.ConfigType
if attemptAddToPrimary {
- switch peb {
+ switch properties.PEB {
case NeverPrimaryEndpoint:
case CanBePrimaryEndpoint:
a.mu.primary = append(a.mu.primary, addrState)
@@ -285,7 +285,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
a.mu.primary[0] = addrState
}
default:
- panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb))
+ panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", properties.PEB))
}
}
@@ -489,12 +489,12 @@ func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tc
// Proceed to add a new temporary endpoint.
addr := localAddr.WithPrefix()
- ep, err := a.addAndAcquireAddressLocked(addr, tempPEB, AddressConfigStatic, false /* deprecated */, false /* permanent */)
+ ep, err := a.addAndAcquireAddressLocked(addr, AddressProperties{PEB: tempPEB}, false /* permanent */)
if err != nil {
// addAndAcquireAddressLocked only returns an error if the address is
// already assigned but we just checked above if the address exists so we
// expect no error.
- panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, %d, %d, false, false): %s", addr, tempPEB, AddressConfigStatic, err))
+ panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, AddressProperties{PEB: %s}, false): %s", addr, tempPEB, err))
}
// From https://golang.org/doc/faq#nil_error:
diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go
index 140f146f6..c55f85743 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state_test.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go
@@ -38,9 +38,9 @@ func TestAddressableEndpointStateCleanup(t *testing.T) {
}
{
- ep, err := s.AddAndAcquirePermanentAddress(addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */)
+ ep, err := s.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{PEB: stack.NeverPrimaryEndpoint})
if err != nil {
- t.Fatalf("s.AddAndAcquirePermanentAddress(%s, %d, %d, false): %s", addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, err)
+ t.Fatalf("s.AddAndAcquirePermanentAddress(%s, AddressProperties{PEB: NeverPrimaryEndpoint}): %s", addr, err)
}
// We don't need the address endpoint.
ep.DecRef()
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index f7fbcbaa7..068dab7ce 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -35,7 +35,6 @@ import (
// Currently, only TCP tracking is supported.
// Our hash table has 16K buckets.
-// TODO(gvisor.dev/issue/170): These should be tunable.
const numBuckets = 1 << 14
// Direction of the tuple.
@@ -165,8 +164,6 @@ func (cn *conn) updateLocked(tcpHeader header.TCP, hook Hook) {
// Update the state of tcb. tcb assumes it's always initialized on the
// client. However, we only need to know whether the connection is
// established or not, so the client/server distinction isn't important.
- // TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle
- // other tcp states.
if cn.tcb.IsEmpty() {
cn.tcb.Init(tcpHeader)
} else if hook == cn.tcbHook {
@@ -246,8 +243,7 @@ func (ct *ConnTrack) init() {
// connFor gets the conn for pkt if it exists, or returns nil
// if it does not. It returns an error when pkt does not contain a valid TCP
// header.
-// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support
-// other transport protocols.
+// TODO(gvisor.dev/issue/6168): Support UDP.
func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) {
tid, err := packetToTupleID(pkt)
if err != nil {
@@ -367,7 +363,7 @@ func (ct *ConnTrack) insertConn(conn *conn) {
// Unlocking can happen in any order.
ct.buckets[tupleBucket].mu.Unlock()
if tupleBucket != replyBucket {
- ct.buckets[replyBucket].mu.Unlock()
+ ct.buckets[replyBucket].mu.Unlock() // +checklocksforce
}
}
@@ -385,7 +381,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
return false
}
- // TODO(gvisor.dev/issue/170): Support other transport protocols.
+ // TODO(gvisor.dev/issue/6168): Support UDP.
if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
return false
}
@@ -409,16 +405,23 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
// validated if checksum offloading is off. It may require IP defrag if the
// packets are fragmented.
+ var newAddr tcpip.Address
+ var newPort uint16
+
+ updateSRCFields := false
+
switch hook {
case Prerouting, Output:
if conn.manip == manipDestination {
switch dir {
case dirOriginal:
- tcpHeader.SetDestinationPort(conn.reply.srcPort)
- netHeader.SetDestinationAddress(conn.reply.srcAddr)
+ newPort = conn.reply.srcPort
+ newAddr = conn.reply.srcAddr
case dirReply:
- tcpHeader.SetSourcePort(conn.original.dstPort)
- netHeader.SetSourceAddress(conn.original.dstAddr)
+ newPort = conn.original.dstPort
+ newAddr = conn.original.dstAddr
+
+ updateSRCFields = true
}
pkt.NatDone = true
}
@@ -426,11 +429,13 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
if conn.manip == manipSource {
switch dir {
case dirOriginal:
- tcpHeader.SetSourcePort(conn.reply.dstPort)
- netHeader.SetSourceAddress(conn.reply.dstAddr)
+ newPort = conn.reply.dstPort
+ newAddr = conn.reply.dstAddr
+
+ updateSRCFields = true
case dirReply:
- tcpHeader.SetDestinationPort(conn.original.srcPort)
- netHeader.SetDestinationAddress(conn.original.srcAddr)
+ newPort = conn.original.srcPort
+ newAddr = conn.original.srcAddr
}
pkt.NatDone = true
}
@@ -441,33 +446,33 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
return false
}
+ fullChecksum := false
+ updatePseudoHeader := false
switch hook {
case Prerouting, Input:
case Output, Postrouting:
// Calculate the TCP checksum and set it.
- tcpHeader.SetChecksum(0)
- length := uint16(len(tcpHeader) + pkt.Data().Size())
- xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
if pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum {
- tcpHeader.SetChecksum(xsum)
+ updatePseudoHeader = true
} else if r.RequiresTXTransportChecksum() {
- xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
- tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
+ fullChecksum = true
+ updatePseudoHeader = true
}
default:
panic(fmt.Sprintf("unrecognized hook = %s", hook))
}
- // After modification, IPv4 packets need a valid checksum.
- if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
- }
+ rewritePacket(
+ netHeader,
+ tcpHeader,
+ updateSRCFields,
+ fullChecksum,
+ updatePseudoHeader,
+ newPort,
+ newAddr,
+ )
// Update the state of tcb.
- // TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle
- // other tcp states.
conn.mu.Lock()
defer conn.mu.Unlock()
@@ -544,8 +549,6 @@ func (ct *ConnTrack) bucket(id tupleID) int {
// reapUnused returns the next bucket that should be checked and the time after
// which it should be called again.
func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, time.Duration) {
- // TODO(gvisor.dev/issue/170): This can be more finely controlled, as
- // it is in Linux via sysctl.
const fractionPerReaping = 128
const maxExpiredPct = 50
const maxFullTraversal = 60 * time.Second
@@ -623,7 +626,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo
// Don't re-unlock if both tuples are in the same bucket.
if differentBuckets {
- ct.buckets[replyBucket].mu.Unlock()
+ ct.buckets[replyBucket].mu.Unlock() // +checklocksforce
}
return true
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index 72f66441f..c2f1f4798 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -181,10 +181,6 @@ func (*fwdTestNetworkProtocol) MinimumPacketSize() int {
return fwdTestNetHeaderLen
}
-func (*fwdTestNetworkProtocol) DefaultPrefixLen() int {
- return fwdTestNetDefaultPrefixLen
-}
-
func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1])
}
@@ -342,6 +338,10 @@ func (e *fwdTestLinkEndpoint) WritePackets(r RouteInfo, pkts PacketBufferList, p
return n, nil
}
+// WriteRawPacket always fails with tcpip.ErrNotSupported; fwdTestLinkEndpoint
+// does not implement raw packet writes.
+func (*fwdTestLinkEndpoint) WriteRawPacket(*PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
+}
+
// Wait implements stack.LinkEndpoint.Wait.
func (*fwdTestLinkEndpoint) Wait() {}
@@ -380,8 +380,15 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (*faketime.M
if err := s.CreateNIC(1, ep1); err != nil {
t.Fatal("CreateNIC #1 failed:", err)
}
- if err := s.AddAddress(1, fwdTestNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress #1 failed:", err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fwdTestNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fwdTestNetDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr1, AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err)
}
// NIC 2 has the link address "b", and added the network address 2.
@@ -393,8 +400,15 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (*faketime.M
if err := s.CreateNIC(2, ep2); err != nil {
t.Fatal("CreateNIC #2 failed:", err)
}
- if err := s.AddAddress(2, fwdTestNetNumber, "\x02"); err != nil {
- t.Fatal("AddAddress #2 failed:", err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: fwdTestNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x02",
+ PrefixLen: fwdTestNetDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(2, protocolAddr2, AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err)
}
nic, ok := s.nics[2]
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 0a26f6dd8..f152c0d83 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -268,10 +268,6 @@ const (
// should continue traversing the network stack and false when it should be
// dropped.
//
-// TODO(gvisor.dev/issue/170): PacketBuffer should hold the route, from
-// which address can be gathered. Currently, address is only needed for
-// prerouting.
-//
// Precondition: pkt.NetworkHeader is set.
func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool {
if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber {
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 2812c89aa..96cc899bb 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -87,9 +87,6 @@ func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Addre
// destination port/IP. Outgoing packets are redirected to the loopback device,
// and incoming packets are redirected to the incoming interface (rather than
// forwarded).
-//
-// TODO(gvisor.dev/issue/170): Other flags need to be added after we support
-// them.
type RedirectTarget struct {
// Port indicates port used to redirect. It is immutable.
Port uint16
@@ -100,9 +97,6 @@ type RedirectTarget struct {
}
// Action implements Target.Action.
-// TODO(gvisor.dev/issue/170): Parse headers without copying. The current
-// implementation only works for Prerouting and calls pkt.Clone(), neither
-// of which should be the case.
func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) {
// Sanity check.
if rt.NetworkProtocol != pkt.NetworkProtocolNumber {
@@ -136,34 +130,26 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r
panic("redirect target is supported only on output and prerouting hooks")
}
- // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if
- // we need to change dest address (for OUTPUT chain) or ports.
switch protocol := pkt.TransportProtocolNumber; protocol {
case header.UDPProtocolNumber:
udpHeader := header.UDP(pkt.TransportHeader().View())
- udpHeader.SetDestinationPort(rt.Port)
- // Calculate UDP checksum and set it.
if hook == Output {
- udpHeader.SetChecksum(0)
- netHeader := pkt.Network()
- netHeader.SetDestinationAddress(address)
-
// Only calculate the checksum if offloading isn't supported.
- if r.RequiresTXTransportChecksum() {
- length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
- xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
- xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
- udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
- }
+ requiresChecksum := r.RequiresTXTransportChecksum()
+ rewritePacket(
+ pkt.Network(),
+ udpHeader,
+ false, /* updateSRCFields */
+ requiresChecksum,
+ requiresChecksum,
+ rt.Port,
+ address,
+ )
+ } else {
+ udpHeader.SetDestinationPort(rt.Port)
}
- // After modification, IPv4 packets need a valid checksum.
- if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
- }
pkt.NatDone = true
case header.TCPProtocolNumber:
if ct == nil {
@@ -222,26 +208,18 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou
switch protocol := pkt.TransportProtocolNumber; protocol {
case header.UDPProtocolNumber:
- udpHeader := header.UDP(pkt.TransportHeader().View())
- udpHeader.SetChecksum(0)
- udpHeader.SetSourcePort(st.Port)
- netHeader := pkt.Network()
- netHeader.SetSourceAddress(st.Addr)
-
// Only calculate the checksum if offloading isn't supported.
- if r.RequiresTXTransportChecksum() {
- length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
- xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
- xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
- udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
- }
+ requiresChecksum := r.RequiresTXTransportChecksum()
+ rewritePacket(
+ pkt.Network(),
+ header.UDP(pkt.TransportHeader().View()),
+ true, /* updateSRCFields */
+ requiresChecksum,
+ requiresChecksum,
+ st.Port,
+ st.Addr,
+ )
- // After modification, IPv4 packets need a valid checksum.
- if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
- }
pkt.NatDone = true
case header.TCPProtocolNumber:
if ct == nil {
@@ -260,3 +238,42 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou
return RuleAccept, 0
}
+
+// rewritePacket rewrites either the source or the destination port and
+// address of a packet to newPort and newAddr.
+//
+// n and t are the packet's network and transport headers. updateSRCFields
+// selects which side is rewritten: the source port/address when true, the
+// destination port/address when false.
+//
+// fullChecksum selects the port setter that also updates the transport
+// checksum, and is forwarded to t.UpdateChecksumPseudoHeaderAddress.
+// updatePseudoHeader controls whether t.UpdateChecksumPseudoHeaderAddress is
+// called at all for the address change.
+func rewritePacket(n header.Network, t header.ChecksummableTransport, updateSRCFields, fullChecksum, updatePseudoHeader bool, newPort uint16, newAddr tcpip.Address) {
+ // Rewrite the transport port on the selected side.
+ if updateSRCFields {
+ if fullChecksum {
+ t.SetSourcePortWithChecksumUpdate(newPort)
+ } else {
+ t.SetSourcePort(newPort)
+ }
+ } else {
+ if fullChecksum {
+ t.SetDestinationPortWithChecksumUpdate(newPort)
+ } else {
+ t.SetDestinationPort(newPort)
+ }
+ }
+
+ if updatePseudoHeader {
+ // Read the current address now: the network header is only rewritten
+ // further down, so n still holds the old value at this point.
+ var oldAddr tcpip.Address
+ if updateSRCFields {
+ oldAddr = n.SourceAddress()
+ } else {
+ oldAddr = n.DestinationAddress()
+ }
+
+ t.UpdateChecksumPseudoHeaderAddress(oldAddr, newAddr, fullChecksum)
+ }
+
+ // Rewrite the network address last. Use the checksum-updating setters when
+ // the network header provides them (header.ChecksummableNetwork); otherwise
+ // fall back to the plain setters.
+ if checksummableNetHeader, ok := n.(header.ChecksummableNetwork); ok {
+ if updateSRCFields {
+ checksummableNetHeader.SetSourceAddressWithChecksumUpdate(newAddr)
+ } else {
+ checksummableNetHeader.SetDestinationAddressWithChecksumUpdate(newAddr)
+ }
+ } else if updateSRCFields {
+ n.SetSourceAddress(newAddr)
+ } else {
+ n.SetDestinationAddress(newAddr)
+ }
+}
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index 93592e7f5..66e5f22ac 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -242,7 +242,6 @@ type IPHeaderFilter struct {
func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicName string) bool {
// Extract header fields.
var (
- // TODO(gvisor.dev/issue/170): Support other filter fields.
transProto tcpip.TransportProtocolNumber
dstAddr tcpip.Address
srcAddr tcpip.Address
@@ -291,7 +290,6 @@ func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicNa
return true
case Postrouting:
- // TODO(gvisor.dev/issue/170): Add the check for POSTROUTING.
return true
default:
panic(fmt.Sprintf("unknown hook: %d", hook))
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 133bacdd0..40b33b6b5 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -52,17 +52,6 @@ const (
linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09")
defaultPrefixLen = 128
-
- // Extra time to use when waiting for an async event to occur.
- defaultAsyncPositiveEventTimeout = 10 * time.Second
-
- // Extra time to use when waiting for an async event to not occur.
- //
- // Since a negative check is used to make sure an event did not happen, it is
- // okay to use a smaller timeout compared to the positive case since execution
- // stall in regards to the monotonic clock will not affect the expected
- // outcome.
- defaultAsyncNegativeEventTimeout = time.Second
)
var (
@@ -112,11 +101,13 @@ type ndpDADEvent struct {
res stack.DADResult
}
-type ndpRouterEvent struct {
- nicID tcpip.NICID
- addr tcpip.Address
- // true if router was discovered, false if invalidated.
- discovered bool
+type ndpOffLinkRouteEvent struct {
+ nicID tcpip.NICID
+ subnet tcpip.Subnet
+ router tcpip.Address
+ prf header.NDPRoutePreference
+ // true if route was updated, false if invalidated.
+ updated bool
}
type ndpPrefixEvent struct {
@@ -140,6 +131,10 @@ type ndpAutoGenAddrEvent struct {
eventType ndpAutoGenAddrEventType
}
+// String implements fmt.Stringer. It renders the event's concrete type
+// followed by its nicID, addr and eventType fields.
+func (e ndpAutoGenAddrEvent) String() string {
+ return fmt.Sprintf("%T{nicID=%d addr=%s eventType=%d}", e, e.nicID, e.addr, e.eventType)
+}
+
type ndpRDNSS struct {
addrs []tcpip.Address
lifetime time.Duration
@@ -167,10 +162,8 @@ var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil)
// related events happen for test purposes.
type ndpDispatcher struct {
dadC chan ndpDADEvent
- routerC chan ndpRouterEvent
- rememberRouter bool
+ offLinkRouteC chan ndpOffLinkRouteEvent
prefixC chan ndpPrefixEvent
- rememberPrefix bool
autoGenAddrC chan ndpAutoGenAddrEvent
rdnssC chan ndpRDNSSEvent
dnsslC chan ndpDNSSLEvent
@@ -189,32 +182,35 @@ func (n *ndpDispatcher) OnDuplicateAddressDetectionResult(nicID tcpip.NICID, add
}
}
-// Implements ipv6.NDPDispatcher.OnDefaultRouterDiscovered.
-func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool {
- if c := n.routerC; c != nil {
- c <- ndpRouterEvent{
+// Implements ipv6.NDPDispatcher.OnOffLinkRouteUpdated.
+func (n *ndpDispatcher) OnOffLinkRouteUpdated(nicID tcpip.NICID, subnet tcpip.Subnet, router tcpip.Address, prf header.NDPRoutePreference) {
+ if c := n.offLinkRouteC; c != nil {
+ c <- ndpOffLinkRouteEvent{
nicID,
- addr,
+ subnet,
+ router,
+ prf,
true,
}
}
-
- return n.rememberRouter
}
-// Implements ipv6.NDPDispatcher.OnDefaultRouterInvalidated.
-func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) {
- if c := n.routerC; c != nil {
- c <- ndpRouterEvent{
+// Implements ipv6.NDPDispatcher.OnOffLinkRouteInvalidated.
+func (n *ndpDispatcher) OnOffLinkRouteInvalidated(nicID tcpip.NICID, subnet tcpip.Subnet, router tcpip.Address) {
+ if c := n.offLinkRouteC; c != nil {
+ var prf header.NDPRoutePreference
+ c <- ndpOffLinkRouteEvent{
nicID,
- addr,
+ subnet,
+ router,
+ prf,
false,
}
}
}
// Implements ipv6.NDPDispatcher.OnOnLinkPrefixDiscovered.
-func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool {
+func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) {
if c := n.prefixC; c != nil {
c <- ndpPrefixEvent{
nicID,
@@ -222,8 +218,6 @@ func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip
true,
}
}
-
- return n.rememberPrefix
}
// Implements ipv6.NDPDispatcher.OnOnLinkPrefixInvalidated.
@@ -237,7 +231,7 @@ func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpi
}
}
-func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) bool {
+func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) {
if c := n.autoGenAddrC; c != nil {
c <- ndpAutoGenAddrEvent{
nicID,
@@ -245,7 +239,6 @@ func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWi
newAddr,
}
}
- return true
}
func (n *ndpDispatcher) OnAutoGenAddressDeprecated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) {
@@ -340,8 +333,12 @@ func TestDADDisabled(t *testing.T) {
Address: addr1,
PrefixLen: defaultPrefixLen,
}
- if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err)
}
// Should get the address immediately since we should not have performed
@@ -386,12 +383,15 @@ func TestDADResolveLoopback(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- addrWithPrefix := tcpip.AddressWithPrefix{
- Address: addr1,
- PrefixLen: defaultPrefixLen,
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addr1,
+ PrefixLen: defaultPrefixLen,
+ },
}
- if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err)
}
// Address should not be considered bound to the NIC yet (DAD ongoing).
@@ -524,8 +524,12 @@ func TestDADResolve(t *testing.T) {
Address: addr1,
PrefixLen: defaultPrefixLen,
}
- if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err)
}
// Make sure the address does not resolve before the resolution time has
@@ -747,8 +751,12 @@ func TestDADFail(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr1.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
// Address should not be considered bound to the NIC yet
@@ -785,8 +793,8 @@ func TestDADFail(t *testing.T) {
// Attempting to add the address again should not fail if the address's
// state was cleaned up when DAD failed.
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
})
}
@@ -858,8 +866,12 @@ func TestDADStop(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr1.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
// Address should not be considered bound to the NIC yet (DAD ongoing).
@@ -982,17 +994,29 @@ func TestSetNDPConfigurations(t *testing.T) {
// Add addresses for each NIC.
addrWithPrefix1 := tcpip.AddressWithPrefix{Address: addr1, PrefixLen: defaultPrefixLen}
- if err := s.AddAddressWithPrefix(nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addrWithPrefix1, err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addrWithPrefix1,
+ }
+ if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID1, protocolAddr1, err)
}
addrWithPrefix2 := tcpip.AddressWithPrefix{Address: addr2, PrefixLen: defaultPrefixLen}
- if err := s.AddAddressWithPrefix(nicID2, header.IPv6ProtocolNumber, addrWithPrefix2); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addrWithPrefix2, err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addrWithPrefix2,
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID2, protocolAddr2, err)
}
expectDADEvent(nicID2, addr2)
addrWithPrefix3 := tcpip.AddressWithPrefix{Address: addr3, PrefixLen: defaultPrefixLen}
- if err := s.AddAddressWithPrefix(nicID3, header.IPv6ProtocolNumber, addrWithPrefix3); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addrWithPrefix3, err)
+ protocolAddr3 := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addrWithPrefix3,
+ }
+ if err := s.AddProtocolAddress(nicID3, protocolAddr3, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID3, protocolAddr3, err)
}
expectDADEvent(nicID3, addr3)
@@ -1039,9 +1063,12 @@ func TestSetNDPConfigurations(t *testing.T) {
}
}
-// raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options
-// and DHCPv6 configurations specified.
-func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) *stack.PacketBuffer {
+// raBuf returns a valid NDP Router Advertisement with options, router
+// preference and DHCPv6 configurations specified.
+func raBuf(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, prf header.NDPRoutePreference, optSer header.NDPOptionsSerializer) *stack.PacketBuffer {
+ const flagsByte = 1
+ const routerLifetimeOffset = 2
+
icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + optSer.Length()
hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
pkt := header.ICMPv6(hdr.Prepend(icmpSize))
@@ -1050,19 +1077,19 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo
raPayload := pkt.MessageBody()
ra := header.NDPRouterAdvert(raPayload)
// Populate the Router Lifetime.
- binary.BigEndian.PutUint16(raPayload[2:], rl)
+ binary.BigEndian.PutUint16(raPayload[routerLifetimeOffset:], rl)
// Populate the Managed Address flag field.
if managedAddress {
- // The Managed Addresses flag field is the 7th bit of byte #1 (0-indexing)
- // of the RA payload.
- raPayload[1] |= 1 << 7
+ // The Managed Addresses flag field is the 7th bit of the flags byte.
+ raPayload[flagsByte] |= 1 << 7
}
// Populate the Other Configurations flag field.
if otherConfigurations {
- // The Other Configurations flag field is the 6th bit of byte #1
- // (0-indexing) of the RA payload.
- raPayload[1] |= 1 << 6
+ // The Other Configurations flag field is the 6th bit of the flags byte.
+ raPayload[flagsByte] |= 1 << 6
}
+ // The Prf field is held in the flags byte.
+ raPayload[flagsByte] |= byte(prf) << 3
opts := ra.Options()
opts.Serialize(optSer)
pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
@@ -1090,7 +1117,7 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo
// Note, raBufWithOpts does not populate any of the RA fields other than the
// Router Lifetime.
func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) *stack.PacketBuffer {
- return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer)
+ return raBuf(ip, rl, false /* managedAddress */, false /* otherConfigurations */, 0 /* prf */, optSer)
}
// raBufWithDHCPv6 returns a valid NDP Router Advertisement with DHCPv6 related
@@ -1098,18 +1125,26 @@ func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializ
//
// Note, raBufWithDHCPv6 does not populate any of the RA fields other than the
// DHCPv6 related ones.
-func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) *stack.PacketBuffer {
- return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfiguratiosns, header.NDPOptionsSerializer{})
+func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfigurations bool) *stack.PacketBuffer {
+ return raBuf(ip, 0, managedAddresses, otherConfigurations, 0 /* prf */, header.NDPOptionsSerializer{})
}
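-// raBuf returns a valid NDP Router Advertisement.
+// raBufSimple returns a valid NDP Router Advertisement.
//
-// Note, raBuf does not populate any of the RA fields other than the
+// Note, raBufSimple does not populate any of the RA fields other than the
// Router Lifetime.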
-func raBuf(ip tcpip.Address, rl uint16) *stack.PacketBuffer {
+func raBufSimple(ip tcpip.Address, rl uint16) *stack.PacketBuffer {
return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{})
}
+// raBufWithPrf returns a valid NDP Router Advertisement with a preference.
+//
+// Note, raBufWithPrf does not populate any of the RA fields other than the
+// Router Lifetime and Default Router Preference fields.
+func raBufWithPrf(ip tcpip.Address, rl uint16, prf header.NDPRoutePreference) *stack.PacketBuffer {
+ return raBuf(ip, rl, false /* managedAddress */, false /* otherConfigurations */, prf, header.NDPOptionsSerializer{})
+}
+
// raBufWithPI returns a valid NDP Router Advertisement with a single Prefix
// Information option.
//
@@ -1148,6 +1183,39 @@ func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, on
})
}
+// raBufWithRIO returns a valid NDP Router Advertisement with a single Route
+// Information option.
+//
+// All fields in the RA will be zero except the RIO option.
+func raBufWithRIO(t *testing.T, ip tcpip.Address, prefix tcpip.AddressWithPrefix, lifetimeSeconds uint32, prf header.NDPRoutePreference) *stack.PacketBuffer {
+ // buf will hold the route information option after the Type and Length
+ // fields.
+ //
+ // 2.3. Route Information Option
+ //
+ // 0 1 2 3
+ // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+ // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ // | Type | Length | Prefix Length |Resvd|Prf|Resvd|
+ // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ // | Route Lifetime |
+ // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ // | Prefix (Variable Length) |
+ // . .
+ // . .
+ // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ var buf [22]byte
+ buf[0] = uint8(prefix.PrefixLen)
+ buf[1] = byte(prf) << 3
+ binary.BigEndian.PutUint32(buf[2:], lifetimeSeconds)
+ if n := copy(buf[6:], prefix.Address); n != len(prefix.Address) {
+ t.Fatalf("got copy(...) = %d, want = %d", n, len(prefix.Address))
+ }
+ return raBufWithOpts(ip, 0 /* router lifetime */, header.NDPOptionsSerializer{
+ header.NDPRouteInformation(buf[:]),
+ })
+}
+
func TestDynamicConfigurationsDisabled(t *testing.T) {
const (
nicID = 1
@@ -1169,7 +1237,7 @@ func TestDynamicConfigurationsDisabled(t *testing.T) {
config: func(enable bool) ipv6.NDPConfigurations {
return ipv6.NDPConfigurations{DiscoverDefaultRouters: enable}
},
- ra: raBuf(llAddr2, 1000),
+ ra: raBufSimple(llAddr2, 1000),
},
{
name: "No Prefix Discovery",
@@ -1205,9 +1273,9 @@ func TestDynamicConfigurationsDisabled(t *testing.T) {
t.Run(fmt.Sprintf("HandleRAs(%s), Forwarding(%t), Enabled(%t)", handle, forwarding, enable), func(t *testing.T) {
ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, 1),
- prefixC: make(chan ndpPrefixEvent, 1),
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1),
+ prefixC: make(chan ndpPrefixEvent, 1),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
ndpConfigs := test.config(enable)
ndpConfigs.HandleRAs = handle
@@ -1277,8 +1345,8 @@ func TestDynamicConfigurationsDisabled(t *testing.T) {
t.Errorf("got v6Stats.UnhandledRouterAdvertisements.Value() = %d, want = %d", got, want)
}
select {
- case e := <-ndpDisp.routerC:
- t.Errorf("unexpectedly discovered a router when configured not to: %#v", e)
+ case e := <-ndpDisp.offLinkRouteC:
+ t.Errorf("unexpectedly updated an off-link route when configured not to: %#v", e)
default:
}
select {
@@ -1304,10 +1372,8 @@ func boolToUint64(v bool) uint64 {
return 0
}
-// Check e to make sure that the event is for addr on nic with ID 1, and the
-// discovered flag set to discovered.
-func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) string {
- return cmp.Diff(ndpRouterEvent{nicID: 1, addr: addr, discovered: discovered}, e, cmp.AllowUnexported(e))
+func checkOffLinkRouteEvent(e ndpOffLinkRouteEvent, nicID tcpip.NICID, subnet tcpip.Subnet, router tcpip.Address, prf header.NDPRoutePreference, updated bool) string {
+ return cmp.Diff(ndpOffLinkRouteEvent{nicID: nicID, subnet: subnet, router: router, prf: prf, updated: updated}, e, cmp.AllowUnexported(e))
}
func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, bool)) {
@@ -1340,167 +1406,176 @@ func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, b
}
}
-// TestRouterDiscoveryDispatcherNoRemember tests that the stack does not
-// remember a discovered router when the dispatcher asks it not to.
-func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) {
- ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- DiscoverDefaultRouters: true,
- },
- NDPDisp: &ndpDisp,
- })},
- Clock: clock,
- })
+func TestOffLinkRouteDiscovery(t *testing.T) {
+ const nicID = 1
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
+ moreSpecificPrefix := tcpip.AddressWithPrefix{Address: testutil.MustParse6("a00::"), PrefixLen: 16}
+ tests := []struct {
+ name string
- // Receive an RA for a router we should not remember.
- const lifetimeSeconds = 1
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, lifetimeSeconds))
- select {
- case e := <-ndpDisp.routerC:
- if diff := checkRouterEvent(e, llAddr2, true); diff != "" {
- t.Errorf("router event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected router discovery event")
- }
+ discoverDefaultRouters bool
+ discoverMoreSpecificRoutes bool
- // Wait for the invalidation time plus some buffer to make sure we do
- // not actually receive any invalidation events as we should not have
- // remembered the router in the first place.
- clock.Advance(lifetimeSeconds * time.Second)
- select {
- case <-ndpDisp.routerC:
- t.Fatal("should not have received any router events")
- default:
+ dest tcpip.Subnet
+ ra func(*testing.T, tcpip.Address, uint16, header.NDPRoutePreference) *stack.PacketBuffer
+ }{
+ {
+ name: "Default router discovery",
+ discoverDefaultRouters: true,
+ discoverMoreSpecificRoutes: false,
+ dest: header.IPv6EmptySubnet,
+ ra: func(_ *testing.T, router tcpip.Address, lifetimeSeconds uint16, prf header.NDPRoutePreference) *stack.PacketBuffer {
+ return raBufWithPrf(router, lifetimeSeconds, prf)
+ },
+ },
+ {
+ name: "More-specific route discovery",
+ discoverDefaultRouters: false,
+ discoverMoreSpecificRoutes: true,
+ dest: moreSpecificPrefix.Subnet(),
+ ra: func(t *testing.T, router tcpip.Address, lifetimeSeconds uint16, prf header.NDPRoutePreference) *stack.PacketBuffer {
+ return raBufWithRIO(t, router, moreSpecificPrefix, uint32(lifetimeSeconds), prf)
+ },
+ },
}
-}
-func TestRouterDiscovery(t *testing.T) {
- testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) {
- ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, 1),
- rememberRouter: true,
- }
- e := channel.New(0, 1280, linkAddr1)
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: handleRAs,
- DiscoverDefaultRouters: true,
- },
- NDPDisp: &ndpDisp,
- })},
- Clock: clock,
- })
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) {
+ ndpDisp := ndpDispatcher{
+ offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: handleRAs,
+ DiscoverDefaultRouters: test.discoverDefaultRouters,
+ DiscoverMoreSpecificRoutes: test.discoverMoreSpecificRoutes,
+ },
+ NDPDisp: &ndpDisp,
+ })},
+ Clock: clock,
+ })
- expectRouterEvent := func(addr tcpip.Address, discovered bool) {
- t.Helper()
+ expectOffLinkRouteEvent := func(addr tcpip.Address, prf header.NDPRoutePreference, updated bool) {
+ t.Helper()
- select {
- case e := <-ndpDisp.routerC:
- if diff := checkRouterEvent(e, addr, discovered); diff != "" {
- t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ select {
+ case e := <-ndpDisp.offLinkRouteC:
+ if diff := checkOffLinkRouteEvent(e, nicID, test.dest, addr, prf, updated); diff != "" {
+ t.Errorf("off-link route event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected router discovery event")
+ }
}
- default:
- t.Fatal("expected router discovery event")
- }
- }
- expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) {
- t.Helper()
+ expectAsyncOffLinkRouteInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) {
+ t.Helper()
- clock.Advance(timeout)
- select {
- case e := <-ndpDisp.routerC:
- if diff := checkRouterEvent(e, addr, false); diff != "" {
- t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ clock.Advance(timeout)
+ select {
+ case e := <-ndpDisp.offLinkRouteC:
+ var prf header.NDPRoutePreference
+ if diff := checkOffLinkRouteEvent(e, nicID, test.dest, addr, prf, false); diff != "" {
+ t.Errorf("off-link route event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("timed out waiting for router discovery event")
+ }
}
- default:
- t.Fatal("timed out waiting for router discovery event")
- }
- }
-
- if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil {
- t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err)
- }
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
+ if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil {
+ t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err)
+ }
- // Rx an RA from lladdr2 with zero lifetime. It should not be
- // remembered.
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
- select {
- case <-ndpDisp.routerC:
- t.Fatal("unexpectedly discovered a router with 0 lifetime")
- default:
- }
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
- // Rx an RA from lladdr2 with a huge lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
- expectRouterEvent(llAddr2, true)
+ // Rx an RA from lladdr2 with zero lifetime. It should not be
+ // remembered.
+ e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 0, header.MediumRoutePreference))
+ select {
+ case <-ndpDisp.offLinkRouteC:
+ t.Fatal("unexpectedly updated an off-link route with 0 lifetime")
+ default:
+ }
- // Rx an RA from another router (lladdr3) with non-zero lifetime.
- const l3LifetimeSeconds = 6
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds))
- expectRouterEvent(llAddr3, true)
+ // Discover an off-link route through llAddr2.
+ e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 1000, header.ReservedRoutePreference))
+ if test.discoverMoreSpecificRoutes {
+ // The reserved value is considered invalid with more-specific route
+ // discovery so we inject the same packet but with the default
+ // (medium) preference value.
+ select {
+ case <-ndpDisp.offLinkRouteC:
+ t.Fatal("unexpectedly updated an off-link route with a reserved preference value")
+ default:
+ }
+ e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 1000, header.MediumRoutePreference))
+ }
+ expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, true)
+
+ // Rx an RA from another router (lladdr3) with non-zero lifetime and
+ // non-default preference value.
+ const l3LifetimeSeconds = 6
+ e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr3, l3LifetimeSeconds, header.HighRoutePreference))
+ expectOffLinkRouteEvent(llAddr3, header.HighRoutePreference, true)
+
+ // Rx an RA from lladdr2 with lesser lifetime and default (medium)
+ // preference value.
+ const l2LifetimeSeconds = 2
+ e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, l2LifetimeSeconds, header.MediumRoutePreference))
+ select {
+ case <-ndpDisp.offLinkRouteC:
+ t.Fatal("should not receive a off-link route event when updating lifetimes for known routers")
+ default:
+ }
- // Rx an RA from lladdr2 with lesser lifetime.
- const l2LifetimeSeconds = 2
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds))
- select {
- case <-ndpDisp.routerC:
- t.Fatal("Should not receive a router event when updating lifetimes for known routers")
- default:
- }
+ // Rx an RA from lladdr2 with a different preference.
+ e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, l2LifetimeSeconds, header.LowRoutePreference))
+ expectOffLinkRouteEvent(llAddr2, header.LowRoutePreference, true)
- // Wait for lladdr2's router invalidation job to execute. The lifetime
- // of the router should have been updated to the most recent (smaller)
- // lifetime.
- //
- // Wait for the normal lifetime plus an extra bit for the
- // router to get invalidated. If we don't get an invalidation
- // event after this time, then something is wrong.
- expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second)
-
- // Rx an RA from lladdr2 with huge lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
- expectRouterEvent(llAddr2, true)
-
- // Rx an RA from lladdr2 with zero lifetime. It should be invalidated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
- expectRouterEvent(llAddr2, false)
-
- // Wait for lladdr3's router invalidation job to execute. The lifetime
- // of the router should have been updated to the most recent (smaller)
- // lifetime.
- //
- // Wait for the normal lifetime plus an extra bit for the
- // router to get invalidated. If we don't get an invalidation
- // event after this time, then something is wrong.
- expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second)
- })
+ // Wait for lladdr2's router invalidation job to execute. The lifetime
+ // of the router should have been updated to the most recent (smaller)
+ // lifetime.
+ //
+ // Advance the fake clock by exactly that lifetime; no extra slack is
+ // needed because the clock is manual. If we don't get an invalidation
+ // event after this time, then something is wrong.
+ expectAsyncOffLinkRouteInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second)
+
+ // Rx an RA from lladdr2 with huge lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 1000, header.MediumRoutePreference))
+ expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, true)
+
+ // Rx an RA from lladdr2 with zero lifetime. It should be invalidated.
+ e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 0, header.MediumRoutePreference))
+ expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, false)
+
+ // Wait for lladdr3's router invalidation job to execute. Unlike
+ // lladdr2, lladdr3's lifetime was never refreshed, so it still has
+ // the lifetime from its first (and only) RA.
+ //
+ // Advance the fake clock by that full lifetime, which is more than
+ // enough given the time already elapsed. If we don't get an
+ // invalidation event after this time, then something is wrong.
+ expectAsyncOffLinkRouteInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second)
+ })
+ })
+ }
}
// TestRouterDiscoveryMaxRouters tests that only
-// ipv6.MaxDiscoveredDefaultRouters discovered routers are remembered.
+// ipv6.MaxDiscoveredOffLinkRoutes discovered routers are remembered.
func TestRouterDiscoveryMaxRouters(t *testing.T) {
+ const nicID = 1
+
ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, 1),
- rememberRouter: true,
+ offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
@@ -1513,23 +1588,23 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) {
})},
})
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
// Receive an RA from 2 more than the max number of discovered routers.
- for i := 1; i <= ipv6.MaxDiscoveredDefaultRouters+2; i++ {
+ for i := 1; i <= ipv6.MaxDiscoveredOffLinkRoutes+2; i++ {
linkAddr := []byte{2, 2, 3, 4, 5, 0}
linkAddr[5] = byte(i)
llAddr := header.LinkLocalAddr(tcpip.LinkAddress(linkAddr))
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr, 5))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr, 5))
- if i <= ipv6.MaxDiscoveredDefaultRouters {
+ if i <= ipv6.MaxDiscoveredOffLinkRoutes {
select {
- case e := <-ndpDisp.routerC:
- if diff := checkRouterEvent(e, llAddr, true); diff != "" {
- t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ case e := <-ndpDisp.offLinkRouteC:
+ if diff := checkOffLinkRouteEvent(e, nicID, header.IPv6EmptySubnet, llAddr, header.MediumRoutePreference, true); diff != "" {
+ t.Errorf("off-link route event mismatch (-want +got):\n%s", diff)
}
default:
t.Fatal("expected router discovery event")
@@ -1537,7 +1612,7 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) {
} else {
select {
- case <-ndpDisp.routerC:
+ case <-ndpDisp.offLinkRouteC:
t.Fatal("should not have discovered a new router after we already discovered the max number of routers")
default:
}
@@ -1551,54 +1626,6 @@ func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) st
return cmp.Diff(ndpPrefixEvent{nicID: 1, prefix: prefix, discovered: discovered}, e, cmp.AllowUnexported(e))
}
-// TestPrefixDiscoveryDispatcherNoRemember tests that the stack does not
-// remember a discovered on-link prefix when the dispatcher asks it not to.
-func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) {
- prefix, subnet, _ := prefixSubnetAddr(0, "")
-
- ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- DiscoverOnLinkPrefixes: true,
- },
- NDPDisp: &ndpDisp,
- })},
- Clock: clock,
- })
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- // Receive an RA with prefix that we should not remember.
- const lifetimeSeconds = 1
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, lifetimeSeconds, 0))
- select {
- case e := <-ndpDisp.prefixC:
- if diff := checkPrefixEvent(e, subnet, true); diff != "" {
- t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected prefix discovery event")
- }
-
- // Wait for the invalidation time plus some buffer to make sure we do
- // not actually receive any invalidation events as we should not have
- // remembered the prefix in the first place.
- clock.Advance(lifetimeSeconds * time.Second)
- select {
- case <-ndpDisp.prefixC:
- t.Fatal("should not have received any prefix events")
- default:
- }
-}
-
func TestPrefixDiscovery(t *testing.T) {
prefix1, subnet1, _ := prefixSubnetAddr(0, "")
prefix2, subnet2, _ := prefixSubnetAddr(1, "")
@@ -1606,8 +1633,7 @@ func TestPrefixDiscovery(t *testing.T) {
testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) {
ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, 1),
- rememberPrefix: true,
+ prefixC: make(chan ndpPrefixEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
clock := faketime.NewManualClock()
@@ -1697,17 +1723,6 @@ func TestPrefixDiscovery(t *testing.T) {
}
func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
- // Update the infinite lifetime value to a smaller value so we can test
- // that when we receive a PI with such a lifetime value, we do not
- // invalidate the prefix.
- const testInfiniteLifetimeSeconds = 2
- const testInfiniteLifetime = testInfiniteLifetimeSeconds * time.Second
- saved := header.NDPInfiniteLifetime
- header.NDPInfiniteLifetime = testInfiniteLifetime
- defer func() {
- header.NDPInfiniteLifetime = saved
- }()
-
prefix := tcpip.AddressWithPrefix{
Address: testutil.MustParse6("102:304:506:708::"),
PrefixLen: 64,
@@ -1715,8 +1730,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
subnet := prefix.Subnet()
ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, 1),
- rememberPrefix: true,
+ prefixC: make(chan ndpPrefixEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
clock := faketime.NewManualClock()
@@ -1750,9 +1764,9 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
// Receive an RA with prefix in an NDP Prefix Information option (PI)
// with infinite valid lifetime which should not get invalidated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds, 0))
expectPrefixEvent(subnet, true)
- clock.Advance(testInfiniteLifetime)
+ clock.Advance(header.NDPInfiniteLifetime)
select {
case <-ndpDisp.prefixC:
t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
@@ -1760,9 +1774,8 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
}
// Receive an RA with finite lifetime.
- // The prefix should get invalidated after 1s.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0))
- clock.Advance(testInfiniteLifetime)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds-1, 0))
+ clock.Advance(header.NDPInfiniteLifetime - time.Second)
select {
case e := <-ndpDisp.prefixC:
if diff := checkPrefixEvent(e, subnet, false); diff != "" {
@@ -1773,23 +1786,13 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
}
// Receive an RA with finite lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds-1, 0))
expectPrefixEvent(subnet, true)
// Receive an RA with prefix with an infinite lifetime.
// The prefix should not be invalidated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0))
- clock.Advance(testInfiniteLifetime)
- select {
- case <-ndpDisp.prefixC:
- t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
- default:
- }
-
- // Receive an RA with a prefix with a lifetime value greater than the
- // set infinite lifetime value.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds+1, 0))
- clock.Advance((testInfiniteLifetimeSeconds + 1) * time.Second)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds, 0))
+ clock.Advance(header.NDPInfiniteLifetime)
select {
case <-ndpDisp.prefixC:
t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
@@ -1806,8 +1809,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
// ipv6.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered.
func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) {
ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, ipv6.MaxDiscoveredOnLinkPrefixes+3),
- rememberPrefix: true,
+ prefixC: make(chan ndpPrefixEvent, ipv6.MaxDiscoveredOnLinkPrefixes+3),
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
@@ -1884,17 +1886,12 @@ func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix,
return cmp.Diff(ndpAutoGenAddrEvent{nicID: 1, addr: addr, eventType: eventType}, e, cmp.AllowUnexported(e))
}
+const minVLSeconds = uint32(ipv6.MinPrefixInformationValidLifetimeForUpdate / time.Second)
+const infiniteLifetimeSeconds = uint32(header.NDPInfiniteLifetime / time.Second)
+
// TestAutoGenAddr tests that an address is properly generated and invalidated
// when configured to do so.
func TestAutoGenAddr(t *testing.T) {
- const newMinVL = 2
- newMinVLDuration := newMinVL * time.Second
- saved := ipv6.MinPrefixInformationValidLifetimeForUpdate
- defer func() {
- ipv6.MinPrefixInformationValidLifetimeForUpdate = saved
- }()
- ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
-
prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
@@ -1903,6 +1900,7 @@ func TestAutoGenAddr(t *testing.T) {
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
@@ -1911,6 +1909,7 @@ func TestAutoGenAddr(t *testing.T) {
},
NDPDisp: &ndpDisp,
})},
+ Clock: clock,
})
if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil {
@@ -1960,8 +1959,9 @@ func TestAutoGenAddr(t *testing.T) {
default:
}
- // Receive an RA with prefix2 in a PI.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
+ // Receive an RA with prefix2 in a PI with a valid lifetime that exceeds
+ // the minimum.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, minVLSeconds+1, 0))
expectAutoGenAddrEvent(addr2, newAddr)
if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
t.Fatalf("Should have %s in the list of addresses", addr1)
@@ -1971,7 +1971,7 @@ func TestAutoGenAddr(t *testing.T) {
}
// Refresh valid lifetime for addr of prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix")
@@ -1979,12 +1979,13 @@ func TestAutoGenAddr(t *testing.T) {
}
// Wait for addr of prefix1 to be invalidated.
+ clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate)
select {
case e := <-ndpDisp.autoGenAddrC:
if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
+ default:
t.Fatal("timed out waiting for addr auto gen event")
}
if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
@@ -2014,20 +2015,7 @@ func addressCheck(addrs []tcpip.ProtocolAddress, containList, notContainList []t
// TestAutoGenTempAddr tests that temporary SLAAC addresses are generated when
// configured to do so as part of IPv6 Privacy Extensions.
func TestAutoGenTempAddr(t *testing.T) {
- const (
- nicID = 1
- newMinVL = 5
- newMinVLDuration = newMinVL * time.Second
- )
-
- savedMinPrefixInformationValidLifetimeForUpdate := ipv6.MinPrefixInformationValidLifetimeForUpdate
- savedMaxDesync := ipv6.MaxDesyncFactor
- defer func() {
- ipv6.MinPrefixInformationValidLifetimeForUpdate = savedMinPrefixInformationValidLifetimeForUpdate
- ipv6.MaxDesyncFactor = savedMaxDesync
- }()
- ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
- ipv6.MaxDesyncFactor = time.Nanosecond
+ const nicID = 1
prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
@@ -2047,218 +2035,211 @@ func TestAutoGenTempAddr(t *testing.T) {
},
}
- // This Run will not return until the parallel tests finish.
- //
- // We need this because we need to do some teardown work after the
- // parallel tests complete.
- //
- // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
- // more details.
- t.Run("group", func(t *testing.T) {
- for i, test := range tests {
- i := i
- test := test
-
- t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
- seed := []byte{uint8(i)}
- var tempIIDHistory [header.IIDSize]byte
- header.InitialTempIID(tempIIDHistory[:], seed, nicID)
- newTempAddr := func(stableAddr tcpip.Address) tcpip.AddressWithPrefix {
- return header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableAddr)
- }
-
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent, 2),
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- DADConfigs: stack.DADConfigurations{
- DupAddrDetectTransmits: test.dupAddrTransmits,
- RetransmitTimer: test.retransmitTimer,
- },
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- AutoGenGlobalAddresses: true,
- AutoGenTempGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- TempIIDSeed: seed,
- })},
- })
-
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
-
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
- }
-
- expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
+ for i, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ seed := []byte{uint8(i)}
+ var tempIIDHistory [header.IIDSize]byte
+ header.InitialTempIID(tempIIDHistory[:], seed, nicID)
+ newTempAddr := func(stableAddr tcpip.Address) tcpip.AddressWithPrefix {
+ return header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableAddr)
+ }
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- case <-time.After(defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for addr auto gen event")
- }
- }
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 2),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ DADConfigs: stack.DADConfigurations{
+ DupAddrDetectTransmits: test.dupAddrTransmits,
+ RetransmitTimer: test.retransmitTimer,
+ },
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ MaxTempAddrValidLifetime: 2 * ipv6.MinPrefixInformationValidLifetimeForUpdate,
+ MaxTempAddrPreferredLifetime: 2 * ipv6.MinPrefixInformationValidLifetimeForUpdate,
+ },
+ NDPDisp: &ndpDisp,
+ TempIIDSeed: seed,
+ })},
+ Clock: clock,
+ })
- expectDADEventAsync := func(addr tcpip.Address) {
- t.Helper()
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
- select {
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr, &stack.DADSucceeded{}); diff != "" {
- t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
- }
- case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for DAD event")
- }
- }
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
- // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
- // with zero valid lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0))
select {
case e := <-ndpDisp.autoGenAddrC:
- t.Fatalf("unexpectedly auto-generated an address with 0 lifetime; event = %+v", e)
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
default:
+ t.Fatal("expected addr auto gen event")
}
+ }
- // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
- // with non-zero valid lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
- expectAutoGenAddrEvent(addr1, newAddr)
- expectDADEventAsync(addr1.Address)
+ expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ clock.RunImmediatelyScheduledJobs()
select {
case e := <-ndpDisp.autoGenAddrC:
- t.Fatalf("unexpectedly got an auto gen addr event = %+v", e)
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
default:
+ t.Fatal("timed out waiting for addr auto gen event")
}
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1}, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
+ }
- // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
- // with non-zero valid & preferred lifetimes.
- tempAddr1 := newTempAddr(addr1.Address)
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
- expectAutoGenAddrEvent(tempAddr1, newAddr)
- expectDADEventAsync(tempAddr1.Address)
- if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
+ expectDADEventAsync := func(addr tcpip.Address) {
+ t.Helper()
- // Receive an RA with prefix2 in an NDP Prefix Information option (PI)
- // with preferred lifetime > valid lifetime
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6))
+ clock.Advance(time.Duration(test.dupAddrTransmits) * test.retransmitTimer)
select {
- case e := <-ndpDisp.autoGenAddrC:
- t.Fatalf("unexpectedly auto-generated an address with preferred lifetime > valid lifetime; event = %+v", e)
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr, &stack.DADSucceeded{}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
+ }
default:
+ t.Fatal("timed out waiting for DAD event")
}
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
+ }
- // Receive an RA with prefix2 in a PI w/ non-zero valid and preferred
- // lifetimes.
- tempAddr2 := newTempAddr(addr2.Address)
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
- expectAutoGenAddrEvent(addr2, newAddr)
- expectDADEventAsync(addr2.Address)
- expectAutoGenAddrEventAsync(tempAddr2, newAddr)
- expectDADEventAsync(tempAddr2.Address)
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with zero valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly auto-generated an address with 0 lifetime; event = %+v", e)
+ default:
+ }
- // Deprecate prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
- expectAutoGenAddrEvent(addr1, deprecatedAddr)
- expectAutoGenAddrEvent(tempAddr1, deprecatedAddr)
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with non-zero valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ expectDADEventAsync(addr1.Address)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly got an auto gen addr event = %+v", e)
+ default:
+ }
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
- // Refresh lifetimes for prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with non-zero valid & preferred lifetimes.
+ tempAddr1 := newTempAddr(addr1.Address)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
+ expectAutoGenAddrEvent(tempAddr1, newAddr)
+ expectDADEventAsync(tempAddr1.Address)
+ if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
- // Reduce valid lifetime and deprecate addresses of prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
- expectAutoGenAddrEvent(addr1, deprecatedAddr)
- expectAutoGenAddrEvent(tempAddr1, deprecatedAddr)
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with preferred lifetime > valid lifetime
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly auto-generated an address with preferred lifetime > valid lifetime; event = %+v", e)
+ default:
+ }
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
- // Wait for addrs of prefix1 to be invalidated. They should be
- // invalidated at the same time.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- var nextAddr tcpip.AddressWithPrefix
- if e.addr == addr1 {
- if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- nextAddr = tempAddr1
- } else {
- if diff := checkAutoGenAddrEvent(e, tempAddr1, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- nextAddr = addr1
- }
+ // Receive an RA with prefix2 in a PI with a valid lifetime that exceeds
+ // the minimum and won't be reached in this test.
+ tempAddr2 := newTempAddr(addr2.Address)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 2*minVLSeconds, 2*minVLSeconds))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ expectDADEventAsync(addr2.Address)
+ expectAutoGenAddrEventAsync(tempAddr2, newAddr)
+ expectDADEventAsync(tempAddr2.Address)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, nextAddr, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- case <-time.After(defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for addr auto gen event")
+ // Deprecate prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr1, deprecatedAddr)
+ expectAutoGenAddrEvent(tempAddr1, deprecatedAddr)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Refresh lifetimes for prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Reduce valid lifetime and deprecate addresses of prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, 0))
+ expectAutoGenAddrEvent(addr1, deprecatedAddr)
+ expectAutoGenAddrEvent(tempAddr1, deprecatedAddr)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Wait for addrs of prefix1 to be invalidated. They should be
+ // invalidated at the same time.
+ clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ var nextAddr tcpip.AddressWithPrefix
+ if e.addr == addr1 {
+ if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for addr auto gen event")
- }
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" {
- t.Fatal(mismatch)
+ nextAddr = tempAddr1
+ } else {
+ if diff := checkAutoGenAddrEvent(e, tempAddr1, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ nextAddr = addr1
}
- // Receive an RA with prefix2 in a PI w/ 0 lifetimes.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 0, 0))
- expectAutoGenAddrEvent(addr2, deprecatedAddr)
- expectAutoGenAddrEvent(tempAddr2, deprecatedAddr)
select {
case e := <-ndpDisp.autoGenAddrC:
- t.Errorf("got unexpected auto gen addr event = %+v", e)
+ if diff := checkAutoGenAddrEvent(e, nextAddr, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
default:
+ t.Fatal("timed out waiting for addr auto gen event")
}
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" {
- t.Fatal(mismatch)
- }
- })
- }
- })
+ default:
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Receive an RA with prefix2 in a PI w/ 0 lifetimes.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 0, 0))
+ expectAutoGenAddrEvent(addr2, deprecatedAddr)
+ expectAutoGenAddrEvent(tempAddr2, deprecatedAddr)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Errorf("got unexpected auto gen addr event = %+v", e)
+ default:
+ }
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+ })
+ }
}
// TestNoAutoGenTempAddrForLinkLocal tests that temporary SLAAC addresses are not
@@ -2266,12 +2247,6 @@ func TestAutoGenTempAddr(t *testing.T) {
func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) {
const nicID = 1
- savedMaxDesyncFactor := ipv6.MaxDesyncFactor
- defer func() {
- ipv6.MaxDesyncFactor = savedMaxDesyncFactor
- }()
- ipv6.MaxDesyncFactor = time.Nanosecond
-
tests := []struct {
name string
dupAddrTransmits uint8
@@ -2287,66 +2262,56 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) {
},
}
- // This Run will not return until the parallel tests finish.
- //
- // We need this because we need to do some teardown work after the
- // parallel tests complete.
- //
- // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
- // more details.
- t.Run("group", func(t *testing.T) {
- for _, test := range tests {
- test := test
-
- t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent, 1),
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- AutoGenTempGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- AutoGenLinkLocal: true,
- })},
- })
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ AutoGenTempGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ AutoGenLinkLocal: true,
+ })},
+ Clock: clock,
+ })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
- // The stable link-local address should auto-generate and resolve DAD.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IIDOffsetInIPv6Address * 8}, newAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
+ // The stable link-local address should auto-generate and resolve DAD.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IIDOffsetInIPv6Address * 8}, newAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- select {
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, llAddr1, &stack.DADSucceeded{}); diff != "" {
- t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
- }
- case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for DAD event")
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ clock.Advance(time.Duration(test.dupAddrTransmits) * test.retransmitTimer)
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, llAddr1, &stack.DADSucceeded{}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
+ default:
+ t.Fatal("timed out waiting for DAD event")
+ }
- // No new addresses should be generated.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- t.Errorf("got unxpected auto gen addr event = %+v", e)
- case <-time.After(defaultAsyncNegativeEventTimeout):
- }
- })
- }
- })
+ // No new addresses should be generated.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Errorf("got unxpected auto gen addr event = %+v", e)
+ default:
+ }
+ })
+ }
}
// TestNoAutoGenTempAddrWithoutStableAddr tests that a temporary SLAAC address
@@ -2359,12 +2324,6 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
retransmitTimer = 2 * time.Second
)
- savedMaxDesyncFactor := ipv6.MaxDesyncFactor
- defer func() {
- ipv6.MaxDesyncFactor = savedMaxDesyncFactor
- }()
- ipv6.MaxDesyncFactor = 0
-
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
var tempIIDHistory [header.IIDSize]byte
header.InitialTempIID(tempIIDHistory[:], nil, nicID)
@@ -2375,6 +2334,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
DADConfigs: stack.DADConfigurations{
@@ -2388,6 +2348,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
},
NDPDisp: &ndpDisp,
})},
+ Clock: clock,
})
if err := s.CreateNIC(nicID, e); err != nil {
@@ -2417,12 +2378,13 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
// Wait for DAD to complete for the stable address then expect the temporary
// address to be generated.
+ clock.Advance(dadTransmits * retransmitTimer)
select {
case e := <-ndpDisp.dadC:
if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADSucceeded{}); diff != "" {
t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
+ default:
t.Fatal("timed out waiting for DAD event")
}
select {
@@ -2430,7 +2392,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, tempAddr, newAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncPositiveEventTimeout):
+ default:
t.Fatal("timed out waiting for addr auto gen event")
}
}
@@ -2439,46 +2401,44 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
// regenerated.
func TestAutoGenTempAddrRegen(t *testing.T) {
const (
- nicID = 1
- regenAfter = 2 * time.Second
- newMinVL = 10
- newMinVLDuration = newMinVL * time.Second
- )
+ nicID = 1
+ regenAdv = 2 * time.Second
- savedMaxDesyncFactor := ipv6.MaxDesyncFactor
- savedMinMaxTempAddrPreferredLifetime := ipv6.MinMaxTempAddrPreferredLifetime
- savedMinMaxTempAddrValidLifetime := ipv6.MinMaxTempAddrValidLifetime
- defer func() {
- ipv6.MaxDesyncFactor = savedMaxDesyncFactor
- ipv6.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime
- ipv6.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime
- }()
- ipv6.MaxDesyncFactor = 0
- ipv6.MinMaxTempAddrPreferredLifetime = newMinVLDuration
- ipv6.MinMaxTempAddrValidLifetime = newMinVLDuration
+ numTempAddrs = 3
+ maxTempAddrValidLifetime = numTempAddrs * ipv6.MinPrefixInformationValidLifetimeForUpdate
+ )
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
var tempIIDHistory [header.IIDSize]byte
header.InitialTempIID(tempIIDHistory[:], nil, nicID)
- tempAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
- tempAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
- tempAddr3 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
+ var tempAddrs [numTempAddrs]tcpip.AddressWithPrefix
+ for i := 0; i < len(tempAddrs); i++ {
+ tempAddrs[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
+ }
ndpDisp := ndpDispatcher{
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
}
e := channel.New(0, 1280, linkAddr1)
ndpConfigs := ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- AutoGenGlobalAddresses: true,
- AutoGenTempGlobalAddresses: true,
- RegenAdvanceDuration: newMinVLDuration - regenAfter,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ RegenAdvanceDuration: regenAdv,
+ MaxTempAddrValidLifetime: maxTempAddrValidLifetime,
+ MaxTempAddrPreferredLifetime: ipv6.MinPrefixInformationValidLifetimeForUpdate,
+ }
+ clock := faketime.NewManualClock()
+ randSource := savingRandSource{
+ s: rand.NewSource(time.Now().UnixNano()),
}
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ndpConfigs,
NDPDisp: &ndpDisp,
})},
+ Clock: clock,
+ RandSource: &randSource,
})
if err := s.CreateNIC(nicID, e); err != nil {
@@ -2501,36 +2461,43 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
t.Helper()
+ clock.Advance(timeout)
select {
case e := <-ndpDisp.autoGenAddrC:
if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(timeout):
+ default:
t.Fatal("timed out waiting for addr auto gen event")
}
}
+ tempDesyncFactor := time.Duration(randSource.lastInt63) % ipv6.MaxDesyncFactor
+ effectiveMaxTempAddrPL := ipv6.MinPrefixInformationValidLifetimeForUpdate - tempDesyncFactor
+ // The time since the last regeneration before a new temporary address is
+ // generated.
+ tempAddrRegenenerationTime := effectiveMaxTempAddrPL - regenAdv
+
// Receive an RA with prefix1 in an NDP Prefix Information option (PI)
// with non-zero valid & preferred lifetimes.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, minVLSeconds))
expectAutoGenAddrEvent(addr, newAddr)
- expectAutoGenAddrEvent(tempAddr1, newAddr)
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1}, nil); mismatch != "" {
+ expectAutoGenAddrEvent(tempAddrs[0], newAddr)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddrs[0]}, nil); mismatch != "" {
t.Fatal(mismatch)
}
// Wait for regeneration
- expectAutoGenAddrEventAsync(tempAddr2, newAddr, regenAfter+defaultAsyncPositiveEventTimeout)
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2}, nil); mismatch != "" {
+ expectAutoGenAddrEventAsync(tempAddrs[1], newAddr, tempAddrRegenenerationTime)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, minVLSeconds))
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddrs[0], tempAddrs[1]}, nil); mismatch != "" {
t.Fatal(mismatch)
}
+ expectAutoGenAddrEventAsync(tempAddrs[0], deprecatedAddr, regenAdv)
// Wait for regeneration
- expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout)
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2, tempAddr3}, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
+ expectAutoGenAddrEventAsync(tempAddrs[2], newAddr, tempAddrRegenenerationTime-regenAdv)
+ expectAutoGenAddrEventAsync(tempAddrs[1], deprecatedAddr, regenAdv)
// Stop generating temporary addresses
ndpConfigs.AutoGenTempGlobalAddresses = false
@@ -2541,45 +2508,24 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
ndpEP.SetNDPConfigurations(ndpConfigs)
}
+ // Refresh lifetimes and wait for the last temporary address to be deprecated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, minVLSeconds))
+ expectAutoGenAddrEventAsync(tempAddrs[2], deprecatedAddr, effectiveMaxTempAddrPL-regenAdv)
+
+ // Refresh lifetimes such that the prefix is valid and preferred forever.
+ //
+ // This should not affect the lifetimes of temporary addresses because they
+ // are capped by the maximum valid and preferred lifetimes for temporary
+ // addresses.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, infiniteLifetimeSeconds, infiniteLifetimeSeconds))
+
// Wait for all the temporary addresses to get invalidated.
- tempAddrs := []tcpip.AddressWithPrefix{tempAddr1, tempAddr2, tempAddr3}
- invalidateAfter := newMinVLDuration - 2*regenAfter
+ invalidateAfter := maxTempAddrValidLifetime - clock.NowMonotonic().Sub(tcpip.MonotonicTime{})
for _, addr := range tempAddrs {
- // Wait for a deprecation then invalidation event, or just an invalidation
- // event. We need to cover both cases but cannot deterministically hit both
- // cases because the deprecation and invalidation jobs could execute in any
- // order.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, deprecatedAddr); diff == "" {
- // If we get a deprecation event first, we should get an invalidation
- // event almost immediately after.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- case <-time.After(defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for addr auto gen event")
- }
- } else if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff == "" {
- // If we get an invalidation event first, we shouldn't get a deprecation
- // event after.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- t.Fatalf("unexpectedly got an auto-generated event = %+v", e)
- case <-time.After(defaultAsyncNegativeEventTimeout):
- }
- } else {
- t.Fatalf("got unexpected auto-generated event = %+v", e)
- }
- case <-time.After(invalidateAfter + defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for addr auto gen event")
- }
-
- invalidateAfter = regenAfter
+ expectAutoGenAddrEventAsync(addr, invalidatedAddr, invalidateAfter)
+ invalidateAfter = tempAddrRegenenerationTime
}
- if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr}, tempAddrs); mismatch != "" {
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr}, tempAddrs[:]); mismatch != "" {
t.Fatal(mismatch)
}
}
@@ -2588,52 +2534,54 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
// regeneration job gets updated when refreshing the address's lifetimes.
func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
const (
- nicID = 1
- regenAfter = 2 * time.Second
- newMinVL = 10
- newMinVLDuration = newMinVL * time.Second
- )
+ nicID = 1
+ regenAdv = 2 * time.Second
- savedMaxDesyncFactor := ipv6.MaxDesyncFactor
- savedMinMaxTempAddrPreferredLifetime := ipv6.MinMaxTempAddrPreferredLifetime
- savedMinMaxTempAddrValidLifetime := ipv6.MinMaxTempAddrValidLifetime
- defer func() {
- ipv6.MaxDesyncFactor = savedMaxDesyncFactor
- ipv6.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime
- ipv6.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime
- }()
- ipv6.MaxDesyncFactor = 0
- ipv6.MinMaxTempAddrPreferredLifetime = newMinVLDuration
- ipv6.MinMaxTempAddrValidLifetime = newMinVLDuration
+ numTempAddrs = 3
+ maxTempAddrPreferredLifetime = ipv6.MinPrefixInformationValidLifetimeForUpdate
+ maxTempAddrPreferredLifetimeSeconds = uint32(maxTempAddrPreferredLifetime / time.Second)
+ )
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
var tempIIDHistory [header.IIDSize]byte
header.InitialTempIID(tempIIDHistory[:], nil, nicID)
- tempAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
- tempAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
- tempAddr3 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
+ var tempAddrs [numTempAddrs]tcpip.AddressWithPrefix
+ for i := 0; i < len(tempAddrs); i++ {
+ tempAddrs[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
+ }
ndpDisp := ndpDispatcher{
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
}
e := channel.New(0, 1280, linkAddr1)
ndpConfigs := ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- AutoGenGlobalAddresses: true,
- AutoGenTempGlobalAddresses: true,
- RegenAdvanceDuration: newMinVLDuration - regenAfter,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ RegenAdvanceDuration: regenAdv,
+ MaxTempAddrPreferredLifetime: maxTempAddrPreferredLifetime,
+ MaxTempAddrValidLifetime: maxTempAddrPreferredLifetime * 2,
+ }
+ clock := faketime.NewManualClock()
+ initialTime := clock.NowMonotonic()
+ randSource := savingRandSource{
+ s: rand.NewSource(time.Now().UnixNano()),
}
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ndpConfigs,
NDPDisp: &ndpDisp,
})},
+ Clock: clock,
+ RandSource: &randSource,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
+ tempDesyncFactor := time.Duration(randSource.lastInt63) % ipv6.MaxDesyncFactor
+
expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
t.Helper()
@@ -2650,22 +2598,23 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
t.Helper()
+ clock.Advance(timeout)
select {
case e := <-ndpDisp.autoGenAddrC:
if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(timeout):
+ default:
t.Fatal("timed out waiting for addr auto gen event")
}
}
// Receive an RA with prefix1 in an NDP Prefix Information option (PI)
// with non-zero valid & preferred lifetimes.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, maxTempAddrPreferredLifetimeSeconds, maxTempAddrPreferredLifetimeSeconds))
expectAutoGenAddrEvent(addr, newAddr)
- expectAutoGenAddrEvent(tempAddr1, newAddr)
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1}, nil); mismatch != "" {
+ expectAutoGenAddrEvent(tempAddrs[0], newAddr)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddrs[0]}, nil); mismatch != "" {
t.Fatal(mismatch)
}
@@ -2673,13 +2622,27 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
//
// A new temporary address should be generated after the regeneration
// time has passed since the prefix is deprecated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 0))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, maxTempAddrPreferredLifetimeSeconds, 0))
expectAutoGenAddrEvent(addr, deprecatedAddr)
- expectAutoGenAddrEvent(tempAddr1, deprecatedAddr)
+ expectAutoGenAddrEvent(tempAddrs[0], deprecatedAddr)
select {
case e := <-ndpDisp.autoGenAddrC:
- t.Fatalf("unexpected auto gen addr event = %+v", e)
- case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout):
+ t.Fatalf("unexpected auto gen addr event = %#v", e)
+ default:
+ }
+
+ effectiveMaxTempAddrPL := maxTempAddrPreferredLifetime - tempDesyncFactor
+ // The time since the last regeneration before a new temporary address is
+ // generated.
+ tempAddrRegenenerationTime := effectiveMaxTempAddrPL - regenAdv
+
+ // Advance the clock by the regeneration time but don't expect a new temporary
+ // address as the prefix is deprecated.
+ clock.Advance(tempAddrRegenenerationTime)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpected auto gen addr event = %#v", e)
+ default:
}
// Prefer the prefix again.
@@ -2687,8 +2650,15 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
// A new temporary address should immediately be generated since the
// regeneration time has already passed since the last address was generated
// - this regeneration does not depend on a job.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
- expectAutoGenAddrEvent(tempAddr2, newAddr)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, maxTempAddrPreferredLifetimeSeconds, maxTempAddrPreferredLifetimeSeconds))
+ expectAutoGenAddrEvent(tempAddrs[1], newAddr)
+ // Wait for the first temporary address to be deprecated.
+ expectAutoGenAddrEventAsync(tempAddrs[0], deprecatedAddr, regenAdv)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpected auto gen addr event = %s", e)
+ default:
+ }
// Increase the maximum lifetimes for temporary addresses to large values
// then refresh the lifetimes of the prefix.
@@ -2699,34 +2669,30 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
// regenerate a new temporary address. Note, new addresses are only
// regenerated after the preferred lifetime - the regenerate advance duration
// has passed.
- ndpConfigs.MaxTempAddrValidLifetime = 100 * time.Second
- ndpConfigs.MaxTempAddrPreferredLifetime = 100 * time.Second
+ const largeLifetimeSeconds = minVLSeconds * 2
+ const largeLifetime = time.Duration(largeLifetimeSeconds) * time.Second
+ ndpConfigs.MaxTempAddrValidLifetime = 2 * largeLifetime
+ ndpConfigs.MaxTempAddrPreferredLifetime = largeLifetime
ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber)
if err != nil {
t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
}
ndpEP := ipv6Ep.(ipv6.NDPEndpoint)
ndpEP.SetNDPConfigurations(ndpConfigs)
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
+ timeSinceInitialTime := clock.NowMonotonic().Sub(initialTime)
+ clock.Advance(largeLifetime - timeSinceInitialTime)
+ expectAutoGenAddrEvent(tempAddrs[0], deprecatedAddr)
+ // to offset the advancement of time to test the first temporary address's
+ // deprecation after the second was generated
+ advLess := regenAdv
+ expectAutoGenAddrEventAsync(tempAddrs[2], newAddr, timeSinceInitialTime-advLess-(tempDesyncFactor+regenAdv))
+ expectAutoGenAddrEventAsync(tempAddrs[1], deprecatedAddr, regenAdv)
select {
case e := <-ndpDisp.autoGenAddrC:
t.Fatalf("unexpected auto gen addr event = %+v", e)
- case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout):
+ default:
}
-
- // Set the maximum lifetimes for temporary addresses such that on the next
- // RA, the regeneration job gets scheduled again.
- //
- // The maximum lifetime is the sum of the minimum lifetimes for temporary
- // addresses + the time that has already passed since the last address was
- // generated so that the regeneration job is needed to generate the next
- // address.
- newLifetimes := newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout
- ndpConfigs.MaxTempAddrValidLifetime = newLifetimes
- ndpConfigs.MaxTempAddrPreferredLifetime = newLifetimes
- ndpEP.SetNDPConfigurations(ndpConfigs)
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
- expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout)
}
// TestMixedSLAACAddrConflictRegen tests SLAAC address regeneration in response
@@ -2853,8 +2819,12 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
continue
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, test.addrs[j].Address); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, test.addrs[j].Address, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: test.addrs[j].Address.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
manuallyAssignedAddresses[test.addrs[j].Address] = struct{}{}
@@ -2954,13 +2924,14 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
// stack.Stack will have a default route through the router (llAddr3) installed
// and a static link-address (linkAddr3) added to the link address cache for the
// router.
-func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) {
+func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) {
t.Helper()
ndpDisp := &ndpDispatcher{
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
@@ -2970,6 +2941,7 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd
NDPDisp: ndpDisp,
})},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ Clock: clock,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -2983,7 +2955,7 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd
if err := s.AddStaticNeighbor(nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3); err != nil {
t.Fatalf("s.AddStaticNeighbor(%d, %d, %s, %s): %s", nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3, err)
}
- return ndpDisp, e, s
+ return ndpDisp, e, s, clock
}
// addrForNewConnectionTo returns the local address used when creating a new
@@ -3057,7 +3029,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
- ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
+ ndpDisp, e, s, _ := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
t.Helper()
@@ -3160,19 +3132,11 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
// when its preferred lifetime expires.
func TestAutoGenAddrJobDeprecation(t *testing.T) {
const nicID = 1
- const newMinVL = 2
- newMinVLDuration := newMinVL * time.Second
-
- saved := ipv6.MinPrefixInformationValidLifetimeForUpdate
- defer func() {
- ipv6.MinPrefixInformationValidLifetimeForUpdate = saved
- }()
- ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
- ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
+ ndpDisp, e, s, clock := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
t.Helper()
@@ -3190,12 +3154,13 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
t.Helper()
+ clock.Advance(timeout)
select {
case e := <-ndpDisp.autoGenAddrC:
if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(timeout):
+ default:
t.Fatal("timed out waiting for addr auto gen event")
}
}
@@ -3213,7 +3178,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
}
// Receive PI for prefix2.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, infiniteLifetimeSeconds, infiniteLifetimeSeconds))
expectAutoGenAddrEvent(addr2, newAddr)
if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
t.Fatalf("should have %s in the list of addresses", addr2)
@@ -3232,7 +3197,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
expectPrimaryAddr(addr1)
// Refresh lifetime for addr of prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, minVLSeconds-1))
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly got an auto-generated event")
@@ -3241,7 +3206,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
expectPrimaryAddr(addr1)
// Wait for addr of prefix1 to be deprecated.
- expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout)
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, ipv6.MinPrefixInformationValidLifetimeForUpdate-time.Second)
if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
@@ -3251,6 +3216,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
// addr2 should be the primary endpoint now since addr1 is deprecated but
// addr2 is not.
expectPrimaryAddr(addr2)
+
// addr1 is deprecated but if explicitly requested, it should be used.
fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID}
if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
@@ -3259,7 +3225,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
// Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make
// sure we do not get a deprecation event again.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, 0))
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly got an auto-generated event")
@@ -3271,7 +3237,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
}
// Refresh lifetimes for addr of prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, minVLSeconds-1))
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly got an auto-generated event")
@@ -3281,7 +3247,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
expectPrimaryAddr(addr1)
// Wait for addr of prefix1 to be deprecated.
- expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout)
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, ipv6.MinPrefixInformationValidLifetimeForUpdate-time.Second)
if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
@@ -3295,7 +3261,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
}
// Wait for addr of prefix1 to be invalidated.
- expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout)
+ expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second)
if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
@@ -3305,7 +3271,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
expectPrimaryAddr(addr2)
// Refresh both lifetimes for addr of prefix2 to the same value.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, minVLSeconds, minVLSeconds))
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly got an auto-generated event")
@@ -3317,6 +3283,17 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
// cases because the deprecation and invalidation handlers could be handled in
// either deprecation then invalidation, or invalidation then deprecation
// (which should be cancelled by the invalidation handler).
+ //
+ // Since we're about to cause both events to fire, we need the dispatcher
+ // channel to be able to hold both.
+ if got, want := len(ndpDisp.autoGenAddrC), 0; got != want {
+ t.Fatalf("got len(ndpDisp.autoGenAddrC) = %d, want %d", got, want)
+ }
+ if got, want := cap(ndpDisp.autoGenAddrC), 1; got != want {
+ t.Fatalf("got cap(ndpDisp.autoGenAddrC) = %d, want %d", got, want)
+ }
+ ndpDisp.autoGenAddrC = make(chan ndpAutoGenAddrEvent, 2)
+ clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate)
select {
case e := <-ndpDisp.autoGenAddrC:
if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" {
@@ -3327,21 +3304,21 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncPositiveEventTimeout):
+ default:
t.Fatal("timed out waiting for addr auto gen event")
}
} else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" {
- // If we get an invalidation event first, we should not get a deprecation
+ // If we get an invalidation event first, we should not get a deprecation
// event after.
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly got an auto-generated event")
- case <-time.After(defaultAsyncNegativeEventTimeout):
+ default:
}
} else {
t.Fatalf("got unexpected auto-generated event")
}
- case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
+ default:
t.Fatal("timed out waiting for addr auto gen event")
}
if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
@@ -3378,15 +3355,6 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
// infinite values.
func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
const infiniteVLSeconds = 2
- const minVLSeconds = 1
- savedIL := header.NDPInfiniteLifetime
- savedMinVL := ipv6.MinPrefixInformationValidLifetimeForUpdate
- defer func() {
- ipv6.MinPrefixInformationValidLifetimeForUpdate = savedMinVL
- header.NDPInfiniteLifetime = savedIL
- }()
- ipv6.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second
- header.NDPInfiniteLifetime = infiniteVLSeconds * time.Second
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
@@ -3410,68 +3378,58 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
},
}
- // This Run will not return until the parallel tests finish.
- //
- // We need this because we need to do some teardown work after the
- // parallel tests complete.
- //
- // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
- // more details.
- t.Run("group", func(t *testing.T) {
- for _, test := range tests {
- test := test
-
- t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
- ndpDisp := ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- })},
- })
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
+ Clock: clock,
+ })
- // Receive an RA with finite prefix.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0))
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
- default:
- t.Fatal("expected addr auto gen event")
+ // Receive an RA with finite prefix.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- // Receive an new RA with prefix with infinite VL.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.infiniteVL, 0))
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
- // Receive a new RA with prefix with finite VL.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0))
+ // Receive a new RA with prefix with infinite VL.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.infiniteVL, 0))
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
+ // Receive a new RA with prefix with finite VL.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0))
- case <-time.After(minVLSeconds*time.Second + defaultAsyncPositiveEventTimeout):
- t.Fatal("timeout waiting for addr auto gen event")
+ clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- })
- }
- })
+
+ default:
+ t.Fatal("timeout waiting for addr auto gen event")
+ }
+ })
+ }
}
// TestAutoGenAddrValidLifetimeUpdates tests that the valid lifetime of an
@@ -3479,12 +3437,6 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
// RFC 4862 section 5.5.3.e.
func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
const infiniteVL = 4294967295
- const newMinVL = 4
- saved := ipv6.MinPrefixInformationValidLifetimeForUpdate
- defer func() {
- ipv6.MinPrefixInformationValidLifetimeForUpdate = saved
- }()
- ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
@@ -3495,137 +3447,129 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
evl uint32
}{
// Should update the VL to the minimum VL for updating if the
- // new VL is less than newMinVL but was originally greater than
+ // new VL is less than minVLSeconds but was originally greater than
// it.
{
"LargeVLToVLLessThanMinVLForUpdate",
9999,
1,
- newMinVL,
+ minVLSeconds,
},
{
"LargeVLTo0",
9999,
0,
- newMinVL,
+ minVLSeconds,
},
{
"InfiniteVLToVLLessThanMinVLForUpdate",
infiniteVL,
1,
- newMinVL,
+ minVLSeconds,
},
{
"InfiniteVLTo0",
infiniteVL,
0,
- newMinVL,
+ minVLSeconds,
},
- // Should not update VL if original VL was less than newMinVL
- // and the new VL is also less than newMinVL.
+ // Should not update VL if original VL was less than minVLSeconds
+ // and the new VL is also less than minVLSeconds.
{
"ShouldNotUpdateWhenBothOldAndNewAreLessThanMinVLForUpdate",
- newMinVL - 1,
- newMinVL - 3,
- newMinVL - 1,
+ minVLSeconds - 1,
+ minVLSeconds - 3,
+ minVLSeconds - 1,
},
// Should take the new VL if the new VL is greater than the
- // remaining time or is greater than newMinVL.
+ // remaining time or is greater than minVLSeconds.
{
"MorethanMinVLToLesserButStillMoreThanMinVLForUpdate",
- newMinVL + 5,
- newMinVL + 3,
- newMinVL + 3,
+ minVLSeconds + 5,
+ minVLSeconds + 3,
+ minVLSeconds + 3,
},
{
"SmallVLToGreaterVLButStillLessThanMinVLForUpdate",
- newMinVL - 3,
- newMinVL - 1,
- newMinVL - 1,
+ minVLSeconds - 3,
+ minVLSeconds - 1,
+ minVLSeconds - 1,
},
{
"SmallVLToGreaterVLThatIsMoreThaMinVLForUpdate",
- newMinVL - 3,
- newMinVL + 1,
- newMinVL + 1,
+ minVLSeconds - 3,
+ minVLSeconds + 1,
+ minVLSeconds + 1,
},
}
- // This Run will not return until the parallel tests finish.
- //
- // We need this because we need to do some teardown work after the
- // parallel tests complete.
- //
- // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
- // more details.
- t.Run("group", func(t *testing.T) {
- for _, test := range tests {
- test := test
-
- t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
- ndpDisp := ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10),
- }
- e := channel.New(10, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- })},
- })
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10),
+ }
+ e := channel.New(10, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
+ Clock: clock,
+ })
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
- // Receive an RA with prefix with initial VL,
- // test.ovl.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0))
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
+ // Receive an RA with prefix with initial VL,
+ // test.ovl.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
- // Receive an new RA with prefix with new VL,
- // test.nvl.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0))
+ // Receive a new RA with prefix with new VL,
+ // test.nvl.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0))
- //
- // Validate that the VL for the address got set
- // to test.evl.
- //
+ //
+ // Validate that the VL for the address got set
+ // to test.evl.
+ //
- // The address should not be invalidated until the effective valid
- // lifetime has passed.
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly received an auto gen addr event")
- case <-time.After(time.Duration(test.evl)*time.Second - defaultAsyncNegativeEventTimeout):
- }
+ // The address should not be invalidated until the effective valid
+ // lifetime has passed.
+ const delta = 1
+ clock.Advance(time.Duration(test.evl)*time.Second - delta)
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly received an auto gen addr event")
+ default:
+ }
- // Wait for the invalidation event.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- case <-time.After(defaultAsyncPositiveEventTimeout):
- t.Fatal("timeout waiting for addr auto gen event")
+ // Wait for the invalidation event.
+ clock.Advance(delta)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- })
- }
- })
+ default:
+ t.Fatal("timeout waiting for addr auto gen event")
+ }
+ })
+ }
}
// TestAutoGenAddrRemoval tests that when auto-generated addresses are removed
@@ -3696,7 +3640,7 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) {
prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
- ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
+ ndpDisp, e, s, _ := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
t.Helper()
@@ -3735,8 +3679,9 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) {
Protocol: header.IPv6ProtocolNumber,
AddressWithPrefix: addr2,
}
- if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil {
- t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err)
+ properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint}
+ if err := s.AddProtocolAddress(nicID, protoAddr2, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v) = %s", nicID, protoAddr2, properties, err)
}
// addr2 should be more preferred now since it is at the front of the primary
// list.
@@ -3824,8 +3769,9 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
}
// Add the address as a static address before SLAAC tries to add it.
- if err := s.AddProtocolAddress(1, tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr.Address, err)
+ protocolAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}) = %s", protocolAddr, err)
}
if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) {
t.Fatalf("Should have %s in the list of addresses", addr1)
@@ -3976,13 +3922,6 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
const maxMaxRetries = 3
const lifetimeSeconds = 10
- // Needed for the temporary address sub test.
- savedMaxDesync := ipv6.MaxDesyncFactor
- defer func() {
- ipv6.MaxDesyncFactor = savedMaxDesync
- }()
- ipv6.MaxDesyncFactor = time.Nanosecond
-
secretKey := makeSecretKey(t)
prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1)
@@ -4008,22 +3947,24 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
}
}
- expectAutoGenAddrEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ expectAutoGenAddrEventAsync := func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
t.Helper()
+ clock.RunImmediatelyScheduledJobs()
select {
case e := <-ndpDisp.autoGenAddrC:
if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncPositiveEventTimeout):
+ default:
t.Fatal("timed out waiting for addr auto gen event")
}
}
- expectDADEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) {
+ expectDADEvent := func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) {
t.Helper()
+ clock.RunImmediatelyScheduledJobs()
select {
case e := <-ndpDisp.dadC:
if diff := checkDADEvent(e, nicID, addr, res); diff != "" {
@@ -4034,15 +3975,16 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
}
}
- expectDADEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) {
+ expectDADEventAsync := func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) {
t.Helper()
+ clock.Advance(dadTransmits * retransmitTimer)
select {
case e := <-ndpDisp.dadC:
if diff := checkDADEvent(e, nicID, addr, res); diff != "" {
t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
+ default:
t.Fatal("timed out waiting for DAD event")
}
}
@@ -4053,7 +3995,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
name string
ndpConfigs ipv6.NDPConfigurations
autoGenLinkLocal bool
- prepareFn func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix
+ prepareFn func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix
addrGenFn func(dadCounter uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix
}{
{
@@ -4062,7 +4004,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
},
- prepareFn: func(_ *testing.T, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix {
+ prepareFn: func(_ *testing.T, _ *faketime.ManualClock, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix {
// Receive an RA with prefix1 in a PI.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds))
return nil
@@ -4076,7 +4018,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
name: "LinkLocal address",
ndpConfigs: ipv6.NDPConfigurations{},
autoGenLinkLocal: true,
- prepareFn: func(*testing.T, *ndpDispatcher, *channel.Endpoint, []byte) []tcpip.AddressWithPrefix {
+ prepareFn: func(*testing.T, *faketime.ManualClock, *ndpDispatcher, *channel.Endpoint, []byte) []tcpip.AddressWithPrefix {
return nil
},
addrGenFn: func(dadCounter uint8, _ []byte) tcpip.AddressWithPrefix {
@@ -4090,14 +4032,14 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: true,
},
- prepareFn: func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix {
+ prepareFn: func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix {
header.InitialTempIID(tempIIDHistory, nil, nicID)
// Generate a stable SLAAC address so temporary addresses will be
// generated.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
expectAutoGenAddrEvent(t, ndpDisp, stableAddrForTempAddrTest, newAddr)
- expectDADEventAsync(t, ndpDisp, stableAddrForTempAddrTest.Address, &stack.DADSucceeded{})
+ expectDADEventAsync(t, clock, ndpDisp, stableAddrForTempAddrTest.Address, &stack.DADSucceeded{})
// The stable address will be assigned throughout the test.
return []tcpip.AddressWithPrefix{stableAddrForTempAddrTest}
@@ -4109,14 +4051,6 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
}
for _, addrType := range addrTypes {
- // This Run will not return until the parallel tests finish.
- //
- // We need this because we need to do some teardown work after the parallel
- // tests complete and limit the number of parallel tests running at the same
- // time to reduce flakes.
- //
- // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
- // more details.
t.Run(addrType.name, func(t *testing.T) {
for maxRetries := uint8(0); maxRetries <= maxMaxRetries; maxRetries++ {
for numFailures := uint8(0); numFailures <= maxRetries+1; numFailures++ {
@@ -4125,8 +4059,6 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
addrType := addrType
t.Run(fmt.Sprintf("%d max retries and %d failures", maxRetries, numFailures), func(t *testing.T) {
- t.Parallel()
-
ndpDisp := ndpDispatcher{
dadC: make(chan ndpDADEvent, 1),
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
@@ -4134,6 +4066,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
e := channel.New(0, 1280, linkAddr1)
ndpConfigs := addrType.ndpConfigs
ndpConfigs.AutoGenAddressConflictRetries = maxRetries
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
AutoGenLinkLocal: addrType.autoGenLinkLocal,
@@ -4150,6 +4083,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
SecretKey: secretKey,
},
})},
+ Clock: clock,
})
opts := stack.NICOptions{Name: nicName}
if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
@@ -4157,12 +4091,12 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
}
var tempIIDHistory [header.IIDSize]byte
- stableAddrs := addrType.prepareFn(t, &ndpDisp, e, tempIIDHistory[:])
+ stableAddrs := addrType.prepareFn(t, clock, &ndpDisp, e, tempIIDHistory[:])
// Simulate DAD conflicts so the address is regenerated.
for i := uint8(0); i < numFailures; i++ {
addr := addrType.addrGenFn(i, tempIIDHistory[:])
- expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr)
+ expectAutoGenAddrEventAsync(t, clock, &ndpDisp, addr, newAddr)
// Should not have any new addresses assigned to the NIC.
if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, stableAddrs, nil); mismatch != "" {
@@ -4172,17 +4106,21 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
// Simulate a DAD conflict.
rxNDPSolicit(e, addr.Address)
expectAutoGenAddrEvent(t, &ndpDisp, addr, invalidatedAddr)
- expectDADEvent(t, &ndpDisp, addr.Address, &stack.DADDupAddrDetected{})
+ expectDADEvent(t, clock, &ndpDisp, addr.Address, &stack.DADDupAddrDetected{})
// Attempting to add the address manually should not fail if the
// address's state was cleaned up when DAD failed.
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr.Address); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr.Address, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if err := s.RemoveAddress(nicID, addr.Address); err != nil {
t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err)
}
- expectDADEvent(t, &ndpDisp, addr.Address, &stack.DADAborted{})
+ expectDADEvent(t, clock, &ndpDisp, addr.Address, &stack.DADAborted{})
}
// Should not have any new addresses assigned to the NIC.
@@ -4194,8 +4132,8 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
// an address after DAD resolves.
if maxRetries+1 > numFailures {
addr := addrType.addrGenFn(numFailures, tempIIDHistory[:])
- expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr)
- expectDADEventAsync(t, &ndpDisp, addr.Address, &stack.DADSucceeded{})
+ expectAutoGenAddrEventAsync(t, clock, &ndpDisp, addr, newAddr)
+ expectDADEventAsync(t, clock, &ndpDisp, addr.Address, &stack.DADSucceeded{})
if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, append(stableAddrs, addr), nil); mismatch != "" {
t.Fatal(mismatch)
}
@@ -4205,7 +4143,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
select {
case e := <-ndpDisp.autoGenAddrC:
t.Fatalf("unexpectedly got an auto-generated address event = %+v", e)
- case <-time.After(defaultAsyncNegativeEventTimeout):
+ default:
}
})
}
@@ -4718,11 +4656,9 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) {
)
ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, 1),
- rememberRouter: true,
- prefixC: make(chan ndpPrefixEvent, 1),
- rememberPrefix: true,
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1),
+ prefixC: make(chan ndpPrefixEvent, 1),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
@@ -4765,17 +4701,17 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) {
),
)
select {
- case e := <-ndpDisp.routerC:
- if diff := checkRouterEvent(e, llAddr3, true /* discovered */); diff != "" {
- t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ case e := <-ndpDisp.offLinkRouteC:
+ if diff := checkOffLinkRouteEvent(e, nicID, header.IPv6EmptySubnet, llAddr3, header.MediumRoutePreference, true /* discovered */); diff != "" {
+ t.Errorf("off-link route event mismatch (-want +got):\n%s", diff)
}
default:
- t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID)
+ t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID)
}
select {
case e := <-ndpDisp.prefixC:
if diff := checkPrefixEvent(e, subnet, true /* discovered */); diff != "" {
- t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ t.Errorf("off-link route event mismatch (-want +got):\n%s", diff)
}
default:
t.Errorf("expected prefix event for %s on NIC(%d)", prefix, nicID)
@@ -4797,8 +4733,8 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) {
t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err)
}
select {
- case e := <-ndpDisp.routerC:
- t.Errorf("unexpected router event = %#v", e)
+ case e := <-ndpDisp.offLinkRouteC:
+ t.Errorf("unexpected off-link route event = %#v", e)
default:
}
select {
@@ -4884,11 +4820,9 @@ func TestCleanupNDPState(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, maxRouterAndPrefixEvents),
- rememberRouter: true,
- prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents),
- rememberPrefix: true,
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents),
+ offLinkRouteC: make(chan ndpOffLinkRouteEvent, maxRouterAndPrefixEvents),
+ prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents),
}
clock := faketime.NewManualClock()
s := stack.New(stack.Options{
@@ -4905,14 +4839,14 @@ func TestCleanupNDPState(t *testing.T) {
Clock: clock,
})
- expectRouterEvent := func() (bool, ndpRouterEvent) {
+ expectOffLinkRouteEvent := func() (bool, ndpOffLinkRouteEvent) {
select {
- case e := <-ndpDisp.routerC:
+ case e := <-ndpDisp.offLinkRouteC:
return true, e
default:
}
- return false, ndpRouterEvent{}
+ return false, ndpOffLinkRouteEvent{}
}
expectPrefixEvent := func() (bool, ndpPrefixEvent) {
@@ -4957,8 +4891,8 @@ func TestCleanupNDPState(t *testing.T) {
// multiple addresses.
e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1)
+ if ok, _ := expectOffLinkRouteEvent(); !ok {
+ t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID1)
}
if ok, _ := expectPrefixEvent(); !ok {
t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1)
@@ -4968,8 +4902,8 @@ func TestCleanupNDPState(t *testing.T) {
}
e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1)
+ if ok, _ := expectOffLinkRouteEvent(); !ok {
+ t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr4, nicID1)
}
if ok, _ := expectPrefixEvent(); !ok {
t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1)
@@ -4979,8 +4913,8 @@ func TestCleanupNDPState(t *testing.T) {
}
e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2)
+ if ok, _ := expectOffLinkRouteEvent(); !ok {
+ t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID2)
}
if ok, _ := expectPrefixEvent(); !ok {
t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2)
@@ -4990,8 +4924,8 @@ func TestCleanupNDPState(t *testing.T) {
}
e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2)
+ if ok, _ := expectOffLinkRouteEvent(); !ok {
+ t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr4, nicID2)
}
if ok, _ := expectPrefixEvent(); !ok {
t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2)
@@ -5032,14 +4966,14 @@ func TestCleanupNDPState(t *testing.T) {
test.cleanupFn(t, s)
// Collect invalidation events after having NDP state cleaned up.
- gotRouterEvents := make(map[ndpRouterEvent]int)
+ gotOffLinkRouteEvents := make(map[ndpOffLinkRouteEvent]int)
for i := 0; i < maxRouterAndPrefixEvents; i++ {
- ok, e := expectRouterEvent()
+ ok, e := expectOffLinkRouteEvent()
if !ok {
- t.Errorf("expected %d router events after becoming a router; got = %d", maxRouterAndPrefixEvents, i)
+ t.Errorf("expected %d off-link route events after becoming a router; got = %d", maxRouterAndPrefixEvents, i)
break
}
- gotRouterEvents[e]++
+ gotOffLinkRouteEvents[e]++
}
gotPrefixEvents := make(map[ndpPrefixEvent]int)
for i := 0; i < maxRouterAndPrefixEvents; i++ {
@@ -5066,14 +5000,14 @@ func TestCleanupNDPState(t *testing.T) {
t.FailNow()
}
- expectedRouterEvents := map[ndpRouterEvent]int{
- {nicID: nicID1, addr: llAddr3, discovered: false}: 1,
- {nicID: nicID1, addr: llAddr4, discovered: false}: 1,
- {nicID: nicID2, addr: llAddr3, discovered: false}: 1,
- {nicID: nicID2, addr: llAddr4, discovered: false}: 1,
+ expectedOffLinkRouteEvents := map[ndpOffLinkRouteEvent]int{
+ {nicID: nicID1, subnet: header.IPv6EmptySubnet, router: llAddr3, updated: false}: 1,
+ {nicID: nicID1, subnet: header.IPv6EmptySubnet, router: llAddr4, updated: false}: 1,
+ {nicID: nicID2, subnet: header.IPv6EmptySubnet, router: llAddr3, updated: false}: 1,
+ {nicID: nicID2, subnet: header.IPv6EmptySubnet, router: llAddr4, updated: false}: 1,
}
- if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" {
- t.Errorf("router events mismatch (-want +got):\n%s", diff)
+ if diff := cmp.Diff(expectedOffLinkRouteEvents, gotOffLinkRouteEvents); diff != "" {
+ t.Errorf("off-link route events mismatch (-want +got):\n%s", diff)
}
expectedPrefixEvents := map[ndpPrefixEvent]int{
{nicID: nicID1, prefix: subnet1, discovered: false}: 1,
@@ -5137,8 +5071,8 @@ func TestCleanupNDPState(t *testing.T) {
// cancelled when the NDP state was cleaned up).
clock.Advance(lifetimeSeconds * time.Second)
select {
- case <-ndpDisp.routerC:
- t.Error("unexpected router event")
+ case <-ndpDisp.offLinkRouteC:
+ t.Error("unexpected off-link route event")
default:
}
select {
@@ -5163,7 +5097,6 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
ndpDisp := ndpDispatcher{
dhcpv6ConfigurationC: make(chan ndpDHCPv6Event, 1),
- rememberRouter: true,
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
@@ -5464,16 +5397,25 @@ func TestRouterSolicitation(t *testing.T) {
RandSource: &randSource,
})
- if err := s.CreateNIC(nicID, &e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ opts := stack.NICOptions{Disabled: true}
+ if err := s.CreateNICWithOptions(nicID, &e, opts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, opts, err)
}
if addr := test.nicAddr; addr != "" {
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("EnableNIC(%d): %s", nicID, err)
+ }
+
// Make sure each RS is sent at the right time.
remaining := test.maxRtrSolicit
if remaining != 0 {
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 378389db2..29d580e76 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -72,9 +72,15 @@ type nic struct {
sync.RWMutex
spoofing bool
promiscuous bool
- // packetEPs is protected by mu, but the contained packetEndpointList are
- // not.
- packetEPs map[tcpip.NetworkProtocolNumber]*packetEndpointList
+ }
+
+ packetEPs struct {
+ mu sync.RWMutex
+
+ // eps is protected by the mutex, but the values contained in it are not.
+ //
+ // +checklocks:mu
+ eps map[tcpip.NetworkProtocolNumber]*packetEndpointList
}
}
@@ -91,6 +97,8 @@ type packetEndpointList struct {
mu sync.RWMutex
// eps is protected by mu, but the contained PacketEndpoint values are not.
+ //
+ // +checklocks:mu
eps []PacketEndpoint
}
@@ -111,6 +119,12 @@ func (p *packetEndpointList) remove(ep PacketEndpoint) {
}
}
+func (p *packetEndpointList) len() int {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+ return len(p.eps)
+}
+
// forEach calls fn with each endpoints in p while holding the read lock on p.
func (p *packetEndpointList) forEach(fn func(PacketEndpoint)) {
p.mu.RLock()
@@ -143,18 +157,16 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
duplicateAddressDetectors: make(map[tcpip.NetworkProtocolNumber]DuplicateAddressDetector),
}
nic.linkResQueue.init(nic)
- nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList)
+
+ nic.packetEPs.mu.Lock()
+ defer nic.packetEPs.mu.Unlock()
+
+ nic.packetEPs.eps = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList)
resolutionRequired := ep.Capabilities()&CapabilityResolutionRequired != 0
- // Register supported packet and network endpoint protocols.
- for _, netProto := range header.Ethertypes {
- nic.mu.packetEPs[netProto] = new(packetEndpointList)
- }
for _, netProto := range stack.networkProtocols {
netNum := netProto.Number()
- nic.mu.packetEPs[netNum] = new(packetEndpointList)
-
netEP := netProto.NewEndpoint(nic, nic)
nic.networkEndpoints[netNum] = netEP
@@ -365,6 +377,8 @@ func (n *nic) writePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt
pkt.EgressRoute = r
pkt.NetworkProtocolNumber = protocol
+ n.deliverOutboundPacket(r.RemoteLinkAddress, pkt)
+
if err := n.LinkEndpoint.WritePacket(r, protocol, pkt); err != nil {
return err
}
@@ -383,6 +397,7 @@ func (n *nic) writePackets(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pk
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
pkt.EgressRoute = r
pkt.NetworkProtocolNumber = protocol
+ n.deliverOutboundPacket(r.RemoteLinkAddress, pkt)
}
writtenPackets, err := n.LinkEndpoint.WritePackets(r, pkts, protocol)
@@ -501,7 +516,7 @@ func (n *nic) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber,
// addAddress adds a new address to n, so that it starts accepting packets
// targeted at the given address (and network protocol).
-func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error {
+func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, properties AddressProperties) tcpip.Error {
ep, ok := n.networkEndpoints[protocolAddress.Protocol]
if !ok {
return &tcpip.ErrUnknownProtocol{}
@@ -512,7 +527,7 @@ func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo
return &tcpip.ErrNotSupported{}
}
- addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */)
+ addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, properties)
if err == nil {
// We have no need for the address endpoint.
addressEndpoint.DecRef()
@@ -699,12 +714,9 @@ func (n *nic) isInGroup(addr tcpip.Address) bool {
// This rule applies only to the slice itself, not to the items of the slice;
// the ownership of the items is not retained by the caller.
func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
- n.mu.RLock()
enabled := n.Enabled()
// If the NIC is not yet enabled, don't receive any packets.
if !enabled {
- n.mu.RUnlock()
-
n.stats.disabledRx.packets.Increment()
n.stats.disabledRx.bytes.IncrementBy(uint64(pkt.Data().Size()))
return
@@ -715,7 +727,6 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
networkEndpoint, ok := n.networkEndpoints[protocol]
if !ok {
- n.mu.RUnlock()
n.stats.unknownL3ProtocolRcvdPackets.Increment()
return
}
@@ -727,44 +738,87 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
}
pkt.RXTransportChecksumValidated = n.LinkEndpoint.Capabilities()&CapabilityRXChecksumOffload != 0
- // Are any packet type sockets listening for this network protocol?
- protoEPs := n.mu.packetEPs[protocol]
- // Other packet type sockets that are listening for all protocols.
- anyEPs := n.mu.packetEPs[header.EthernetProtocolAll]
- n.mu.RUnlock()
-
// Deliver to interested packet endpoints without holding NIC lock.
+ var packetEPPkt *PacketBuffer
deliverPacketEPs := func(ep PacketEndpoint) {
- p := pkt.Clone()
- p.PktType = tcpip.PacketHost
- ep.HandlePacket(n.id, local, protocol, p)
+ if packetEPPkt == nil {
+ // Packet endpoints hold the full packet.
+ //
+ // We perform a deep copy because higher-level endpoints may point to
+ // the middle of a view that is held by a packet endpoint. Save/Restore
+ // does not support overlapping slices and will panic in this case.
+ //
+ // TODO(https://gvisor.dev/issue/6517): Avoid this copy once S/R supports
+ // overlapping slices (e.g. by passing a shallow copy of pkt to the packet
+ // endpoint).
+ packetEPPkt = NewPacketBuffer(PacketBufferOptions{
+ Data: PayloadSince(pkt.LinkHeader()).ToVectorisedView(),
+ })
+ // If a link header was populated in the original packet buffer, then
+ // populate it in the packet buffer we provide to packet endpoints as
+ // packet endpoints inspect link headers.
+ packetEPPkt.LinkHeader().Consume(pkt.LinkHeader().View().Size())
+ packetEPPkt.PktType = tcpip.PacketHost
+ }
+
+ ep.HandlePacket(n.id, local, protocol, packetEPPkt.Clone())
}
- if protoEPs != nil {
+
+ n.packetEPs.mu.Lock()
+ // Are any packet type sockets listening for this network protocol?
+ protoEPs, protoEPsOK := n.packetEPs.eps[protocol]
+ // Other packet type sockets that are listening for all protocols.
+ anyEPs, anyEPsOK := n.packetEPs.eps[header.EthernetProtocolAll]
+ n.packetEPs.mu.Unlock()
+
+ if protoEPsOK {
protoEPs.forEach(deliverPacketEPs)
}
- if anyEPs != nil {
+ if anyEPsOK {
anyEPs.forEach(deliverPacketEPs)
}
networkEndpoint.HandlePacket(pkt)
}
-// DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket.
-func (n *nic) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
- n.mu.RLock()
+// deliverOutboundPacket delivers outgoing packets to interested endpoints.
+func (n *nic) deliverOutboundPacket(remote tcpip.LinkAddress, pkt *PacketBuffer) {
+ n.packetEPs.mu.RLock()
+ defer n.packetEPs.mu.RUnlock()
// We do not deliver to protocol specific packet endpoints as on Linux
// only ETH_P_ALL endpoints get outbound packets.
// Add any other packet sockets that maybe listening for all protocols.
- eps := n.mu.packetEPs[header.EthernetProtocolAll]
- n.mu.RUnlock()
+ eps, ok := n.packetEPs.eps[header.EthernetProtocolAll]
+ if !ok {
+ return
+ }
+
+ local := n.LinkAddress()
+ var packetEPPkt *PacketBuffer
eps.forEach(func(ep PacketEndpoint) {
- p := pkt.Clone()
- p.PktType = tcpip.PacketOutgoing
- // Add the link layer header as outgoing packets are intercepted
- // before the link layer header is created.
- n.LinkEndpoint.AddHeader(local, remote, protocol, p)
- ep.HandlePacket(n.id, local, protocol, p)
+ if packetEPPkt == nil {
+ // Packet endpoints hold the full packet.
+ //
+ // We perform a deep copy because higher-level endpoints may point to
+ // the middle of a view that is held by a packet endpoint. Save/Restore
+ // does not support overlapping slices and will panic in this case.
+ //
+ // TODO(https://gvisor.dev/issue/6517): Avoid this copy once S/R supports
+ // overlapping slices (e.g. by passing a shallow copy of pkt to the packet
+ // endpoint).
+ packetEPPkt = NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: pkt.AvailableHeaderBytes(),
+ Data: PayloadSince(pkt.NetworkHeader()).ToVectorisedView(),
+ })
+ // Add the link layer header as outgoing packets are intercepted before
+ // the link layer header is created and packet endpoints are interested
+ // in the link header.
+ n.LinkEndpoint.AddHeader(local, remote, pkt.NetworkProtocolNumber, packetEPPkt)
+ packetEPPkt.PktType = tcpip.PacketOutgoing
+ }
+
+ ep.HandlePacket(n.id, local, pkt.NetworkProtocolNumber, packetEPPkt.Clone())
})
}
@@ -779,17 +833,11 @@ func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt
transProto := state.proto
- // Raw socket packets are delivered based solely on the transport
- // protocol number. We do not inspect the payload to ensure it's
- // validly formed.
- n.stack.demux.deliverRawPacket(protocol, pkt)
-
// TransportHeader is empty only when pkt is an ICMP packet or was reassembled
// from fragments.
if pkt.TransportHeader().View().IsEmpty() {
- // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader
- // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a
- // full explanation.
+ // ICMP packets don't have their TransportHeader fields set yet, parse it
+ // here. See icmp/protocol.go:protocol.Parse for a full explanation.
if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber {
// ICMP packets may be longer, but until icmp.Parse is implemented, here
// we parse it using the minimum size.
@@ -878,6 +926,17 @@ func (n *nic) DeliverTransportError(local, remote tcpip.Address, net tcpip.Netwo
}
}
+// DeliverRawPacket implements TransportDispatcher.
+func (n *nic) DeliverRawPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) {
+ // For ICMPv4 only we validate the header length for compatibility with
+ // raw(7) ICMP_FILTER. The same check is made in Linux here:
+ // https://github.com/torvalds/linux/blob/70585216/net/ipv4/raw.c#L189.
+ if protocol == header.ICMPv4ProtocolNumber && pkt.TransportHeader().View().Size()+pkt.Data().Size() < header.ICMPv4MinimumSize {
+ return
+ }
+ n.stack.demux.deliverRawPacket(protocol, pkt)
+}
+
// ID implements NetworkInterface.
func (n *nic) ID() tcpip.NICID {
return n.id
@@ -912,12 +971,13 @@ func (n *nic) setNUDConfigs(protocol tcpip.NetworkProtocolNumber, c NUDConfigura
}
func (n *nic) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) tcpip.Error {
- n.mu.Lock()
- defer n.mu.Unlock()
+ n.packetEPs.mu.Lock()
+ defer n.packetEPs.mu.Unlock()
- eps, ok := n.mu.packetEPs[netProto]
+ eps, ok := n.packetEPs.eps[netProto]
if !ok {
- return &tcpip.ErrNotSupported{}
+ eps = new(packetEndpointList)
+ n.packetEPs.eps[netProto] = eps
}
eps.add(ep)
@@ -925,14 +985,17 @@ func (n *nic) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa
}
func (n *nic) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) {
- n.mu.Lock()
- defer n.mu.Unlock()
+ n.packetEPs.mu.Lock()
+ defer n.packetEPs.mu.Unlock()
- eps, ok := n.mu.packetEPs[netProto]
+ eps, ok := n.packetEPs.eps[netProto]
if !ok {
return
}
eps.remove(ep)
+ if eps.len() == 0 {
+ delete(n.packetEPs.eps, netProto)
+ }
}
// isValidForOutgoing returns true if the endpoint can be used to send out a
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index 5cb342f78..c8ad93f29 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -127,11 +127,6 @@ func (*testIPv6Protocol) MinimumPacketSize() int {
return header.IPv6MinimumSize
}
-// DefaultPrefixLen implements NetworkProtocol.DefaultPrefixLen.
-func (*testIPv6Protocol) DefaultPrefixLen() int {
- return header.IPv6AddressSize * 8
-}
-
// ParseAddresses implements NetworkProtocol.ParseAddresses.
func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
h := header.IPv6(v)
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 9192d8433..29c22bfd4 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -282,14 +282,12 @@ func (pk *PacketBuffer) headerView(typ headerType) tcpipbuffer.View {
return v
}
-// Clone makes a shallow copy of pk.
-//
-// Clone should be called in such cases so that no modifications is done to
-// underlying packet payload.
+// Clone makes a semi-deep copy of pk. The underlying packet payload is
+// shared. Hence, no modifications is done to underlying packet payload.
func (pk *PacketBuffer) Clone() *PacketBuffer {
return &PacketBuffer{
PacketBufferEntry: pk.PacketBufferEntry,
- buf: pk.buf,
+ buf: pk.buf.Clone(),
reserved: pk.reserved,
pushed: pk.pushed,
consumed: pk.consumed,
@@ -321,14 +319,14 @@ func (pk *PacketBuffer) Network() header.Network {
}
}
-// CloneToInbound makes a shallow copy of the packet buffer to be used as an
-// inbound packet.
+// CloneToInbound makes a semi-deep copy of the packet buffer (similar to
+// Clone) to be used as an inbound packet.
//
// See PacketBuffer.Data for details about how a packet buffer holds an inbound
// packet.
func (pk *PacketBuffer) CloneToInbound() *PacketBuffer {
newPk := &PacketBuffer{
- buf: pk.buf,
+ buf: pk.buf.Clone(),
// Treat unfilled header portion as reserved.
reserved: pk.AvailableHeaderBytes(),
}
diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go
index a8da34992..87b023445 100644
--- a/pkg/tcpip/stack/packet_buffer_test.go
+++ b/pkg/tcpip/stack/packet_buffer_test.go
@@ -123,6 +123,32 @@ func TestPacketHeaderPush(t *testing.T) {
}
}
+func TestPacketBufferClone(t *testing.T) {
+ data := concatViews(makeView(20), makeView(30), makeView(40))
+ pk := NewPacketBuffer(PacketBufferOptions{
+ // Make a copy of data to make sure our truth data won't be taint by
+ // PacketBuffer.
+ Data: buffer.NewViewFromBytes(data).ToVectorisedView(),
+ })
+
+ bytesToDelete := 30
+ originalSize := data.Size()
+
+ clonedPks := []*PacketBuffer{
+ pk.Clone(),
+ pk.CloneToInbound(),
+ }
+ pk.Data().DeleteFront(bytesToDelete)
+ if got, want := pk.Data().Size(), originalSize-bytesToDelete; got != want {
+ t.Errorf("original packet was not changed: size expected = %d, got = %d", want, got)
+ }
+ for _, clonedPk := range clonedPks {
+ if got := clonedPk.Data().Size(); got != originalSize {
+ t.Errorf("cloned packet should not be modified: expected size = %d, got = %d", originalSize, got)
+ }
+ }
+}
+
func TestPacketHeaderConsume(t *testing.T) {
for _, test := range []struct {
name string
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index a038389e0..31b3a554d 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -265,6 +265,11 @@ type TransportDispatcher interface {
//
// DeliverTransportError takes ownership of the packet buffer.
DeliverTransportError(local, remote tcpip.Address, _ tcpip.NetworkProtocolNumber, _ tcpip.TransportProtocolNumber, _ TransportError, _ *PacketBuffer)
+
+ // DeliverRawPacket delivers a packet to any subscribed raw sockets.
+ //
+ // DeliverRawPacket does NOT take ownership of the packet buffer.
+ DeliverRawPacket(tcpip.TransportProtocolNumber, *PacketBuffer)
}
// PacketLooping specifies where an outbound packet should be sent.
@@ -313,8 +318,7 @@ type PrimaryEndpointBehavior int
const (
// CanBePrimaryEndpoint indicates the endpoint can be used as a primary
- // endpoint for new connections with no local address. This is the
- // default when calling NIC.AddAddress.
+ // endpoint for new connections with no local address.
CanBePrimaryEndpoint PrimaryEndpointBehavior = iota
// FirstPrimaryEndpoint indicates the endpoint should be the first
@@ -327,6 +331,19 @@ const (
NeverPrimaryEndpoint
)
+func (peb PrimaryEndpointBehavior) String() string {
+ switch peb {
+ case CanBePrimaryEndpoint:
+ return "CanBePrimaryEndpoint"
+ case FirstPrimaryEndpoint:
+ return "FirstPrimaryEndpoint"
+ case NeverPrimaryEndpoint:
+ return "NeverPrimaryEndpoint"
+ default:
+ panic(fmt.Sprintf("unknown primary endpoint behavior: %d", peb))
+ }
+}
+
// AddressConfigType is the method used to add an address.
type AddressConfigType int
@@ -346,6 +363,14 @@ const (
AddressConfigSlaacTemp
)
+// AddressProperties contains additional properties that can be configured when
+// adding an address.
+type AddressProperties struct {
+ PEB PrimaryEndpointBehavior
+ ConfigType AddressConfigType
+ Deprecated bool
+}
+
// AssignableAddressEndpoint is a reference counted address endpoint that may be
// assigned to a NetworkEndpoint.
type AssignableAddressEndpoint interface {
@@ -452,7 +477,7 @@ type AddressableEndpoint interface {
// Returns *tcpip.ErrDuplicateAddress if the address exists.
//
// Acquires and returns the AddressEndpoint for the added address.
- AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, tcpip.Error)
+ AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties AddressProperties) (AddressEndpoint, tcpip.Error)
// RemovePermanentAddress removes the passed address if it is a permanent
// address.
@@ -680,9 +705,6 @@ type NetworkProtocol interface {
// than this targeted at this protocol.
MinimumPacketSize() int
- // DefaultPrefixLen returns the protocol's default prefix length.
- DefaultPrefixLen() int
-
// ParseAddresses returns the source and destination addresses stored in a
// packet of this protocol.
ParseAddresses(v buffer.View) (src, dst tcpip.Address)
@@ -728,16 +750,6 @@ type NetworkDispatcher interface {
//
// DeliverNetworkPacket takes ownership of pkt.
DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
-
- // DeliverOutboundPacket is called by link layer when a packet is being
- // sent out.
- //
- // pkt.LinkHeader may or may not be set before calling
- // DeliverOutboundPacket. Some packets do not have link headers (e.g.
- // packets sent via loopback), and won't have the field set.
- //
- // DeliverOutboundPacket takes ownership of pkt.
- DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
// LinkEndpointCapabilities is the type associated with the capabilities
@@ -841,6 +853,14 @@ type LinkEndpoint interface {
// offload is enabled. If it will be used for something else, syscall filters
// may need to be updated.
WritePackets(RouteInfo, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error)
+
+ // WriteRawPacket writes a packet directly to the link.
+ //
+ // If the link-layer has its own header, the payload must already include the
+ // header.
+ //
+ // WriteRawPacket takes ownership of the packet.
+ WriteRawPacket(*PacketBuffer) tcpip.Error
}
// InjectableLinkEndpoint is a LinkEndpoint where inbound packets are
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 40d277312..98867a828 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -72,7 +72,8 @@ type Stack struct {
// rawFactory creates raw endpoints. If nil, raw endpoints are
// disabled. It is set during Stack creation and is immutable.
- rawFactory RawFactory
+ rawFactory RawFactory
+ packetEndpointWriteSupported bool
demux *transportDemuxer
@@ -108,7 +109,7 @@ type Stack struct {
handleLocal bool
// tables are the iptables packet filtering and manipulation rules.
- // TODO(gvisor.dev/issue/170): S/R this field.
+ // TODO(gvisor.dev/issue/4595): S/R this field.
tables *IPTables
// resumableEndpoints is a list of endpoints that need to be resumed if the
@@ -119,8 +120,7 @@ type Stack struct {
// by the stack.
icmpRateLimiter *ICMPRateLimiter
- // seed is a one-time random value initialized at stack startup
- // and is used to seed the TCP port picking on active connections
+ // seed is a one-time random value initialized at stack startup.
//
// TODO(gvisor.dev/issue/940): S/R this field.
seed uint32
@@ -161,6 +161,10 @@ type Stack struct {
// This is required to prevent potential ACK loops.
// Setting this to 0 will disable all rate limiting.
tcpInvalidRateLimit time.Duration
+
+ // tsOffsetSecret is the secret key for generating timestamp offsets
+ // initialized at stack startup.
+ tsOffsetSecret uint32
}
// UniqueID is an abstract generator of unique identifiers.
@@ -215,6 +219,10 @@ type Options struct {
// this is non-nil.
RawFactory RawFactory
+ // AllowPacketEndpointWrite determines if packet endpoints support write
+ // operations.
+ AllowPacketEndpointWrite bool
+
// RandSource is an optional source to use to generate random
// numbers. If omitted it defaults to a Source seeded by the data
// returned by the stack secure RNG.
@@ -356,23 +364,24 @@ func New(opts Options) *Stack {
opts.NUDConfigs.resetInvalidFields()
s := &Stack{
- transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
- networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
- nics: make(map[tcpip.NICID]*nic),
- defaultForwardingEnabled: make(map[tcpip.NetworkProtocolNumber]struct{}),
- cleanupEndpoints: make(map[TransportEndpoint]struct{}),
- PortManager: ports.NewPortManager(),
- clock: clock,
- stats: opts.Stats.FillIn(),
- handleLocal: opts.HandleLocal,
- tables: opts.IPTables,
- icmpRateLimiter: NewICMPRateLimiter(),
- seed: seed,
- nudConfigs: opts.NUDConfigs,
- uniqueIDGenerator: opts.UniqueID,
- nudDisp: opts.NUDDisp,
- randomGenerator: randomGenerator,
- secureRNG: opts.SecureRNG,
+ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
+ networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
+ nics: make(map[tcpip.NICID]*nic),
+ packetEndpointWriteSupported: opts.AllowPacketEndpointWrite,
+ defaultForwardingEnabled: make(map[tcpip.NetworkProtocolNumber]struct{}),
+ cleanupEndpoints: make(map[TransportEndpoint]struct{}),
+ PortManager: ports.NewPortManager(),
+ clock: clock,
+ stats: opts.Stats.FillIn(),
+ handleLocal: opts.HandleLocal,
+ tables: opts.IPTables,
+ icmpRateLimiter: NewICMPRateLimiter(),
+ seed: seed,
+ nudConfigs: opts.NUDConfigs,
+ uniqueIDGenerator: opts.UniqueID,
+ nudDisp: opts.NUDDisp,
+ randomGenerator: randomGenerator,
+ secureRNG: opts.SecureRNG,
sendBufferSize: tcpip.SendBufferSizeOption{
Min: MinBufferSize,
Default: DefaultBufferSize,
@@ -384,6 +393,7 @@ func New(opts Options) *Stack {
Max: DefaultMaxBufferSize,
},
tcpInvalidRateLimit: defaultTCPInvalidRateLimit,
+ tsOffsetSecret: randomGenerator.Uint32(),
}
// Add specified network protocols.
@@ -780,6 +790,9 @@ func (s *Stack) removeNICLocked(id tcpip.NICID) tcpip.Error {
if !ok {
return &tcpip.ErrUnknownNICID{}
}
+ if nic.IsLoopback() {
+ return &tcpip.ErrNotSupported{}
+ }
delete(s.nics, id)
// Remove routes in-place. n tracks the number of routes written.
@@ -903,46 +916,9 @@ type NICStateFlags struct {
Loopback bool
}
-// AddAddress adds a new network-layer address to the specified NIC.
-func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error {
- return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint)
-}
-
-// AddAddressWithPrefix is the same as AddAddress, but allows you to specify
-// the address prefix.
-func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.AddressWithPrefix) tcpip.Error {
- ap := tcpip.ProtocolAddress{
- Protocol: protocol,
- AddressWithPrefix: addr,
- }
- return s.AddProtocolAddressWithOptions(id, ap, CanBePrimaryEndpoint)
-}
-
-// AddProtocolAddress adds a new network-layer protocol address to the
-// specified NIC.
-func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) tcpip.Error {
- return s.AddProtocolAddressWithOptions(id, protocolAddress, CanBePrimaryEndpoint)
-}
-
-// AddAddressWithOptions is the same as AddAddress, but allows you to specify
-// whether the new endpoint can be primary or not.
-func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) tcpip.Error {
- netProto, ok := s.networkProtocols[protocol]
- if !ok {
- return &tcpip.ErrUnknownProtocol{}
- }
- return s.AddProtocolAddressWithOptions(id, tcpip.ProtocolAddress{
- Protocol: protocol,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: addr,
- PrefixLen: netProto.DefaultPrefixLen(),
- },
- }, peb)
-}
-
-// AddProtocolAddressWithOptions is the same as AddProtocolAddress, but allows
-// you to specify whether the new endpoint can be primary or not.
-func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error {
+// AddProtocolAddress adds an address to the specified NIC, possibly with extra
+// properties.
+func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, properties AddressProperties) tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -951,7 +927,7 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc
return &tcpip.ErrUnknownNICID{}
}
- return nic.addAddress(protocolAddress, peb)
+ return nic.addAddress(protocolAddress, properties)
}
// RemoveAddress removes an existing network-layer address from the specified
@@ -1646,9 +1622,27 @@ func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress,
ReserveHeaderBytes: int(nic.MaxHeaderLength()),
Data: payload,
})
+ pkt.NetworkProtocolNumber = netProto
return nic.WritePacketToRemote(remote, netProto, pkt)
}
+// WriteRawPacket writes data directly to the specified NIC without adding any
+// headers.
+func (s *Stack) WriteRawPacket(nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) tcpip.Error {
+ s.mu.RLock()
+ nic, ok := s.nics[nicID]
+ s.mu.RUnlock()
+ if !ok {
+ return &tcpip.ErrUnknownNICID{}
+ }
+
+ pkt := NewPacketBuffer(PacketBufferOptions{
+ Data: payload,
+ })
+ pkt.NetworkProtocolNumber = proto
+ return nic.WriteRawPacket(pkt)
+}
+
// NetworkProtocolInstance returns the protocol instance in the stack for the
// specified network protocol. This method is public for protocol implementers
// and tests to use.
@@ -1816,8 +1810,7 @@ func (s *Stack) SetNUDConfigurations(id tcpip.NICID, proto tcpip.NetworkProtocol
return nic.setNUDConfigs(proto, c)
}
-// Seed returns a 32 bit value that can be used as a seed value for port
-// picking, ISN generation etc.
+// Seed returns a 32 bit value that can be used as a seed value.
//
// NOTE: The seed is generated once during stack initialization only.
func (s *Stack) Seed() uint32 {
@@ -1872,9 +1865,8 @@ const (
// ParsePacketBufferTransport parses the provided packet buffer's transport
// header.
func (s *Stack) ParsePacketBufferTransport(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) ParseResult {
- // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader
- // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a
- // full explanation.
+ // ICMP packets don't have their TransportHeader fields set yet, parse it
+ // here. See icmp/protocol.go:protocol.Parse for a full explanation.
if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber {
return ParsedOK
}
@@ -1942,3 +1934,9 @@ func (s *Stack) IsSubnetBroadcast(nicID tcpip.NICID, protocol tcpip.NetworkProto
return false
}
+
+// PacketEndpointWriteSupported returns true iff packet endpoints support write
+// operations.
+func (s *Stack) PacketEndpointWriteSupported() bool {
+ return s.packetEndpointWriteSupported
+}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 21951d05a..cd4137794 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -234,10 +234,6 @@ func (*fakeNetworkProtocol) MinimumPacketSize() int {
return fakeNetHeaderLen
}
-func (*fakeNetworkProtocol) DefaultPrefixLen() int {
- return fakeDefaultPrefixLen
-}
-
func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int {
return f.packetCount[int(intfAddr)%len(f.packetCount)]
}
@@ -349,12 +345,26 @@ func TestNetworkReceive(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x02"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x02",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr2, err)
}
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
@@ -517,8 +527,15 @@ func TestNetworkSend(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
// Make sure that the link-layer endpoint received the outbound packet.
@@ -538,12 +555,26 @@ func TestNetworkSendMultiRoute(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr3 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x03",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr3, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr3, err)
}
ep2 := channel.New(10, defaultMTU, "")
@@ -551,12 +582,26 @@ func TestNetworkSendMultiRoute(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x02",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(2, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err)
}
- if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr4 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x04",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(2, protocolAddr4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr4, err)
}
// Set a route table that sends all packets with odd destination
@@ -719,38 +764,59 @@ func TestRemoveUnknownNIC(t *testing.T) {
}
func TestRemoveNIC(t *testing.T) {
- const nicID = 1
+ for _, tt := range []struct {
+ name string
+ linkep stack.LinkEndpoint
+ expectErr tcpip.Error
+ }{
+ {
+ name: "loopback",
+ linkep: loopback.New(),
+ expectErr: &tcpip.ErrNotSupported{},
+ },
+ {
+ name: "channel",
+ linkep: channel.New(0, defaultMTU, ""),
+ expectErr: nil,
+ },
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ const nicID = 1
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
+ })
- e := linkEPWithMockedAttach{
- LinkEndpoint: loopback.New(),
- }
- if err := s.CreateNIC(nicID, &e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
+ e := linkEPWithMockedAttach{
+ LinkEndpoint: tt.linkep,
+ }
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
- // NIC should be present in NICInfo and attached to a NetworkDispatcher.
- allNICInfo := s.NICInfo()
- if _, ok := allNICInfo[nicID]; !ok {
- t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo)
- }
- if !e.isAttached() {
- t.Fatal("link endpoint not attached to a network dispatcher")
- }
+ // NIC should be present in NICInfo and attached to a NetworkDispatcher.
+ allNICInfo := s.NICInfo()
+ if _, ok := allNICInfo[nicID]; !ok {
+ t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo)
+ }
+ if !e.isAttached() {
+ t.Fatal("link endpoint not attached to a network dispatcher")
+ }
- // Removing a NIC should remove it from NICInfo and e should be detached from
- // the NetworkDispatcher.
- if err := s.RemoveNIC(nicID); err != nil {
- t.Fatalf("s.RemoveNIC(%d): %s", nicID, err)
- }
- if nicInfo, ok := s.NICInfo()[nicID]; ok {
- t.Errorf("got unexpected NICInfo entry for deleted NIC %d = %+v", nicID, nicInfo)
- }
- if e.isAttached() {
- t.Error("link endpoint for removed NIC still attached to a network dispatcher")
+ // Removing a NIC should remove it from NICInfo and e should be detached from
+ // the NetworkDispatcher.
+ if got, want := s.RemoveNIC(nicID), tt.expectErr; got != want {
+ t.Fatalf("got s.RemoveNIC(%d) = %s, want %s", nicID, got, want)
+ }
+ if tt.expectErr == nil {
+ if nicInfo, ok := s.NICInfo()[nicID]; ok {
+ t.Errorf("got unexpected NICInfo entry for deleted NIC %d = %+v", nicID, nicInfo)
+ }
+ if e.isAttached() {
+ t.Error("link endpoint for removed NIC still attached to a network dispatcher")
+ }
+ }
+ })
}
}
@@ -791,8 +857,15 @@ func TestRouteWithDownNIC(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
}
- if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addr1,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr1, err)
}
ep2 := channel.New(1, defaultMTU, "")
@@ -800,8 +873,15 @@ func TestRouteWithDownNIC(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
}
- if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addr2,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddr2, err)
}
// Set a route table that sends all packets with odd destination
@@ -957,12 +1037,26 @@ func TestRoutes(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr3 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x03",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr3, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr3, err)
}
ep2 := channel.New(10, defaultMTU, "")
@@ -970,12 +1064,26 @@ func TestRoutes(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x02",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(2, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err)
}
- if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr4 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x04",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(2, protocolAddr4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr4, err)
}
// Set a route table that sends all packets with odd destination
@@ -1037,8 +1145,15 @@ func TestAddressRemoval(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet("\x00", "\x00")
@@ -1087,8 +1202,15 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) {
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
buf := buffer.NewView(30)
- if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet("\x00", "\x00")
@@ -1221,8 +1343,15 @@ func TestEndpointExpiration(t *testing.T) {
// 2. Add Address, everything should work.
//-----------------------
- if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
verifyAddress(t, s, nicID, localAddr)
testRecv(t, fakeNet, localAddrByte, ep, buf)
@@ -1249,8 +1378,8 @@ func TestEndpointExpiration(t *testing.T) {
// 4. Add Address back, everything should work again.
//-----------------------
- if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
verifyAddress(t, s, nicID, localAddr)
testRecv(t, fakeNet, localAddrByte, ep, buf)
@@ -1289,8 +1418,8 @@ func TestEndpointExpiration(t *testing.T) {
// 7. Add Address back, everything should work again.
//-----------------------
- if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
verifyAddress(t, s, nicID, localAddr)
testRecv(t, fakeNet, localAddrByte, ep, buf)
@@ -1432,8 +1561,15 @@ func TestExternalSendWithHandleLocal(t *testing.T) {
if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: nicID}})
@@ -1489,8 +1625,15 @@ func TestSpoofingWithAddress(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
@@ -1612,8 +1755,8 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) {
}
protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{Address: header.IPv4Any}}
- if err := s.AddProtocolAddress(1, protoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err)
+ if err := s.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}) failed: %s", protoAddr, err)
}
r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
if err != nil {
@@ -1657,13 +1800,13 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
t.Fatalf("CreateNIC failed: %s", err)
}
nic1ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic1Addr}
- if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err)
+ if err := s.AddProtocolAddress(1, nic1ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}) failed: %s", nic1ProtoAddr, err)
}
nic2ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic2Addr}
- if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil {
- t.Fatalf("AddAddress(2, %v) failed: %v", nic2ProtoAddr, err)
+ if err := s.AddProtocolAddress(2, nic2ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(2, %+v, {}) failed: %s", nic2ProtoAddr, err)
}
// Set the initial route table.
@@ -1705,7 +1848,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
// 2. Case: Having an explicit route for broadcast will select that one.
rt = append(
[]tcpip.Route{
- {Destination: tcpip.AddressWithPrefix{Address: header.IPv4Broadcast, PrefixLen: 8 * header.IPv4AddressSize}.Subnet(), NIC: 1},
+ {Destination: header.IPv4Broadcast.WithPrefix().Subnet(), NIC: 1},
},
rt...,
)
@@ -1787,8 +1930,15 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, want)
}
- if err := s.AddAddress(1, fakeNetNumber, anyAddr); err != nil {
- t.Fatalf("AddAddress(%v, %v) failed: %v", fakeNetNumber, anyAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: anyAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
if r, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); tc.routeNeeded {
@@ -1865,22 +2015,27 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
// Add an address and in case of a primary one include a
// prefixLen.
address := tcpip.Address(bytes.Repeat([]byte{byte(i)}, addrLen))
+ properties := stack.AddressProperties{PEB: behavior}
if behavior == stack.CanBePrimaryEndpoint {
protocolAddress := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: address,
- PrefixLen: addrLen * 8,
- },
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: address.WithPrefix(),
}
- if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil {
- t.Fatalf("AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, protocolAddress, behavior, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddress, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddress, properties, err)
}
// Remember the address/prefix.
primaryAddrAdded[protocolAddress.AddressWithPrefix] = struct{}{}
} else {
- if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s:", nicID, fakeNetNumber, address, behavior, err)
+ protocolAddress := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddress, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddress, properties, err)
}
}
}
@@ -1975,8 +2130,8 @@ func TestGetMainNICAddressAddRemove(t *testing.T) {
PrefixLen: tc.prefixLen,
},
}
- if err := s.AddProtocolAddress(1, protocolAddress); err != nil {
- t.Fatal("AddProtocolAddress failed:", err)
+ if err := s.AddProtocolAddress(1, protocolAddress, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", protocolAddress, err)
}
// Check that we get the right initial address and prefix length.
@@ -2026,33 +2181,6 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto
}
}
-func TestAddAddress(t *testing.T) {
- const nicID = 1
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- var addrGen addressGenerator
- expectedAddresses := make([]tcpip.ProtocolAddress, 0, 2)
- for _, addrLen := range []int{4, 16} {
- address := addrGen.next(addrLen)
- if err := s.AddAddress(nicID, fakeNetNumber, address); err != nil {
- t.Fatalf("AddAddress(address=%s) failed: %s", address, err)
- }
- expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen},
- })
- }
-
- gotAddresses := s.AllAddresses()[nicID]
- verifyAddresses(t, expectedAddresses, gotAddresses)
-}
-
func TestAddProtocolAddress(t *testing.T) {
const nicID = 1
s := stack.New(stack.Options{
@@ -2063,96 +2191,43 @@ func TestAddProtocolAddress(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- var addrGen addressGenerator
- addrLenRange := []int{4, 16}
- prefixLenRange := []int{8, 13, 20, 32}
- expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange))
- for _, addrLen := range addrLenRange {
- for _, prefixLen := range prefixLenRange {
- protocolAddress := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: addrGen.next(addrLen),
- PrefixLen: prefixLen,
- },
- }
- if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil {
- t.Errorf("AddProtocolAddress(%+v) failed: %s", protocolAddress, err)
- }
- expectedAddresses = append(expectedAddresses, protocolAddress)
- }
- }
-
- gotAddresses := s.AllAddresses()[nicID]
- verifyAddresses(t, expectedAddresses, gotAddresses)
-}
-
-func TestAddAddressWithOptions(t *testing.T) {
- const nicID = 1
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
addrLenRange := []int{4, 16}
behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint}
- expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(behaviorRange))
+ configTypeRange := []stack.AddressConfigType{stack.AddressConfigStatic, stack.AddressConfigSlaac, stack.AddressConfigSlaacTemp}
+ deprecatedRange := []bool{false, true}
+ wantAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(behaviorRange)*len(configTypeRange)*len(deprecatedRange))
var addrGen addressGenerator
for _, addrLen := range addrLenRange {
for _, behavior := range behaviorRange {
- address := addrGen.next(addrLen)
- if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil {
- t.Fatalf("AddAddressWithOptions(address=%s, behavior=%d) failed: %s", address, behavior, err)
- }
- expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen},
- })
- }
- }
-
- gotAddresses := s.AllAddresses()[nicID]
- verifyAddresses(t, expectedAddresses, gotAddresses)
-}
-
-func TestAddProtocolAddressWithOptions(t *testing.T) {
- const nicID = 1
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- addrLenRange := []int{4, 16}
- prefixLenRange := []int{8, 13, 20, 32}
- behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint}
- expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)*len(behaviorRange))
- var addrGen addressGenerator
- for _, addrLen := range addrLenRange {
- for _, prefixLen := range prefixLenRange {
- for _, behavior := range behaviorRange {
- protocolAddress := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: addrGen.next(addrLen),
- PrefixLen: prefixLen,
- },
- }
- if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil {
- t.Fatalf("AddProtocolAddressWithOptions(%+v, %d) failed: %s", protocolAddress, behavior, err)
+ for _, configType := range configTypeRange {
+ for _, deprecated := range deprecatedRange {
+ address := addrGen.next(addrLen)
+ properties := stack.AddressProperties{
+ PEB: behavior,
+ ConfigType: configType,
+ Deprecated: deprecated,
+ }
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v) failed: %s", nicID, protocolAddr, properties, err)
+ }
+ wantAddresses = append(wantAddresses, tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen},
+ })
}
- expectedAddresses = append(expectedAddresses, protocolAddress)
}
}
}
gotAddresses := s.AllAddresses()[nicID]
- verifyAddresses(t, expectedAddresses, gotAddresses)
+ verifyAddresses(t, wantAddresses, gotAddresses)
}
func TestCreateNICWithOptions(t *testing.T) {
@@ -2269,8 +2344,15 @@ func TestNICStats(t *testing.T) {
if err := s.CreateNIC(nicid, ep); err != nil {
t.Fatal("CreateNIC failed: ", err)
}
- if err := s.AddAddress(nicid, fakeNetNumber, nic.addr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: nic.addr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicid, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicid, protocolAddr, err)
}
{
@@ -2714,8 +2796,16 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
// be returned by a call to GetMainNICAddress;
// else, it should.
const address1 = tcpip.Address("\x01")
- if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address1, pi); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address1, pi, err)
+ properties := stack.AddressProperties{PEB: pi}
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address1,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr, properties, err)
}
addr, err := s.GetMainNICAddress(nicID, fakeNetNumber)
if err != nil {
@@ -2764,16 +2854,31 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
// Add some other address with peb set to
// FirstPrimaryEndpoint.
const address3 = tcpip.Address("\x03")
- if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address3, stack.FirstPrimaryEndpoint); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address3, stack.FirstPrimaryEndpoint, err)
-
+ protocolAddr3 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address3,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ properties = stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint}
+ if err := s.AddProtocolAddress(nicID, protocolAddr3, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr3, properties, err)
}
// Add back the address we removed earlier and
// make sure the new peb was respected.
// (The address should just be promoted now).
- if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address1, ps); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address1, pi, err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address1,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ properties = stack.AddressProperties{PEB: ps}
+ if err := s.AddProtocolAddress(nicID, protocolAddr1, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr1, properties, err)
}
var primaryAddrs []tcpip.Address
for _, pa := range s.NICInfo()[nicID].ProtocolAddresses {
@@ -3075,8 +3180,12 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
}
for _, a := range test.nicAddrs {
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, a); err != nil {
- t.Errorf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, a, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: a.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
@@ -3182,8 +3291,12 @@ func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, addr1, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: addr1.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
// The NIC should have joined addr1's solicited node multicast address.
@@ -3338,8 +3451,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
PrefixLen: 128,
},
}
- if err := s.AddProtocolAddress(nicID, addr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err)
+ if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err)
}
// Address should be in the list of all addresses.
@@ -3666,8 +3779,8 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
if err := s.CreateNIC(nicID1, ep); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
}
- if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err)
+ if err := s.AddProtocolAddress(nicID1, test.nicAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, test.nicAddr, err)
}
s.SetRouteTable(test.routes)
@@ -3729,8 +3842,8 @@ func TestResolveWith(t *testing.T) {
PrefixLen: 24,
},
}
- if err := s.AddProtocolAddress(nicID, addr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, addr, err)
+ if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err)
}
s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}})
@@ -3771,8 +3884,15 @@ func TestRouteReleaseAfterAddrRemoval(t *testing.T) {
if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet("\x00", "\x00")
@@ -3860,8 +3980,8 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) {
PrefixLen: 8,
},
}
- if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protocolAddress, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddress, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddress, err)
}
// Check that we get the right initial address and prefix length.
@@ -3969,44 +4089,44 @@ func TestFindRouteWithForwarding(t *testing.T) {
)
type netCfg struct {
- proto tcpip.NetworkProtocolNumber
- factory stack.NetworkProtocolFactory
- nic1Addr tcpip.Address
- nic2Addr tcpip.Address
- remoteAddr tcpip.Address
+ proto tcpip.NetworkProtocolNumber
+ factory stack.NetworkProtocolFactory
+ nic1AddrWithPrefix tcpip.AddressWithPrefix
+ nic2AddrWithPrefix tcpip.AddressWithPrefix
+ remoteAddr tcpip.Address
}
fakeNetCfg := netCfg{
- proto: fakeNetNumber,
- factory: fakeNetFactory,
- nic1Addr: nic1Addr,
- nic2Addr: nic2Addr,
- remoteAddr: remoteAddr,
+ proto: fakeNetNumber,
+ factory: fakeNetFactory,
+ nic1AddrWithPrefix: tcpip.AddressWithPrefix{Address: nic1Addr, PrefixLen: fakeDefaultPrefixLen},
+ nic2AddrWithPrefix: tcpip.AddressWithPrefix{Address: nic2Addr, PrefixLen: fakeDefaultPrefixLen},
+ remoteAddr: remoteAddr,
}
globalIPv6Addr1 := tcpip.Address(net.ParseIP("a::1").To16())
globalIPv6Addr2 := tcpip.Address(net.ParseIP("a::2").To16())
ipv6LinkLocalNIC1WithGlobalRemote := netCfg{
- proto: ipv6.ProtocolNumber,
- factory: ipv6.NewProtocol,
- nic1Addr: llAddr1,
- nic2Addr: globalIPv6Addr2,
- remoteAddr: globalIPv6Addr1,
+ proto: ipv6.ProtocolNumber,
+ factory: ipv6.NewProtocol,
+ nic1AddrWithPrefix: llAddr1.WithPrefix(),
+ nic2AddrWithPrefix: globalIPv6Addr2.WithPrefix(),
+ remoteAddr: globalIPv6Addr1,
}
ipv6GlobalNIC1WithLinkLocalRemote := netCfg{
- proto: ipv6.ProtocolNumber,
- factory: ipv6.NewProtocol,
- nic1Addr: globalIPv6Addr1,
- nic2Addr: llAddr1,
- remoteAddr: llAddr2,
+ proto: ipv6.ProtocolNumber,
+ factory: ipv6.NewProtocol,
+ nic1AddrWithPrefix: globalIPv6Addr1.WithPrefix(),
+ nic2AddrWithPrefix: llAddr1.WithPrefix(),
+ remoteAddr: llAddr2,
}
ipv6GlobalNIC1WithLinkLocalMulticastRemote := netCfg{
- proto: ipv6.ProtocolNumber,
- factory: ipv6.NewProtocol,
- nic1Addr: globalIPv6Addr1,
- nic2Addr: globalIPv6Addr2,
- remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ proto: ipv6.ProtocolNumber,
+ factory: ipv6.NewProtocol,
+ nic1AddrWithPrefix: globalIPv6Addr1.WithPrefix(),
+ nic2AddrWithPrefix: globalIPv6Addr2.WithPrefix(),
+ remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
}
tests := []struct {
@@ -4015,8 +4135,8 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg netCfg
forwardingEnabled bool
- addrNIC tcpip.NICID
- localAddr tcpip.Address
+ addrNIC tcpip.NICID
+ localAddrWithPrefix tcpip.AddressWithPrefix
findRouteErr tcpip.Error
dependentOnForwarding bool
@@ -4026,7 +4146,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: false,
addrNIC: nicID1,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4035,7 +4155,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: true,
addrNIC: nicID1,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4044,7 +4164,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: false,
addrNIC: nicID1,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4053,7 +4173,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: true,
addrNIC: nicID1,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: true,
},
@@ -4062,7 +4182,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: false,
addrNIC: nicID2,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4071,7 +4191,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: true,
addrNIC: nicID2,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4080,7 +4200,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: false,
addrNIC: nicID2,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4089,7 +4209,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: true,
addrNIC: nicID2,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4097,7 +4217,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and localAddr on same NIC as route",
netCfg: fakeNetCfg,
forwardingEnabled: false,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4105,7 +4225,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and localAddr on same NIC as route",
netCfg: fakeNetCfg,
forwardingEnabled: false,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4113,7 +4233,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and localAddr on different NIC as route",
netCfg: fakeNetCfg,
forwardingEnabled: false,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4121,7 +4241,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and localAddr on different NIC as route",
netCfg: fakeNetCfg,
forwardingEnabled: true,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: true,
},
@@ -4145,7 +4265,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and link-local local addr with route on different NIC",
netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
forwardingEnabled: false,
- localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr,
+ localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4153,7 +4273,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and link-local local addr with route on same NIC",
netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
forwardingEnabled: true,
- localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr,
+ localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4161,7 +4281,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and global local addr with route on same NIC",
netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
forwardingEnabled: true,
- localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic2Addr,
+ localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4169,7 +4289,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and link-local local addr with route on same NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
forwardingEnabled: false,
- localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4177,7 +4297,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and link-local local addr with route on same NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
forwardingEnabled: true,
- localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4185,7 +4305,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and global local addr with link-local remote on different NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
forwardingEnabled: false,
- localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNetworkUnreachable{},
dependentOnForwarding: false,
},
@@ -4193,7 +4313,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and global local addr with link-local remote on different NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
forwardingEnabled: true,
- localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNetworkUnreachable{},
dependentOnForwarding: false,
},
@@ -4201,7 +4321,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and global local addr with link-local multicast remote on different NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
forwardingEnabled: false,
- localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNetworkUnreachable{},
dependentOnForwarding: false,
},
@@ -4209,7 +4329,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and global local addr with link-local multicast remote on different NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
forwardingEnabled: true,
- localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNetworkUnreachable{},
dependentOnForwarding: false,
},
@@ -4217,7 +4337,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and global local addr with link-local multicast remote on same NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
forwardingEnabled: false,
- localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4225,7 +4345,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and global local addr with link-local multicast remote on same NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
forwardingEnabled: true,
- localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4247,12 +4367,20 @@ func TestFindRouteWithForwarding(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s:", nicID2, err)
}
- if err := s.AddAddress(nicID1, test.netCfg.proto, test.netCfg.nic1Addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, test.netCfg.proto, test.netCfg.nic1Addr, err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: test.netCfg.proto,
+ AddressWithPrefix: test.netCfg.nic1AddrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr1, err)
}
- if err := s.AddAddress(nicID2, test.netCfg.proto, test.netCfg.nic2Addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, test.netCfg.proto, test.netCfg.nic2Addr, err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: test.netCfg.proto,
+ AddressWithPrefix: test.netCfg.nic2AddrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddr2, err)
}
if err := s.SetForwardingDefaultAndAllNICs(test.netCfg.proto, test.forwardingEnabled); err != nil {
@@ -4261,20 +4389,20 @@ func TestFindRouteWithForwarding(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}})
- r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */)
+ r, err := s.FindRoute(test.addrNIC, test.localAddrWithPrefix.Address, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */)
if err == nil {
defer r.Release()
}
if diff := cmp.Diff(test.findRouteErr, err); diff != "" {
- t.Fatalf("unexpected error from FindRoute(%d, %s, %s, %d, false), (-want, +got):\n%s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, diff)
+ t.Fatalf("unexpected error from FindRoute(%d, %s, %s, %d, false), (-want, +got):\n%s", test.addrNIC, test.localAddrWithPrefix.Address, test.netCfg.remoteAddr, test.netCfg.proto, diff)
}
if test.findRouteErr != nil {
return
}
- if r.LocalAddress() != test.localAddr {
- t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), test.localAddr)
+ if r.LocalAddress() != test.localAddrWithPrefix.Address {
+ t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), test.localAddrWithPrefix.Address)
}
if r.RemoteAddress() != test.netCfg.remoteAddr {
t.Errorf("got r.RemoteAddress() = %s, want = %s", r.RemoteAddress(), test.netCfg.remoteAddr)
@@ -4297,8 +4425,8 @@ func TestFindRouteWithForwarding(t *testing.T) {
if !ok {
t.Fatal("packet not sent through ep2")
}
- if pkt.Route.LocalAddress != test.localAddr {
- t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddr)
+ if pkt.Route.LocalAddress != test.localAddrWithPrefix.Address {
+ t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddrWithPrefix.Address)
}
if pkt.Route.RemoteAddress != test.netCfg.remoteAddr {
t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.netCfg.remoteAddr)
diff --git a/pkg/tcpip/stack/tcp.go b/pkg/tcpip/stack/tcp.go
index e90c1a770..dc7289441 100644
--- a/pkg/tcpip/stack/tcp.go
+++ b/pkg/tcpip/stack/tcp.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/internal/tcp"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
)
@@ -380,15 +381,18 @@ type TCPSndBufState struct {
// SndClosed indicates that the endpoint has been closed for sends.
SndClosed bool
- // SndBufInQueue is the number of bytes in the send queue.
- SndBufInQueue seqnum.Size
-
// PacketTooBigCount is used to notify the main protocol routine how
// many times a "packet too big" control packet is received.
PacketTooBigCount int
// SndMTU is the smallest MTU seen in the control packets received.
SndMTU int
+
+ // AutoTuneSndBufDisabled indicates that the auto tuning of send buffer
+ // is disabled.
+ //
+ // Must be accessed using atomic operations.
+ AutoTuneSndBufDisabled uint32
}
// TCPEndpointStateInner contains the members of TCPEndpointState used directly
@@ -399,7 +403,7 @@ type TCPSndBufState struct {
type TCPEndpointStateInner struct {
// TSOffset is a randomized offset added to the value of the TSVal
// field in the timestamp option.
- TSOffset uint32
+ TSOffset tcp.TSOffset
// SACKPermitted is set to true if the peer sends the TCPSACKPermitted
// option in the SYN/SYN-ACK.
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 8a8454a6a..542d9257c 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -16,6 +16,7 @@ package stack
import (
"fmt"
+
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
@@ -31,11 +32,13 @@ type protocolIDs struct {
// transportEndpoints manages all endpoints of a given protocol. It has its own
// mutex so as to reduce interference between protocols.
type transportEndpoints struct {
- // mu protects all fields of the transportEndpoints.
- mu sync.RWMutex
+ mu sync.RWMutex
+ // +checklocks:mu
endpoints map[TransportEndpointID]*endpointsByNIC
// rawEndpoints contains endpoints for raw sockets, which receive all
// traffic of a given protocol regardless of port.
+ //
+ // +checklocks:mu
rawEndpoints []RawTransportEndpoint
}
@@ -68,7 +71,7 @@ func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint {
// descending order of match quality. If a call to yield returns false,
// iterEndpointsLocked stops iteration and returns immediately.
//
-// Preconditions: eps.mu must be locked.
+// +checklocks:eps.mu
func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield func(*endpointsByNIC) bool) {
// Try to find a match with the id as provided.
if ep, ok := eps.endpoints[id]; ok {
@@ -109,7 +112,7 @@ func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield
// findAllEndpointsLocked returns all endpointsByNIC in eps that match id, in
// descending order of match quality.
//
-// Preconditions: eps.mu must be locked.
+// +checklocks:eps.mu
func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []*endpointsByNIC {
var matchedEPs []*endpointsByNIC
eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
@@ -121,7 +124,7 @@ func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []
// findEndpointLocked returns the endpoint that most closely matches the given id.
//
-// Preconditions: eps.mu must be locked.
+// +checklocks:eps.mu
func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpointsByNIC {
var matchedEP *endpointsByNIC
eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
@@ -132,10 +135,12 @@ func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpo
}
type endpointsByNIC struct {
- mu sync.RWMutex
- endpoints map[tcpip.NICID]*multiPortEndpoint
// seed is a random secret for a jenkins hash.
seed uint32
+
+ mu sync.RWMutex
+ // +checklocks:mu
+ endpoints map[tcpip.NICID]*multiPortEndpoint
}
func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint {
@@ -170,7 +175,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *Packet
return true
}
// multiPortEndpoints are guaranteed to have at least one element.
- transEP := selectEndpoint(id, mpep, epsByNIC.seed)
+ transEP := mpep.selectEndpoint(id, epsByNIC.seed)
if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue {
queuedProtocol.QueuePacket(transEP, id, pkt)
epsByNIC.mu.RUnlock()
@@ -199,7 +204,7 @@ func (epsByNIC *endpointsByNIC) handleError(n *nic, id TransportEndpointID, tran
// broadcast like we are doing with handlePacket above?
// multiPortEndpoints are guaranteed to have at least one element.
- selectEndpoint(id, mpep, epsByNIC.seed).HandleError(transErr, pkt)
+ mpep.selectEndpoint(id, epsByNIC.seed).HandleError(transErr, pkt)
}
// registerEndpoint returns true if it succeeds. It fails and returns
@@ -215,10 +220,17 @@ func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto t
netProto: netProto,
transProto: transProto,
}
- epsByNIC.endpoints[bindToDevice] = multiPortEp
}
- return multiPortEp.singleRegisterEndpoint(t, flags)
+ if err := multiPortEp.singleRegisterEndpoint(t, flags); err != nil {
+ return err
+ }
+ // Only add this newly created multiportEndpoint if the singleRegisterEndpoint
+ // succeeded.
+ if !ok {
+ epsByNIC.endpoints[bindToDevice] = multiPortEp
+ }
+ return nil
}
func (epsByNIC *endpointsByNIC) checkEndpoint(flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
@@ -325,15 +337,18 @@ func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber
//
// +stateify savable
type multiPortEndpoint struct {
- mu sync.RWMutex `state:"nosave"`
demux *transportDemuxer
netProto tcpip.NetworkProtocolNumber
transProto tcpip.TransportProtocolNumber
+ flags ports.FlagCounter
+
+ mu sync.RWMutex `state:"nosave"`
// endpoints stores the transport endpoints in the order in which they
// were bound. This is required for UDP SO_REUSEADDR.
+ //
+ // +checklocks:mu
endpoints []TransportEndpoint
- flags ports.FlagCounter
}
func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint {
@@ -354,13 +369,16 @@ func reciprocalScale(val, n uint32) uint32 {
// selectEndpoint calculates a hash of destination and source addresses and
// ports then uses it to select a socket. In this case, all packets from one
// address will be sent to same endpoint.
-func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint {
- if len(mpep.endpoints) == 1 {
- return mpep.endpoints[0]
+func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID, seed uint32) TransportEndpoint {
+ ep.mu.RLock()
+ defer ep.mu.RUnlock()
+
+ if len(ep.endpoints) == 1 {
+ return ep.endpoints[0]
}
- if mpep.flags.SharedFlags().ToFlags().Effective().MostRecent {
- return mpep.endpoints[len(mpep.endpoints)-1]
+ if ep.flags.SharedFlags().ToFlags().Effective().MostRecent {
+ return ep.endpoints[len(ep.endpoints)-1]
}
payload := []byte{
@@ -376,8 +394,8 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32
h.Write([]byte(id.RemoteAddress))
hash := h.Sum32()
- idx := reciprocalScale(hash, uint32(len(mpep.endpoints)))
- return mpep.endpoints[idx]
+ idx := reciprocalScale(hash, uint32(len(ep.endpoints)))
+ return ep.endpoints[idx]
}
func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *PacketBuffer) {
@@ -405,7 +423,6 @@ func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *Packet
func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
-
bits := flags.Bits() & ports.MultiBindFlagMask
if len(ep.endpoints) != 0 {
@@ -468,17 +485,21 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
eps.mu.Lock()
defer eps.mu.Unlock()
-
epsByNIC, ok := eps.endpoints[id]
if !ok {
epsByNIC = &endpointsByNIC{
endpoints: make(map[tcpip.NICID]*multiPortEndpoint),
- seed: d.stack.Seed(),
+ seed: d.stack.seed,
}
+ }
+ if err := epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice); err != nil {
+ return err
+ }
+ // Only add this newly created epsByNIC if registerEndpoint succeeded.
+ if !ok {
eps.endpoints[id] = epsByNIC
}
-
- return epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice)
+ return nil
}
func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
@@ -646,7 +667,7 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN
}
}
- ep := selectEndpoint(id, mpep, epsByNIC.seed)
+ ep := mpep.selectEndpoint(id, epsByNIC.seed)
epsByNIC.mu.RUnlock()
return ep
}
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 0972c94de..cd3a8c25a 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -35,7 +35,7 @@ import (
const (
testSrcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- testDstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ testDstAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
testSrcAddrV4 = "\x0a\x00\x00\x01"
testDstAddrV4 = "\x0a\x00\x00\x02"
@@ -64,12 +64,20 @@ func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICI
}
linkEps[linkEpID] = channelEp
- if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, testDstAddrV4); err != nil {
- t.Fatalf("AddAddress IPv4 failed: %s", err)
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.Address(testDstAddrV4).WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(linkEpID, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", linkEpID, protocolAddrV4, err)
}
- if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, testDstAddrV6); err != nil {
- t.Fatalf("AddAddress IPv6 failed: %s", err)
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: testDstAddrV6.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(linkEpID, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", linkEpID, protocolAddrV6, err)
}
}
@@ -203,6 +211,56 @@ func TestTransportDemuxerRegister(t *testing.T) {
}
}
+func TestTransportDemuxerRegisterMultiple(t *testing.T) {
+ type test struct {
+ flags ports.Flags
+ want tcpip.Error
+ }
+ for _, subtest := range []struct {
+ name string
+ tests []test
+ }{
+ {"zeroFlags", []test{
+ {ports.Flags{}, nil},
+ {ports.Flags{}, &tcpip.ErrPortInUse{}},
+ }},
+ {"multibindFlags", []test{
+ // Allow multiple registrations same TransportEndpointID with multibind flags.
+ {ports.Flags{LoadBalanced: true, MostRecent: true}, nil},
+ {ports.Flags{LoadBalanced: true, MostRecent: true}, nil},
+ // Disallow registration w/same ID for a non-multibindflag.
+ {ports.Flags{TupleOnly: true}, &tcpip.ErrPortInUse{}},
+ }},
+ } {
+ t.Run(subtest.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+ var eps []tcpip.Endpoint
+ for idx, test := range subtest.tests {
+ var wq waiter.Queue
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatal(err)
+ }
+ eps = append(eps, ep)
+ tEP, ok := ep.(stack.TransportEndpoint)
+ if !ok {
+ t.Fatalf("%T does not implement stack.TransportEndpoint", ep)
+ }
+ id := stack.TransportEndpointID{LocalPort: 1}
+ if got, want := s.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber}, udp.ProtocolNumber, id, tEP, test.flags, 0), test.want; got != want {
+ t.Fatalf("test index: %d, s.RegisterTransportEndpoint(ipv4.ProtocolNumber, udp.ProtocolNumber, _, _, %+v, 0) = %s, want %s", idx, test.flags, got, want)
+ }
+ }
+ for _, ep := range eps {
+ ep.Close()
+ }
+ })
+ }
+}
+
// TestBindToDeviceDistribution injects varied packets on input devices and checks that
// the distribution of packets received matches expectations.
func TestBindToDeviceDistribution(t *testing.T) {
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 839178809..655931715 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -357,8 +357,15 @@ func TestTransportReceive(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
// Create endpoint and connect to remote address.
@@ -428,8 +435,15 @@ func TestTransportControlReceive(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
// Create endpoint and connect to remote address.
@@ -497,8 +511,15 @@ func TestTransportSend(t *testing.T) {
t.Fatalf("CreateNIC failed: %v", err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 91622fa4c..a9ce148b9 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -19,7 +19,7 @@
// The starting point is the creation and configuration of a stack. A stack can
// be created by calling the New() function of the tcpip/stack/stack package;
// configuring a stack involves creating NICs (via calls to Stack.CreateNIC()),
-// adding network addresses (via calls to Stack.AddAddress()), and
+// adding network addresses (via calls to Stack.AddProtocolAddress()), and
// setting a route table (via a call to Stack.SetRouteTable()).
//
// Once a stack is configured, endpoints can be created by calling
@@ -465,11 +465,11 @@ type ControlMessages struct {
// PacketOwner is used to get UID and GID of the packet.
type PacketOwner interface {
- // UID returns UID of the packet.
- UID() uint32
+ // KUID returns the KUID of the packet.
+ KUID() uint32
- // GID returns GID of the packet.
- GID() uint32
+ // KGID returns the KGID of the packet.
+ KGID() uint32
}
// ReadOptions contains options for Endpoint.Read.
@@ -1845,6 +1845,10 @@ type TCPStats struct {
// FailedPortReservations is the number of times TCP failed to reserve
// a port.
FailedPortReservations *StatCounter
+
+ // SegmentsAckedWithDSACK is the number of segments acknowledged with
+ // DSACK.
+ SegmentsAckedWithDSACK *StatCounter
}
// UDPStats collects UDP-specific stats.
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
index 92fa6257d..6e1d4720d 100644
--- a/pkg/tcpip/tests/integration/forward_test.go
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -473,11 +473,19 @@ func TestMulticastForwarding(t *testing.T) {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err)
}
- if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err)
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv4Addr,
}
- if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err)
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err)
+ }
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv6Addr,
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err)
}
if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
@@ -612,8 +620,8 @@ func TestPerInterfaceForwarding(t *testing.T) {
addr: utils.RouterNIC2IPv6Addr,
},
} {
- if err := s.AddProtocolAddress(add.nicID, add.addr); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", add.nicID, add.addr, err)
+ if err := s.AddProtocolAddress(add.nicID, add.addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", add.nicID, add.addr, err)
}
}
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
index f9ab7d0af..28b49c6be 100644
--- a/pkg/tcpip/tests/integration/iptables_test.go
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -49,10 +49,10 @@ const (
nicName = "nic1"
anotherNicName = "nic2"
linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
- srcAddrV4 = "\x0a\x00\x00\x01"
- dstAddrV4 = "\x0a\x00\x00\x02"
- srcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- dstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ srcAddrV4 = tcpip.Address("\x0a\x00\x00\x01")
+ dstAddrV4 = tcpip.Address("\x0a\x00\x00\x02")
+ srcAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ dstAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
payloadSize = 20
)
@@ -66,8 +66,12 @@ func genStackV6(t *testing.T) (*stack.Stack, *channel.Endpoint) {
if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err)
}
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, dstAddrV6); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, dstAddrV6, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: dstAddrV6.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
return s, e
}
@@ -82,8 +86,12 @@ func genStackV4(t *testing.T) (*stack.Stack, *channel.Endpoint) {
if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err)
}
- if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, dstAddrV4); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, dstAddrV4, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: dstAddrV4.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
return s, e
}
@@ -601,11 +609,19 @@ func TestIPTableWritePackets(t *testing.T) {
if err := s.CreateNIC(nicID, &e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, srcAddrV6); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, srcAddrV6, err)
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: srcAddrV6.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV6, err)
}
- if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, srcAddrV4); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, srcAddrV4, err)
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: srcAddrV4.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV4, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -856,11 +872,19 @@ func TestForwardingHook(t *testing.T) {
t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err)
}
- if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err)
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv4Addr.Address.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err)
+ }
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv6Addr.Address.WithPrefix(),
}
- if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err)
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err)
}
if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
@@ -1037,22 +1061,22 @@ func TestInputHookWithLocalForwarding(t *testing.T) {
if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil {
t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err)
}
- if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv4Addr1, err)
+ if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID1, utils.Ipv4Addr1, err)
}
- if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv6Addr1, err)
+ if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID1, utils.Ipv6Addr1, err)
}
e2 := channel.New(1, header.IPv6MinimumMTU, "")
if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil {
t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err)
}
- if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv4Addr2, err)
+ if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID2, utils.Ipv4Addr2, err)
}
- if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv6Addr2, err)
+ if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID2, utils.Ipv6Addr2, err)
}
if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
index 27caa0c28..95ddd8ec3 100644
--- a/pkg/tcpip/tests/integration/link_resolution_test.go
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -56,17 +56,17 @@ func setupStack(t *testing.T, stackOpts stack.Options, host1NICID, host2NICID tc
t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err)
}
- if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv4Addr1); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, utils.Ipv4Addr1, err)
+ if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv4Addr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", host1NICID, utils.Ipv4Addr1, err)
}
- if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv4Addr2); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, utils.Ipv4Addr2, err)
+ if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv4Addr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", host2NICID, utils.Ipv4Addr2, err)
}
- if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv6Addr1); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, utils.Ipv6Addr1, err)
+ if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv6Addr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", host1NICID, utils.Ipv6Addr1, err)
}
- if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv6Addr2); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, utils.Ipv6Addr2, err)
+ if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv6Addr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", host2NICID, utils.Ipv6Addr2, err)
}
host1Stack.SetRouteTable([]tcpip.Route{
@@ -568,8 +568,8 @@ func TestForwardingWithLinkResolutionFailure(t *testing.T) {
Protocol: test.networkProtocolNumber,
AddressWithPrefix: test.incomingAddr,
}
- if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingProtoAddr, err)
+ if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingProtoAddr, err)
}
// Set up endpoint through which we will attempt to forward packets.
@@ -582,8 +582,8 @@ func TestForwardingWithLinkResolutionFailure(t *testing.T) {
Protocol: test.networkProtocolNumber,
AddressWithPrefix: test.outgoingAddr,
}
- if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingProtoAddr, err)
+ if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingProtoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index 87d36e1dd..f33223e79 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -44,20 +44,17 @@ type ndpDispatcher struct{}
func (*ndpDispatcher) OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) {
}
-func (*ndpDispatcher) OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool {
- return false
+func (*ndpDispatcher) OnOffLinkRouteUpdated(tcpip.NICID, tcpip.Subnet, tcpip.Address, header.NDPRoutePreference) {
}
-func (*ndpDispatcher) OnDefaultRouterInvalidated(tcpip.NICID, tcpip.Address) {}
+func (*ndpDispatcher) OnOffLinkRouteInvalidated(tcpip.NICID, tcpip.Subnet, tcpip.Address) {}
-func (*ndpDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool {
- return false
+func (*ndpDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) {
}
func (*ndpDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) {}
-func (*ndpDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool {
- return true
+func (*ndpDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) {
}
func (*ndpDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) {}
@@ -198,8 +195,8 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) {
if err := s.CreateNIC(nicID, loopback.New()); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddProtocolAddress(nicID, test.addAddress); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, test.addAddress, err)
+ if err := s.AddProtocolAddress(nicID, test.addAddress, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, test.addAddress, err)
}
s.SetRouteTable([]tcpip.Route{
{
@@ -293,8 +290,8 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) {
if err := s.CreateNIC(nicID, loopback.New()); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
+ if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
{
@@ -434,8 +431,8 @@ func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) {
if err := s.CreateNIC(nicID, loopback.New()); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddProtocolAddress(nicID, test.addAddress); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, test.addAddress, err)
+ if err := s.AddProtocolAddress(nicID, test.addAddress, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, test.addAddress, err)
}
s.SetRouteTable([]tcpip.Route{
{
@@ -696,21 +693,40 @@ func TestExternalLoopbackTraffic(t *testing.T) {
if err := s.CreateNIC(nicID1, e); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
}
- if err := s.AddAddressWithPrefix(nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr, err)
+ v4Addr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv4Addr,
}
- if err := s.AddAddressWithPrefix(nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr, err)
+ if err := s.AddProtocolAddress(nicID1, v4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, v4Addr, err)
+ }
+ v6Addr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv6Addr,
+ }
+ if err := s.AddProtocolAddress(nicID1, v6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, v6Addr, err)
}
if err := s.CreateNIC(nicID2, loopback.New()); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
}
- if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, ipv4Loopback); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, ipv4Loopback, err)
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: ipv4Loopback,
+ PrefixLen: 8,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err)
+ }
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: header.IPv6Loopback.WithPrefix(),
}
- if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, header.IPv6Loopback); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, header.IPv6Loopback, err)
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err)
}
if test.forwarding {
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index 2d0a6e6a7..7753e7d6e 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -119,12 +119,12 @@ func TestPingMulticastBroadcast(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: utils.Ipv4Addr}
- if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err)
+ if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtoAddr, err)
}
ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: utils.Ipv6Addr}
- if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv6ProtoAddr, err)
+ if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtoAddr, err)
}
// Default routes for IPv4 and IPv6 so ICMP can find a route to the remote
@@ -396,8 +396,8 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr}
- if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
+ if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
}
var wq waiter.Queue
@@ -474,8 +474,8 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
PrefixLen: 8,
},
}
- if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
+ if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -642,8 +642,8 @@ func TestUDPAddRemoveMembershipSocketOption(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr}
- if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
+ if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
}
// Set the route table so that UDP can find a NIC that is
diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go
index ac3c703d4..422eb8408 100644
--- a/pkg/tcpip/tests/integration/route_test.go
+++ b/pkg/tcpip/tests/integration/route_test.go
@@ -47,7 +47,10 @@ func TestLocalPing(t *testing.T) {
// request/reply packets.
icmpDataOffset = 8
)
- ipv4Loopback := testutil.MustParse4("127.0.0.1")
+ ipv4Loopback := tcpip.AddressWithPrefix{
+ Address: testutil.MustParse4("127.0.0.1"),
+ PrefixLen: 8,
+ }
channelEP := func() stack.LinkEndpoint { return channel.New(1, header.IPv6MinimumMTU, "") }
channelEPCheck := func(t *testing.T, e stack.LinkEndpoint) {
@@ -82,7 +85,7 @@ func TestLocalPing(t *testing.T) {
transProto tcpip.TransportProtocolNumber
netProto tcpip.NetworkProtocolNumber
linkEndpoint func() stack.LinkEndpoint
- localAddr tcpip.Address
+ localAddr tcpip.AddressWithPrefix
icmpBuf func(*testing.T) buffer.View
expectedConnectErr tcpip.Error
checkLinkEndpoint func(t *testing.T, e stack.LinkEndpoint)
@@ -101,7 +104,7 @@ func TestLocalPing(t *testing.T) {
transProto: icmp.ProtocolNumber6,
netProto: ipv6.ProtocolNumber,
linkEndpoint: loopback.New,
- localAddr: header.IPv6Loopback,
+ localAddr: header.IPv6Loopback.WithPrefix(),
icmpBuf: ipv6ICMPBuf,
checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {},
},
@@ -110,7 +113,7 @@ func TestLocalPing(t *testing.T) {
transProto: icmp.ProtocolNumber4,
netProto: ipv4.ProtocolNumber,
linkEndpoint: channelEP,
- localAddr: utils.Ipv4Addr.Address,
+ localAddr: utils.Ipv4Addr,
icmpBuf: ipv4ICMPBuf,
checkLinkEndpoint: channelEPCheck,
},
@@ -119,7 +122,7 @@ func TestLocalPing(t *testing.T) {
transProto: icmp.ProtocolNumber6,
netProto: ipv6.ProtocolNumber,
linkEndpoint: channelEP,
- localAddr: utils.Ipv6Addr.Address,
+ localAddr: utils.Ipv6Addr,
icmpBuf: ipv6ICMPBuf,
checkLinkEndpoint: channelEPCheck,
},
@@ -182,9 +185,13 @@ func TestLocalPing(t *testing.T) {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if len(test.localAddr) != 0 {
- if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err)
+ if len(test.localAddr.Address) != 0 {
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: test.netProto,
+ AddressWithPrefix: test.localAddr,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
@@ -197,7 +204,7 @@ func TestLocalPing(t *testing.T) {
}
defer ep.Close()
- connAddr := tcpip.FullAddress{Addr: test.localAddr}
+ connAddr := tcpip.FullAddress{Addr: test.localAddr.Address}
if err := ep.Connect(connAddr); err != test.expectedConnectErr {
t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr)
}
@@ -229,8 +236,8 @@ func TestLocalPing(t *testing.T) {
if diff := cmp.Diff(buffer.View(w.Bytes()[icmpDataOffset:]), payload[icmpDataOffset:]); diff != "" {
t.Errorf("received data mismatch (-want +got):\n%s", diff)
}
- if rr.RemoteAddr.Addr != test.localAddr {
- t.Errorf("got addr.Addr = %s, want = %s", rr.RemoteAddr.Addr, test.localAddr)
+ if rr.RemoteAddr.Addr != test.localAddr.Address {
+ t.Errorf("got addr.Addr = %s, want = %s", rr.RemoteAddr.Addr, test.localAddr.Address)
}
test.checkLinkEndpoint(t, e)
@@ -302,11 +309,12 @@ func TestLocalUDP(t *testing.T) {
}
if subTest.addAddress {
- if err := s.AddProtocolAddressWithOptions(nicID, test.canBePrimaryAddr, stack.CanBePrimaryEndpoint); err != nil {
- t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.canBePrimaryAddr, stack.FirstPrimaryEndpoint, err)
+ if err := s.AddProtocolAddress(nicID, test.canBePrimaryAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, test.canBePrimaryAddr, err)
}
- if err := s.AddProtocolAddressWithOptions(nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint); err != nil {
- t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint, err)
+ properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint}
+ if err := s.AddProtocolAddress(nicID, test.firstPrimaryAddr, properties); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, %+v): %s", nicID, test.firstPrimaryAddr, properties, err)
}
}
diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go
index 2e6ae55ea..947bcc7b1 100644
--- a/pkg/tcpip/tests/utils/utils.go
+++ b/pkg/tcpip/tests/utils/utils.go
@@ -231,29 +231,29 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.
t.Fatalf("routerStack.SetForwardingDefaultAndAllNICs(%d): %s", ipv6.ProtocolNumber, err)
}
- if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv4Addr); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", Host1NICID, Host1IPv4Addr, err)
+ if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", Host1NICID, Host1IPv4Addr, err)
}
- if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv4Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID1, RouterNIC1IPv4Addr, err)
+ if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID1, RouterNIC1IPv4Addr, err)
}
- if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv4Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID2, RouterNIC2IPv4Addr, err)
+ if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID2, RouterNIC2IPv4Addr, err)
}
- if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv4Addr); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", Host2NICID, Host2IPv4Addr, err)
+ if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", Host2NICID, Host2IPv4Addr, err)
}
- if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv6Addr); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", Host1NICID, Host1IPv6Addr, err)
+ if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", Host1NICID, Host1IPv6Addr, err)
}
- if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv6Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID1, RouterNIC1IPv6Addr, err)
+ if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID1, RouterNIC1IPv6Addr, err)
}
- if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv6Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID2, RouterNIC2IPv6Addr, err)
+ if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID2, RouterNIC2IPv6Addr, err)
}
- if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv6Addr); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", Host2NICID, Host2IPv6Addr, err)
+ if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", Host2NICID, Host2IPv6Addr, err)
}
host1Stack.SetRouteTable([]tcpip.Route{
diff --git a/pkg/tcpip/transport/BUILD b/pkg/tcpip/transport/BUILD
new file mode 100644
index 000000000..af332ed91
--- /dev/null
+++ b/pkg/tcpip/transport/BUILD
@@ -0,0 +1,13 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "transport",
+ srcs = [
+ "datagram.go",
+ "transport.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = ["//pkg/tcpip"],
+)
diff --git a/pkg/tcpip/transport/datagram.go b/pkg/tcpip/transport/datagram.go
new file mode 100644
index 000000000..dfce72c69
--- /dev/null
+++ b/pkg/tcpip/transport/datagram.go
@@ -0,0 +1,49 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package transport
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// DatagramEndpointState is the state of a datagram-based endpoint.
+type DatagramEndpointState tcpip.EndpointState
+
+// The states a datagram-based endpoint may be in.
+const (
+ _ DatagramEndpointState = iota
+ DatagramEndpointStateInitial
+ DatagramEndpointStateBound
+ DatagramEndpointStateConnected
+ DatagramEndpointStateClosed
+)
+
+// String implements fmt.Stringer.
+func (s DatagramEndpointState) String() string {
+ switch s {
+ case DatagramEndpointStateInitial:
+ return "INITIAL"
+ case DatagramEndpointStateBound:
+ return "BOUND"
+ case DatagramEndpointStateConnected:
+ return "CONNECTED"
+ case DatagramEndpointStateClosed:
+ return "CLOSED"
+ default:
+ panic(fmt.Sprintf("unhandled %[1]T variant = %[1]d", s))
+ }
+}
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index fb77febcf..1e519085d 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -213,6 +213,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// reacquire the mutex in exclusive mode.
//
// Returns true for retry if preparation should be retried.
+// +checklocks:e.mu
func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
switch e.state {
case stateInitial:
@@ -229,10 +230,8 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip
}
e.mu.RUnlock()
- defer e.mu.RLock()
-
e.mu.Lock()
- defer e.mu.Unlock()
+ defer e.mu.DowngradeLock()
// The state changed when we released the shared locked and re-acquired
// it in exclusive mode. Try again.
@@ -330,6 +329,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
route = r
}
+ // TODO(https://gvisor.dev/issue/6538): Avoid this allocation.
v := make([]byte, p.Len())
if _, err := io.ReadFull(p, v); err != nil {
return 0, &tcpip.ErrBadBuffer{}
@@ -688,9 +688,20 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error {
return nil
}
+// isBroadcastOrMulticast reports whether addr is the IPv4 broadcast address,
+// an IPv4 or IPv6 multicast address, or a subnet broadcast address on the NIC
+// identified by nicID for the endpoint's network protocol.
+func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, addr tcpip.Address) bool {
+	switch {
+	case addr == header.IPv4Broadcast:
+		return true
+	case header.IsV4MulticastAddress(addr), header.IsV6MulticastAddress(addr):
+		return true
+	default:
+		// Only consult the stack once the cheap address checks have failed.
+		return e.stack.IsSubnetBroadcast(nicID, e.NetProto, addr)
+	}
+}
+
// Bind binds the endpoint to a specific local address and port.
// Specifying a NIC is optional.
func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
+ if len(addr.Addr) != 0 && e.isBroadcastOrMulticast(addr.NIC, addr.Addr) {
+ return &tcpip.ErrBadLocalAddress{}
+ }
+
e.mu.Lock()
defer e.mu.Unlock()
@@ -758,8 +769,6 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
switch e.NetProto {
case header.IPv4ProtocolNumber:
h := header.ICMPv4(pkt.TransportHeader().View())
- // TODO(gvisor.dev/issue/170): Determine if len(h) check is still needed
- // after early parsing.
if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
@@ -767,8 +776,6 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
}
case header.IPv6ProtocolNumber:
h := header.ICMPv6(pkt.TransportHeader().View())
- // TODO(gvisor.dev/issue/170): Determine if len(h) check is still needed
- // after early parsing.
if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
diff --git a/pkg/tcpip/transport/icmp/icmp_test.go b/pkg/tcpip/transport/icmp/icmp_test.go
index cc950cbde..729f50e9a 100644
--- a/pkg/tcpip/transport/icmp/icmp_test.go
+++ b/pkg/tcpip/transport/icmp/icmp_test.go
@@ -55,8 +55,12 @@ func addNICWithDefaultRoute(t *testing.T, s *stack.Stack, id tcpip.NICID, name s
t.Fatalf("s.CreateNIC(%d, _) = %s", id, err)
}
- if err := s.AddAddress(id, ipv4.ProtocolNumber, addrV4); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s) = %s", id, ipv4.ProtocolNumber, addrV4, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addrV4.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(id, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", id, protocolAddr, err)
}
s.AddRoute(tcpip.Route{
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 47f7dd1cb..fa82affc1 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -123,8 +123,6 @@ func (*protocol) Wait() {}
// Parse implements stack.TransportProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
- // TODO(gvisor.dev/issue/170): Implement parsing of ICMP.
- //
// Right now, the Parse() method is tied to enabled protocols passed into
// stack.New. This works for UDP and TCP, but we handle ICMP traffic even
// when netstack users don't pass ICMP as a supported protocol.
diff --git a/pkg/tcpip/transport/internal/network/BUILD b/pkg/tcpip/transport/internal/network/BUILD
new file mode 100644
index 000000000..b1edce39b
--- /dev/null
+++ b/pkg/tcpip/transport/internal/network/BUILD
@@ -0,0 +1,45 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+# Go library target for package network. Visibility is restricted to the raw
+# and udp transport packages.
+go_library(
+ name = "network",
+ srcs = [
+ "endpoint.go",
+ "endpoint_state.go",
+ ],
+ visibility = [
+ "//pkg/tcpip/transport/raw:__pkg__",
+ "//pkg/tcpip/transport/udp:__pkg__",
+ ],
+ deps = [
+ "//pkg/sync",
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport",
+ ],
+)
+
+# Test target for the network library; the tests live in endpoint_test.go.
+go_test(
+ name = "network_test",
+ size = "small",
+ srcs = ["endpoint_test.go"],
+ deps = [
+ ":network",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/testutil",
+ "//pkg/tcpip/transport",
+ "//pkg/tcpip/transport/udp",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ ],
+)
diff --git a/pkg/tcpip/transport/internal/network/endpoint.go b/pkg/tcpip/transport/internal/network/endpoint.go
new file mode 100644
index 000000000..e3094f59f
--- /dev/null
+++ b/pkg/tcpip/transport/internal/network/endpoint.go
@@ -0,0 +1,811 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package network provides facilities to support tcpip.Endpoints that operate
+// at the network layer or above.
+package network
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
+)
+
+// Endpoint is a datagram-based endpoint. It only supports sending datagrams to
+// a peer.
+//
+// +stateify savable
+type Endpoint struct {
+ // The following fields must only be set once then never changed.
+ stack *stack.Stack `state:"manual"`
+ ops *tcpip.SocketOptions
+ netProto tcpip.NetworkProtocolNumber
+ transProto tcpip.TransportProtocolNumber
+
+ mu sync.RWMutex `state:"nosave"`
+ // wasBound is set once the endpoint has been bound through BindAndThen.
+ // Disconnect uses it to decide which state the endpoint returns to.
+ //
+ // +checklocks:mu
+ wasBound bool
+ // owner is the owner of transmitted packets.
+ //
+ // +checklocks:mu
+ owner tcpip.PacketOwner
+ // writeShutdown is set by Shutdown; once set, AcquireContextForWrite fails
+ // with ErrClosedForSend.
+ //
+ // +checklocks:mu
+ writeShutdown bool
+ // effectiveNetProto is the network protocol selected by the most recent
+ // bind or connect. It starts out equal to netProto.
+ //
+ // +checklocks:mu
+ effectiveNetProto tcpip.NetworkProtocolNumber
+ // connectedRoute is the route to the connected peer, or nil when the
+ // endpoint is not connected.
+ //
+ // +checklocks:mu
+ connectedRoute *stack.Route `state:"manual"`
+ // multicastMemberships is the set of multicast groups joined through this
+ // endpoint. It is nil before Init and after Close.
+ //
+ // +checklocks:mu
+ multicastMemberships map[multicastMembership]struct{}
+ // ttl is the TTL used for non-multicast destinations; zero selects the
+ // route's default TTL.
+ //
+ // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ // +checklocks:mu
+ ttl uint8
+ // multicastTTL is the TTL used for multicast destinations.
+ //
+ // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ // +checklocks:mu
+ multicastTTL uint8
+ // multicastAddr is the local address used when routing to a multicast
+ // destination if neither a local address nor a NIC is otherwise known.
+ //
+ // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ // +checklocks:mu
+ multicastAddr tcpip.Address
+ // multicastNICID is the NIC used when routing to a multicast destination
+ // if no NIC is otherwise specified.
+ //
+ // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ // +checklocks:mu
+ multicastNICID tcpip.NICID
+ // ipv4TOS is the TOS value used for packets sent over IPv4 routes.
+ //
+ // +checklocks:mu
+ ipv4TOS uint8
+ // ipv6TClass is the traffic class used for packets sent over IPv6 routes.
+ //
+ // +checklocks:mu
+ ipv6TClass uint8
+
+ // Lock ordering: mu > infoMu.
+ infoMu sync.RWMutex `state:"nosave"`
+ // info has a dedicated mutex so that we can avoid lock ordering violations
+ // when reading the endpoint's info. If we used mu, we need to guarantee
+ // that any lock taken while mu is held is not held when calling Info()
+ // which is not true as of writing (we hold mu while registering transport
+ // endpoints, which takes the transport demuxer lock, but we also hold the
+ // demuxer lock when delivering packets/errors to endpoints).
+ //
+ // Writes must be performed through setInfo.
+ //
+ // +checklocks:infoMu
+ info stack.TransportEndpointInfo
+
+ // state holds a transport.DatagramEndpointState.
+ //
+ // state must be accessed with atomics so that we can avoid lock ordering
+ // violations when reading the state. If we used mu, we need to guarantee
+ // that any lock taken while mu is held is not held when calling State()
+ // which is not true as of writing (we hold mu while registering transport
+ // endpoints, which takes the transport demuxer lock, but we also hold the
+ // demuxer lock when delivering packets/errors to endpoints).
+ //
+ // Writes must be performed through setEndpointState.
+ //
+ // +checkatomics
+ state uint32
+}
+
+// multicastMembership identifies a multicast group joined on a specific NIC.
+// It is the key type of Endpoint.multicastMemberships.
+//
+// +stateify savable
+type multicastMembership struct {
+ nicID tcpip.NICID
+ multicastAddr tcpip.Address
+}
+
+// Init initializes the endpoint.
+//
+// It panics if the endpoint was already initialized (detected through a
+// non-nil multicastMemberships map) or if netProto is neither IPv4 nor IPv6.
+func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ops *tcpip.SocketOptions) {
+ // multicastMemberships is guarded by mu, so read it under the lock.
+ e.mu.Lock()
+ memberships := e.multicastMemberships
+ e.mu.Unlock()
+ if memberships != nil {
+ panic(fmt.Sprintf("endpoint is already initialized; got e.multicastMemberships = %#v, want = nil", memberships))
+ }
+
+ // Only IPv4 and IPv6 are accepted.
+ switch netProto {
+ case header.IPv4ProtocolNumber, header.IPv6ProtocolNumber:
+ default:
+ panic(fmt.Sprintf("invalid protocol number = %d", netProto))
+ }
+
+ // Overwrite the whole struct; fields not listed here get their zero value.
+ *e = Endpoint{
+ stack: s,
+ ops: ops,
+ netProto: netProto,
+ transProto: transProto,
+
+ info: stack.TransportEndpointInfo{
+ NetProto: netProto,
+ TransProto: transProto,
+ },
+ effectiveNetProto: netProto,
+ // Linux defaults to TTL=1.
+ multicastTTL: 1,
+ multicastMemberships: make(map[multicastMembership]struct{}),
+ }
+
+ // setEndpointState requires mu to be held.
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.setEndpointState(transport.DatagramEndpointStateInitial)
+}
+
+// NetProto returns the network protocol the endpoint was initialized with.
+//
+// No lock is taken: netProto is one of the fields that is set once and never
+// changed afterwards.
+func (e *Endpoint) NetProto() tcpip.NetworkProtocolNumber {
+ return e.netProto
+}
+
+// setEndpointState sets the state of the endpoint.
+//
+// e.mu must be held to synchronize changes to state with the rest of the
+// endpoint.
+//
+// +checklocks:e.mu
+func (e *Endpoint) setEndpointState(state transport.DatagramEndpointState) {
+ // Stored atomically so that State can read it without holding e.mu.
+ atomic.StoreUint32(&e.state, uint32(state))
+}
+
+// State returns the state of the endpoint.
+//
+// The state is loaded atomically; e.mu is not taken.
+func (e *Endpoint) State() transport.DatagramEndpointState {
+ return transport.DatagramEndpointState(atomic.LoadUint32(&e.state))
+}
+
+// Close cleans the endpoint's resources and leaves the endpoint in a closed
+// state.
+//
+// It leaves every multicast group joined through the endpoint and releases
+// the connected route, if any. Closing an already-closed endpoint is a no-op.
+func (e *Endpoint) Close() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if e.State() == transport.DatagramEndpointStateClosed {
+ return
+ }
+
+ // Errors from LeaveGroup are ignored; Close has no way to report them.
+ for mem := range e.multicastMemberships {
+ e.stack.LeaveGroup(e.netProto, mem.nicID, mem.multicastAddr)
+ }
+ // NOTE(review): the map is nil from here on, and SetSockOpt's
+ // AddMembershipOption writes to it without checking for the closed state;
+ // confirm callers never set that option on a closed endpoint.
+ e.multicastMemberships = nil
+
+ if e.connectedRoute != nil {
+ e.connectedRoute.Release()
+ e.connectedRoute = nil
+ }
+
+ e.setEndpointState(transport.DatagramEndpointStateClosed)
+}
+
+// SetOwner sets the owner of transmitted packets.
+//
+// The owner is copied into each WriteContext created by
+// AcquireContextForWrite and assigned to packets written through it.
+func (e *Endpoint) SetOwner(owner tcpip.PacketOwner) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.owner = owner
+}
+
+// calculateTTL picks the TTL for packets sent over route.
+//
+// Routes to a multicast destination always use multicastTTL. Any other route
+// uses ttl, unless ttl is zero, in which case the route's default TTL is used.
+func calculateTTL(route *stack.Route, ttl uint8, multicastTTL uint8) uint8 {
+	remote := route.RemoteAddress()
+	switch {
+	case header.IsV4MulticastAddress(remote), header.IsV6MulticastAddress(remote):
+		return multicastTTL
+	case ttl != 0:
+		return ttl
+	default:
+		return route.DefaultTTL()
+	}
+}
+
+// WriteContext holds the context for a write.
+//
+// It is created by Endpoint.AcquireContextForWrite and holds a route that is
+// released through Release.
+type WriteContext struct {
+ // transProto is the transport protocol placed in the network header params.
+ transProto tcpip.TransportProtocolNumber
+ // route is the route packets are written to.
+ route *stack.Route
+ // ttl and tos are the TTL and TOS placed in the network header params.
+ ttl uint8
+ tos uint8
+ // owner is assigned to every packet written through this context.
+ owner tcpip.PacketOwner
+}
+
+// Release releases held resources.
+//
+// The held route is released and c is reset to its zero value.
+func (c *WriteContext) Release() {
+ c.route.Release()
+ *c = WriteContext{}
+}
+
+// WritePacketInfo is the properties of a packet that may be written.
+//
+// Every field is populated from the WriteContext's route by PacketInfo.
+type WritePacketInfo struct {
+ // NetProto is the route's network protocol.
+ NetProto tcpip.NetworkProtocolNumber
+ // LocalAddress and RemoteAddress are the route's local and remote addresses.
+ LocalAddress, RemoteAddress tcpip.Address
+ // MaxHeaderLength is the value reported by the route's MaxHeaderLength.
+ MaxHeaderLength uint16
+ // RequiresTXTransportChecksum is the value reported by the route's
+ // RequiresTXTransportChecksum.
+ RequiresTXTransportChecksum bool
+}
+
+// PacketInfo returns the properties of a packet that will be written.
+func (c *WriteContext) PacketInfo() WritePacketInfo {
+ return WritePacketInfo{
+ NetProto: c.route.NetProto(),
+ LocalAddress: c.route.LocalAddress(),
+ RemoteAddress: c.route.RemoteAddress(),
+ MaxHeaderLength: c.route.MaxHeaderLength(),
+ RequiresTXTransportChecksum: c.route.RequiresTXTransportChecksum(),
+ }
+}
+
+// WritePacket attempts to write the packet.
+//
+// The packet's owner is set to the context's owner. When headerIncluded is
+// true the packet is written as a header-included packet; otherwise it is
+// written with the context's transport protocol, TTL and TOS as network
+// header parameters.
+func (c *WriteContext) WritePacket(pkt *stack.PacketBuffer, headerIncluded bool) tcpip.Error {
+	pkt.Owner = c.owner
+
+	if !headerIncluded {
+		params := stack.NetworkHeaderParams{
+			Protocol: c.transProto,
+			TTL:      c.ttl,
+			TOS:      c.tos,
+		}
+		return c.route.WritePacket(params, pkt)
+	}
+
+	return c.route.WriteHeaderIncludedPacket(pkt)
+}
+
+// AcquireContextForWrite acquires a WriteContext.
+//
+// On success the returned context holds a route and must be released with
+// WriteContext.Release. When opts.To is nil the endpoint's connected route is
+// used; otherwise a route to opts.To is looked up.
+//
+// Errors returned by this function directly:
+//   - ErrInvalidOptionValue if opts.More is set.
+//   - ErrInvalidEndpointState if the endpoint is closed.
+//   - ErrClosedForSend if the endpoint was shut down for writes.
+//   - ErrDestinationRequired if opts.To is nil and the endpoint is not
+//     connected.
+//   - ErrNoRoute if opts.To names a NIC other than the one the endpoint is
+//     bound to.
+//   - ErrBroadcastDisabled if the route is an outbound broadcast and the
+//     broadcast socket option is off.
+//
+// Errors from address validation and route lookup are returned unchanged.
+func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext, tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op.
+ if opts.More {
+ return WriteContext{}, &tcpip.ErrInvalidOptionValue{}
+ }
+
+ if e.State() == transport.DatagramEndpointStateClosed {
+ return WriteContext{}, &tcpip.ErrInvalidEndpointState{}
+ }
+
+ if e.writeShutdown {
+ return WriteContext{}, &tcpip.ErrClosedForSend{}
+ }
+
+ route := e.connectedRoute
+ if opts.To == nil {
+ // If the user doesn't specify a destination, they should have
+ // connected to another address.
+ if e.State() != transport.DatagramEndpointStateConnected {
+ return WriteContext{}, &tcpip.ErrDestinationRequired{}
+ }
+
+ // Acquire the connected route on behalf of the returned context so that
+ // WriteContext.Release can release it.
+ route.Acquire()
+ } else {
+ // Reject destination address if it goes through a different
+ // NIC than the endpoint was bound to.
+ nicID := opts.To.NIC
+ if nicID == 0 {
+ nicID = tcpip.NICID(e.ops.GetBindToDevice())
+ }
+ info := e.Info()
+ if info.BindNICID != 0 {
+ if nicID != 0 && nicID != info.BindNICID {
+ return WriteContext{}, &tcpip.ErrNoRoute{}
+ }
+
+ nicID = info.BindNICID
+ }
+ if nicID == 0 {
+ nicID = info.RegisterNICID
+ }
+
+ dst, netProto, err := e.checkV4Mapped(*opts.To)
+ if err != nil {
+ return WriteContext{}, err
+ }
+
+ route, _, err = e.connectRouteRLocked(nicID, dst, netProto)
+ if err != nil {
+ return WriteContext{}, err
+ }
+ }
+
+ // From here on this function holds a route that must be released on every
+ // error path.
+ if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() {
+ route.Release()
+ return WriteContext{}, &tcpip.ErrBroadcastDisabled{}
+ }
+
+ // The TOS/traffic class depends on the network protocol of the route.
+ var tos uint8
+ switch netProto := route.NetProto(); netProto {
+ case header.IPv4ProtocolNumber:
+ tos = e.ipv4TOS
+ case header.IPv6ProtocolNumber:
+ tos = e.ipv6TClass
+ default:
+ panic(fmt.Sprintf("invalid protocol number = %d", netProto))
+ }
+
+ return WriteContext{
+ transProto: e.transProto,
+ route: route,
+ ttl: calculateTTL(route, e.ttl, e.multicastTTL),
+ tos: tos,
+ owner: e.owner,
+ }, nil
+}
+
+// Disconnect disconnects the endpoint from its peer.
+//
+// It is a no-op unless the endpoint is connected. An endpoint that was bound
+// returns to the bound state with its bind address as local address; any
+// other endpoint returns to the initial state with an empty ID. The connected
+// route is released.
+func (e *Endpoint) Disconnect() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if e.State() != transport.DatagramEndpointStateConnected {
+ return
+ }
+
+ info := e.Info()
+ // Exclude ephemerally bound endpoints.
+ if e.wasBound {
+ info.ID = stack.TransportEndpointID{
+ LocalAddress: info.BindAddr,
+ }
+ e.setEndpointState(transport.DatagramEndpointStateBound)
+ } else {
+ info.ID = stack.TransportEndpointID{}
+ e.setEndpointState(transport.DatagramEndpointStateInitial)
+ }
+ e.setInfo(info)
+
+ e.connectedRoute.Release()
+ e.connectedRoute = nil
+}
+
+// connectRouteRLocked establishes a route to the specified interface or the
+// configured multicast interface if no interface is specified and the
+// specified address is a multicast address.
+//
+// On success it returns the route found by the stack together with the NIC ID
+// that was used for the lookup. Callers in this file release the route when
+// they are done with it.
+//
+// TODO(https://gvisor.dev/issue/6590): Annotate read lock requirement.
+// +checklocks:e.mu
+func (e *Endpoint) connectRouteRLocked(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) {
+ localAddr := e.Info().ID.LocalAddress
+ if e.isBroadcastOrMulticast(nicID, netProto, localAddr) {
+ // A packet can only originate from a unicast address (i.e., an interface).
+ localAddr = ""
+ }
+
+ // For multicast destinations fall back to the configured multicast NIC,
+ // and to the configured multicast address when there is still no NIC and
+ // no local address.
+ if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) {
+ if nicID == 0 {
+ nicID = e.multicastNICID
+ }
+ if localAddr == "" && nicID == 0 {
+ localAddr = e.multicastAddr
+ }
+ }
+
+ // Find a route to the desired destination.
+ r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.ops.GetMulticastLoop())
+ if err != nil {
+ return nil, 0, err
+ }
+ return r, nicID, nil
+}
+
+// Connect connects the endpoint to the address.
+//
+// It is ConnectAndThen with a callback that always succeeds.
+func (e *Endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
+ return e.ConnectAndThen(addr, func(_ tcpip.NetworkProtocolNumber, _, _ stack.TransportEndpointID) tcpip.Error {
+ return nil
+ })
+}
+
+// ConnectAndThen connects the endpoint to the address and then calls the
+// provided function.
+//
+// If the function returns an error, the endpoint's state does not change and
+// the route found for the new connection is released. The function will be
+// called with the network protocol used to connect to the peer and the source
+// and destination addresses that will be used to send traffic to the peer.
+//
+// The port in addr is ignored. If the endpoint is bound to a NIC, addr.NIC
+// must be unset or match that NIC, otherwise ErrInvalidEndpointState is
+// returned; the same error is returned if the endpoint is not in the initial,
+// bound or connected state. Errors from address validation, route lookup and
+// f are returned unchanged.
+func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error) tcpip.Error {
+	addr.Port = 0
+
+	e.mu.Lock()
+	defer e.mu.Unlock()
+
+	info := e.Info()
+	nicID := addr.NIC
+	switch e.State() {
+	case transport.DatagramEndpointStateInitial:
+	case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+		if info.BindNICID == 0 {
+			break
+		}
+
+		if nicID != 0 && nicID != info.BindNICID {
+			return &tcpip.ErrInvalidEndpointState{}
+		}
+
+		nicID = info.BindNICID
+	default:
+		return &tcpip.ErrInvalidEndpointState{}
+	}
+
+	addr, netProto, err := e.checkV4Mapped(addr)
+	if err != nil {
+		return err
+	}
+
+	r, nicID, err := e.connectRouteRLocked(nicID, addr, netProto)
+	if err != nil {
+		return err
+	}
+
+	id := stack.TransportEndpointID{
+		LocalAddress:  info.ID.LocalAddress,
+		RemoteAddress: r.RemoteAddress(),
+	}
+	// An endpoint that was never bound takes its local address from the route.
+	if e.State() == transport.DatagramEndpointStateInitial {
+		id.LocalAddress = r.LocalAddress()
+	}
+
+	if err := f(r.NetProto(), info.ID, id); err != nil {
+		// The new route is not going to be stored in the endpoint, so it must
+		// be released here to avoid leaking it.
+		r.Release()
+		return err
+	}
+
+	if e.connectedRoute != nil {
+		// If the endpoint was previously connected then release any previous route.
+		e.connectedRoute.Release()
+	}
+	e.connectedRoute = r
+	info.ID = id
+	info.RegisterNICID = nicID
+	e.setInfo(info)
+	e.effectiveNetProto = netProto
+	e.setEndpointState(transport.DatagramEndpointStateConnected)
+	return nil
+}
+
+// Shutdown shuts down the endpoint for writes.
+//
+// It returns ErrNotConnected when the endpoint is in the initial or closed
+// state, marks the endpoint as write-shutdown when it is bound or connected,
+// and panics on any other state.
+func (e *Endpoint) Shutdown() tcpip.Error {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+
+	state := e.State()
+	if state == transport.DatagramEndpointStateInitial || state == transport.DatagramEndpointStateClosed {
+		return &tcpip.ErrNotConnected{}
+	}
+	if state == transport.DatagramEndpointStateBound || state == transport.DatagramEndpointStateConnected {
+		e.writeShutdown = true
+		return nil
+	}
+	panic(fmt.Sprintf("unhandled state = %s", state))
+}
+
+// checkV4Mapped determines the effective network protocol and converts addr
+// to its canonical form.
+//
+// On failure the zero address, a zero protocol number and the error are
+// returned.
+func (e *Endpoint) checkV4Mapped(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
+	epInfo := e.Info()
+	canonical, proto, err := epInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only())
+	if err == nil {
+		return canonical, proto, nil
+	}
+	return tcpip.FullAddress{}, 0, err
+}
+
+// isBroadcastOrMulticast reports whether addr is the IPv4 broadcast address,
+// an IPv4 or IPv6 multicast address, or a subnet broadcast address on the NIC
+// identified by nicID for netProto.
+func (e *Endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
+	switch {
+	case addr == header.IPv4Broadcast:
+		return true
+	case header.IsV4MulticastAddress(addr), header.IsV6MulticastAddress(addr):
+		return true
+	default:
+		// Only consult the stack once the cheap address checks have failed.
+		return e.stack.IsSubnetBroadcast(nicID, netProto, addr)
+	}
+}
+
+// Bind binds the endpoint to the address.
+//
+// It is BindAndThen with a callback that always succeeds.
+func (e *Endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
+ return e.BindAndThen(addr, func(tcpip.NetworkProtocolNumber, tcpip.Address) tcpip.Error {
+ return nil
+ })
+}
+
+// BindAndThen binds the endpoint to the address and then calls the provided
+// function.
+//
+// If the function returns an error, the endpoint's state does not change. The
+// function will be called with the bound network protocol and address.
+//
+// The port in addr is ignored. ErrInvalidEndpointState is returned unless the
+// endpoint is in the initial state, and ErrBadLocalAddress is returned when a
+// non-empty address that is neither broadcast nor multicast is not found by
+// the stack's CheckLocalAddress.
+func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProtocolNumber, tcpip.Address) tcpip.Error) tcpip.Error {
+ addr.Port = 0
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // Don't allow binding once endpoint is not in the initial state
+ // anymore.
+ if e.State() != transport.DatagramEndpointStateInitial {
+ return &tcpip.ErrInvalidEndpointState{}
+ }
+
+ addr, netProto, err := e.checkV4Mapped(addr)
+ if err != nil {
+ return err
+ }
+
+ // Broadcast, multicast and empty addresses skip the local address check;
+ // for them nicID stays equal to addr.NIC.
+ nicID := addr.NIC
+ if len(addr.Addr) != 0 && !e.isBroadcastOrMulticast(addr.NIC, netProto, addr.Addr) {
+ nicID = e.stack.CheckLocalAddress(nicID, netProto, addr.Addr)
+ if nicID == 0 {
+ return &tcpip.ErrBadLocalAddress{}
+ }
+ }
+
+ if err := f(netProto, addr.Addr); err != nil {
+ return err
+ }
+
+ e.wasBound = true
+
+ // BindNICID records the NIC requested by the caller, RegisterNICID the NIC
+ // resolved above.
+ info := e.Info()
+ info.ID = stack.TransportEndpointID{
+ LocalAddress: addr.Addr,
+ }
+ info.BindNICID = addr.NIC
+ info.RegisterNICID = nicID
+ info.BindAddr = addr.Addr
+ e.setInfo(info)
+ e.effectiveNetProto = netProto
+ e.setEndpointState(transport.DatagramEndpointStateBound)
+ return nil
+}
+
+// WasBound returns true iff the endpoint was ever bound.
+//
+// The flag is set by BindAndThen and is read here under the read lock.
+func (e *Endpoint) WasBound() bool {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.wasBound
+}
+
+// GetLocalAddress returns the address that the endpoint is bound to.
+//
+// The NIC is the endpoint's registered NIC. The address is the bind address,
+// except for a connected endpoint, where it is the local address of the
+// connected route.
+func (e *Endpoint) GetLocalAddress() tcpip.FullAddress {
+	e.mu.RLock()
+	defer e.mu.RUnlock()
+
+	epInfo := e.Info()
+	local := tcpip.FullAddress{
+		NIC:  epInfo.RegisterNICID,
+		Addr: epInfo.BindAddr,
+	}
+	if e.State() == transport.DatagramEndpointStateConnected {
+		local.Addr = e.connectedRoute.LocalAddress()
+	}
+	return local
+}
+
+// GetRemoteAddress returns the address that the endpoint is connected to.
+//
+// The boolean result is false, and the address is the zero value, when the
+// endpoint is not connected.
+func (e *Endpoint) GetRemoteAddress() (tcpip.FullAddress, bool) {
+	e.mu.RLock()
+	defer e.mu.RUnlock()
+
+	if e.State() == transport.DatagramEndpointStateConnected {
+		remote := tcpip.FullAddress{
+			Addr: e.connectedRoute.RemoteAddress(),
+			NIC:  e.Info().RegisterNICID,
+		}
+		return remote, true
+	}
+
+	return tcpip.FullAddress{}, false
+}
+
+// SetSockOptInt sets the socket option.
+//
+// MTUDiscoverOption only accepts tcpip.PMTUDiscoveryDont; any other value
+// yields ErrNotSupported. The TTL, TOS and traffic class options store v
+// converted to uint8. Options not handled here are ignored and nil is
+// returned.
+//
+// NOTE(review): v is truncated to uint8 without a range check; presumably
+// callers validate the range first - confirm.
+func (e *Endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
+ switch opt {
+ case tcpip.MTUDiscoverOption:
+ // Return not supported if the value is not disabling path
+ // MTU discovery.
+ if v != tcpip.PMTUDiscoveryDont {
+ return &tcpip.ErrNotSupported{}
+ }
+
+ case tcpip.MulticastTTLOption:
+ e.mu.Lock()
+ e.multicastTTL = uint8(v)
+ e.mu.Unlock()
+
+ case tcpip.TTLOption:
+ e.mu.Lock()
+ e.ttl = uint8(v)
+ e.mu.Unlock()
+
+ case tcpip.IPv4TOSOption:
+ e.mu.Lock()
+ e.ipv4TOS = uint8(v)
+ e.mu.Unlock()
+
+ case tcpip.IPv6TrafficClassOption:
+ e.mu.Lock()
+ e.ipv6TClass = uint8(v)
+ e.mu.Unlock()
+ }
+
+ return nil
+}
+
+// GetSockOptInt returns the socket option.
+//
+// MTUDiscoverOption always reports tcpip.PMTUDiscoveryDont. The TTL, TOS and
+// traffic class options report the value currently stored in the endpoint.
+// Any other option yields -1 and ErrUnknownProtocolOption.
+//
+// Every case only reads endpoint fields, so the shared (read) lock is taken.
+func (e *Endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
+	switch opt {
+	case tcpip.MTUDiscoverOption:
+		// The only supported setting is path MTU discovery disabled.
+		return tcpip.PMTUDiscoveryDont, nil
+
+	case tcpip.MulticastTTLOption:
+		e.mu.RLock()
+		v := int(e.multicastTTL)
+		e.mu.RUnlock()
+		return v, nil
+
+	case tcpip.TTLOption:
+		e.mu.RLock()
+		v := int(e.ttl)
+		e.mu.RUnlock()
+		return v, nil
+
+	case tcpip.IPv4TOSOption:
+		e.mu.RLock()
+		v := int(e.ipv4TOS)
+		e.mu.RUnlock()
+		return v, nil
+
+	case tcpip.IPv6TrafficClassOption:
+		e.mu.RLock()
+		v := int(e.ipv6TClass)
+		e.mu.RUnlock()
+		return v, nil
+
+	default:
+		return -1, &tcpip.ErrUnknownProtocolOption{}
+	}
+}
+
+// SetSockOpt sets the socket option.
+//
+// Handled options:
+//   - MulticastInterfaceOption: selects the NIC and local address used for
+//     multicast destinations, or clears both when neither is given.
+//   - AddMembershipOption / RemoveMembershipOption: join or leave a multicast
+//     group on a NIC and record the membership in the endpoint.
+//   - SocketDetachFilterOption: accepted and ignored.
+//
+// Options of any other type are ignored and nil is returned.
+func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
+ switch v := opt.(type) {
+ case *tcpip.MulticastInterfaceOption:
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ fa := tcpip.FullAddress{Addr: v.InterfaceAddr}
+ fa, netProto, err := e.checkV4Mapped(fa)
+ if err != nil {
+ return err
+ }
+ nic := v.NIC
+ addr := fa.Addr
+
+ // Neither a NIC nor an address: clear the multicast interface.
+ if nic == 0 && addr == "" {
+ e.multicastAddr = ""
+ e.multicastNICID = 0
+ break
+ }
+
+ // An explicit NIC must exist; otherwise the NIC is derived from the
+ // interface address.
+ if nic != 0 {
+ if !e.stack.CheckNIC(nic) {
+ return &tcpip.ErrBadLocalAddress{}
+ }
+ } else {
+ nic = e.stack.CheckLocalAddress(0, netProto, addr)
+ if nic == 0 {
+ return &tcpip.ErrBadLocalAddress{}
+ }
+ }
+
+ // The multicast interface may not differ from the NIC the endpoint is
+ // bound to.
+ if info := e.Info(); info.BindNICID != 0 && info.BindNICID != nic {
+ return &tcpip.ErrInvalidEndpointState{}
+ }
+
+ e.multicastNICID = nic
+ e.multicastAddr = addr
+
+ case *tcpip.AddMembershipOption:
+ if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
+ return &tcpip.ErrInvalidOptionValue{}
+ }
+
+ nicID := v.NIC
+
+ // Resolve the NIC: from a route to the group when nothing was specified,
+ // or from the interface address when one was given.
+ if v.InterfaceAddr.Unspecified() {
+ if nicID == 0 {
+ if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.netProto, false /* multicastLoop */); err == nil {
+ nicID = r.NICID()
+ r.Release()
+ }
+ }
+ } else {
+ nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
+ }
+ if nicID == 0 {
+ return &tcpip.ErrUnknownDevice{}
+ }
+
+ memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // Joining the same group on the same NIC twice is an error.
+ if _, ok := e.multicastMemberships[memToInsert]; ok {
+ return &tcpip.ErrPortInUse{}
+ }
+
+ if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+ return err
+ }
+
+ // NOTE(review): Close sets multicastMemberships to nil and this write
+ // does not check for the closed state; confirm this option is never set
+ // on a closed endpoint.
+ e.multicastMemberships[memToInsert] = struct{}{}
+
+ case *tcpip.RemoveMembershipOption:
+ if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
+ return &tcpip.ErrInvalidOptionValue{}
+ }
+
+ // The NIC is resolved the same way as for AddMembershipOption.
+ nicID := v.NIC
+ if v.InterfaceAddr.Unspecified() {
+ if nicID == 0 {
+ if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.netProto, false /* multicastLoop */); err == nil {
+ nicID = r.NICID()
+ r.Release()
+ }
+ }
+ } else {
+ nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
+ }
+ if nicID == 0 {
+ return &tcpip.ErrUnknownDevice{}
+ }
+
+ memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // Only memberships recorded by this endpoint can be removed.
+ if _, ok := e.multicastMemberships[memToRemove]; !ok {
+ return &tcpip.ErrBadLocalAddress{}
+ }
+
+ if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+ return err
+ }
+
+ delete(e.multicastMemberships, memToRemove)
+
+ case *tcpip.SocketDetachFilterOption:
+ return nil
+ }
+ return nil
+}
+
+// GetSockOpt returns the socket option.
+//
+// Only MulticastInterfaceOption is supported; it is filled in with the
+// endpoint's configured multicast NIC and address. Any other option yields
+// ErrUnknownProtocolOption.
+func (e *Endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
+	switch o := opt.(type) {
+	case *tcpip.MulticastInterfaceOption:
+		// The endpoint is only read here, so the shared (read) lock is enough.
+		e.mu.RLock()
+		*o = tcpip.MulticastInterfaceOption{
+			NIC:           e.multicastNICID,
+			InterfaceAddr: e.multicastAddr,
+		}
+		e.mu.RUnlock()
+
+	default:
+		return &tcpip.ErrUnknownProtocolOption{}
+	}
+	return nil
+}
+
+// Info returns a copy of the endpoint info.
+//
+// Only infoMu is taken (in read mode); e.mu does not need to be held.
+func (e *Endpoint) Info() stack.TransportEndpointInfo {
+ e.infoMu.RLock()
+ defer e.infoMu.RUnlock()
+ return e.info
+}
+
+// setInfo sets the endpoint's info.
+//
+// e.mu must be held to synchronize changes to info with the rest of the
+// endpoint.
+//
+// +checklocks:e.mu
+func (e *Endpoint) setInfo(info stack.TransportEndpointInfo) {
+ // infoMu is taken after e.mu, matching the documented lock order
+ // (mu > infoMu).
+ e.infoMu.Lock()
+ defer e.infoMu.Unlock()
+ e.info = info
+}
diff --git a/pkg/tcpip/transport/internal/network/endpoint_state.go b/pkg/tcpip/transport/internal/network/endpoint_state.go
new file mode 100644
index 000000000..68bd1fbf6
--- /dev/null
+++ b/pkg/tcpip/transport/internal/network/endpoint_state.go
@@ -0,0 +1,58 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package network
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
+)
+
+// Resume implements tcpip.ResumableEndpoint.Resume.
+//
+// It attaches the endpoint to s, rejoins every recorded multicast group and
+// then re-validates the bound address (bound endpoints) or looks up the
+// connected route again (connected endpoints). Any failure panics.
+func (e *Endpoint) Resume(s *stack.Stack) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ e.stack = s
+
+ for m := range e.multicastMemberships {
+ if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil {
+ panic(fmt.Sprintf("e.stack.JoinGroup(%d, %d, %s): %s", e.netProto, m.nicID, m.multicastAddr, err))
+ }
+ }
+
+ info := e.Info()
+
+ switch state := e.State(); state {
+ case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
+ // Nothing to restore.
+ case transport.DatagramEndpointStateBound:
+ // Same check as at bind time: broadcast, multicast and empty addresses
+ // are not validated against the stack.
+ if len(info.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(info.RegisterNICID, e.effectiveNetProto, info.ID.LocalAddress) {
+ if e.stack.CheckLocalAddress(info.RegisterNICID, e.effectiveNetProto, info.ID.LocalAddress) == 0 {
+ panic(fmt.Sprintf("got e.stack.CheckLocalAddress(%d, %d, %s) = 0, want != 0", info.RegisterNICID, e.effectiveNetProto, info.ID.LocalAddress))
+ }
+ }
+ case transport.DatagramEndpointStateConnected:
+ // connectedRoute is tagged state:"manual", so it is looked up again
+ // against the new stack.
+ var err tcpip.Error
+ multicastLoop := e.ops.GetMulticastLoop()
+ e.connectedRoute, err = e.stack.FindRoute(info.RegisterNICID, info.ID.LocalAddress, info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop)
+ if err != nil {
+ panic(fmt.Sprintf("e.stack.FindRoute(%d, %s, %s, %d, %t): %s", info.RegisterNICID, info.ID.LocalAddress, info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop, err))
+ }
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
+ }
+}
diff --git a/pkg/tcpip/transport/internal/network/endpoint_test.go b/pkg/tcpip/transport/internal/network/endpoint_test.go
new file mode 100644
index 000000000..f263a9ea2
--- /dev/null
+++ b/pkg/tcpip/transport/internal/network/endpoint_test.go
@@ -0,0 +1,318 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package network_test
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+)
+
+var (
+ ipv4NICAddr = testutil.MustParse4("1.2.3.4") // IPv4 address the tests assign to the NIC.
+ ipv6NICAddr = testutil.MustParse6("a::1") // IPv6 address the tests assign to the NIC.
+ ipv4RemoteAddr = testutil.MustParse4("6.7.8.9") // IPv4 peer used as connect and route destination.
+ ipv6RemoteAddr = testutil.MustParse6("b::1") // IPv6 peer used as connect and route destination.
+)
+
+// TestEndpointStateTransitions drives an endpoint through Initial, Bound, Connected and Closed.
+func TestEndpointStateTransitions(t *testing.T) {
+	const nicID = 1
+
+	data := buffer.View([]byte{1, 2, 4, 5})
+	v4Checker := func(t *testing.T, b buffer.View) {
+		checker.IPv4(t, b,
+			checker.SrcAddr(ipv4NICAddr),
+			checker.DstAddr(ipv4RemoteAddr),
+			checker.IPPayload(data),
+		)
+	}
+
+	v6Checker := func(t *testing.T, b buffer.View) {
+		checker.IPv6(t, b,
+			checker.SrcAddr(ipv6NICAddr),
+			checker.DstAddr(ipv6RemoteAddr),
+			checker.IPPayload(data),
+		)
+	}
+
+	tests := []struct {
+		name                    string
+		netProto                tcpip.NetworkProtocolNumber
+		expectedMaxHeaderLength uint16
+		expectedNetProto        tcpip.NetworkProtocolNumber
+		expectedLocalAddr       tcpip.Address
+		bindAddr                tcpip.Address
+		expectedBoundAddr       tcpip.Address
+		remoteAddr              tcpip.Address
+		expectedRemoteAddr      tcpip.Address
+		checker                 func(*testing.T, buffer.View)
+	}{
+		{
+			name:                    "IPv4",
+			netProto:                ipv4.ProtocolNumber,
+			expectedMaxHeaderLength: header.IPv4MaximumHeaderSize,
+			expectedNetProto:        ipv4.ProtocolNumber,
+			expectedLocalAddr:       ipv4NICAddr,
+			bindAddr:                header.IPv4AllSystems,
+			expectedBoundAddr:       header.IPv4AllSystems,
+			remoteAddr:              ipv4RemoteAddr,
+			expectedRemoteAddr:      ipv4RemoteAddr,
+			checker:                 v4Checker,
+		},
+		{
+			name:                    "IPv6",
+			netProto:                ipv6.ProtocolNumber,
+			expectedMaxHeaderLength: header.IPv6FixedHeaderSize,
+			expectedNetProto:        ipv6.ProtocolNumber,
+			expectedLocalAddr:       ipv6NICAddr,
+			bindAddr:                header.IPv6AllNodesMulticastAddress,
+			expectedBoundAddr:       header.IPv6AllNodesMulticastAddress,
+			remoteAddr:              ipv6RemoteAddr,
+			expectedRemoteAddr:      ipv6RemoteAddr,
+			checker:                 v6Checker,
+		},
+		{
+			name:                    "IPv4-mapped-IPv6",
+			netProto:                ipv6.ProtocolNumber,
+			expectedMaxHeaderLength: header.IPv4MaximumHeaderSize,
+			expectedNetProto:        ipv4.ProtocolNumber,
+			expectedLocalAddr:       ipv4NICAddr,
+			bindAddr:                testutil.MustParse6("::ffff:e000:0001"),
+			expectedBoundAddr:       header.IPv4AllSystems,
+			remoteAddr:              testutil.MustParse6("::ffff:0607:0809"),
+			expectedRemoteAddr:      ipv4RemoteAddr,
+			checker:                 v4Checker,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			s := stack.New(stack.Options{
+				NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+				TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+				Clock:              &faketime.NullClock{},
+			})
+			e := channel.New(1, header.IPv6MinimumMTU, "")
+			if err := s.CreateNIC(nicID, e); err != nil {
+				t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+			}
+
+			ipv4ProtocolAddr := tcpip.ProtocolAddress{
+				Protocol:          ipv4.ProtocolNumber,
+				AddressWithPrefix: ipv4NICAddr.WithPrefix(),
+			}
+			if err := s.AddProtocolAddress(nicID, ipv4ProtocolAddr, stack.AddressProperties{}); err != nil {
+				t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtocolAddr, err)
+			}
+			ipv6ProtocolAddr := tcpip.ProtocolAddress{
+				Protocol:          ipv6.ProtocolNumber,
+				AddressWithPrefix: ipv6NICAddr.WithPrefix(),
+			}
+			if err := s.AddProtocolAddress(nicID, ipv6ProtocolAddr, stack.AddressProperties{}); err != nil {
+				t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtocolAddr, err)
+			}
+
+			s.SetRouteTable([]tcpip.Route{
+				{Destination: ipv4RemoteAddr.WithPrefix().Subnet(), NIC: nicID},
+				{Destination: ipv6RemoteAddr.WithPrefix().Subnet(), NIC: nicID},
+			})
+
+			var ops tcpip.SocketOptions
+			var ep network.Endpoint
+			ep.Init(s, test.netProto, udp.ProtocolNumber, &ops)
+			defer ep.Close()
+			if state := ep.State(); state != transport.DatagramEndpointStateInitial {
+				t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateInitial)
+			}
+			// Initial -> Bound.
+			bindAddr := tcpip.FullAddress{Addr: test.bindAddr}
+			if err := ep.Bind(bindAddr); err != nil {
+				t.Fatalf("ep.Bind(%#v): %s", bindAddr, err)
+			}
+			if state := ep.State(); state != transport.DatagramEndpointStateBound {
+				t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateBound)
+			}
+			if diff := cmp.Diff(tcpip.FullAddress{Addr: test.expectedBoundAddr}, ep.GetLocalAddress()); diff != "" {
+				t.Errorf("ep.GetLocalAddress() mismatch (-want +got):\n%s", diff)
+			}
+			if addr, connected := ep.GetRemoteAddress(); connected {
+				t.Errorf("got ep.GetRemoteAddress() = (%#v, true), want = (_, false)", addr)
+			}
+			// Bound -> Connected.
+			connectAddr := tcpip.FullAddress{Addr: test.remoteAddr}
+			if err := ep.Connect(connectAddr); err != nil {
+				t.Fatalf("ep.Connect(%#v): %s", connectAddr, err)
+			}
+			if state := ep.State(); state != transport.DatagramEndpointStateConnected {
+				t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateConnected)
+			}
+			if diff := cmp.Diff(tcpip.FullAddress{Addr: test.expectedLocalAddr}, ep.GetLocalAddress()); diff != "" {
+				t.Errorf("ep.GetLocalAddress() mismatch (-want +got):\n%s", diff)
+			}
+			if addr, connected := ep.GetRemoteAddress(); !connected {
+				t.Errorf("got ep.GetRemoteAddress() = (_, false), want = (%#v, true)", tcpip.FullAddress{Addr: test.expectedRemoteAddr})
+			} else if diff := cmp.Diff(tcpip.FullAddress{Addr: test.expectedRemoteAddr}, addr); diff != "" {
+				t.Errorf("remote address mismatch (-want +got):\n%s", diff)
+			}
+			// Write one packet and check what reaches the link endpoint.
+			ctx, err := ep.AcquireContextForWrite(tcpip.WriteOptions{})
+			if err != nil {
+				t.Fatalf("ep.AcquireContextForWrite({}): %s", err)
+			}
+			defer ctx.Release()
+			info := ctx.PacketInfo()
+			if diff := cmp.Diff(network.WritePacketInfo{
+				NetProto:                    test.expectedNetProto,
+				LocalAddress:                test.expectedLocalAddr,
+				RemoteAddress:               test.expectedRemoteAddr,
+				MaxHeaderLength:             test.expectedMaxHeaderLength,
+				RequiresTXTransportChecksum: true,
+			}, info); diff != "" {
+				t.Errorf("write packet info mismatch (-want +got):\n%s", diff)
+			}
+			if err := ctx.WritePacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
+				ReserveHeaderBytes: int(info.MaxHeaderLength),
+				Data:               data.ToVectorisedView(),
+			}), false /* headerIncluded */); err != nil {
+				t.Fatalf("ctx.WritePacket(_, false): %s", err)
+			}
+			if pkt, ok := e.Read(); !ok {
+				t.Fatalf("expected packet to be read from link endpoint")
+			} else {
+				test.checker(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()))
+			}
+			// Connected -> Closed.
+			ep.Close()
+			if state := ep.State(); state != transport.DatagramEndpointStateClosed {
+				t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateClosed)
+			}
+		})
+	}
+}
+
+// TestBindNICID checks WasBound and Info on an endpoint before and after Bind,
+// for multicast and unicast addresses, binding once with NIC 0 and once with
+// the test NIC.
+func TestBindNICID(t *testing.T) {
+	const nicID = 1
+
+	tests := []struct {
+		name     string
+		netProto tcpip.NetworkProtocolNumber
+		bindAddr tcpip.Address
+		unicast  bool
+	}{
+		{
+			name:     "IPv4 multicast",
+			netProto: ipv4.ProtocolNumber,
+			bindAddr: header.IPv4AllSystems,
+			unicast:  false,
+		},
+		{
+			name:     "IPv6 multicast",
+			netProto: ipv6.ProtocolNumber,
+			bindAddr: header.IPv6AllNodesMulticastAddress,
+			unicast:  false,
+		},
+		{
+			name:     "IPv4 unicast",
+			netProto: ipv4.ProtocolNumber,
+			bindAddr: ipv4NICAddr,
+			unicast:  true,
+		},
+		{
+			name:     "IPv6 unicast",
+			netProto: ipv6.ProtocolNumber,
+			bindAddr: ipv6NICAddr,
+			unicast:  true,
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			for _, bindNIC := range []tcpip.NICID{0, nicID} {
+				t.Run(fmt.Sprintf("BindNICID=%d", bindNIC), func(t *testing.T) {
+					stk := stack.New(stack.Options{
+						NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+						TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+						Clock:              &faketime.NullClock{},
+					})
+					if err := stk.CreateNIC(nicID, loopback.New()); err != nil {
+						t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+					}
+
+					// Assign one IPv4 and one IPv6 address to the NIC, IPv4 first.
+					for _, protoAddr := range []tcpip.ProtocolAddress{
+						{Protocol: ipv4.ProtocolNumber, AddressWithPrefix: ipv4NICAddr.WithPrefix()},
+						{Protocol: ipv6.ProtocolNumber, AddressWithPrefix: ipv6NICAddr.WithPrefix()},
+					} {
+						if err := stk.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
+							t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
+						}
+					}
+
+					var ops tcpip.SocketOptions
+					var ep network.Endpoint
+					ep.Init(stk, tc.netProto, udp.ProtocolNumber, &ops)
+					defer ep.Close()
+					// Before Bind the endpoint must be unbound and report only its protocol numbers.
+					if ep.WasBound() {
+						t.Fatal("got ep.WasBound() = true, want = false")
+					}
+					wantInfo := stack.TransportEndpointInfo{NetProto: tc.netProto, TransProto: udp.ProtocolNumber}
+					if diff := cmp.Diff(wantInfo, ep.Info()); diff != "" {
+						t.Fatalf("ep.Info() mismatch (-want +got):\n%s", diff)
+					}
+
+					bindAddr := tcpip.FullAddress{Addr: tc.bindAddr, NIC: bindNIC}
+					if err := ep.Bind(bindAddr); err != nil {
+						t.Fatalf("ep.Bind(%#v): %s", bindAddr, err)
+					}
+					if !ep.WasBound() {
+						t.Error("got ep.WasBound() = false, want = true")
+					}
+					wantInfo.ID = stack.TransportEndpointID{LocalAddress: bindAddr.Addr}
+					wantInfo.BindAddr = bindAddr.Addr
+					wantInfo.BindNICID = bindAddr.NIC
+					// Unicast binds are expected to register on the NIC holding the address,
+					// even when NIC 0 was requested; other binds keep the requested NIC.
+					wantInfo.RegisterNICID = bindAddr.NIC
+					if tc.unicast {
+						wantInfo.RegisterNICID = nicID
+					}
+					if diff := cmp.Diff(wantInfo, ep.Info()); diff != "" {
+						t.Errorf("ep.Info() mismatch (-want +got):\n%s", diff)
+					}
+				})
+			}
+		})
+	}
+}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index cd8c99d41..1f30e5adb 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -25,7 +25,6 @@
package packet
import (
- "fmt"
"io"
"time"
@@ -60,49 +59,46 @@ type packet struct {
//
// +stateify savable
type endpoint struct {
- stack.TransportEndpointInfo
tcpip.DefaultSocketOptionsHandler
// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
- netProto tcpip.NetworkProtocolNumber
waiterQueue *waiter.Queue
cooked bool
-
- // The following fields are used to manage the receive queue and are
- // protected by rcvMu.
- rcvMu sync.Mutex `state:"nosave"`
- rcvList packetList
+ ops tcpip.SocketOptions
+ stats tcpip.TransportEndpointStats `state:"nosave"`
+
+ // The following fields are used to manage the receive queue.
+ rcvMu sync.Mutex `state:"nosave"`
+ // +checklocks:rcvMu
+ rcvList packetList
+ // +checklocks:rcvMu
rcvBufSize int
- rcvClosed bool
-
- // The following fields are protected by mu.
- mu sync.RWMutex `state:"nosave"`
- closed bool
- stats tcpip.TransportEndpointStats `state:"nosave"`
- bound bool
+ // +checklocks:rcvMu
+ rcvClosed bool
+ // +checklocks:rcvMu
+ rcvDisabled bool
+
+ mu sync.RWMutex `state:"nosave"`
+ // +checklocks:mu
+ netProto tcpip.NetworkProtocolNumber
+ // +checklocks:mu
+ closed bool
+ // +checklocks:mu
+ bound bool
+ // +checklocks:mu
boundNIC tcpip.NICID
- // lastErrorMu protects lastError.
lastErrorMu sync.Mutex `state:"nosave"`
- lastError tcpip.Error
-
- // ops is used to get socket level options.
- ops tcpip.SocketOptions
-
- // frozen indicates if the packets should be delivered to the endpoint
- // during restore.
- frozen bool
+ // +checklocks:lastErrorMu
+ lastError tcpip.Error
}
// NewEndpoint returns a new packet endpoint.
func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
ep := &endpoint{
- stack: s,
- TransportEndpointInfo: stack.TransportEndpointInfo{
- NetProto: netProto,
- },
+ stack: s,
cooked: cooked,
netProto: netProto,
waiterQueue: waiterQueue,
@@ -207,9 +203,52 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul
return res, nil
}
-func (*endpoint) Write(tcpip.Payloader, tcpip.WriteOptions) (int64, tcpip.Error) {
- // TODO(gvisor.dev/issue/173): Implement.
- return 0, &tcpip.ErrInvalidOptionValue{}
+func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
+ if !ep.stack.PacketEndpointWriteSupported() {
+ return 0, &tcpip.ErrNotSupported{}
+ }
+
+ ep.mu.Lock()
+ closed := ep.closed
+ nicID := ep.boundNIC
+ proto := ep.netProto
+ ep.mu.Unlock()
+ if closed {
+ return 0, &tcpip.ErrClosedForSend{}
+ }
+
+ var remote tcpip.LinkAddress
+ if to := opts.To; to != nil {
+ remote = tcpip.LinkAddress(to.Addr)
+
+ if n := to.NIC; n != 0 {
+ nicID = n
+ }
+
+ if p := to.Port; p != 0 {
+ proto = tcpip.NetworkProtocolNumber(p)
+ }
+ }
+
+ if nicID == 0 {
+ return 0, &tcpip.ErrInvalidOptionValue{}
+ }
+
+ // TODO(https://gvisor.dev/issue/6538): Avoid this allocation.
+ payloadBytes := make(buffer.View, p.Len())
+ if _, err := io.ReadFull(p, payloadBytes); err != nil {
+ return 0, &tcpip.ErrBadBuffer{}
+ }
+
+ if err := func() tcpip.Error {
+ if ep.cooked {
+ return ep.stack.WritePacketToRemote(nicID, remote, proto, payloadBytes.ToVectorisedView())
+ }
+ return ep.stack.WriteRawPacket(nicID, proto, payloadBytes.ToVectorisedView())
+ }(); err != nil {
+ return 0, err
+ }
+ return int64(len(payloadBytes)), nil
}
// Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be
@@ -244,8 +283,6 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi
// Bind implements tcpip.Endpoint.Bind.
func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
- // TODO(gvisor.dev/issue/173): Add Bind support.
-
// "By default, all packets of the specified protocol type are passed
// to a packet socket. To get packets only from a specific interface
// use bind(2) specifying an address in a struct sockaddr_ll to bind
@@ -256,7 +293,8 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
- if ep.bound && ep.boundNIC == addr.NIC {
+ netProto := tcpip.NetworkProtocolNumber(addr.Port)
+ if ep.bound && ep.boundNIC == addr.NIC && ep.netProto == netProto {
// If the NIC being bound is the same then just return success.
return nil
}
@@ -266,12 +304,13 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
ep.bound = false
// Bind endpoint to receive packets from specific interface.
- if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil {
+ if err := ep.stack.RegisterPacketEndpoint(addr.NIC, netProto, ep); err != nil {
return err
}
ep.bound = true
ep.boundNIC = addr.NIC
+ ep.netProto = netProto
return nil
}
@@ -374,7 +413,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
}
rcvBufSize := ep.ops.GetReceiveBufferSize()
- if ep.frozen || ep.rcvBufSize >= int(rcvBufSize) {
+ if ep.rcvDisabled || ep.rcvBufSize >= int(rcvBufSize) {
ep.rcvMu.Unlock()
ep.stack.Stats().DroppedPackets.Increment()
ep.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
@@ -383,78 +422,39 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
wasEmpty := ep.rcvBufSize == 0
- // Push new packet into receive list and increment the buffer size.
- var packet packet
- // TODO(gvisor.dev/issue/173): Return network protocol.
+ rcvdPkt := packet{
+ packetInfo: tcpip.LinkPacketInfo{
+ Protocol: netProto,
+ PktType: pkt.PktType,
+ },
+ senderAddr: tcpip.FullAddress{
+ NIC: nicID,
+ },
+ receivedAt: ep.stack.Clock().Now(),
+ }
+
if !pkt.LinkHeader().View().IsEmpty() {
- // Get info directly from the ethernet header.
hdr := header.Ethernet(pkt.LinkHeader().View())
- packet.senderAddr = tcpip.FullAddress{
- NIC: nicID,
- Addr: tcpip.Address(hdr.SourceAddress()),
- }
- packet.packetInfo.Protocol = netProto
- packet.packetInfo.PktType = pkt.PktType
- } else {
- // Guess the would-be ethernet header.
- packet.senderAddr = tcpip.FullAddress{
- NIC: nicID,
- Addr: tcpip.Address(localAddr),
- }
- packet.packetInfo.Protocol = netProto
- packet.packetInfo.PktType = pkt.PktType
+ rcvdPkt.senderAddr.Addr = tcpip.Address(hdr.SourceAddress())
}
if ep.cooked {
- // Cooked packets can simply be queued.
- switch pkt.PktType {
- case tcpip.PacketHost:
- packet.data = pkt.Data().ExtractVV()
- case tcpip.PacketOutgoing:
- // Strip Link Header.
- var combinedVV buffer.VectorisedView
- if v := pkt.NetworkHeader().View(); !v.IsEmpty() {
- combinedVV.AppendView(v)
- }
- if v := pkt.TransportHeader().View(); !v.IsEmpty() {
- combinedVV.AppendView(v)
- }
- combinedVV.Append(pkt.Data().ExtractVV())
- packet.data = combinedVV
- default:
- panic(fmt.Sprintf("unexpected PktType in pkt: %+v", pkt))
+ // Cooked packet endpoints don't include the link-headers in received
+ // packets.
+ if v := pkt.NetworkHeader().View(); !v.IsEmpty() {
+ rcvdPkt.data.AppendView(v)
}
-
- } else {
- // Raw packets need their ethernet headers prepended before
- // queueing.
- var linkHeader buffer.View
- if pkt.PktType != tcpip.PacketOutgoing {
- if pkt.LinkHeader().View().IsEmpty() {
- // We weren't provided with an actual ethernet header,
- // so fake one.
- ethFields := header.EthernetFields{
- SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}),
- DstAddr: localAddr,
- Type: netProto,
- }
- fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
- fakeHeader.Encode(&ethFields)
- linkHeader = buffer.View(fakeHeader)
- } else {
- linkHeader = append(buffer.View(nil), pkt.LinkHeader().View()...)
- }
- combinedVV := linkHeader.ToVectorisedView()
- combinedVV.Append(pkt.Data().ExtractVV())
- packet.data = combinedVV
- } else {
- packet.data = buffer.NewVectorisedView(pkt.Size(), pkt.Views())
+ if v := pkt.TransportHeader().View(); !v.IsEmpty() {
+ rcvdPkt.data.AppendView(v)
}
+ rcvdPkt.data.Append(pkt.Data().ExtractVV())
+ } else {
+ // Raw packet endpoints include link-headers in received packets.
+ rcvdPkt.data = buffer.NewVectorisedView(pkt.Size(), pkt.Views())
}
- packet.receivedAt = ep.stack.Clock().Now()
- ep.rcvList.PushBack(&packet)
- ep.rcvBufSize += packet.data.Size()
+ ep.rcvList.PushBack(&rcvdPkt)
+ ep.rcvBufSize += rcvdPkt.data.Size()
ep.rcvMu.Unlock()
ep.stats.PacketsReceived.Increment()
@@ -472,10 +472,8 @@ func (*endpoint) State() uint32 {
// Info returns a copy of the endpoint info.
func (ep *endpoint) Info() tcpip.EndpointInfo {
ep.mu.RLock()
- // Make a copy of the endpoint info.
- ret := ep.TransportEndpointInfo
- ep.mu.RUnlock()
- return &ret
+ defer ep.mu.RUnlock()
+ return &stack.TransportEndpointInfo{NetProto: ep.netProto}
}
// Stats returns a pointer to the endpoint stats.
@@ -490,18 +488,3 @@ func (*endpoint) SetOwner(tcpip.PacketOwner) {}
func (ep *endpoint) SocketOptions() *tcpip.SocketOptions {
return &ep.ops
}
-
-// freeze prevents any more packets from being delivered to the endpoint.
-func (ep *endpoint) freeze() {
- ep.mu.Lock()
- ep.frozen = true
- ep.mu.Unlock()
-}
-
-// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows
-// new packets to be delivered again.
-func (ep *endpoint) thaw() {
- ep.mu.Lock()
- ep.frozen = false
- ep.mu.Unlock()
-}
diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go
index e729921db..d2768db7b 100644
--- a/pkg/tcpip/transport/packet/endpoint_state.go
+++ b/pkg/tcpip/transport/packet/endpoint_state.go
@@ -34,28 +34,26 @@ func (p *packet) loadReceivedAt(nsec int64) {
// saveData saves packet.data field.
func (p *packet) saveData() buffer.VectorisedView {
- // We cannot save p.data directly as p.data.views may alias to p.views,
- // which is not allowed by state framework (in-struct pointer).
return p.data.Clone(nil)
}
// loadData loads packet.data field.
func (p *packet) loadData(data buffer.VectorisedView) {
- // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization
- // here because data.views is not guaranteed to be loaded by now. Plus,
- // data.views will be allocated anyway so there really is little point
- // of utilizing p.views for data.views.
p.data = data
}
// beforeSave is invoked by stateify.
func (ep *endpoint) beforeSave() {
- ep.freeze()
+ ep.rcvMu.Lock()
+ defer ep.rcvMu.Unlock()
+ ep.rcvDisabled = true
}
// afterLoad is invoked by stateify.
func (ep *endpoint) afterLoad() {
- ep.thaw()
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
ep.stack = stack.StackFromEnv
ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
@@ -63,4 +61,8 @@ func (ep *endpoint) afterLoad() {
if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil {
panic(err)
}
+
+ ep.rcvMu.Lock()
+ ep.rcvDisabled = false
+ ep.rcvMu.Unlock()
}
diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD
index 2eab09088..b7e97e218 100644
--- a/pkg/tcpip/transport/raw/BUILD
+++ b/pkg/tcpip/transport/raw/BUILD
@@ -33,6 +33,8 @@ go_library(
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport",
+ "//pkg/tcpip/transport/internal/network",
"//pkg/tcpip/transport/packet",
"//pkg/waiter",
],
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 1bce2769a..3040a445b 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -26,6 +26,7 @@
package raw
import (
+ "fmt"
"io"
"time"
@@ -34,6 +35,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -57,15 +60,19 @@ type rawPacket struct {
//
// +stateify savable
type endpoint struct {
- stack.TransportEndpointInfo
tcpip.DefaultSocketOptionsHandler
// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
+ transProto tcpip.TransportProtocolNumber
waiterQueue *waiter.Queue
associated bool
+ net network.Endpoint
+ stats tcpip.TransportEndpointStats `state:"nosave"`
+ ops tcpip.SocketOptions
+
// The following fields are used to manage the receive queue and are
// protected by rcvMu.
rcvMu sync.Mutex `state:"nosave"`
@@ -74,20 +81,7 @@ type endpoint struct {
rcvClosed bool
// The following fields are protected by mu.
- mu sync.RWMutex `state:"nosave"`
- closed bool
- connected bool
- bound bool
- // route is the route to a remote network endpoint. It is set via
- // Connect(), and is valid only when conneted is true.
- route *stack.Route `state:"manual"`
- stats tcpip.TransportEndpointStats `state:"nosave"`
- // owner is used to get uid and gid of the packet.
- owner tcpip.PacketOwner
-
- // ops is used to get socket level options.
- ops tcpip.SocketOptions
-
+ mu sync.RWMutex `state:"nosave"`
// frozen indicates if the packets should be delivered to the endpoint
// during restore.
frozen bool
@@ -99,16 +93,9 @@ func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, tcpip.Error) {
- if netProto != header.IPv4ProtocolNumber && netProto != header.IPv6ProtocolNumber {
- return nil, &tcpip.ErrUnknownProtocol{}
- }
-
e := &endpoint{
- stack: s,
- TransportEndpointInfo: stack.TransportEndpointInfo{
- NetProto: netProto,
- TransProto: transProto,
- },
+ stack: s,
+ transProto: transProto,
waiterQueue: waiterQueue,
associated: associated,
}
@@ -116,6 +103,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
e.ops.SetHeaderIncluded(!associated)
e.ops.SetSendBufferSize(32*1024, false /* notify */)
e.ops.SetReceiveBufferSize(32*1024, false /* notify */)
+ e.net.Init(s, netProto, transProto, &e.ops)
// Override with stack defaults.
var ss tcpip.SendBufferSizeOption
@@ -132,12 +120,12 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
// headers included. Because they're write-only, We don't need to
// register with the stack.
if !associated {
- e.ops.SetReceiveBufferSize(0, false)
+ e.ops.SetReceiveBufferSize(0, false /* notify */)
e.waiterQueue = nil
return e, nil
}
- if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil {
+ if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil {
return nil, err
}
@@ -154,11 +142,17 @@ func (e *endpoint) Close() {
e.mu.Lock()
defer e.mu.Unlock()
- if e.closed || !e.associated {
+ if e.net.State() == transport.DatagramEndpointStateClosed {
return
}
- e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e)
+ e.net.Close()
+
+ if !e.associated {
+ return
+ }
+
+ e.stack.UnregisterRawTransportEndpoint(e.net.NetProto(), e.transProto, e)
e.rcvMu.Lock()
defer e.rcvMu.Unlock()
@@ -170,15 +164,6 @@ func (e *endpoint) Close() {
e.rcvList.Remove(e.rcvList.Front())
}
- e.connected = false
-
- if e.route != nil {
- e.route.Release()
- e.route = nil
- }
-
- e.closed = true
-
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
}
@@ -186,9 +171,7 @@ func (e *endpoint) Close() {
func (*endpoint) ModerateRecvBuf(int) {}
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
- e.mu.Lock()
- defer e.mu.Unlock()
- e.owner = owner
+ e.net.SetOwner(owner)
}
// Read implements tcpip.Endpoint.Read.
@@ -236,14 +219,15 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// Write implements tcpip.Endpoint.Write.
func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
+ netProto := e.net.NetProto()
// We can create, but not write to, unassociated IPv6 endpoints.
- if !e.associated && e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber {
+ if !e.associated && netProto == header.IPv6ProtocolNumber {
return 0, &tcpip.ErrInvalidOptionValue{}
}
if opts.To != nil {
// Raw sockets do not support sending to a IPv4 address on a IPv6 endpoint.
- if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize {
+ if netProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize {
return 0, &tcpip.ErrInvalidOptionValue{}
}
}
@@ -269,99 +253,25 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
}
func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
- // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op.
- if opts.More {
- return 0, &tcpip.ErrInvalidOptionValue{}
+ ctx, err := e.net.AcquireContextForWrite(opts)
+ if err != nil {
+ return 0, err
}
- payloadBytes, route, owner, err := func() ([]byte, *stack.Route, tcpip.PacketOwner, tcpip.Error) {
- e.mu.RLock()
- defer e.mu.RUnlock()
-
- if e.closed {
- return nil, nil, nil, &tcpip.ErrInvalidEndpointState{}
- }
-
- payloadBytes := make([]byte, p.Len())
- if _, err := io.ReadFull(p, payloadBytes); err != nil {
- return nil, nil, nil, &tcpip.ErrBadBuffer{}
- }
- // If this is an unassociated socket and callee provided a nonzero
- // destination address, route using that address.
- if e.ops.GetHeaderIncluded() {
- ip := header.IPv4(payloadBytes)
- if !ip.IsValid(len(payloadBytes)) {
- return nil, nil, nil, &tcpip.ErrInvalidOptionValue{}
- }
- dstAddr := ip.DestinationAddress()
- // Update dstAddr with the address in the IP header, unless
- // opts.To is set (e.g. if sendto specifies a specific
- // address).
- if dstAddr != tcpip.Address([]byte{0, 0, 0, 0}) && opts.To == nil {
- opts.To = &tcpip.FullAddress{
- NIC: 0, // NIC is unset.
- Addr: dstAddr, // The address from the payload.
- Port: 0, // There are no ports here.
- }
- }
- }
-
- // Did the user caller provide a destination? If not, use the connected
- // destination.
- if opts.To == nil {
- // If the user doesn't specify a destination, they should have
- // connected to another address.
- if !e.connected {
- return nil, nil, nil, &tcpip.ErrDestinationRequired{}
- }
-
- e.route.Acquire()
-
- return payloadBytes, e.route, e.owner, nil
- }
-
- // The caller provided a destination. Reject destination address if it
- // goes through a different NIC than the endpoint was bound to.
- nic := opts.To.NIC
- if e.bound && nic != 0 && nic != e.BindNICID {
- return nil, nil, nil, &tcpip.ErrNoRoute{}
- }
+ // TODO(https://gvisor.dev/issue/6538): Avoid this allocation.
+ payloadBytes := make([]byte, p.Len())
+ if _, err := io.ReadFull(p, payloadBytes); err != nil {
+ return 0, &tcpip.ErrBadBuffer{}
+ }
- // Find the route to the destination. If BindAddress is 0,
- // FindRoute will choose an appropriate source address.
- route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false)
- if err != nil {
- return nil, nil, nil, err
- }
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(ctx.PacketInfo().MaxHeaderLength),
+ Data: buffer.View(payloadBytes).ToVectorisedView(),
+ })
- return payloadBytes, route, e.owner, nil
- }()
- if err != nil {
+ if err := ctx.WritePacket(pkt, e.ops.GetHeaderIncluded()); err != nil {
return 0, err
}
- defer route.Release()
-
- if e.ops.GetHeaderIncluded() {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buffer.View(payloadBytes).ToVectorisedView(),
- })
- if err := route.WriteHeaderIncludedPacket(pkt); err != nil {
- return 0, err
- }
- } else {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(route.MaxHeaderLength()),
- Data: buffer.View(payloadBytes).ToVectorisedView(),
- })
- pkt.Owner = owner
- if err := route.WritePacket(stack.NetworkHeaderParams{
- Protocol: e.TransProto,
- TTL: route.DefaultTTL(),
- TOS: stack.DefaultTOS,
- }, pkt); err != nil {
- return 0, err
- }
- }
return int64(len(payloadBytes)), nil
}
@@ -373,66 +283,29 @@ func (*endpoint) Disconnect() tcpip.Error {
// Connect implements tcpip.Endpoint.Connect.
func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
+ netProto := e.net.NetProto()
+
// Raw sockets do not support connecting to an IPv4 address on an IPv6 endpoint.
- if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize {
+ if netProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize {
return &tcpip.ErrAddressFamilyNotSupported{}
}
- e.mu.Lock()
- defer e.mu.Unlock()
-
- if e.closed {
- return &tcpip.ErrInvalidEndpointState{}
- }
-
- nic := addr.NIC
- if e.bound {
- if e.BindNICID == 0 {
- // If we're bound, but not to a specific NIC, the NIC
- // in addr will be used. Nothing to do here.
- } else if addr.NIC == 0 {
- // If we're bound to a specific NIC, but addr doesn't
- // specify a NIC, use the bound NIC.
- nic = e.BindNICID
- } else if addr.NIC != e.BindNICID {
- // We're bound and addr specifies a NIC. They must be
- // the same.
- return &tcpip.ErrInvalidEndpointState{}
- }
- }
-
- // Find a route to the destination.
- route, err := e.stack.FindRoute(nic, "", addr.Addr, e.NetProto, false)
- if err != nil {
- return err
- }
-
- if e.associated {
- // Re-register the endpoint with the appropriate NIC.
- if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil {
- route.Release()
- return err
+ return e.net.ConnectAndThen(addr, func(_ tcpip.NetworkProtocolNumber, _, _ stack.TransportEndpointID) tcpip.Error {
+ if e.associated {
+ // Re-register the endpoint with the appropriate NIC.
+ if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil {
+ return err
+ }
+ e.stack.UnregisterRawTransportEndpoint(netProto, e.transProto, e)
}
- e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e)
- e.RegisterNICID = nic
- }
- if e.route != nil {
- // If the endpoint was previously connected then release any previous route.
- e.route.Release()
- }
- e.route = route
- e.connected = true
-
- return nil
+ return nil
+ })
}
// Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets.
func (e *endpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error {
- e.mu.Lock()
- defer e.mu.Unlock()
-
- if !e.connected {
+ if e.net.State() != transport.DatagramEndpointStateConnected {
return &tcpip.ErrNotConnected{}
}
return nil
@@ -450,33 +323,26 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi
// Bind implements tcpip.Endpoint.Bind.
func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
- e.mu.Lock()
- defer e.mu.Unlock()
-
- // If a local address was specified, verify that it's valid.
- if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, addr.Addr) == 0 {
- return &tcpip.ErrBadLocalAddress{}
- }
+ return e.net.BindAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, _ tcpip.Address) tcpip.Error {
+ if !e.associated {
+ return nil
+ }
- if e.associated {
// Re-register the endpoint with the appropriate NIC.
- if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil {
+ if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil {
return err
}
- e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e)
- e.RegisterNICID = addr.NIC
- e.BindNICID = addr.NIC
- }
-
- e.BindAddr = addr.Addr
- e.bound = true
-
- return nil
+ e.stack.UnregisterRawTransportEndpoint(netProto, e.transProto, e)
+ return nil
+ })
}
// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
-func (*endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
- return tcpip.FullAddress{}, &tcpip.ErrNotSupported{}
+func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
+ a := e.net.GetLocalAddress()
+ // Linux returns the protocol in the port field.
+ a.Port = uint16(e.transProto)
+ return a, nil
}
// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
@@ -509,17 +375,17 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
return nil
default:
- return &tcpip.ErrUnknownProtocolOption{}
+ return e.net.SetSockOpt(opt)
}
}
-func (*endpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error {
- return &tcpip.ErrUnknownProtocolOption{}
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
+ return e.net.SetSockOptInt(opt, v)
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error {
- return &tcpip.ErrUnknownProtocolOption{}
+func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
+ return e.net.GetSockOpt(opt)
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
@@ -536,100 +402,108 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
return v, nil
default:
- return -1, &tcpip.ErrUnknownProtocolOption{}
+ return e.net.GetSockOptInt(opt)
}
}
// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
- e.mu.RLock()
- e.rcvMu.Lock()
+ notifyReadableEvents := func() bool {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ e.rcvMu.Lock()
+ defer e.rcvMu.Unlock()
+
+ // Drop the packet if our buffer is currently full or if this is an unassociated
+ // endpoint (i.e endpoint created w/ IPPROTO_RAW). Such endpoints are send only
+ // See: https://man7.org/linux/man-pages/man7/raw.7.html
+ //
+ // An IPPROTO_RAW socket is send only. If you really want to receive
+ // all IP packets, use a packet(7) socket with the ETH_P_IP protocol.
+ // Note that packet sockets don't reassemble IP fragments, unlike raw
+ // sockets.
+ if e.rcvClosed || !e.associated {
+ e.stack.Stats().DroppedPackets.Increment()
+ e.stats.ReceiveErrors.ClosedReceiver.Increment()
+ return false
+ }
- // Drop the packet if our buffer is currently full or if this is an unassociated
- // endpoint (i.e endpoint created w/ IPPROTO_RAW). Such endpoints are send only
- // See: https://man7.org/linux/man-pages/man7/raw.7.html
- //
- // An IPPROTO_RAW socket is send only. If you really want to receive
- // all IP packets, use a packet(7) socket with the ETH_P_IP protocol.
- // Note that packet sockets don't reassemble IP fragments, unlike raw
- // sockets.
- if e.rcvClosed || !e.associated {
- e.rcvMu.Unlock()
- e.mu.RUnlock()
- e.stack.Stats().DroppedPackets.Increment()
- e.stats.ReceiveErrors.ClosedReceiver.Increment()
- return
- }
+ rcvBufSize := e.ops.GetReceiveBufferSize()
+ if e.frozen || e.rcvBufSize >= int(rcvBufSize) {
+ e.stack.Stats().DroppedPackets.Increment()
+ e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
+ return false
+ }
- rcvBufSize := e.ops.GetReceiveBufferSize()
- if e.frozen || e.rcvBufSize >= int(rcvBufSize) {
- e.rcvMu.Unlock()
- e.mu.RUnlock()
- e.stack.Stats().DroppedPackets.Increment()
- e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
- return
- }
+ srcAddr := pkt.Network().SourceAddress()
+ info := e.net.Info()
+
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial:
+ case transport.DatagramEndpointStateConnected:
+ // If connected, only accept packets from the remote address we
+ // connected to.
+ if info.ID.RemoteAddress != srcAddr {
+ return false
+ }
- remoteAddr := pkt.Network().SourceAddress()
+ // Connected sockets may also have been bound to a specific
+ // address/NIC.
+ fallthrough
+ case transport.DatagramEndpointStateBound:
+ // If bound to a NIC, only accept data for that NIC.
+ if info.BindNICID != 0 && info.BindNICID != pkt.NICID {
+ return false
+ }
- if e.bound {
- // If bound to a NIC, only accept data for that NIC.
- if e.BindNICID != 0 && e.BindNICID != pkt.NICID {
- e.rcvMu.Unlock()
- e.mu.RUnlock()
- return
- }
- // If bound to an address, only accept data for that address.
- if e.BindAddr != "" && e.BindAddr != remoteAddr {
- e.rcvMu.Unlock()
- e.mu.RUnlock()
- return
+ // If bound to an address, only accept data for that address.
+ if info.BindAddr != "" && info.BindAddr != pkt.Network().DestinationAddress() {
+ return false
+ }
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
}
- }
- // If connected, only accept packets from the remote address we
- // connected to.
- if e.connected && e.route.RemoteAddress() != remoteAddr {
- e.rcvMu.Unlock()
- e.mu.RUnlock()
- return
- }
+ wasEmpty := e.rcvBufSize == 0
- wasEmpty := e.rcvBufSize == 0
+ // Push new packet into receive list and increment the buffer size.
+ packet := &rawPacket{
+ senderAddr: tcpip.FullAddress{
+ NIC: pkt.NICID,
+ Addr: srcAddr,
+ },
+ }
- // Push new packet into receive list and increment the buffer size.
- packet := &rawPacket{
- senderAddr: tcpip.FullAddress{
- NIC: pkt.NICID,
- Addr: remoteAddr,
- },
- }
+ // Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not.
+ // We copy headers' underlying bytes because pkt.*Header may point to
+ // the middle of a slice, and another struct may point to the "outer"
+ // slice. Save/restore doesn't support overlapping slices and will fail.
+ //
+ // TODO(https://gvisor.dev/issue/6517): Avoid the copy once S/R supports
+ // overlapping slices.
+ var combinedVV buffer.VectorisedView
+ if info.NetProto == header.IPv4ProtocolNumber {
+ network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
+ headers := make(buffer.View, 0, len(network)+len(transport))
+ headers = append(headers, network...)
+ headers = append(headers, transport...)
+ combinedVV = headers.ToVectorisedView()
+ } else {
+ combinedVV = append(buffer.View(nil), pkt.TransportHeader().View()...).ToVectorisedView()
+ }
+ combinedVV.Append(pkt.Data().ExtractVV())
+ packet.data = combinedVV
+ packet.receivedAt = e.stack.Clock().Now()
- // Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not.
- // We copy headers' underlying bytes because pkt.*Header may point to
- // the middle of a slice, and another struct may point to the "outer"
- // slice. Save/restore doesn't support overlapping slices and will fail.
- var combinedVV buffer.VectorisedView
- if e.TransportEndpointInfo.NetProto == header.IPv4ProtocolNumber {
- network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
- headers := make(buffer.View, 0, len(network)+len(transport))
- headers = append(headers, network...)
- headers = append(headers, transport...)
- combinedVV = headers.ToVectorisedView()
- } else {
- combinedVV = append(buffer.View(nil), pkt.TransportHeader().View()...).ToVectorisedView()
- }
- combinedVV.Append(pkt.Data().ExtractVV())
- packet.data = combinedVV
- packet.receivedAt = e.stack.Clock().Now()
+ e.rcvList.PushBack(packet)
+ e.rcvBufSize += packet.data.Size()
+ e.stats.PacketsReceived.Increment()
- e.rcvList.PushBack(packet)
- e.rcvBufSize += packet.data.Size()
- e.rcvMu.Unlock()
- e.mu.RUnlock()
- e.stats.PacketsReceived.Increment()
- // Notify waiters that there's data to be read.
- if wasEmpty {
+ // Notify waiters that there is data to be read now.
+ return wasEmpty
+ }()
+
+ if notifyReadableEvents {
e.waiterQueue.Notify(waiter.ReadableEvents)
}
}
@@ -641,10 +515,7 @@ func (e *endpoint) State() uint32 {
// Info returns a copy of the endpoint info.
func (e *endpoint) Info() tcpip.EndpointInfo {
- e.mu.RLock()
- // Make a copy of the endpoint info.
- ret := e.TransportEndpointInfo
- e.mu.RUnlock()
+ ret := e.net.Info()
return &ret
}
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
index 39669b445..e74713064 100644
--- a/pkg/tcpip/transport/raw/endpoint_state.go
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -15,6 +15,7 @@
package raw
import (
+ "fmt"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -60,35 +61,16 @@ func (e *endpoint) beforeSave() {
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
+ e.net.Resume(s)
+
e.thaw()
e.stack = s
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
- // If the endpoint is connected, re-connect.
- if e.connected {
- var err tcpip.Error
- // TODO(gvisor.dev/issue/4906): Properly restore the route with the right
- // remote address. We used to pass e.remote.RemoteAddress which was
- // effectively the empty address but since moving e.route to hold a pointer
- // to a route instead of the route by value, we pass the empty address
- // directly. Obviously this was always wrong since we should provide the
- // remote address we were connected to, to properly restore the route.
- e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, "", e.NetProto, false)
- if err != nil {
- panic(err)
- }
- }
-
- // If the endpoint is bound, re-bind.
- if e.bound {
- if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.BindAddr) == 0 {
- panic(&tcpip.ErrBadLocalAddress{})
- }
- }
-
if e.associated {
- if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil {
- panic(err)
+ netProto := e.net.NetProto()
+ if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil {
+ panic(fmt.Sprintf("e.stack.RegisterRawTransportEndpoint(%d, %d, _): %s", netProto, e.transProto, err))
}
}
}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 8436d2cf0..5148fe157 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -68,6 +68,7 @@ go_library(
"//pkg/tcpip/hash/jenkins",
"//pkg/tcpip/header",
"//pkg/tcpip/header/parse",
+ "//pkg/tcpip/internal/tcp",
"//pkg/tcpip/ports",
"//pkg/tcpip/seqnum",
"//pkg/tcpip/stack",
@@ -96,6 +97,7 @@ go_test(
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/loopback",
"//pkg/tcpip/link/sniffer",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index d807b13b7..03c9fafa1 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -72,7 +72,8 @@ func encodeMSS(mss uint16) uint32 {
// and must not be accessed or have its methods called concurrently as they
// may mutate the stored objects.
type listenContext struct {
- stack *stack.Stack
+ stack *stack.Stack
+ protocol *protocol
// rcvWnd is the receive window that is sent by this listening context
// in the initial SYN-ACK.
@@ -119,9 +120,10 @@ func timeStamp(clock tcpip.Clock) uint32 {
}
// newListenContext creates a new listen context.
-func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6Only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
+func newListenContext(stk *stack.Stack, protocol *protocol, listenEP *endpoint, rcvWnd seqnum.Size, v6Only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
l := &listenContext{
stack: stk,
+ protocol: protocol,
rcvWnd: rcvWnd,
hasher: sha1.New(),
v6Only: v6Only,
@@ -201,7 +203,7 @@ func (l *listenContext) useSynCookies() bool {
// createConnectingEndpoint creates a new endpoint in a connecting state, with
// the connection parameters given by the arguments.
-func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) {
+func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) {
// Create a new endpoint.
netProto := l.netProto
if netProto == 0 {
@@ -213,7 +215,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header
return nil, err
}
- n := newEndpoint(l.stack, netProto, queue)
+ n := newEndpoint(l.stack, l.protocol, netProto, queue)
n.ops.SetV6Only(l.v6Only)
n.TransportEndpointInfo.ID = s.id
n.boundNICID = s.nicID
@@ -244,10 +246,10 @@ func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header
// On success, a handshake h is returned with h.ep.mu held.
//
// Precondition: if l.listenEP != nil, l.listenEP.mu must be locked.
-func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, tcpip.Error) {
+func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, tcpip.Error) {
// Create new endpoint.
irs := s.sequenceNumber
- isn := generateSecureISN(s.id, l.stack.Clock(), l.stack.Seed())
+ isn := generateSecureISN(s.id, l.stack.Clock(), l.protocol.seqnumSecret)
ep, err := l.createConnectingEndpoint(s, opts, queue)
if err != nil {
return nil, err
@@ -323,14 +325,16 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q
// established endpoint is returned with e.mu held.
//
// Precondition: if l.listenEP != nil, l.listenEP.mu must be locked.
-func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, tcpip.Error) {
+func (l *listenContext) performHandshake(s *segment, opts header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, tcpip.Error) {
h, err := l.startHandshake(s, opts, queue, owner)
if err != nil {
return nil, err
}
ep := h.ep
- if err := h.complete(); err != nil {
+ // N.B. the endpoint is generated above by startHandshake, and will be
+ // returned locked. This first call is forced.
+ if err := h.complete(); err != nil { // +checklocksforce
ep.stack.Stats().TCP.FailedConnectionAttempts.Increment()
ep.stats.FailedConnectionAttempts.Increment()
l.cleanupFailedHandshake(h)
@@ -364,6 +368,7 @@ func (l *listenContext) closeAllPendingEndpoints() {
}
// Precondition: h.ep.mu must be held.
+// +checklocks:h.ep.mu
func (l *listenContext) cleanupFailedHandshake(h *handshake) {
e := h.ep
e.mu.Unlock()
@@ -492,7 +497,7 @@ func (e *endpoint) notifyAborted() {
// cookies to accept connections.
//
// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked.
-func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) tcpip.Error {
+func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts header.TCPSynOptions) tcpip.Error {
defer s.decRef()
h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner)
@@ -504,7 +509,9 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
}
go func() {
- if err := h.complete(); err != nil {
+ // Note that startHandshake returns a locked endpoint. The
+ // force call here just makes it so.
+ if err := h.complete(); err != nil { // +checklocksforce
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
ctx.cleanupFailedHandshake(h)
@@ -576,7 +583,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
if !ctx.useSynCookies() {
s.incRef()
atomic.AddInt32(&e.synRcvdCount, 1)
- return e.handleSynSegment(ctx, s, &opts)
+ return e.handleSynSegment(ctx, s, opts)
}
route, err := e.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */)
if err != nil {
@@ -595,10 +602,14 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
synOpts := header.TCPSynOptions{
WS: -1,
TS: opts.TS,
- TSVal: tcpTimeStamp(e.stack.Clock().NowMonotonic(), timeStampOffset(e.stack.Rand())),
TSEcr: opts.TSVal,
MSS: calculateAdvertisedMSS(e.userMSS, route),
}
+ if opts.TS {
+ offset := e.protocol.tsOffset(s.dstAddr, s.srcAddr)
+ now := e.stack.Clock().NowMonotonic()
+ synOpts.TSVal = offset.TSVal(now)
+ }
cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS))
fields := tcpFields{
id: s.id,
@@ -665,7 +676,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
}
e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment()
// Create newly accepted endpoint and deliver it.
- rcvdSynOptions := &header.TCPSynOptions{
+ rcvdSynOptions := header.TCPSynOptions{
MSS: mssTable[data],
// Disable Window scaling as original SYN is
// lost.
@@ -720,25 +731,22 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
}
n.isRegistered = true
-
- // clear the tsOffset for the newly created
- // endpoint as the Timestamp was already
- // randomly offset when the original SYN-ACK was
- // sent above.
- n.TSOffset = 0
+ n.TSOffset = n.protocol.tsOffset(s.dstAddr, s.srcAddr)
// Switch state to connected.
n.isConnectNotified = true
- n.transitionToStateEstablishedLocked(&handshake{
- ep: n,
- iss: iss,
- ackNum: irs + 1,
- rcvWnd: seqnum.Size(n.initialReceiveWindow()),
- sndWnd: s.window,
- rcvWndScale: e.rcvWndScaleForHandshake(),
- sndWndScale: rcvdSynOptions.WS,
- mss: rcvdSynOptions.MSS,
- })
+ h := handshake{
+ ep: n,
+ iss: iss,
+ ackNum: irs + 1,
+ rcvWnd: seqnum.Size(n.initialReceiveWindow()),
+ sndWnd: s.window,
+ rcvWndScale: e.rcvWndScaleForHandshake(),
+ sndWndScale: rcvdSynOptions.WS,
+ mss: rcvdSynOptions.MSS,
+ sampleRTTWithTSOnly: true,
+ }
+ h.transitionToStateEstablishedLocked(s)
// Requeue the segment if the ACK completing the handshake has more info
// to be processed by the newly established endpoint.
@@ -774,7 +782,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
e.mu.Lock()
v6Only := e.ops.GetV6Only()
- ctx := newListenContext(e.stack, e, rcvWnd, v6Only, e.NetProto)
+ ctx := newListenContext(e.stack, e.protocol, e, rcvWnd, v6Only, e.NetProto)
defer func() {
// Mark endpoint as closed. This will prevent goroutines running
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 2137ebc25..5d8e18484 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -105,6 +105,11 @@ type handshake struct {
// sendSYNOpts is the cached values for the SYN options to be sent.
sendSYNOpts header.TCPSynOptions
+
+ // sampleRTTWithTSOnly is true when the segment was retransmitted or we can't
+ // tell; then RTT can only be sampled when the incoming segment has timestamp
+ // options enabled.
+ sampleRTTWithTSOnly bool
}
func (e *endpoint) newHandshake() *handshake {
@@ -117,10 +122,12 @@ func (e *endpoint) newHandshake() *handshake {
h.resetState()
// Store reference to handshake state in endpoint.
e.h = h
+ // By the time handshake is created, e.ID is already initialized.
+ e.TSOffset = e.protocol.tsOffset(e.ID.LocalAddress, e.ID.RemoteAddress)
return h
}
-func (e *endpoint) newPassiveHandshake(isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) *handshake {
+func (e *endpoint) newPassiveHandshake(isn, irs seqnum.Value, opts header.TCPSynOptions, deferAccept time.Duration) *handshake {
h := e.newHandshake()
h.resetToSynRcvd(isn, irs, opts, deferAccept)
return h
@@ -150,20 +157,23 @@ func (h *handshake) resetState() {
h.flags = header.TCPFlagSyn
h.ackNum = 0
h.mss = 0
- h.iss = generateSecureISN(h.ep.TransportEndpointInfo.ID, h.ep.stack.Clock(), h.ep.stack.Seed())
+ h.iss = generateSecureISN(h.ep.TransportEndpointInfo.ID, h.ep.stack.Clock(), h.ep.protocol.seqnumSecret)
}
// generateSecureISN generates a secure Initial Sequence number based on the
// recommendation here https://tools.ietf.org/html/rfc6528#page-3.
func generateSecureISN(id stack.TransportEndpointID, clock tcpip.Clock, seed uint32) seqnum.Value {
isnHasher := jenkins.Sum32(seed)
- isnHasher.Write([]byte(id.LocalAddress))
- isnHasher.Write([]byte(id.RemoteAddress))
+ // Per hash.Hash.Writer:
+ //
+ // It never returns an error.
+ _, _ = isnHasher.Write([]byte(id.LocalAddress))
+ _, _ = isnHasher.Write([]byte(id.RemoteAddress))
portBuf := make([]byte, 2)
binary.LittleEndian.PutUint16(portBuf, id.LocalPort)
- isnHasher.Write(portBuf)
+ _, _ = isnHasher.Write(portBuf)
binary.LittleEndian.PutUint16(portBuf, id.RemotePort)
- isnHasher.Write(portBuf)
+ _, _ = isnHasher.Write(portBuf)
// The time period here is 64ns. This is similar to what Linux uses to
// generate a sequence number that overlaps less than one
// time per MSL (2 minutes).
@@ -190,7 +200,7 @@ func (h *handshake) effectiveRcvWndScale() uint8 {
// resetToSynRcvd resets the state of the handshake object to the SYN-RCVD
// state.
-func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) {
+func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts header.TCPSynOptions, deferAccept time.Duration) {
h.active = false
h.state = handshakeSynRcvd
h.flags = header.TCPFlagSyn | header.TCPFlagAck
@@ -251,10 +261,10 @@ func (h *handshake) synSentState(s *segment) tcpip.Error {
rcvSynOpts := parseSynSegmentOptions(s)
// Remember if the Timestamp option was negotiated.
- h.ep.maybeEnableTimestamp(&rcvSynOpts)
+ h.ep.maybeEnableTimestamp(rcvSynOpts)
// Remember if the SACKPermitted option was negotiated.
- h.ep.maybeEnableSACKPermitted(&rcvSynOpts)
+ h.ep.maybeEnableSACKPermitted(rcvSynOpts)
// Remember the sequence we'll ack from now on.
h.ackNum = s.sequenceNumber + 1
@@ -266,8 +276,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error {
// and the handshake is completed.
if s.flags.Contains(header.TCPFlagAck) {
h.state = handshakeCompleted
-
- h.ep.transitionToStateEstablishedLocked(h)
+ h.transitionToStateEstablishedLocked(s)
h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale())
return nil
@@ -283,7 +292,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error {
synOpts := header.TCPSynOptions{
WS: int(h.effectiveRcvWndScale()),
TS: rcvSynOpts.TS,
- TSVal: h.ep.timestamp(),
+ TSVal: h.ep.tsValNow(),
TSEcr: h.ep.recentTimestamp(),
// We only send SACKPermitted if the other side indicated it
@@ -353,7 +362,7 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error {
synOpts := header.TCPSynOptions{
WS: h.rcvWndScale,
TS: h.ep.SendTSOk,
- TSVal: h.ep.timestamp(),
+ TSVal: h.ep.tsValNow(),
TSEcr: h.ep.recentTimestamp(),
SACKPermitted: h.ep.SACKPermitted,
MSS: h.ep.amss,
@@ -402,9 +411,10 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error {
if h.ep.SendTSOk && s.parsedOptions.TS {
h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber)
}
+
h.state = handshakeCompleted
- h.ep.transitionToStateEstablishedLocked(h)
+ h.transitionToStateEstablishedLocked(s)
// Requeue the segment if the ACK completing the handshake has more info
// to be processed by the newly established endpoint.
@@ -480,7 +490,7 @@ func (h *handshake) start() {
synOpts := header.TCPSynOptions{
WS: h.rcvWndScale,
TS: true,
- TSVal: h.ep.timestamp(),
+ TSVal: h.ep.tsValNow(),
TSEcr: h.ep.recentTimestamp(),
SACKPermitted: bool(sackEnabled),
MSS: h.ep.amss,
@@ -511,6 +521,7 @@ func (h *handshake) start() {
}
// complete completes the TCP 3-way handshake initiated by h.start().
+// +checklocks:h.ep.mu
func (h *handshake) complete() tcpip.Error {
// Set up the wakers.
var s sleep.Sleeper
@@ -556,6 +567,10 @@ func (h *handshake) complete() tcpip.Error {
ack: h.ackNum,
rcvWnd: h.rcvWnd,
}, h.sendSYNOpts)
+ // If we have ever retransmitted the SYN-ACK or
+ // SYN segment, we should only measure RTT if
+ // TS option is present.
+ h.sampleRTTWithTSOnly = true
}
case wakerForNotification:
@@ -599,6 +614,40 @@ func (h *handshake) complete() tcpip.Error {
return nil
}
+// transitionToStateEstablishedLocked transitions the endpoint of the handshake
+// to an established state given the last segment received from peer. It also
+// initializes sender/receiver.
+func (h *handshake) transitionToStateEstablishedLocked(s *segment) {
+ // Transfer handshake state to TCP connection. We disable
+ // receive window scaling if the peer doesn't support it
+ // (indicated by a negative send window scale).
+ h.ep.snd = newSender(h.ep, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
+
+ now := h.ep.stack.Clock().NowMonotonic()
+
+ var rtt time.Duration
+ if h.ep.SendTSOk && s.parsedOptions.TSEcr != 0 {
+ rtt = h.ep.elapsed(now, s.parsedOptions.TSEcr)
+ }
+ if !h.sampleRTTWithTSOnly && rtt == 0 {
+ rtt = now.Sub(h.startTime)
+ }
+
+ if rtt > 0 {
+ h.ep.snd.updateRTO(rtt)
+ }
+
+ h.ep.rcvQueueInfo.rcvQueueMu.Lock()
+ h.ep.rcv = newReceiver(h.ep, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale())
+ // Bootstrap the auto tuning algorithm. Starting at zero will
+ // result in a really large receive window after the first auto
+ // tuning adjustment.
+ h.ep.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes = int(h.rcvWnd)
+ h.ep.rcvQueueInfo.rcvQueueMu.Unlock()
+
+ h.ep.setEndpointState(StateEstablished)
+}
+
type backoffTimer struct {
timeout time.Duration
maxTimeout time.Duration
@@ -872,7 +921,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
// Ref: https://tools.ietf.org/html/rfc7323#section-5.4.
offset += header.EncodeNOP(options[offset:])
offset += header.EncodeNOP(options[offset:])
- offset += header.EncodeTSOption(e.timestamp(), e.recentTimestamp(), options[offset:])
+ offset += header.EncodeTSOption(e.tsValNow(), e.recentTimestamp(), options[offset:])
}
if e.SACKPermitted && len(sackBlocks) > 0 {
offset += header.EncodeNOP(options[offset:])
@@ -909,30 +958,13 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags header.TCPFlags, se
return err
}
-func (e *endpoint) handleWrite() {
- e.sndQueueInfo.sndQueueMu.Lock()
- next := e.drainSendQueueLocked()
- e.sndQueueInfo.sndQueueMu.Unlock()
-
- e.sendData(next)
-}
-
-// Move packets from send queue to send list.
-//
-// Precondition: e.sndBufMu must be locked.
-func (e *endpoint) drainSendQueueLocked() *segment {
- first := e.sndQueueInfo.sndQueue.Front()
- if first != nil {
- e.snd.writeList.PushBackList(&e.sndQueueInfo.sndQueue)
- e.sndQueueInfo.SndBufInQueue = 0
- }
- return first
-}
-
// Precondition: e.mu must be locked.
func (e *endpoint) sendData(next *segment) {
// Initialize the next segment to write if it's currently nil.
if e.snd.writeNext == nil {
+ if next == nil {
+ return
+ }
e.snd.writeNext = next
}
@@ -940,17 +972,6 @@ func (e *endpoint) sendData(next *segment) {
e.snd.sendData()
}
-func (e *endpoint) handleClose() {
- if !e.EndpointState().connected() {
- return
- }
- // Drain the send queue.
- e.handleWrite()
-
- // Mark send side as closed.
- e.snd.Closed = true
-}
-
// resetConnectionLocked puts the endpoint in an error state with the given
// error code and sends a RST if and only if the error is not ErrConnectionReset
// indicating that the connection is being reset due to receiving a RST. This
@@ -992,26 +1013,6 @@ func (e *endpoint) completeWorkerLocked() {
}
}
-// transitionToStateEstablisedLocked transitions a given endpoint
-// to an established state using the handshake parameters provided.
-// It also initializes sender/receiver.
-func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
- // Transfer handshake state to TCP connection. We disable
- // receive window scaling if the peer doesn't support it
- // (indicated by a negative send window scale).
- e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
-
- e.rcvQueueInfo.rcvQueueMu.Lock()
- e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale())
- // Bootstrap the auto tuning algorithm. Starting at zero will
- // result in a really large receive window after the first auto
- // tuning adjustment.
- e.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes = int(h.rcvWnd)
- e.rcvQueueInfo.rcvQueueMu.Unlock()
-
- e.setEndpointState(StateEstablished)
-}
-
// transitionToStateCloseLocked ensures that the endpoint is
// cleaned up from the transport demuxer, "before" moving to
// StateClose. This will ensure that no packet will be
@@ -1130,7 +1131,7 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err tcpip.Error) {
func (e *endpoint) handleSegmentsLocked(fastPath bool) tcpip.Error {
checkRequeue := true
for i := 0; i < maxSegmentsPerWake; i++ {
- if e.EndpointState().closed() {
+ if state := e.EndpointState(); state.closed() || state == StateTimeWait {
return nil
}
s := e.segmentQueue.dequeue()
@@ -1311,42 +1312,45 @@ func (e *endpoint) disableKeepaliveTimer() {
e.keepalive.Unlock()
}
-// protocolMainLoop is the main loop of the TCP protocol. It runs in its own
-// goroutine and is responsible for sending segments and handling received
-// segments.
-func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) tcpip.Error {
- e.mu.Lock()
- var closeTimer tcpip.Timer
- var closeWaker sleep.Waker
-
- epilogue := func() {
- // e.mu is expected to be hold upon entering this section.
- if e.snd != nil {
- e.snd.resendTimer.cleanup()
- e.snd.probeTimer.cleanup()
- e.snd.reorderTimer.cleanup()
- }
+// protocolMainLoopDone is called at the end of protocolMainLoop.
+// +checklocksrelease:e.mu
+func (e *endpoint) protocolMainLoopDone(closeTimer tcpip.Timer) {
+ if e.snd != nil {
+ e.snd.resendTimer.cleanup()
+ e.snd.probeTimer.cleanup()
+ e.snd.reorderTimer.cleanup()
+ }
- if closeTimer != nil {
- closeTimer.Stop()
- }
+ if closeTimer != nil {
+ closeTimer.Stop()
+ }
- e.completeWorkerLocked()
+ e.completeWorkerLocked()
- if e.drainDone != nil {
- close(e.drainDone)
- }
+ if e.drainDone != nil {
+ close(e.drainDone)
+ }
- e.mu.Unlock()
+ e.mu.Unlock()
- e.drainClosingSegmentQueue()
+ e.drainClosingSegmentQueue()
- // When the protocol loop exits we should wake up our waiters.
- e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
- }
+ // When the protocol loop exits we should wake up our waiters.
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
+}
+
+// protocolMainLoop is the main loop of the TCP protocol. It runs in its own
+// goroutine and is responsible for sending segments and handling received
+// segments.
+func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) {
+ var (
+ closeTimer tcpip.Timer
+ closeWaker sleep.Waker
+ )
+ e.mu.Lock()
if handshake {
- if err := e.h.complete(); err != nil {
+ if err := e.h.complete(); err != nil { // +checklocksforce
e.lastErrorMu.Lock()
e.lastError = err
e.lastErrorMu.Unlock()
@@ -1355,9 +1359,8 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
e.hardError = err
e.workerCleanup = true
- // Lock released below.
- epilogue()
- return err
+ e.protocolMainLoopDone(closeTimer)
+ return
}
}
@@ -1402,14 +1405,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
{
w: &e.sndQueueInfo.sndWaker,
f: func() tcpip.Error {
- e.handleWrite()
- return nil
- },
- },
- {
- w: &e.sndQueueInfo.sndCloseWaker,
- f: func() tcpip.Error {
- e.handleClose()
+ e.sendData(nil /* next */)
return nil
},
},
@@ -1507,7 +1503,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
// Only block the worker if the endpoint
// is not in closed state or error state.
close(e.drainDone)
- e.mu.Unlock()
+ e.mu.Unlock() // +checklocksforce
<-e.undrain
e.mu.Lock()
}
@@ -1568,8 +1564,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
if err != nil {
e.resetConnectionLocked(err)
}
- // Lock released below.
- epilogue()
}
loop:
@@ -1593,7 +1587,8 @@ loop:
// just want to terminate the loop and cleanup the
// endpoint.
cleanupOnError(nil)
- return nil
+ e.protocolMainLoopDone(closeTimer)
+ return
case StateTimeWait:
fallthrough
case StateClose:
@@ -1601,7 +1596,8 @@ loop:
default:
if err := funcs[v].f(); err != nil {
cleanupOnError(err)
- return nil
+ e.protocolMainLoopDone(closeTimer)
+ return
}
}
}
@@ -1624,21 +1620,19 @@ loop:
// Handle any StateError transition from StateTimeWait.
if e.EndpointState() == StateError {
cleanupOnError(nil)
- return nil
+ e.protocolMainLoopDone(closeTimer)
+ return
}
e.transitionToStateCloseLocked()
- // Lock released below.
- epilogue()
+ e.protocolMainLoopDone(closeTimer)
// A new SYN was received during TIME_WAIT and we need to abort
// the timewait and redirect the segment to the listener queue
if reuseTW != nil {
reuseTW()
}
-
- return nil
}
// handleTimeWaitSegments processes segments received during TIME_WAIT
@@ -1700,6 +1694,7 @@ func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func()
// should be executed after releasing the endpoint registrations. This is
// done in cases where a new SYN is received during TIME_WAIT that carries
// a sequence number larger than one see on the connection.
+// +checklocks:e.mu
func (e *endpoint) doTimeWait() (twReuse func()) {
// Trigger a 2 * MSL time wait state. During this period
// we will drop all incoming segments.
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
index dff7cb89c..7d110516b 100644
--- a/pkg/tcpip/transport/tcp/dispatcher.go
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -127,7 +127,7 @@ func (p *processor) start(wg *sync.WaitGroup) {
case !ep.segmentQueue.empty():
p.epQ.enqueue(ep)
}
- ep.mu.Unlock()
+ ep.mu.Unlock() // +checklocksforce
} else {
ep.newSegmentWaker.Assert()
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index a27e2110b..d2b8f298f 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -20,7 +20,6 @@ import (
"fmt"
"io"
"math"
- "math/rand"
"runtime"
"strings"
"sync/atomic"
@@ -293,16 +292,9 @@ type sndQueueInfo struct {
sndQueueMu sync.Mutex `state:"nosave"`
stack.TCPSndBufState
- // sndQueue holds segments that are ready to be sent.
- sndQueue segmentList `state:"wait"`
-
- // sndWaker is used to signal the protocol goroutine when segments are
- // added to the `sndQueue`.
+ // sndWaker is used to signal the protocol goroutine when there may be
+ // segments that need to be sent.
sndWaker sleep.Waker `state:"manual"`
-
- // sndCloseWaker is used to notify the protocol goroutine when the send
- // side is closed.
- sndCloseWaker sleep.Waker `state:"manual"`
}
// rcvQueueInfo contains the endpoint's rcvQueue and associated metadata.
@@ -385,6 +377,7 @@ type endpoint struct {
// The following fields are initialized at creation time and do not
// change throughout the lifetime of the endpoint.
stack *stack.Stack `state:"manual"`
+ protocol *protocol `state:"manual"`
waiterQueue *waiter.Queue `state:"wait"`
uniqueID uint64
@@ -485,7 +478,7 @@ type endpoint struct {
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
- // tcpRecovery is the loss deteoction algorithm used by TCP.
+ // tcpRecovery is the loss recovery algorithm used by TCP.
tcpRecovery tcpip.TCPRecovery
// sack holds TCP SACK related information for this endpoint.
@@ -671,6 +664,7 @@ func calculateAdvertisedMSS(userMSS uint16, r *stack.Route) uint16 {
// The assumption behind spinning here being that background packet processing
// should not be holding the lock for long and spinning reduces latency as we
// avoid an expensive sleep/wakeup of of the syscall goroutine).
+// +checklocksacquire:e.mu
func (e *endpoint) LockUser() {
for {
// Try first if the sock is locked then check if it's owned
@@ -690,7 +684,7 @@ func (e *endpoint) LockUser() {
continue
}
atomic.StoreUint32(&e.ownedByUser, 1)
- return
+ return // +checklocksforce
}
}
@@ -707,7 +701,7 @@ func (e *endpoint) LockUser() {
// protocol goroutine altogether.
//
// Precondition: e.LockUser() must have been called before calling e.UnlockUser()
-// +checklocks:e.mu
+// +checklocksrelease:e.mu
func (e *endpoint) UnlockUser() {
// Lock segment queue before checking so that we avoid a race where
// segments can be queued between the time we check if queue is empty
@@ -743,12 +737,13 @@ func (e *endpoint) UnlockUser() {
}
// StopWork halts packet processing. Only to be used in tests.
+// +checklocksacquire:e.mu
func (e *endpoint) StopWork() {
e.mu.Lock()
}
// ResumeWork resumes packet processing. Only to be used in tests.
-// +checklocks:e.mu
+// +checklocksrelease:e.mu
func (e *endpoint) ResumeWork() {
e.mu.Unlock()
}
@@ -759,7 +754,7 @@ func (e *endpoint) ResumeWork() {
//
// Precondition: e.mu must be held to call this method.
func (e *endpoint) setEndpointState(state EndpointState) {
- oldstate := EndpointState(atomic.LoadUint32(&e.state))
+ oldstate := EndpointState(atomic.SwapUint32(&e.state, uint32(state)))
switch state {
case StateEstablished:
e.stack.Stats().TCP.CurrentEstablished.Increment()
@@ -776,7 +771,6 @@ func (e *endpoint) setEndpointState(state EndpointState) {
e.stack.Stats().TCP.CurrentEstablished.Decrement()
}
}
- atomic.StoreUint32(&e.state, uint32(state))
}
// EndpointState returns the current state of the endpoint.
@@ -809,9 +803,10 @@ type keepalive struct {
waker sleep.Waker `state:"nosave"`
}
-func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
+func newEndpoint(s *stack.Stack, protocol *protocol, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
e := &endpoint{
- stack: s,
+ stack: s,
+ protocol: protocol,
TransportEndpointInfo: stack.TransportEndpointInfo{
NetProto: netProto,
TransProto: header.TCPProtocolNumber,
@@ -875,14 +870,12 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
e.maxSynRetries = uint8(synRetries)
}
- s.TransportProtocolOption(ProtocolNumber, &e.tcpRecovery)
-
if p := s.GetTCPProbe(); p != nil {
e.probe = p
}
e.segmentQueue.ep = e
- e.TSOffset = timeStampOffset(e.stack.Rand())
+
e.acceptCond = sync.NewCond(&e.acceptMu)
e.keepalive.timer.init(e.stack.Clock(), &e.keepalive.waker)
@@ -1487,87 +1480,101 @@ func (e *endpoint) isEndpointWritableLocked() (int, tcpip.Error) {
return avail, nil
}
-// Write writes data to the endpoint's peer.
-func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
- // Linux completely ignores any address passed to sendto(2) for TCP sockets
- // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
- // and opts.EndOfRecord are also ignored.
+// readFromPayloader reads a slice from the Payloader.
+// +checklocks:e.mu
+// +checklocks:e.sndQueueInfo.sndQueueMu
+func (e *endpoint) readFromPayloader(p tcpip.Payloader, opts tcpip.WriteOptions, avail int) ([]byte, tcpip.Error) {
+ // We can release locks while copying data.
+ //
+ // This is not possible if atomic is set, because we can't allow the
+ // available buffer space to be consumed by some other caller while we
+ // are copying data in.
+ if !opts.Atomic {
+ e.sndQueueInfo.sndQueueMu.Unlock()
+ defer e.sndQueueInfo.sndQueueMu.Lock()
- e.LockUser()
- defer e.UnlockUser()
+ e.UnlockUser()
+ defer e.LockUser()
+ }
- nextSeg, n, err := func() (*segment, int, tcpip.Error) {
- e.sndQueueInfo.sndQueueMu.Lock()
- defer e.sndQueueInfo.sndQueueMu.Unlock()
+ // Fetch data.
+ if l := p.Len(); l < avail {
+ avail = l
+ }
+ if avail == 0 {
+ return nil, nil
+ }
+ v := make([]byte, avail)
+ n, err := p.Read(v)
+ if err != nil && err != io.EOF {
+ return nil, &tcpip.ErrBadBuffer{}
+ }
+ return v[:n], nil
+}
+
+// queueSegment reads data from the payloader and returns a segment to be sent.
+// +checklocks:e.mu
+func (e *endpoint) queueSegment(p tcpip.Payloader, opts tcpip.WriteOptions) (*segment, int, tcpip.Error) {
+ e.sndQueueInfo.sndQueueMu.Lock()
+ defer e.sndQueueInfo.sndQueueMu.Unlock()
+
+ avail, err := e.isEndpointWritableLocked()
+ if err != nil {
+ e.stats.WriteErrors.WriteClosed.Increment()
+ return nil, 0, err
+ }
+
+ v, err := e.readFromPayloader(p, opts, avail)
+ if err != nil {
+ return nil, 0, err
+ }
+
+ // Do not queue zero length segments.
+ if len(v) == 0 {
+ return nil, 0, nil
+ }
+ if !opts.Atomic {
+ // Since we released locks in between it's possible that the
+ // endpoint transitioned to a CLOSED/ERROR state, so make
+ // sure endpoint is still writable before trying to write.
avail, err := e.isEndpointWritableLocked()
if err != nil {
e.stats.WriteErrors.WriteClosed.Increment()
return nil, 0, err
}
- v, err := func() ([]byte, tcpip.Error) {
- // We can release locks while copying data.
- //
- // This is not possible if atomic is set, because we can't allow the
- // available buffer space to be consumed by some other caller while we
- // are copying data in.
- if !opts.Atomic {
- e.sndQueueInfo.sndQueueMu.Unlock()
- defer e.sndQueueInfo.sndQueueMu.Lock()
-
- e.UnlockUser()
- defer e.LockUser()
- }
-
- // Fetch data.
- if l := p.Len(); l < avail {
- avail = l
- }
- if avail == 0 {
- return nil, nil
- }
- v := make([]byte, avail)
- n, err := p.Read(v)
- if err != nil && err != io.EOF {
- return nil, &tcpip.ErrBadBuffer{}
- }
- return v[:n], nil
- }()
- if len(v) == 0 || err != nil {
- return nil, 0, err
+ // Discard any excess data copied in due to avail being reduced due
+ // to a simultaneous write call to the socket.
+ if avail < len(v) {
+ v = v[:avail]
}
+ }
- if !opts.Atomic {
- // Since we released locks in between it's possible that the
- // endpoint transitioned to a CLOSED/ERROR states so make
- // sure endpoint is still writable before trying to write.
- avail, err := e.isEndpointWritableLocked()
- if err != nil {
- e.stats.WriteErrors.WriteClosed.Increment()
- return nil, 0, err
- }
+ // Add data to the send queue.
+ s := newOutgoingSegment(e.TransportEndpointInfo.ID, e.stack.Clock(), v)
+ e.sndQueueInfo.SndBufUsed += len(v)
+ e.snd.writeList.PushBack(s)
- // Discard any excess data copied in due to avail being reduced due
- // to a simultaneous write call to the socket.
- if avail < len(v) {
- v = v[:avail]
- }
- }
+ return s, len(v), nil
+}
- // Add data to the send queue.
- s := newOutgoingSegment(e.TransportEndpointInfo.ID, e.stack.Clock(), v)
- e.sndQueueInfo.SndBufUsed += len(v)
- e.sndQueueInfo.SndBufInQueue += seqnum.Size(len(v))
- e.sndQueueInfo.sndQueue.PushBack(s)
+// Write writes data to the endpoint's peer.
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
+ // Linux completely ignores any address passed to sendto(2) for TCP sockets
+ // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
+ // and opts.EndOfRecord are also ignored.
+
+ e.LockUser()
+ defer e.UnlockUser()
- return e.drainSendQueueLocked(), len(v), nil
- }()
// Return if either we didn't queue anything or if an error occurred while
// attempting to queue data.
+ nextSeg, n, err := e.queueSegment(p, opts)
if n == 0 || err != nil {
return 0, err
}
+
e.sendData(nextSeg)
return int64(n), nil
}
@@ -1711,6 +1718,27 @@ func (e *endpoint) OnSetReceiveBufferSize(rcvBufSz, oldSz int64) (newSz int64) {
return rcvBufSz
}
+// OnSetSendBufferSize implements tcpip.SocketOptionsHandler.OnSetSendBufferSize.
+func (e *endpoint) OnSetSendBufferSize(sz int64) int64 {
+ atomic.StoreUint32(&e.sndQueueInfo.TCPSndBufState.AutoTuneSndBufDisabled, 1)
+ return sz
+}
+
+// WakeupWriters implements tcpip.SocketOptionsHandler.WakeupWriters.
+func (e *endpoint) WakeupWriters() {
+ e.LockUser()
+ defer e.UnlockUser()
+
+ sendBufferSize := e.getSendBufferSize()
+ e.sndQueueInfo.sndQueueMu.Lock()
+ notify := (sendBufferSize - e.sndQueueInfo.SndBufUsed) >= e.sndQueueInfo.SndBufUsed>>1
+ e.sndQueueInfo.sndQueueMu.Unlock()
+
+ if notify {
+ e.waiterQueue.Notify(waiter.WritableEvents)
+ }
+}
+
// SetSockOptInt sets a socket option.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
// Lower 2 bits represents ECN bits. RFC 3168, section 23.1
@@ -2171,7 +2199,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
portBuf := make([]byte, 2)
binary.LittleEndian.PutUint16(portBuf, e.ID.RemotePort)
- h := jenkins.Sum32(e.stack.Seed())
+ h := jenkins.Sum32(e.protocol.portOffsetSecret)
for _, s := range [][]byte{
[]byte(e.ID.LocalAddress),
[]byte(e.ID.RemoteAddress),
@@ -2314,7 +2342,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
// connection setting here.
if !handshake {
e.segmentQueue.mu.Lock()
- for _, l := range []segmentList{e.segmentQueue.list, e.sndQueueInfo.sndQueue, e.snd.writeList} {
+ for _, l := range []segmentList{e.segmentQueue.list, e.snd.writeList} {
for s := l.Front(); s != nil; s = s.Next() {
s.id = e.TransportEndpointInfo.ID
e.sndQueueInfo.sndWaker.Assert()
@@ -2323,6 +2351,9 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
e.segmentQueue.mu.Unlock()
e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0)
e.setEndpointState(StateEstablished)
+ // Set the new auto tuned send buffer size after entering
+ // established state.
+ e.ops.SetSendBufferSize(e.computeTCPSendBufferSize(), false /* notify */)
}
if run {
@@ -2372,6 +2403,9 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error {
e.notifyProtocolGoroutine(notifyTickleWorker)
return nil
}
+ // Wake up any readers that may be waiting for the stream to become
+ // readable.
+ e.waiterQueue.Notify(waiter.ReadableEvents)
}
// Close for write.
@@ -2388,12 +2422,20 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error {
// Queue fin segment.
s := newOutgoingSegment(e.TransportEndpointInfo.ID, e.stack.Clock(), nil)
- e.sndQueueInfo.sndQueue.PushBack(s)
- e.sndQueueInfo.SndBufInQueue++
+ e.snd.writeList.PushBack(s)
// Mark endpoint as closed.
e.sndQueueInfo.SndClosed = true
e.sndQueueInfo.sndQueueMu.Unlock()
- e.handleClose()
+
+ // Drain the send queue.
+ e.sendData(s)
+
+ // Mark send side as closed.
+ e.snd.Closed = true
+
+ // Wake up any writers that may be waiting for the stream to become
+ // writable.
+ e.waiterQueue.Notify(waiter.WritableEvents)
}
return nil
@@ -2501,6 +2543,7 @@ func (e *endpoint) listen(backlog int) tcpip.Error {
// startAcceptedLoop sets up required state and starts a goroutine with the
// main loop for accepted connections.
+// +checklocksrelease:e.mu
func (e *endpoint) startAcceptedLoop() {
e.workerRunning = true
e.mu.Unlock()
@@ -2745,13 +2788,20 @@ func (e *endpoint) updateSndBufferUsage(v int) {
e.sndQueueInfo.sndQueueMu.Lock()
notify := e.sndQueueInfo.SndBufUsed >= sendBufferSize>>1
e.sndQueueInfo.SndBufUsed -= v
+
+ // Get the new send buffer size with auto tuning, but do not set it
+ // unless we decide to notify the writers.
+ newSndBufSz := e.computeTCPSendBufferSize()
+
// We only notify when there is half the sendBufferSize available after
// a full buffer event occurs. This ensures that we don't wake up
// writers to queue just 1-2 segments and go back to sleep.
- notify = notify && e.sndQueueInfo.SndBufUsed < sendBufferSize>>1
+ notify = notify && e.sndQueueInfo.SndBufUsed < int(newSndBufSz)>>1
e.sndQueueInfo.sndQueueMu.Unlock()
if notify {
+ // Set the new send buffer size calculated from auto tuning.
+ e.ops.SetSendBufferSize(newSndBufSz, false /* notify */)
e.waiterQueue.Notify(waiter.WritableEvents)
}
}
@@ -2855,46 +2905,29 @@ func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value,
// maybeEnableTimestamp marks the timestamp option enabled for this endpoint if
// the SYN options indicate that timestamp option was negotiated. It also
// initializes the recentTS with the value provided in synOpts.TSval.
-func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) {
+func (e *endpoint) maybeEnableTimestamp(synOpts header.TCPSynOptions) {
if synOpts.TS {
e.SendTSOk = true
e.setRecentTimestamp(synOpts.TSVal)
}
}
-// timestamp returns the timestamp value to be used in the TSVal field of the
-// timestamp option for outgoing TCP segments for a given endpoint.
-func (e *endpoint) timestamp() uint32 {
- return tcpTimeStamp(e.stack.Clock().NowMonotonic(), e.TSOffset)
+func (e *endpoint) tsVal(now tcpip.MonotonicTime) uint32 {
+ return e.TSOffset.TSVal(now)
}
-// tcpTimeStamp returns a timestamp offset by the provided offset. This is
-// not inlined above as it's used when SYN cookies are in use and endpoint
-// is not created at the time when the SYN cookie is sent.
-func tcpTimeStamp(curTime tcpip.MonotonicTime, offset uint32) uint32 {
- d := curTime.Sub(tcpip.MonotonicTime{})
- return uint32(d.Milliseconds()) + offset
+func (e *endpoint) tsValNow() uint32 {
+ return e.tsVal(e.stack.Clock().NowMonotonic())
}
-// timeStampOffset returns a randomized timestamp offset to be used when sending
-// timestamp values in a timestamp option for a TCP segment.
-func timeStampOffset(rng *rand.Rand) uint32 {
- // Initialize a random tsOffset that will be added to the recentTS
- // everytime the timestamp is sent when the Timestamp option is enabled.
- //
- // See https://tools.ietf.org/html/rfc7323#section-5.4 for details on
- // why this is required.
- //
- // NOTE: This is not completely to spec as normally this should be
- // initialized in a manner analogous to how sequence numbers are
- // randomized per connection basis. But for now this is sufficient.
- return rng.Uint32()
+func (e *endpoint) elapsed(now tcpip.MonotonicTime, tsEcr uint32) time.Duration {
+ return e.TSOffset.Elapsed(now, tsEcr)
}
// maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint
// if the SYN options indicate that the SACK option was negotiated and the TCP
// stack is configured to enable TCP SACK option.
-func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) {
+func (e *endpoint) maybeEnableSACKPermitted(synOpts header.TCPSynOptions) {
var v tcpip.TCPSACKEnabled
if err := e.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil {
// Stack doesn't support SACK. So just return.
@@ -2902,6 +2935,7 @@ func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) {
}
if bool(v) && synOpts.SACKPermitted {
e.SACKPermitted = true
+ e.stack.TransportProtocolOption(ProtocolNumber, &e.tcpRecovery)
}
}
@@ -3072,3 +3106,36 @@ func GetTCPReceiveBufferLimits(s tcpip.StackHandler) tcpip.ReceiveBufferSizeOpti
Max: ss.Max,
}
}
+
+// computeTCPSendBufferSize implements auto tuning of send buffer size and
+// returns the new send buffer size.
+func (e *endpoint) computeTCPSendBufferSize() int64 {
+ curSndBufSz := int64(e.getSendBufferSize())
+
+ // Auto tuning is disabled when the user explicitly sets the send
+ // buffer size with SO_SNDBUF option.
+ if disabled := atomic.LoadUint32(&e.sndQueueInfo.TCPSndBufState.AutoTuneSndBufDisabled); disabled == 1 {
+ return curSndBufSz
+ }
+
+ const packetOverheadFactor = 2
+ curMSS := e.snd.MaxPayloadSize
+ numSeg := InitialCwnd
+ if numSeg < e.snd.SndCwnd {
+ numSeg = e.snd.SndCwnd
+ }
+
+ // SndCwnd indicates the number of segments that can be sent. This means
+ // that the sender can send up to #SndCwnd segments and the send buffer
+ // size should be set to SndCwnd*MSS to accommodate sending of all the
+ // segments.
+ newSndBufSz := int64(numSeg * curMSS * packetOverheadFactor)
+ if newSndBufSz < curSndBufSz {
+ return curSndBufSz
+ }
+ if ss := GetTCPSendBufferLimits(e.stack); int64(ss.Max) < newSndBufSz {
+ newSndBufSz = int64(ss.Max)
+ }
+
+ return newSndBufSz
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 952ccacdd..f2e8b3840 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -170,6 +170,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
snd.probeTimer.init(s.Clock(), &snd.probeWaker)
}
e.stack = s
+ e.protocol = protocolFromStack(s)
e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits)
e.segmentQueue.thaw()
epState := EndpointState(e.origEndpointState)
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 65c86823a..128ef09e3 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -54,7 +54,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward
maxInFlight: maxInFlight,
handler: handler,
inFlight: make(map[stack.TransportEndpointID]struct{}),
- listen: newListenContext(s, nil /* listenEP */, seqnum.Size(rcvWnd), true, 0),
+ listen: newListenContext(s, protocolFromStack(s), nil /* listenEP */, seqnum.Size(rcvWnd), true, 0),
}
}
@@ -152,7 +152,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
}
f := r.forwarder
- ep, err := f.listen.performHandshake(r.segment, &header.TCPSynOptions{
+ ep, err := f.listen.performHandshake(r.segment, header.TCPSynOptions{
MSS: r.synOptions.MSS,
WS: r.synOptions.WS,
TS: r.synOptions.TS,
@@ -164,8 +164,9 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
return nil, err
}
- // Start the protocol goroutine.
- ep.startAcceptedLoop()
+ // Start the protocol goroutine. Note that the endpoint is returned
+ // from performHandshake locked.
+ ep.startAcceptedLoop() // +checklocksforce
return ep, nil
}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 2fc282e73..e4410ad93 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -23,8 +23,10 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/header/parse"
+ "gvisor.dev/gvisor/pkg/tcpip/internal/tcp"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/raw"
@@ -49,10 +51,6 @@ const (
// MaxBufferSize is the largest size a receive/send buffer can grow to.
MaxBufferSize = 4 << 20 // 4MB
- // MaxUnprocessedSegments is the maximum number of unprocessed segments
- // that can be queued for a given endpoint.
- MaxUnprocessedSegments = 300
-
// DefaultTCPLingerTimeout is the amount of time that sockets linger in
// FIN_WAIT_2 state before being marked closed.
DefaultTCPLingerTimeout = 60 * time.Second
@@ -96,6 +94,11 @@ type protocol struct {
maxRetries uint32
synRetries uint8
dispatcher dispatcher
+
+ // The following secrets are initialized once and stay unchanged after.
+ seqnumSecret uint32
+ portOffsetSecret uint32
+ tsOffsetSecret uint32
}
// Number returns the tcp protocol number.
@@ -105,7 +108,7 @@ func (*protocol) Number() tcpip.TransportProtocolNumber {
// NewEndpoint creates a new tcp endpoint.
func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
- return newEndpoint(p.stack, netProto, waiterQueue), nil
+ return newEndpoint(p.stack, p, netProto, waiterQueue), nil
}
// NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently
@@ -156,6 +159,24 @@ func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID,
return stack.UnknownDestinationPacketHandled
}
+func (p *protocol) tsOffset(src, dst tcpip.Address) tcp.TSOffset {
+ // Initialize a random tsOffset that will be added to the recentTS
+ // every time the timestamp is sent when the Timestamp option is enabled.
+ //
+ // See https://tools.ietf.org/html/rfc7323#section-5.4 for details on
+ // why this is required.
+ //
+ // TODO(https://gvisor.dev/issues/6473): This is not really secure as
+ // it does not use the recommended algorithm linked above.
+ h := jenkins.Sum32(p.tsOffsetSecret)
+ // Per hash.Hash.Writer:
+ //
+ // It never returns an error.
+ _, _ = h.Write([]byte(src))
+ _, _ = h.Write([]byte(dst))
+ return tcp.NewTSOffset(h.Sum32())
+}
+
// replyWithReset replies to the given segment with a reset segment.
//
// If the passed TTL is 0, then the route's default TTL will be used.
@@ -292,22 +313,26 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) tcpip
case *tcpip.TCPMinRTOOption:
p.mu.Lock()
+ defer p.mu.Unlock()
if *v < 0 {
p.minRTO = MinRTO
+ } else if minRTO := time.Duration(*v); minRTO <= p.maxRTO {
+ p.minRTO = minRTO
} else {
- p.minRTO = time.Duration(*v)
+ return &tcpip.ErrInvalidOptionValue{}
}
- p.mu.Unlock()
return nil
case *tcpip.TCPMaxRTOOption:
p.mu.Lock()
+ defer p.mu.Unlock()
if *v < 0 {
p.maxRTO = MaxRTO
+ } else if maxRTO := time.Duration(*v); maxRTO >= p.minRTO {
+ p.maxRTO = maxRTO
} else {
- p.maxRTO = time.Duration(*v)
+ return &tcpip.ErrInvalidOptionValue{}
}
- p.mu.Unlock()
return nil
case *tcpip.TCPMaxRetriesOption:
@@ -478,9 +503,16 @@ func NewProtocol(s *stack.Stack) stack.TransportProtocol {
minRTO: MinRTO,
maxRTO: MaxRTO,
maxRetries: MaxRetries,
- // TODO(gvisor.dev/issue/5243): Set recovery to tcpip.TCPRACKLossDetection.
- recovery: 0,
+ recovery: tcpip.TCPRACKLossDetection,
+ seqnumSecret: s.Rand().Uint32(),
+ portOffsetSecret: s.Rand().Uint32(),
+ tsOffsetSecret: s.Rand().Uint32(),
}
p.dispatcher.init(s.Rand(), runtime.GOMAXPROCS(0))
return &p
}
+
+// protocolFromStack retrieves the tcp.protocol instance from stack s.
+func protocolFromStack(s *stack.Stack) *protocol {
+ return s.TransportProtocolInstance(ProtocolNumber).(*protocol)
+}
diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go
index 0da4eafaa..3b055c294 100644
--- a/pkg/tcpip/transport/tcp/rack.go
+++ b/pkg/tcpip/transport/tcp/rack.go
@@ -80,7 +80,6 @@ func (rc *rackControl) init(snd *sender, iss seqnum.Value) {
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-09#section-6.2
func (rc *rackControl) update(seg *segment, ackSeg *segment) {
rtt := rc.snd.ep.stack.Clock().NowMonotonic().Sub(seg.xmitTime)
- tsOffset := rc.snd.ep.TSOffset
// If the ACK is for a retransmitted packet, do not update if it is a
// spurious inference which is determined by below checks:
@@ -92,7 +91,7 @@ func (rc *rackControl) update(seg *segment, ackSeg *segment) {
// step 2
if seg.xmitCount > 1 {
if ackSeg.parsedOptions.TS && ackSeg.parsedOptions.TSEcr != 0 {
- if ackSeg.parsedOptions.TSEcr < tcpTimeStamp(seg.xmitTime, tsOffset) {
+ if ackSeg.parsedOptions.TSEcr < rc.snd.ep.tsVal(seg.xmitTime) {
return
}
}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index 661ca604a..90e493978 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -477,7 +477,7 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) {
// segments. This ensures that we always leave some space for the inorder
// segments to arrive allowing pending segments to be processed and
// delivered to the user.
- if rcvBufSize := r.ep.ops.GetReceiveBufferSize(); rcvBufSize > 0 && r.PendingBufUsed < int(rcvBufSize)>>2 {
+ if rcvBufSize := r.ep.ops.GetReceiveBufferSize(); rcvBufSize > 0 && (r.PendingBufUsed+int(segLen)) < int(rcvBufSize)>>2 {
r.ep.rcvQueueInfo.rcvQueueMu.Lock()
r.PendingBufUsed += s.segMemSize()
r.ep.rcvQueueInfo.rcvQueueMu.Unlock()
@@ -559,7 +559,6 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn
// (2) returns to TIME-WAIT state if the SYN turns out
// to be an old duplicate".
if s.flags.Contains(header.TCPFlagSyn) && r.RcvNxt.LessThan(segSeq) {
-
return false, true
}
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 72d58dcff..2fabf1594 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -382,6 +382,9 @@ func (s *sender) updateRTO(rtt time.Duration) {
if s.RTO < s.minRTO {
s.RTO = s.minRTO
}
+ if s.RTO > s.maxRTO {
+ s.RTO = s.maxRTO
+ }
}
// resendSegment resends the first unacknowledged segment.
@@ -1154,6 +1157,13 @@ func (s *sender) walkSACK(rcvdSeg *segment) {
idx := 0
n := len(rcvdSeg.parsedOptions.SACKBlocks)
if checkDSACK(rcvdSeg) {
+ dsackBlock := rcvdSeg.parsedOptions.SACKBlocks[0]
+ numDSACK := uint64(dsackBlock.End-dsackBlock.Start) / uint64(s.MaxPayloadSize)
+ // numDSACK can be zero when DSACK is sent for subsegments.
+ if numDSACK < 1 {
+ numDSACK = 1
+ }
+ s.ep.stack.Stats().TCP.SegmentsAckedWithDSACK.IncrementBy(numDSACK)
s.rc.setDSACKSeen(true)
idx = 1
n--
@@ -1335,10 +1345,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// some new data, i.e., only if it advances the left edge of
// the send window.
if s.ep.SendTSOk && rcvdSeg.parsedOptions.TSEcr != 0 {
- // TSVal/Ecr values sent by Netstack are at a millisecond
- // granularity.
- elapsed := time.Duration(s.ep.timestamp()-rcvdSeg.parsedOptions.TSEcr) * time.Millisecond
- s.updateRTO(elapsed)
+ s.updateRTO(s.ep.elapsed(s.ep.stack.Clock().NowMonotonic(), rcvdSeg.parsedOptions.TSEcr))
}
if s.shouldSchedulePTO() {
@@ -1408,9 +1415,6 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
ackLeft -= datalen
}
- // Update the send buffer usage and notify potential waiters.
- s.ep.updateSndBufferUsage(int(acked))
-
// Clear SACK information for all acked data.
s.ep.scoreboard.Delete(s.SndUna)
@@ -1430,6 +1434,9 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
}
}
+ // Update the send buffer usage and notify potential waiters.
+ s.ep.updateSndBufferUsage(int(acked))
+
// It is possible for s.outstanding to drop below zero if we get
// a retransmit timeout, reset outstanding to zero but later
// get an ack that cover previously sent data.
diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
index ced3a9c58..84fb1c416 100644
--- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
@@ -16,6 +16,7 @@
// iterations taking long enough that the retransmit timer can kick in causing
// the congestion window measurements to fail due to extra packets etc.
//
+//go:build !race
// +build !race
package tcp_test
diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go
index d6cf786a1..c35db7c95 100644
--- a/pkg/tcpip/transport/tcp/tcp_rack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go
@@ -33,12 +33,11 @@ const (
tsOptionSize = 12
maxTCPOptionSize = 40
mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload
- latency = 5 * time.Millisecond
)
-func setStackRACKPermitted(t *testing.T, c *context.Context) {
+func setStackTCPRecovery(t *testing.T, c *context.Context, recovery int) {
t.Helper()
- opt := tcpip.TCPRACKLossDetection
+ opt := tcpip.TCPRecovery(recovery)
if err := c.Stack().SetTransportProtocolOption(header.TCPProtocolNumber, &opt); err != nil {
t.Fatalf("c.s.SetTransportProtocolOption(%d, &%v(%v)): %s", header.TCPProtocolNumber, opt, opt, err)
}
@@ -70,7 +69,6 @@ func TestRACKUpdate(t *testing.T) {
close(probeDone)
})
setStackSACKPermitted(t, c, true)
- setStackRACKPermitted(t, c)
createConnectedWithSACKAndTS(c)
data := make([]byte, maxPayload)
@@ -129,7 +127,6 @@ func TestRACKDetectReorder(t *testing.T) {
close(probeDone)
})
setStackSACKPermitted(t, c, true)
- setStackRACKPermitted(t, c)
createConnectedWithSACKAndTS(c)
data := make([]byte, ackNumToVerify*maxPayload)
for i := range data {
@@ -162,10 +159,13 @@ func TestRACKDetectReorder(t *testing.T) {
func sendAndReceiveWithSACK(t *testing.T, c *context.Context, numPackets int, enableRACK bool) []byte {
setStackSACKPermitted(t, c, true)
- if enableRACK {
- setStackRACKPermitted(t, c)
+ if !enableRACK {
+ setStackTCPRecovery(t, c, 0)
}
- createConnectedWithSACKAndTS(c)
+ // The delay should be below the initial RTO (1s), otherwise retransmission
+ // will start. Choose a relatively large value so that the estimated RTT
+ // stays high even after a few rounds of undelayed RTT samples.
+ c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true}, 800*time.Millisecond /* delay */)
data := make([]byte, numPackets*maxPayload)
for i := range data {
@@ -183,9 +183,6 @@ func sendAndReceiveWithSACK(t *testing.T, c *context.Context, numPackets int, en
for i := 0; i < numPackets; i++ {
c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
bytesRead += maxPayload
- // This delay is added to increase RTT as low RTT can cause TLP
- // before sending ACK.
- time.Sleep(latency)
}
return data
@@ -542,6 +539,28 @@ func TestRACKDetectDSACK(t *testing.T) {
case invalidDSACKDetected:
t.Fatalf("RACK DSACK detected when there is no duplicate SACK")
}
+
+ metricPollFn := func() error {
+ tcpStats := c.Stack().Stats().TCP
+ stats := []struct {
+ stat *tcpip.StatCounter
+ name string
+ want uint64
+ }{
+ // Check DSACK was received for one segment.
+ {tcpStats.SegmentsAckedWithDSACK, "stats.TCP.SegmentsAckedWithDSACK", 1},
+ }
+ for _, s := range stats {
+ if got, want := s.stat.Value(), s.want; got != want {
+ return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
+ }
+ }
+ return nil
+ }
+
+ if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+ t.Error(err)
+ }
}
// TestRACKDetectDSACKWithOutOfOrder tests that RACK detects DSACK with out of
@@ -682,6 +701,28 @@ func TestRACKDetectDSACKSingleDup(t *testing.T) {
case invalidDSACKDetected:
t.Fatalf("RACK DSACK detected when there is no duplicate SACK")
}
+
+ metricPollFn := func() error {
+ tcpStats := c.Stack().Stats().TCP
+ stats := []struct {
+ stat *tcpip.StatCounter
+ name string
+ want uint64
+ }{
+ // Check DSACK was received for a subsegment.
+ {tcpStats.SegmentsAckedWithDSACK, "stats.TCP.SegmentsAckedWithDSACK", 1},
+ }
+ for _, s := range stats {
+ if got, want := s.stat.Value(), s.want; got != want {
+ return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
+ }
+ }
+ return nil
+ }
+
+ if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+ t.Error(err)
+ }
}
// TestRACKDetectDSACKDupWithCumulativeACK tests DSACK for two non-contiguous
@@ -998,7 +1039,6 @@ func TestRACKWithWindowFull(t *testing.T) {
defer c.Cleanup()
setStackSACKPermitted(t, c, true)
- setStackRACKPermitted(t, c)
createConnectedWithSACKAndTS(c)
seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index 20c9761f2..6255355bb 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -35,13 +35,13 @@ import (
// SACKPermitted option enabled if the stack in the context has the SACK support
// enabled.
func createConnectedWithSACKPermittedOption(c *context.Context) *context.RawEndpoint {
- return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled()})
+ return c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{SACKPermitted: c.SACKEnabled()})
}
// createConnectedWithSACKAndTS creates and connects c.ep with the SACK & TS
// option enabled if the stack in the context has SACK and TS enabled.
func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint {
- return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true})
+ return c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true})
}
func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) {
@@ -61,6 +61,7 @@ func TestSackPermittedConnect(t *testing.T) {
defer c.Cleanup()
setStackSACKPermitted(t, c, sackEnabled)
+ setStackTCPRecovery(t, c, 0)
rep := createConnectedWithSACKPermittedOption(c)
data := []byte{1, 2, 3}
@@ -105,8 +106,9 @@ func TestSackDisabledConnect(t *testing.T) {
defer c.Cleanup()
setStackSACKPermitted(t, c, sackEnabled)
+ setStackTCPRecovery(t, c, 0)
- rep := c.CreateConnectedWithOptions(header.TCPSynOptions{})
+ rep := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{})
data := []byte{1, 2, 3}
@@ -166,8 +168,9 @@ func TestSackPermittedAccept(t *testing.T) {
}
}
setStackSACKPermitted(t, c, sackEnabled)
+ setStackTCPRecovery(t, c, 0)
- rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted})
+ rep := c.AcceptWithOptionsNoDelay(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted})
// Now verify no SACK blocks are
// received when sack is disabled.
data := []byte{1, 2, 3}
@@ -239,8 +242,9 @@ func TestSackDisabledAccept(t *testing.T) {
}
setStackSACKPermitted(t, c, sackEnabled)
+ setStackTCPRecovery(t, c, 0)
- rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
+ rep := c.AcceptWithOptionsNoDelay(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
// Now verify no SACK blocks are
// received when sack is disabled.
@@ -386,6 +390,7 @@ func TestSACKRecovery(t *testing.T) {
log.Printf("state: %+v\n", s)
})
setStackSACKPermitted(t, c, true)
+ setStackTCPRecovery(t, c, 0)
createConnectedWithSACKAndTS(c)
const iterations = 3
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 9bbe9bc3e..58817371e 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -28,6 +28,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
@@ -1381,8 +1382,12 @@ func TestListenerReadinessOnEvent(t *testing.T) {
if err := s.CreateNIC(id, ep); err != nil {
t.Fatalf("CreateNIC(%d, %T): %s", id, ep, err)
}
- if err := s.AddAddress(id, ipv4.ProtocolNumber, context.StackAddr); err != nil {
- t.Fatalf("AddAddress(%d, ipv4.ProtocolNumber, %s): %s", id, context.StackAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.Address(context.StackAddr).WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(id, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", id, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{
{Destination: header.IPv4EmptySubnet, NIC: id},
@@ -2127,6 +2132,214 @@ func TestFullWindowReceive(t *testing.T) {
)
}
+// TestSmallReceiveBufferReadiness fills the queues between two loopback
+// endpoints whose client has a small receive buffer, then drains them and
+// fails if the client is readable without its waiter having been notified.
+func TestSmallReceiveBufferReadiness(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
+ })
+ ep := loopback.New()
+ if testing.Verbose() {
+ ep = sniffer.New(ep)
+ }
+
+ const nicID = 1
+ nicOpts := stack.NICOptions{Name: "nic1"}
+ if err := s.CreateNICWithOptions(nicID, ep, nicOpts); err != nil {
+ t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %s", nicOpts, err)
+ }
+
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address("\x7f\x00\x00\x01"),
+ PrefixLen: 8,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) failed: %s", nicID, protocolAddr, err)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet("\x7f\x00\x00\x00", "\xff\x00\x00\x00")
+ if err != nil {
+ t.Fatalf("tcpip.NewSubnet failed: %s", err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: subnet,
+ NIC: nicID,
+ },
+ })
+ }
+
+ listenerEntry, listenerCh := waiter.NewChannelEntry(nil)
+ var listenerWQ waiter.Queue
+ listener, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ defer listener.Close()
+ listenerWQ.EventRegister(&listenerEntry, waiter.ReadableEvents)
+ defer listenerWQ.EventUnregister(&listenerEntry)
+
+ if err := listener.Bind(tcpip.FullAddress{}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+ if err := listener.Listen(1); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ localAddress, err := listener.GetLocalAddress()
+ if err != nil {
+ t.Fatalf("GetLocalAddress failed: %s", err)
+ }
+ // Exercise progressively smaller client receive buffers: 8KiB, 4KiB, 2KiB, 1KiB.
+ for i := 8; i > 0; i /= 2 {
+ size := int64(i << 10)
+ t.Run(fmt.Sprintf("size=%d", size), func(t *testing.T) {
+ var clientWQ waiter.Queue
+ client, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ defer client.Close()
+ switch err := client.Connect(localAddress).(type) {
+ case nil:
+ t.Fatal("Connect returned nil error")
+ case *tcpip.ErrConnectStarted:
+ default:
+ t.Fatalf("Connect failed: %s", err)
+ }
+ // Wait for the listener to signal readability, then accept the connection.
+ <-listenerCh
+ server, serverWQ, err := listener.Accept(nil)
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+ defer server.Close()
+
+ client.SocketOptions().SetReceiveBufferSize(size, true)
+ // Send buffer size doesn't seem to affect this test.
+ // server.SocketOptions().SetSendBufferSize(size, true)
+
+ clientEntry, clientCh := waiter.NewChannelEntry(nil)
+ clientWQ.EventRegister(&clientEntry, waiter.ReadableEvents)
+ defer clientWQ.EventUnregister(&clientEntry)
+
+ serverEntry, serverCh := waiter.NewChannelEntry(nil)
+ serverWQ.EventRegister(&serverEntry, waiter.WritableEvents)
+ defer serverWQ.EventUnregister(&serverEntry)
+ // Keep writing until a write would block and the server stays unwritable for 100ms.
+ var total int64
+ for {
+ var b [64 << 10]byte
+ var r bytes.Reader
+ r.Reset(b[:])
+ switch n, err := server.Write(&r, tcpip.WriteOptions{}); err.(type) {
+ case nil:
+ t.Logf("wrote %d bytes", n)
+ total += n
+ continue
+ case *tcpip.ErrWouldBlock:
+ select {
+ case <-serverCh:
+ continue
+ case <-time.After(100 * time.Millisecond):
+ // Well and truly full.
+ t.Logf("send and receive queues are full")
+ }
+ default:
+ t.Fatalf("Write failed: %s", err)
+ }
+ break
+ }
+ t.Logf("wrote %d bytes in total", total)
+ // Run a writer and a reader concurrently; the deferred Wait holds the subtest open until both exit.
+ var wg sync.WaitGroup
+ defer wg.Wait()
+
+ wg.Add(2)
+ go func() {
+ defer wg.Done()
+ var b [64 << 10]byte
+ var r bytes.Reader
+ r.Reset(b[:])
+ if err := func() error {
+ var total int64
+ defer func() { t.Logf("wrote %d bytes in total", total) }() // closure: read total at exit, not at defer time
+ for r.Len() != 0 {
+ switch n, err := server.Write(&r, tcpip.WriteOptions{}); err.(type) {
+ case nil:
+ t.Logf("wrote %d bytes", n)
+ total += n
+ case *tcpip.ErrWouldBlock:
+ for {
+ t.Logf("waiting on server")
+ select {
+ case <-serverCh:
+ case <-time.After(time.Second):
+ if readiness := server.Readiness(waiter.WritableEvents); readiness != 0 {
+ t.Logf("server.Readiness(%b) = %b but channel not signaled", waiter.WritableEvents, readiness)
+ }
+ continue
+ }
+ break
+ }
+ default:
+ return fmt.Errorf("server.Write failed: %s", err)
+ }
+ }
+ if err := server.Shutdown(tcpip.ShutdownWrite); err != nil {
+ return fmt.Errorf("server.Shutdown failed: %s", err)
+ }
+ t.Logf("server end shutdown done")
+ return nil
+ }(); err != nil {
+ t.Error(err)
+ }
+ }()
+
+ go func() {
+ defer wg.Done()
+ if err := func() error {
+ total := 0
+ defer func() { t.Logf("read %d bytes in total", total) }() // closure: read total at exit, not at defer time
+ for {
+ switch res, err := client.Read(ioutil.Discard, tcpip.ReadOptions{}); err.(type) {
+ case nil:
+ t.Logf("read %d bytes", res.Count)
+ total += res.Count
+ t.Logf("read total %d bytes till now", total)
+ case *tcpip.ErrClosedForReceive:
+ return nil
+ case *tcpip.ErrWouldBlock:
+ for {
+ t.Logf("waiting on client")
+ select {
+ case <-clientCh:
+ case <-time.After(time.Second):
+ if readiness := client.Readiness(waiter.ReadableEvents); readiness != 0 {
+ return fmt.Errorf("client.Readiness(%b) = %b but channel not signaled", waiter.ReadableEvents, readiness)
+ }
+ continue
+ }
+ break
+ }
+ default:
+ return fmt.Errorf("client.Read failed: %s", err)
+ }
+ }
+ }(); err != nil {
+ t.Error(err)
+ }
+ }()
+ })
+ }
+}
+
// Test the stack receive window advertisement on receiving segments smaller than
// segment overhead. It tests for the right edge of the window to not grow when
// the endpoint is not being read from.
@@ -2143,11 +2356,11 @@ func TestSmallSegReceiveWindowAdvertisement(t *testing.T) {
t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
}
- c.AcceptWithOptions(tcp.FindWndScale(seqnum.Size(opt.Default)), header.TCPSynOptions{MSS: defaultIPv4MSS})
+ c.AcceptWithOptionsNoDelay(tcp.FindWndScale(seqnum.Size(opt.Default)), header.TCPSynOptions{MSS: defaultIPv4MSS})
// Bump up the receive buffer size such that, when the receive window grows,
// the scaled window exceeds maxUint16.
- c.EP.SocketOptions().SetReceiveBufferSize(int64(opt.Max), true)
+ c.EP.SocketOptions().SetReceiveBufferSize(int64(opt.Max)*2, true /* notify */)
// Keep the payload size < segment overhead and such that it is a multiple
// of the window scaled value. This enables the test to perform equality
@@ -2267,7 +2480,7 @@ func TestNoWindowShrinking(t *testing.T) {
initialWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale
initialLastAcceptableSeq := iss.Add(seqnum.Size(initialWnd))
// Now shrink the receive buffer to half its original size.
- c.EP.SocketOptions().SetReceiveBufferSize(int64(rcvBufSize/2), true)
+ c.EP.SocketOptions().SetReceiveBufferSize(int64(rcvBufSize), true /* notify */)
data := generateRandomPayload(t, rcvBufSize)
// Send a payload of half the size of rcvBufSize.
@@ -2523,7 +2736,7 @@ func TestScaledWindowAccept(t *testing.T) {
defer ep.Close()
// Set the window size greater than the maximum non-scaled window.
- ep.SocketOptions().SetReceiveBufferSize(65535*3, true)
+ ep.SocketOptions().SetReceiveBufferSize(65535*6, true /* notify */)
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
@@ -2535,7 +2748,7 @@ func TestScaledWindowAccept(t *testing.T) {
// Do 3-way handshake.
// wndScale expected is 3 as 65535 * 3 * 2 < 65535 * 2^3 but > 65535 *2 *2
- c.PassiveConnectWithOptions(100, 3 /* wndScale */, header.TCPSynOptions{MSS: defaultIPv4MSS})
+ c.PassiveConnectWithOptions(100, 3 /* wndScale */, header.TCPSynOptions{MSS: defaultIPv4MSS}, 0 /* delay */)
// Try to accept the connection.
we, ch := waiter.NewChannelEntry(nil)
@@ -2595,7 +2808,7 @@ func TestNonScaledWindowAccept(t *testing.T) {
defer ep.Close()
// Set the window size greater than the maximum non-scaled window.
- ep.SocketOptions().SetReceiveBufferSize(65535*3, true)
+ ep.SocketOptions().SetReceiveBufferSize(65535*6, true /* notify */)
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
@@ -3188,7 +3401,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
// Set the buffer size to a deterministic size so that we can check the
// window scaling option.
const rcvBufferSize = 0x20000
- ep.SocketOptions().SetReceiveBufferSize(rcvBufferSize, true)
+ ep.SocketOptions().SetReceiveBufferSize(rcvBufferSize*2, true /* notify */)
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
@@ -3327,7 +3540,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
// window scaling option.
const rcvBufferSize = 0x20000
const wndScale = 3
- c.EP.SocketOptions().SetReceiveBufferSize(rcvBufferSize, true)
+ c.EP.SocketOptions().SetReceiveBufferSize(rcvBufferSize*2, true /* notify */)
// Start connection attempt.
we, ch := waiter.NewChannelEntry(nil)
@@ -3451,17 +3664,13 @@ loop:
for {
switch _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err.(type) {
case *tcpip.ErrWouldBlock:
- select {
- case <-ch:
- // Expect the state to be StateError and subsequent Reads to fail with HardError.
- _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{})
- if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" {
- t.Fatalf("c.EP.Read() mismatch (-want +got):\n%s", d)
- }
- break loop
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for reset to arrive")
+ <-ch
+ // Expect the state to be StateError and subsequent Reads to fail with HardError.
+ _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{})
+ if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" {
+ t.Fatalf("c.EP.Read() mismatch (-want +got):\n%s", d)
}
+ break loop
case *tcpip.ErrConnectionReset:
break loop
default:
@@ -3472,14 +3681,27 @@ loop:
if tcp.EndpointState(c.EP.State()) != tcp.StateError {
t.Fatalf("got EP state is not StateError")
}
- if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 {
- t.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got)
+
+ checkValid := func() []error {
+ var errors []error
+ if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 {
+ errors = append(errors, fmt.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got))
+ }
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ errors = append(errors, fmt.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got))
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ errors = append(errors, fmt.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got))
+ }
+ return errors
}
- if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
+
+ start := time.Now()
+ for time.Since(start) < time.Minute && len(checkValid()) > 0 {
+ time.Sleep(50 * time.Millisecond)
}
- if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
+ for _, err := range checkValid() {
+ t.Error(err)
}
}
@@ -3523,6 +3745,12 @@ func TestMaxRetransmitsTimeout(t *testing.T) {
t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
+ // Pin the minimum RTO to initRTO; further down we wait for the connection to time out after MaxRetries retransmits.
+ initRTO := time.Second
+ minRTOOpt := tcpip.TCPMinRTOOption(initRTO)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err)
+ }
c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
@@ -3545,8 +3773,6 @@ func TestMaxRetransmitsTimeout(t *testing.T) {
),
)
}
- // Wait for the connection to timeout after MaxRetries retransmits.
- initRTO := 1 * time.Second
select {
case <-notifyCh:
case <-time.After((2 << numRetries) * initRTO):
@@ -3581,9 +3807,13 @@ func TestMaxRTO(t *testing.T) {
defer c.Cleanup()
rto := 1 * time.Second
- opt := tcpip.TCPMaxRTOOption(rto)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
+ minRTOOpt := tcpip.TCPMinRTOOption(rto / 2)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err)
+ }
+ maxRTOOpt := tcpip.TCPMaxRTOOption(rto)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &maxRTOOpt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, maxRTOOpt, maxRTOOpt, err)
}
c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
@@ -3609,12 +3839,44 @@ func TestMaxRTO(t *testing.T) {
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
- if time.Since(start).Round(time.Second).Seconds() != rto.Seconds() {
- t.Errorf("Retransmit interval not capped to MaxRTO.\n")
+ if elapsed := time.Since(start); elapsed.Round(time.Second).Seconds() != rto.Seconds() {
+ t.Errorf("Retransmit interval not capped to MaxRTO(%s). %s", rto, elapsed)
}
}
}
+// TestZeroSizedWriteRetransmit tests that a zero-sized write does not
+// result in a panic on an RTO, as no segment should have been queued for
+// a zero-sized write.
+func TestZeroSizedWriteRetransmit(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
+ // The zero-valued bytes.Reader is empty, so this first write carries no payload.
+ var r bytes.Reader
+ _, err := c.EP.Write(&r, tcpip.WriteOptions{})
+ if err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+ // Now do a non-zero sized (one byte) write to trigger actual sending of data.
+ r.Reset(make([]byte, 1))
+ _, err = c.EP.Write(&r, tcpip.WriteOptions{})
+ if err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+ // Do not ACK the packet; expect two packets, the original transmit and a
+ // retransmit. Neither should cause a panic.
+ for i := 0; i < 2; i++ {
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
+ ),
+ )
+ }
+}
+
// TestRetransmitIPv4IDUniqueness tests that the IPv4 Identification field is
// unique on retransmits.
func TestRetransmitIPv4IDUniqueness(t *testing.T) {
@@ -3629,6 +3891,10 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
+ minRTOOpt := tcpip.TCPMinRTOOption(time.Second)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err)
+ }
c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
// Disabling PMTU discovery causes all packets sent from this socket to
@@ -4628,52 +4894,6 @@ func TestDefaultBufferSizes(t *testing.T) {
checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*3)
}
-func TestMinMaxBufferSizes(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
- })
-
- // Check the default values.
- ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %s", err)
- }
- defer ep.Close()
-
- // Change the min/max values for send/receive
- {
- opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20}
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
- }
- }
-
- {
- opt := tcpip.TCPSendBufferSizeRangeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30}
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
- }
- }
-
- // Set values below the min/2.
- ep.SocketOptions().SetReceiveBufferSize(99, true)
- checkRecvBufferSize(t, ep, 200)
-
- ep.SocketOptions().SetSendBufferSize(149, true)
-
- checkSendBufferSize(t, ep, 300)
-
- // Set values above the max.
- ep.SocketOptions().SetReceiveBufferSize(1+tcp.DefaultReceiveBufferSize*20, true)
- // Values above max are capped at max and then doubled.
- checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20*2)
-
- ep.SocketOptions().SetSendBufferSize(1+tcp.DefaultSendBufferSize*30, true)
- // Values above max are capped at max and then doubled.
- checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30*2)
-}
-
func TestBindToDeviceOption(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
@@ -4741,13 +4961,17 @@ func makeStack() (*stack.Stack, tcpip.Error) {
}
for _, ct := range []struct {
- number tcpip.NetworkProtocolNumber
- address tcpip.Address
+ number tcpip.NetworkProtocolNumber
+ addrWithPrefix tcpip.AddressWithPrefix
}{
- {ipv4.ProtocolNumber, context.StackAddr},
- {ipv6.ProtocolNumber, context.StackV6Addr},
+ {ipv4.ProtocolNumber, context.StackAddrWithPrefix},
+ {ipv6.ProtocolNumber, context.StackV6AddrWithPrefix},
} {
- if err := s.AddAddress(1, ct.number, ct.address); err != nil {
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ct.number,
+ AddressWithPrefix: ct.addrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
return nil, err
}
}
@@ -4951,7 +5175,7 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
t.Fatalf("got s.SetPortRange(%d, %d) = %s, want = nil", start, end, err)
}
for i := start; i <= end; i++ {
- if makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil {
+ if err := makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil {
t.Fatalf("Bind(%d) failed: %s", i, err)
}
}
@@ -6068,6 +6292,11 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
// complete the connection to test that the large SEQ num
// did not change the state from SYN-RCVD.
+ // Set up to be notified about connection establishment.
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.ReadableEvents)
+ defer c.WQ.EventUnregister(&we)
+
// Send ACK to move to ESTABLISHED state.
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
@@ -6078,32 +6307,12 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
RcvWnd: 30000,
})
+ <-ch
newEP, _, err := c.EP.Accept(nil)
- switch err.(type) {
- case nil, *tcpip.ErrWouldBlock:
- default:
+ if err != nil {
t.Fatalf("Accept failed: %s", err)
}
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Try to accept the connections in the backlog.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
- // Wait for connection to be established.
- select {
- case <-ch:
- newEP, _, err = c.EP.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
// Now verify that the TCP socket is usable and in a connected state.
data := "Don't panic"
var r strings.Reader
@@ -6209,12 +6418,26 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
RcvWnd: 30000,
})
- time.Sleep(50 * time.Millisecond)
- if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want {
- t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %d, want = %d", got, want)
+ checkValid := func() []error {
+ var errors []error
+ if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want {
+ errors = append(errors, fmt.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %d, want = %d", got, want))
+ }
+ if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want {
+ errors = append(errors, fmt.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %d, want = %d", got, want))
+ }
+ return errors
}
- if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want {
- t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %d, want = %d", got, want)
+
+ start := time.Now()
+ for time.Since(start) < time.Minute && len(checkValid()) > 0 {
+ time.Sleep(50 * time.Millisecond)
+ }
+ for _, err := range checkValid() {
+ t.Error(err)
+ }
+ if t.Failed() {
+ t.FailNow()
}
we, ch := waiter.NewChannelEntry(nil)
@@ -6225,15 +6448,10 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
_, _, err = c.EP.Accept(nil)
if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
- select {
- case <-ch:
- _, _, err = c.EP.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
+ <-ch
+ _, _, err = c.EP.Accept(nil)
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
}
}
}
@@ -6315,7 +6533,7 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
- c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS})
+ c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS}, 0 /* delay */)
// Try to accept the connection.
we, ch := waiter.NewChannelEntry(nil)
@@ -6396,7 +6614,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// maximum buffer size defined above.
c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize))
- rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4})
+ rawEP := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{TS: true, WS: 4})
// NOTE: The timestamp values in the sent packets are meaningless to the
// peer so we just increment the timestamp value by 1 every batch as we
@@ -6526,7 +6744,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
// maximum buffer size used by stack.
c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize))
- rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4})
+ rawEP := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{TS: true, WS: 4})
tsVal := rawEP.TSVal
rawEP.NextSeqNum--
rawEP.SendPacketWithTS(nil, tsVal)
@@ -7441,6 +7659,11 @@ func TestTCPUserTimeout(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
+ initRTO := 1 * time.Second
+ minRTOOpt := tcpip.TCPMinRTOOption(initRTO)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err)
+ }
c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
@@ -7451,7 +7674,6 @@ func TestTCPUserTimeout(t *testing.T) {
// Ensure that on the next retransmit timer fire, the user timeout has
// expired.
- initRTO := 1 * time.Second
userTimeout := initRTO / 2
v := tcpip.TCPUserTimeoutOption(userTimeout)
if err := c.EP.SetSockOpt(&v); err != nil {
@@ -7483,7 +7705,7 @@ func TestTCPUserTimeout(t *testing.T) {
select {
case <-notifyCh:
case <-time.After(2 * initRTO):
- t.Fatalf("connection still alive after %s, should have been closed after :%s", 2*initRTO, userTimeout)
+ t.Fatalf("connection still alive after %s, should have been closed after %s", 2*initRTO, userTimeout)
}
// No packet should be received as the connection should be silently
@@ -7717,7 +7939,7 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) {
// Increasing the buffer from should generate an ACK,
// since window grew from small value to larger equal MSS
- c.EP.SocketOptions().SetReceiveBufferSize(rcvBuf*2, true)
+ c.EP.SocketOptions().SetReceiveBufferSize(rcvBuf*4, true /* notify */)
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
@@ -7965,6 +8187,151 @@ func TestSetStackTimeWaitReuse(t *testing.T) {
}
}
+// TestHandshakeRTT checks the RTT and RTTVar reported through
+// tcpip.TCPInfoOption immediately after the three-way handshake. It covers
+// active (connect) and passive (accept) opens, with and without the TCP
+// timestamp option, with and without SYN cookies, and with a short or a long
+// handshake delay.
+func TestHandshakeRTT(t *testing.T) {
+	type testCase struct {
+		connect   bool
+		tsEnabled bool
+		useCookie bool
+		retrans   bool
+		delay     time.Duration
+		wantRTT   time.Duration
+	}
+	var testCases []testCase
+	for _, connect := range []bool{false, true} {
+		for _, tsEnabled := range []bool{false, true} {
+			for _, useCookie := range []bool{false, true} {
+				for _, retrans := range []bool{false, true} {
+					// SYN cookies only apply to the listening side, so the
+					// combination with an active open is skipped.
+					if connect && useCookie {
+						continue
+					}
+					// NOTE(review): the longer delay presumably exceeds the
+					// initial RTO so that a handshake segment is retransmitted
+					// - confirm against the stack's defaults.
+					delay := 800 * time.Millisecond
+					if retrans {
+						delay = 1200 * time.Millisecond
+					}
+					wantRTT := delay
+					// If syncookie is enabled, sample RTT only when TS option is enabled.
+					if !retrans && useCookie && !tsEnabled {
+						wantRTT = 0
+					}
+					// If retransmitted, sample RTT only when TS option is enabled.
+					if retrans && !tsEnabled {
+						wantRTT = 0
+					}
+					testCases = append(testCases, testCase{
+						connect:   connect,
+						tsEnabled: tsEnabled,
+						useCookie: useCookie,
+						retrans:   retrans,
+						delay:     delay,
+						wantRTT:   wantRTT,
+					})
+				}
+			}
+		}
+	}
+	for _, tt := range testCases {
+		// Copy the loop variable: the subtests run in parallel and outlive
+		// the iteration that spawned them.
+		tt := tt
+		t.Run(fmt.Sprintf("connect=%t,TS=%t,cookie=%t,retrans=%t", tt.connect, tt.tsEnabled, tt.useCookie, tt.retrans), func(t *testing.T) {
+			t.Parallel()
+			c := context.New(t, defaultMTU)
+			// Release the stack and endpoint created for this subtest.
+			defer c.Cleanup()
+			if tt.useCookie {
+				opt := tcpip.TCPAlwaysUseSynCookies(true)
+				if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+					t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
+				}
+			}
+			synOpts := header.TCPSynOptions{}
+			if tt.tsEnabled {
+				synOpts.TS = true
+				synOpts.TSVal = 42
+			}
+			if tt.connect {
+				c.CreateConnectedWithOptions(synOpts, tt.delay)
+			} else {
+				synOpts.MSS = defaultIPv4MSS
+				synOpts.WS = -1
+				c.AcceptWithOptions(-1, synOpts, tt.delay)
+			}
+			var info tcpip.TCPInfoOption
+			if err := c.EP.GetSockOpt(&info); err != nil {
+				t.Fatalf("c.EP.GetSockOpt(&%T) = %s", info, err)
+			}
+			// Round to the nearest multiple of the expected RTT to absorb
+			// timing jitter. Duration.Round returns its receiver unchanged for
+			// a non-positive argument, so wantRTT == 0 requires an RTT of
+			// exactly zero.
+			if got := info.RTT.Round(tt.wantRTT); got != tt.wantRTT {
+				t.Fatalf("got info.RTT=%s, expect %s", got, tt.wantRTT)
+			}
+			// RTTVar must be set exactly when an RTT sample was expected.
+			if info.RTTVar != 0 && tt.wantRTT == 0 {
+				t.Fatalf("got info.RTTVar=%s, expect 0", info.RTTVar)
+			}
+			if info.RTTVar == 0 && tt.wantRTT != 0 {
+				t.Fatalf("got info.RTTVar=0, expect non zero")
+			}
+		})
+	}
+}
+
+// TestSetRTO verifies that the stack-wide TCP minimum and maximum RTO options
+// reject values that would make the minimum exceed the maximum, and that
+// accepted values can be read back unchanged.
+func TestSetRTO(t *testing.T) {
+	// This context is only used to read the stack's default RTO bounds, from
+	// which the table below derives valid and invalid values.
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+	minRTO, maxRTO := tcpRTOMinMax(t, c)
+	for _, tt := range []struct {
+		name   string
+		minRTO time.Duration
+		maxRTO time.Duration
+		err    tcpip.Error
+	}{
+		{
+			name:   "invalid minRTO",
+			minRTO: maxRTO + time.Second,
+			err:    &tcpip.ErrInvalidOptionValue{},
+		},
+		{
+			name:   "invalid maxRTO",
+			maxRTO: minRTO - time.Millisecond,
+			err:    &tcpip.ErrInvalidOptionValue{},
+		},
+		{
+			name:   "valid minRTO",
+			minRTO: maxRTO - time.Second,
+		},
+		{
+			name:   "valid maxRTO",
+			maxRTO: minRTO + time.Millisecond,
+		},
+	} {
+		t.Run(tt.name, func(t *testing.T) {
+			// Each case gets a fresh stack so that an option accepted by one
+			// case cannot influence the next.
+			c := context.New(t, defaultMTU)
+			defer c.Cleanup()
+			var opt tcpip.SettableTransportProtocolOption
+			if tt.minRTO > 0 {
+				minOpt := tcpip.TCPMinRTOOption(tt.minRTO)
+				opt = &minOpt
+			}
+			if tt.maxRTO > 0 {
+				maxOpt := tcpip.TCPMaxRTOOption(tt.maxRTO)
+				opt = &maxOpt
+			}
+			err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt)
+			// Compare by value: the expected error is a freshly allocated
+			// pointer, so pointer identity with the returned error is not a
+			// reliable equality test.
+			if got, want := err, tt.err; !cmp.Equal(got, want) {
+				t.Fatalf("c.Stack().SetTransportProtocolOption(TCP, &%T(%v)) = %v, want = %v", opt, opt, got, want)
+			}
+			if tt.err == nil {
+				gotMin, gotMax := tcpRTOMinMax(t, c)
+				if tt.minRTO > 0 && tt.minRTO != gotMin {
+					t.Fatalf("got minRTO = %s, want %s", gotMin, tt.minRTO)
+				}
+				if tt.maxRTO > 0 && tt.maxRTO != gotMax {
+					t.Fatalf("got maxRTO = %s, want %s", gotMax, tt.maxRTO)
+				}
+			}
+		})
+	}
+}
+
+// tcpRTOMinMax reads the TCP minimum and maximum RTO transport protocol
+// options from c's stack and returns them, in that order, as durations. The
+// test is failed immediately if either option cannot be read.
+func tcpRTOMinMax(t *testing.T, c *context.Context) (time.Duration, time.Duration) {
+	t.Helper()
+	s := c.Stack()
+
+	var lower tcpip.TCPMinRTOOption
+	if err := s.TransportProtocolOption(tcp.ProtocolNumber, &lower); err != nil {
+		t.Fatalf("c.Stack().TransportProtocolOption(TCP, %T): %s", lower, err)
+	}
+
+	var upper tcpip.TCPMaxRTOOption
+	if err := s.TransportProtocolOption(tcp.ProtocolNumber, &upper); err != nil {
+		t.Fatalf("c.Stack().TransportProtocolOption(TCP, %T): %s", upper, err)
+	}
+
+	return time.Duration(lower), time.Duration(upper)
+}
+
// generateRandomPayload generates a random byte slice of the specified length
// causing a fatal test failure if it is unable to do so.
func generateRandomPayload(t *testing.T, n int) []byte {
@@ -7975,3 +8342,192 @@ func generateRandomPayload(t *testing.T, n int) []byte {
}
return buf
}
+
+// TestSendBufferTuning fills the send buffer of a connected endpoint, ACKs
+// everything that was sent, and then checks the resulting send buffer size. If
+// the size was set explicitly on the endpoint it must stay at the configured
+// default; otherwise it must equal SndCwnd * packetOverheadFactor *
+// maxPayload.
+func TestSendBufferTuning(t *testing.T) {
+	const maxPayload = 536
+	const mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload
+	const packetOverheadFactor = 2
+
+	testCases := []struct {
+		name               string
+		autoTuningDisabled bool
+	}{
+		{"autoTuningDisabled", true},
+		{"autoTuningEnabled", false},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			c := context.New(t, mtu)
+			defer c.Cleanup()
+
+			// Set the stack option for send buffer size.
+			const defaultSndBufSz = maxPayload * tcp.InitialCwnd
+			const maxSndBufSz = defaultSndBufSz * 10
+			{
+				opt := tcpip.TCPSendBufferSizeRangeOption{Min: 1, Default: defaultSndBufSz, Max: maxSndBufSz}
+				if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+					t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
+				}
+			}
+
+			c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
+
+			oldSz := c.EP.SocketOptions().GetSendBufferSize()
+			if oldSz != defaultSndBufSz {
+				t.Fatalf("Wrong send buffer size got %d want %d", oldSz, defaultSndBufSz)
+			}
+
+			// NOTE(review): explicitly setting the send buffer size is
+			// presumably what turns auto tuning off for the endpoint - confirm.
+			if tc.autoTuningDisabled {
+				c.EP.SocketOptions().SetSendBufferSize(defaultSndBufSz, true /* notify */)
+			}
+
+			data := make([]byte, maxPayload)
+			for i := range data {
+				data[i] = byte(i)
+			}
+
+			w, ch := waiter.NewChannelEntry(nil)
+			c.WQ.EventRegister(&w, waiter.WritableEvents)
+			defer c.WQ.EventUnregister(&w)
+
+			bytesRead := 0
+			for {
+				// Packets will be sent till the send buffer
+				// size is reached.
+				var r bytes.Reader
+				r.Reset(data[bytesRead : bytesRead+maxPayload])
+				_, err := c.EP.Write(&r, tcpip.WriteOptions{})
+				if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
+					break
+				}
+				// Any other error means the write never happened; fail here
+				// instead of waiting for a packet that will not arrive.
+				if err != nil {
+					t.Fatalf("Write failed: %s", err)
+				}
+
+				c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, 0)
+				bytesRead += maxPayload
+				// Extend the expected data by one more payload for the next
+				// write. The contents repeat every maxPayload bytes, so the
+				// first chunk is appended again.
+				data = append(data, data[:maxPayload]...)
+			}
+
+			// Send an ACK and wait for connection to become writable again.
+			c.SendAck(seqnum.Value(context.TestInitialSequenceNumber).Add(1), bytesRead)
+			select {
+			case <-ch:
+				if err := c.EP.LastError(); err != nil {
+					t.Fatalf("Write failed: %s", err)
+				}
+			case <-time.After(1 * time.Second):
+				t.Fatalf("Timed out waiting for connection")
+			}
+
+			outSz := int64(defaultSndBufSz)
+			if !tc.autoTuningDisabled {
+				// Calculate the new auto tuned send buffer.
+				var info tcpip.TCPInfoOption
+				if err := c.EP.GetSockOpt(&info); err != nil {
+					t.Fatalf("GetSockOpt failed: %v", err)
+				}
+				outSz = int64(info.SndCwnd) * packetOverheadFactor * maxPayload
+			}
+
+			if newSz := c.EP.SocketOptions().GetSendBufferSize(); newSz != outSz {
+				t.Fatalf("Wrong send buffer size, got %d want %d", newSz, outSz)
+			}
+		})
+	}
+}
+
+// TestTimestampSynCookies checks that an endpoint accepted while SYN cookies
+// are always on keeps the timestamp offset that was used for the SYN-ACK: the
+// TSVal of a later data segment must equal the current clock reading plus the
+// offset derived from the SYN-ACK's TSVal. The stack runs on a manual clock so
+// the expected values are deterministic.
+func TestTimestampSynCookies(t *testing.T) {
+	clock := faketime.NewManualClock()
+	// tsNow returns the clock's monotonic time in milliseconds, truncated to
+	// 32 bits like a TCP timestamp value.
+	tsNow := func() uint32 {
+		return uint32(clock.NowMonotonic().Sub(tcpip.MonotonicTime{}).Milliseconds())
+	}
+	// Advance the clock so that NowMonotonic is non-zero.
+	clock.Advance(time.Second)
+	c := context.NewWithOpts(t, context.Options{
+		EnableV4: true,
+		EnableV6: true,
+		MTU:      defaultMTU,
+		Clock:    clock,
+	})
+	defer c.Cleanup()
+	opt := tcpip.TCPAlwaysUseSynCookies(true)
+	if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+		t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
+	}
+	wq := &waiter.Queue{}
+	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+	if err != nil {
+		t.Fatalf("NewEndpoint failed: %s", err)
+	}
+	defer ep.Close()
+
+	// Two NOPs pad the timestamp option written at tcpOpts[2:].
+	tcpOpts := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
+	header.EncodeTSOption(42, 0, tcpOpts[2:])
+	if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+		t.Fatalf("Bind failed: %s", err)
+	}
+	if err := ep.Listen(10); err != nil {
+		t.Fatalf("Listen failed: %s", err)
+	}
+	iss := seqnum.Value(context.TestInitialSequenceNumber)
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagSyn,
+		RcvWnd:  seqnum.Size(512),
+		SeqNum:  iss,
+		TCPOpts: tcpOpts[:],
+	})
+	// Get the TSVal of SYN-ACK.
+	b := c.GetPacket()
+	tcpHdr := header.TCP(header.IPv4(b).Payload())
+	c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+	initialTSVal := tcpHdr.ParsedOptions().TSVal
+	// derive the tsOffset.
+	tsOffset := initialTSVal - tsNow()
+
+	// Complete the handshake, echoing the SYN-ACK's TSVal.
+	header.EncodeTSOption(420, initialTSVal, tcpOpts[2:])
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagAck,
+		RcvWnd:  seqnum.Size(512),
+		SeqNum:  iss + 1,
+		AckNum:  c.IRS + 1,
+		TCPOpts: tcpOpts[:],
+	})
+
+	// Register for readable events before the first Accept so that a
+	// connection completing between Accept and registration is still
+	// signalled on ch.
+	we, ch := waiter.NewChannelEntry(nil)
+	wq.EventRegister(&we, waiter.ReadableEvents)
+	defer wq.EventUnregister(&we)
+
+	// Try to accept the connection.
+	c.EP, _, err = ep.Accept(nil)
+	if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
+		// Wait for connection to be established.
+		select {
+		case <-ch:
+			c.EP, _, err = ep.Accept(nil)
+			if err != nil {
+				t.Fatalf("Accept failed: %s", err)
+			}
+
+		case <-time.After(1 * time.Second):
+			t.Fatalf("Timed out waiting for accept")
+		}
+	} else if err != nil {
+		t.Fatalf("failed to accept: %s", err)
+	}
+
+	// Advance the clock again so that we expect the next TSVal to change.
+	clock.Advance(time.Second)
+	data := []byte{1, 2, 3}
+	var r bytes.Reader
+	r.Reset(data)
+	if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
+		t.Fatalf("Write failed: %s", err)
+	}
+
+	// The endpoint should have a correct TSOffset so that the received TSVal
+	// should match our expectation.
+	if got, want := header.TCP(header.IPv4(c.GetPacket()).Payload()).ParsedOptions().TSVal, tsNow()+tsOffset; got != want {
+		t.Fatalf("got TSVal = %d, want %d", got, want)
+	}
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
index 1deb1fe4d..65925daa5 100644
--- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
@@ -32,7 +32,7 @@ import (
// createConnectedWithTimestampOption creates and connects c.ep with the
// timestamp option enabled.
func createConnectedWithTimestampOption(c *context.Context) *context.RawEndpoint {
- return c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, TSVal: 1})
+ return c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{TS: true, TSVal: 1})
}
// TestTimeStampEnabledConnect tests that netstack sends the timestamp option on
@@ -131,7 +131,7 @@ func TestTimeStampDisabledConnect(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnectedWithOptions(header.TCPSynOptions{})
+ c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{})
}
func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) {
@@ -147,7 +147,7 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS
t.Logf("Test w/ CookieEnabled = %v", cookieEnabled)
tsVal := rand.Uint32()
- c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, TS: true, TSVal: tsVal})
+ c.AcceptWithOptionsNoDelay(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, TS: true, TSVal: tsVal})
// Now send some data and validate that timestamp is echoed correctly in the ACK.
data := []byte{1, 2, 3}
@@ -209,7 +209,7 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd
}
t.Logf("Test w/ CookieEnabled = %v", cookieEnabled)
- c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
+ c.AcceptWithOptionsNoDelay(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
// Now send some data with the accepted connection endpoint and validate
// that no timestamp option is sent in the TCP segment.
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 53efecc5a..88bb99354 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -122,6 +122,9 @@ type Options struct {
// MTU indicates the maximum transmission unit on the link layer.
MTU uint32
+
+ // Clock that is used by Stack.
+ Clock tcpip.Clock
}
// Context provides an initialized Network stack and a link layer endpoint
@@ -182,6 +185,7 @@ func NewWithOpts(t *testing.T, opts Options) *Context {
stackOpts := stack.Options{
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
+ Clock: opts.Clock,
}
if opts.EnableV4 {
stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv4.NewProtocol)
@@ -239,8 +243,8 @@ func NewWithOpts(t *testing.T, opts Options) *Context {
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: StackAddrWithPrefix,
}
- if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err)
+ if err := s.AddProtocolAddress(1, v4ProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", v4ProtocolAddr, err)
}
routeTable = append(routeTable, tcpip.Route{
Destination: header.IPv4EmptySubnet,
@@ -253,8 +257,8 @@ func NewWithOpts(t *testing.T, opts Options) *Context {
Protocol: ipv6.ProtocolNumber,
AddressWithPrefix: StackV6AddrWithPrefix,
}
- if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err)
+ if err := s.AddProtocolAddress(1, v6ProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", v6ProtocolAddr, err)
}
routeTable = append(routeTable, tcpip.Route{
Destination: header.IPv6EmptySubnet,
@@ -757,7 +761,7 @@ func (c *Context) Create(epRcvBuf int) {
}
if epRcvBuf != -1 {
- c.EP.SocketOptions().SetReceiveBufferSize(int64(epRcvBuf), true /* notify */)
+ c.EP.SocketOptions().SetReceiveBufferSize(int64(epRcvBuf)*2, true /* notify */)
}
}
@@ -879,13 +883,21 @@ func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) {
)
}
+// CreateConnectedWithOptionsNoDelay is a convenience wrapper around
+// CreateConnectedWithOptions that passes a zero delay, so no extra wait is
+// requested during the handshake. It returns the RawEndpoint produced by
+// CreateConnectedWithOptions for the given SYN options.
+func (c *Context) CreateConnectedWithOptionsNoDelay(wantOptions header.TCPSynOptions) *RawEndpoint {
+ return c.CreateConnectedWithOptions(wantOptions, 0 /* delay */)
+}
+
// CreateConnectedWithOptions creates and connects c.ep with the specified TCP
// options enabled and returns a RawEndpoint which represents the other end of
-// the connection.
+// the connection. It delays before a SYNACK is sent. This makes c.EP have a
+// higher RTT estimate so that spurious TLPs aren't sent in tests, which helps
+// reduce flakiness.
//
// It also verifies where required(eg.Timestamp) that the ACK to the SYN-ACK
// does not carry an option that was not requested.
-func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *RawEndpoint {
+func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions, delay time.Duration) *RawEndpoint {
var err tcpip.Error
c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
@@ -911,18 +923,17 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
// TS value.
mss := uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize)
- checker.IPv4(c.t, b,
- checker.TCP(
- checker.DstPort(TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- checker.TCPSynOptions(header.TCPSynOptions{
- MSS: mss,
- TS: true,
- WS: int(c.WindowScale),
- SACKPermitted: c.SACKEnabled(),
- }),
- ),
+ synChecker := checker.TCP(
+ checker.DstPort(TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ checker.TCPSynOptions(header.TCPSynOptions{
+ MSS: mss,
+ TS: true,
+ WS: int(c.WindowScale),
+ SACKPermitted: c.SACKEnabled(),
+ }),
)
+ checker.IPv4(c.t, b, synChecker)
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
}
@@ -948,6 +959,10 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
// Build SYN-ACK.
c.IRS = seqnum.Value(tcpSeg.SequenceNumber())
iss := seqnum.Value(TestInitialSequenceNumber)
+ if delay > 0 {
+ // Sleep so that RTT is increased.
+ time.Sleep(delay)
+ }
c.SendPacket(nil, &Headers{
SrcPort: tcpSeg.DestinationPort(),
DstPort: tcpSeg.SourcePort(),
@@ -959,7 +974,17 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
})
// Read ACK.
- ackPacket := c.GetPacket()
+ var ackPacket []byte
+ // Ignore retransmitted SYN packets.
+ for {
+ packet := c.GetPacket()
+ if header.TCP(header.IPv4(packet).Payload()).Flags()&header.TCPFlagSyn != 0 {
+ checker.IPv4(c.t, packet, synChecker)
+ } else {
+ ackPacket = packet
+ break
+ }
+ }
// Verify TCP header fields.
tcpCheckers := []checker.TransportChecker{
@@ -1016,13 +1041,19 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
}
}
-// AcceptWithOptions initializes a listening endpoint and connects to it with the
-// provided options enabled. It also verifies that the SYN-ACK has the expected
-// values for the provided options.
+// AcceptWithOptionsNoDelay is a convenience wrapper around AcceptWithOptions
+// that passes a zero delay, so no extra wait is requested during the
+// handshake. wndScale and synOptions are forwarded unchanged, and the
+// RawEndpoint returned by AcceptWithOptions is returned as is.
+func (c *Context) AcceptWithOptionsNoDelay(wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
+ return c.AcceptWithOptions(wndScale, synOptions, 0 /* delay */)
+}
+
+// AcceptWithOptions initializes a listening endpoint and connects to it with
+// the provided options enabled. It delays before the final ACK of the 3WHS is
+// sent. It also verifies that the SYN-ACK has the expected values for the
+// provided options.
//
// The function returns a RawEndpoint representing the other end of the accepted
// endpoint.
-func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
+func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOptions, delay time.Duration) *RawEndpoint {
// Create EP and start listening.
wq := &waiter.Queue{}
ep, err := c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
@@ -1045,7 +1076,7 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption
c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
}
- rep := c.PassiveConnectWithOptions(100, wndScale, synOptions)
+ rep := c.PassiveConnectWithOptions(100, wndScale, synOptions, delay)
// Try to accept the connection.
we, ch := waiter.NewChannelEntry(nil)
@@ -1077,13 +1108,14 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption
// PassiveConnectWithOptions.
func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCPSynOptions) {
synOptions.WS = -1
- c.PassiveConnectWithOptions(maxPayload, wndScale, synOptions)
+ c.PassiveConnectWithOptions(maxPayload, wndScale, synOptions, 0 /* delay */)
}
// PassiveConnectWithOptions initiates a new connection (with the specified TCP
// options enabled) to the port on which the Context.ep is listening for new
// connections. It also validates that the SYN-ACK has the expected values for
-// the enabled options.
+// the enabled options. The final ACK of the handshake is delayed by the
+// specified duration.
//
// NOTE: MSS is not a negotiated option and it can be asymmetric
// in each direction. This function uses the maxPayload to set the MSS to be
@@ -1093,7 +1125,7 @@ func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCP
// wndScale is the expected window scale in the SYN-ACK and synOptions.WS is the
// value of the window scaling option to be sent in the SYN. If synOptions.WS >
// 0 then we send the WindowScale option.
-func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
+func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions, delay time.Duration) *RawEndpoint {
c.t.Helper()
opts := make([]byte, header.TCPOptionsMaximumSize)
offset := 0
@@ -1180,7 +1212,10 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions
ackHeaders.TCPOpts = opts[:]
}
- // Send ACK.
+ // Send ACK, delay if needed.
+ if delay > 0 {
+ time.Sleep(delay)
+ }
c.SendPacket(nil, ackHeaders)
c.RcvdWindowScale = uint8(rcvdSynOptions.WS)
diff --git a/pkg/tcpip/transport/transport.go b/pkg/tcpip/transport/transport.go
new file mode 100644
index 000000000..4c2ae87f4
--- /dev/null
+++ b/pkg/tcpip/transport/transport.go
@@ -0,0 +1,16 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package transport supports transport protocols.
+package transport
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index cdc344ab7..5cc7a2886 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -35,6 +35,8 @@ go_library(
"//pkg/tcpip/header/parse",
"//pkg/tcpip/ports",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport",
+ "//pkg/tcpip/transport/internal/network",
"//pkg/tcpip/transport/raw",
"//pkg/waiter",
],
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index def9d7186..4255457f9 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -15,8 +15,8 @@
package udp
import (
+ "fmt"
"io"
- "sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/sync"
@@ -25,12 +25,15 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network"
"gvisor.dev/gvisor/pkg/waiter"
)
// +stateify savable
type udpPacket struct {
udpPacketEntry
+ netProto tcpip.NetworkProtocolNumber
senderAddress tcpip.FullAddress
destinationAddress tcpip.FullAddress
packetInfo tcpip.IPPacketInfo
@@ -40,36 +43,6 @@ type udpPacket struct {
tos uint8
}
-// EndpointState represents the state of a UDP endpoint.
-type EndpointState tcpip.EndpointState
-
-// Endpoint states. Note that are represented in a netstack-specific manner and
-// may not be meaningful externally. Specifically, they need to be translated to
-// Linux's representation for these states if presented to userspace.
-const (
- _ EndpointState = iota
- StateInitial
- StateBound
- StateConnected
- StateClosed
-)
-
-// String implements fmt.Stringer.
-func (s EndpointState) String() string {
- switch s {
- case StateInitial:
- return "INITIAL"
- case StateBound:
- return "BOUND"
- case StateConnected:
- return "CONNECTING"
- case StateClosed:
- return "CLOSED"
- default:
- return "UNKNOWN"
- }
-}
-
// endpoint represents a UDP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
@@ -79,7 +52,6 @@ func (s EndpointState) String() string {
//
// +stateify savable
type endpoint struct {
- stack.TransportEndpointInfo
tcpip.DefaultSocketOptionsHandler
// The following fields are initialized at creation time and do not
@@ -87,6 +59,10 @@ type endpoint struct {
stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue
uniqueID uint64
+ net network.Endpoint
+ // TODO(b/142022063): Add ability to save and restore per endpoint stats.
+ stats tcpip.TransportEndpointStats `state:"nosave"`
+ ops tcpip.SocketOptions
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.
@@ -96,37 +72,19 @@ type endpoint struct {
rcvBufSize int
rcvClosed bool
- // The following fields are protected by the mu mutex.
- mu sync.RWMutex `state:"nosave"`
- // state must be read/set using the EndpointState()/setEndpointState()
- // methods.
- state uint32
- route *stack.Route `state:"manual"`
- dstPort uint16
- ttl uint8
- multicastTTL uint8
- multicastAddr tcpip.Address
- multicastNICID tcpip.NICID
- portFlags ports.Flags
-
lastErrorMu sync.Mutex `state:"nosave"`
lastError tcpip.Error
+ // The following fields are protected by the mu mutex.
+ mu sync.RWMutex `state:"nosave"`
+ portFlags ports.Flags
+
// Values used to reserve a port or register a transport endpoint.
// (which ever happens first).
boundBindToDevice tcpip.NICID
boundPortFlags ports.Flags
- // sendTOS represents IPv4 TOS or IPv6 TrafficClass,
- // applied while sending packets. Defaults to 0 as on Linux.
- sendTOS uint8
-
- // shutdownFlags represent the current shutdown state of the endpoint.
- shutdownFlags tcpip.ShutdownFlags
-
- // multicastMemberships that need to be remvoed when the endpoint is
- // closed. Protected by the mu mutex.
- multicastMemberships map[multicastMembership]struct{}
+ readShutdown bool
// effectiveNetProtos contains the network protocols actually in use. In
// most cases it will only contain "netProto", but in cases like IPv6
@@ -136,55 +94,25 @@ type endpoint struct {
// address).
effectiveNetProtos []tcpip.NetworkProtocolNumber
- // TODO(b/142022063): Add ability to save and restore per endpoint stats.
- stats tcpip.TransportEndpointStats `state:"nosave"`
-
- // owner is used to get uid and gid of the packet.
- owner tcpip.PacketOwner
-
- // ops is used to get socket level options.
- ops tcpip.SocketOptions
-
// frozen indicates if the packets should be delivered to the endpoint
// during restore.
frozen bool
-}
-// +stateify savable
-type multicastMembership struct {
- nicID tcpip.NICID
- multicastAddr tcpip.Address
+ localPort uint16
+ remotePort uint16
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
e := &endpoint{
- stack: s,
- TransportEndpointInfo: stack.TransportEndpointInfo{
- NetProto: netProto,
- TransProto: header.UDPProtocolNumber,
- },
+ stack: s,
waiterQueue: waiterQueue,
- // RFC 1075 section 5.4 recommends a TTL of 1 for membership
- // requests.
- //
- // RFC 5135 4.2.1 appears to assume that IGMP messages have a
- // TTL of 1.
- //
- // RFC 5135 Appendix A defines TTL=1: A multicast source that
- // wants its traffic to not traverse a router (e.g., leave a
- // home network) may find it useful to send traffic with IP
- // TTL=1.
- //
- // Linux defaults to TTL=1.
- multicastTTL: 1,
- multicastMemberships: make(map[multicastMembership]struct{}),
- state: uint32(StateInitial),
- uniqueID: s.UniqueID(),
+ uniqueID: s.UniqueID(),
}
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
e.ops.SetMulticastLoop(true)
e.ops.SetSendBufferSize(32*1024, false /* notify */)
e.ops.SetReceiveBufferSize(32*1024, false /* notify */)
+ e.net.Init(s, netProto, header.UDPProtocolNumber, &e.ops)
// Override with stack defaults.
var ss tcpip.SendBufferSizeOption
@@ -200,20 +128,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
return e
}
-// setEndpointState updates the state of the endpoint to state atomically. This
-// method is unexported as the only place we should update the state is in this
-// package but we allow the state to be read freely without holding e.mu.
-//
-// Precondition: e.mu must be held to call this method.
-func (e *endpoint) setEndpointState(state EndpointState) {
- atomic.StoreUint32(&e.state, uint32(state))
-}
-
-// EndpointState() returns the current state of the endpoint.
-func (e *endpoint) EndpointState() EndpointState {
- return EndpointState(atomic.LoadUint32(&e.state))
-}
-
// UniqueID implements stack.TransportEndpoint.
func (e *endpoint) UniqueID() uint64 {
return e.uniqueID
@@ -244,16 +158,22 @@ func (e *endpoint) Abort() {
// associated with it.
func (e *endpoint) Close() {
e.mu.Lock()
- e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
- switch e.EndpointState() {
- case StateBound, StateConnected:
- e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial:
+ case transport.DatagramEndpointStateClosed:
+ e.mu.Unlock()
+ return
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ id := e.net.Info().ID
+ id.LocalPort = e.localPort
+ id.RemotePort = e.remotePort
+ e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, id, e, e.boundPortFlags, e.boundBindToDevice)
portRes := ports.Reservation{
Networks: e.effectiveNetProtos,
Transport: ProtocolNumber,
- Addr: e.ID.LocalAddress,
- Port: e.ID.LocalPort,
+ Addr: id.LocalAddress,
+ Port: id.LocalPort,
Flags: e.boundPortFlags,
BindToDevice: e.boundBindToDevice,
Dest: tcpip.FullAddress{},
@@ -261,13 +181,10 @@ func (e *endpoint) Close() {
e.stack.ReleasePort(portRes)
e.boundBindToDevice = 0
e.boundPortFlags = ports.Flags{}
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
}
- for mem := range e.multicastMemberships {
- e.stack.LeaveGroup(e.NetProto, mem.nicID, mem.multicastAddr)
- }
- e.multicastMemberships = make(map[multicastMembership]struct{})
-
// Close the receive list and drain it.
e.rcvMu.Lock()
e.rcvClosed = true
@@ -278,14 +195,9 @@ func (e *endpoint) Close() {
}
e.rcvMu.Unlock()
- if e.route != nil {
- e.route.Release()
- e.route = nil
- }
-
- // Update the state.
- e.setEndpointState(StateClosed)
-
+ e.net.Shutdown()
+ e.net.Close()
+ e.readShutdown = true
e.mu.Unlock()
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
@@ -324,14 +236,21 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
HasTimestamp: true,
Timestamp: p.receivedAt.UnixNano(),
}
- if e.ops.GetReceiveTOS() {
- cm.HasTOS = true
- cm.TOS = p.tos
- }
- if e.ops.GetReceiveTClass() {
- cm.HasTClass = true
- // Although TClass is an 8-bit value it's read in the CMsg as a uint32.
- cm.TClass = uint32(p.tos)
+
+ switch p.netProto {
+ case header.IPv4ProtocolNumber:
+ if e.ops.GetReceiveTOS() {
+ cm.HasTOS = true
+ cm.TOS = p.tos
+ }
+ case header.IPv6ProtocolNumber:
+ if e.ops.GetReceiveTClass() {
+ cm.HasTClass = true
+ // Although TClass is an 8-bit value it's read in the CMsg as a uint32.
+ cm.TClass = uint32(p.tos)
+ }
+ default:
+ panic(fmt.Sprintf("unrecognized network protocol = %d", p.netProto))
}
if e.ops.GetReceivePacketInfo() {
cm.HasIPPacketInfo = true
@@ -359,18 +278,19 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
return res, nil
}
-// prepareForWrite prepares the endpoint for sending data. In particular, it
-// binds it if it's still in the initial state. To do so, it must first
+// prepareForWriteInner prepares the endpoint for sending data. In particular,
+// it binds it if it's still in the initial state. To do so, it must first
// reacquire the mutex in exclusive mode.
//
// Returns true for retry if preparation should be retried.
-func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
- switch e.EndpointState() {
- case StateInitial:
- case StateConnected:
+// +checklocks:e.mu
+func (e *endpoint) prepareForWriteInner(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
+ switch e.net.State() {
+ case transport.DatagramEndpointStateInitial:
+ case transport.DatagramEndpointStateConnected:
return false, nil
- case StateBound:
+ case transport.DatagramEndpointStateBound:
if to == nil {
return false, &tcpip.ErrDestinationRequired{}
}
@@ -380,14 +300,12 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip
}
e.mu.RUnlock()
- defer e.mu.RLock()
-
e.mu.Lock()
- defer e.mu.Unlock()
+ defer e.mu.DowngradeLock()
// The state changed when we released the shared locked and re-acquired
// it in exclusive mode. Try again.
- if e.EndpointState() != StateInitial {
+ if e.net.State() != transport.DatagramEndpointStateInitial {
return true, nil
}
@@ -399,33 +317,6 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip
return true, nil
}
-// connectRoute establishes a route to the specified interface or the
-// configured multicast interface if no interface is specified and the
-// specified address is a multicast address.
-func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) {
- localAddr := e.ID.LocalAddress
- if e.isBroadcastOrMulticast(nicID, netProto, localAddr) {
- // A packet can only originate from a unicast address (i.e., an interface).
- localAddr = ""
- }
-
- if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) {
- if nicID == 0 {
- nicID = e.multicastNICID
- }
- if localAddr == "" && nicID == 0 {
- localAddr = e.multicastAddr
- }
- }
-
- // Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.ops.GetMulticastLoop())
- if err != nil {
- return nil, 0, err
- }
- return r, nicID, nil
-}
-
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
@@ -449,37 +340,15 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
return n, err
}
-func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
- if err := e.LastError(); err != nil {
- return 0, err
- }
-
- // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
- if opts.More {
- return 0, &tcpip.ErrInvalidOptionValue{}
- }
-
- to := opts.To
-
+func (e *endpoint) prepareForWrite(p tcpip.Payloader, opts tcpip.WriteOptions) (udpPacketInfo, tcpip.Error) {
e.mu.RLock()
- lockReleased := false
- defer func() {
- if lockReleased {
- return
- }
- e.mu.RUnlock()
- }()
-
- // If we've shutdown with SHUT_WR we are in an invalid state for sending.
- if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
- return 0, &tcpip.ErrClosedForSend{}
- }
+ defer e.mu.RUnlock()
// Prepare for write.
for {
- retry, err := e.prepareForWrite(to)
+ retry, err := e.prepareForWriteInner(opts.To)
if err != nil {
- return 0, err
+ return udpPacketInfo{}, err
}
if !retry {
@@ -487,50 +356,29 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
}
}
- route := e.route
- dstPort := e.dstPort
- if to != nil {
- // Reject destination address if it goes through a different
- // NIC than the endpoint was bound to.
- nicID := to.NIC
- if nicID == 0 {
- nicID = tcpip.NICID(e.ops.GetBindToDevice())
- }
- if e.BindNICID != 0 {
- if nicID != 0 && nicID != e.BindNICID {
- return 0, &tcpip.ErrNoRoute{}
- }
-
- nicID = e.BindNICID
- }
-
- if to.Port == 0 {
+ dst, connected := e.net.GetRemoteAddress()
+ dst.Port = e.remotePort
+ if opts.To != nil {
+ if opts.To.Port == 0 {
// Port 0 is an invalid port to send to.
- return 0, &tcpip.ErrInvalidEndpointState{}
- }
-
- dst, netProto, err := e.checkV4MappedLocked(*to)
- if err != nil {
- return 0, err
+ return udpPacketInfo{}, &tcpip.ErrInvalidEndpointState{}
}
- r, _, err := e.connectRoute(nicID, dst, netProto)
- if err != nil {
- return 0, err
- }
- defer r.Release()
-
- route = r
- dstPort = dst.Port
+ dst = *opts.To
+ } else if !connected {
+ return udpPacketInfo{}, &tcpip.ErrDestinationRequired{}
}
- if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() {
- return 0, &tcpip.ErrBroadcastDisabled{}
+ ctx, err := e.net.AcquireContextForWrite(opts)
+ if err != nil {
+ return udpPacketInfo{}, err
}
+ // TODO(https://gvisor.dev/issue/6538): Avoid this allocation.
v := make([]byte, p.Len())
if _, err := io.ReadFull(p, v); err != nil {
- return 0, &tcpip.ErrBadBuffer{}
+ ctx.Release()
+ return udpPacketInfo{}, &tcpip.ErrBadBuffer{}
}
if len(v) > header.UDPMaximumPacketSize {
// Payload can't possibly fit in a packet.
@@ -538,35 +386,25 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
if so.GetRecvError() {
so.QueueLocalErr(
&tcpip.ErrMessageTooLong{},
- route.NetProto(),
+ e.net.NetProto(),
header.UDPMaximumPacketSize,
- tcpip.FullAddress{
- NIC: route.NICID(),
- Addr: route.RemoteAddress(),
- Port: dstPort,
- },
+ dst,
v,
)
}
- return 0, &tcpip.ErrMessageTooLong{}
+ ctx.Release()
+ return udpPacketInfo{}, &tcpip.ErrMessageTooLong{}
}
- ttl := e.ttl
- useDefaultTTL := ttl == 0
-
- if header.IsV4MulticastAddress(route.RemoteAddress()) || header.IsV6MulticastAddress(route.RemoteAddress()) {
- ttl = e.multicastTTL
- // Multicast allows a 0 TTL.
- useDefaultTTL = false
- }
-
- localPort := e.ID.LocalPort
- sendTOS := e.sendTOS
- owner := e.owner
- noChecksum := e.SocketOptions().GetNoChecksum()
- lockReleased = true
- e.mu.RUnlock()
+ return udpPacketInfo{
+ ctx: ctx,
+ data: v,
+ localPort: e.localPort,
+ remotePort: dst.Port,
+ }, nil
+}
+func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
// Do not hold lock when sending as loopback is synchronous and if the UDP
// datagram ends up generating an ICMP response then it can result in a
// deadlock where the ICMP response handling ends up acquiring this endpoint's
@@ -577,10 +415,53 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
//
// See: https://golang.org/pkg/sync/#RWMutex for details on why recursive read
// locking is prohibited.
- if err := sendUDP(route, buffer.View(v).ToVectorisedView(), localPort, dstPort, ttl, useDefaultTTL, sendTOS, owner, noChecksum); err != nil {
+
+ if err := e.LastError(); err != nil {
return 0, err
}
- return int64(len(v)), nil
+
+ udpInfo, err := e.prepareForWrite(p, opts)
+ if err != nil {
+ return 0, err
+ }
+ defer udpInfo.ctx.Release()
+
+ pktInfo := udpInfo.ctx.PacketInfo()
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.UDPMinimumSize + int(pktInfo.MaxHeaderLength),
+ Data: udpInfo.data.ToVectorisedView(),
+ })
+
+ // Initialize the UDP header.
+ udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
+ pkt.TransportProtocolNumber = ProtocolNumber
+
+ length := uint16(pkt.Size())
+ udp.Encode(&header.UDPFields{
+ SrcPort: udpInfo.localPort,
+ DstPort: udpInfo.remotePort,
+ Length: length,
+ })
+
+ // Set the checksum field unless TX checksum offload is enabled.
+ // On IPv4, UDP checksum is optional, and a zero value indicates the
+ // transmitter skipped the checksum generation (RFC768).
+ // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
+ if pktInfo.RequiresTXTransportChecksum &&
+ (!e.ops.GetNoChecksum() || pktInfo.NetProto == header.IPv6ProtocolNumber) {
+ udp.SetChecksum(^udp.CalculateChecksum(header.ChecksumCombine(
+ header.PseudoHeaderChecksum(ProtocolNumber, pktInfo.LocalAddress, pktInfo.RemoteAddress, length),
+ pkt.Data().AsRange().Checksum(),
+ )))
+ }
+ if err := udpInfo.ctx.WritePacket(pkt, false /* headerIncluded */); err != nil {
+ e.stack.Stats().UDP.PacketSendErrors.Increment()
+ return 0, err
+ }
+
+ // Track count of packets sent.
+ e.stack.Stats().UDP.PacketsSent.Increment()
+ return int64(len(udpInfo.data)), nil
}
// OnReuseAddressSet implements tcpip.SocketOptionsHandler.
@@ -599,36 +480,7 @@ func (e *endpoint) OnReusePortSet(v bool) {
// SetSockOptInt implements tcpip.Endpoint.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
- switch opt {
- case tcpip.MTUDiscoverOption:
- // Return not supported if the value is not disabling path
- // MTU discovery.
- if v != tcpip.PMTUDiscoveryDont {
- return &tcpip.ErrNotSupported{}
- }
-
- case tcpip.MulticastTTLOption:
- e.mu.Lock()
- e.multicastTTL = uint8(v)
- e.mu.Unlock()
-
- case tcpip.TTLOption:
- e.mu.Lock()
- e.ttl = uint8(v)
- e.mu.Unlock()
-
- case tcpip.IPv4TOSOption:
- e.mu.Lock()
- e.sendTOS = uint8(v)
- e.mu.Unlock()
-
- case tcpip.IPv6TrafficClassOption:
- e.mu.Lock()
- e.sendTOS = uint8(v)
- e.mu.Unlock()
- }
-
- return nil
+ return e.net.SetSockOptInt(opt, v)
}
var _ tcpip.SocketOptionsHandler = (*endpoint)(nil)
@@ -640,145 +492,12 @@ func (e *endpoint) HasNIC(id int32) bool {
// SetSockOpt implements tcpip.Endpoint.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
- switch v := opt.(type) {
- case *tcpip.MulticastInterfaceOption:
- e.mu.Lock()
- defer e.mu.Unlock()
-
- fa := tcpip.FullAddress{Addr: v.InterfaceAddr}
- fa, netProto, err := e.checkV4MappedLocked(fa)
- if err != nil {
- return err
- }
- nic := v.NIC
- addr := fa.Addr
-
- if nic == 0 && addr == "" {
- e.multicastAddr = ""
- e.multicastNICID = 0
- break
- }
-
- if nic != 0 {
- if !e.stack.CheckNIC(nic) {
- return &tcpip.ErrBadLocalAddress{}
- }
- } else {
- nic = e.stack.CheckLocalAddress(0, netProto, addr)
- if nic == 0 {
- return &tcpip.ErrBadLocalAddress{}
- }
- }
-
- if e.BindNICID != 0 && e.BindNICID != nic {
- return &tcpip.ErrInvalidEndpointState{}
- }
-
- e.multicastNICID = nic
- e.multicastAddr = addr
-
- case *tcpip.AddMembershipOption:
- if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
- return &tcpip.ErrInvalidOptionValue{}
- }
-
- nicID := v.NIC
-
- if v.InterfaceAddr.Unspecified() {
- if nicID == 0 {
- if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.NetProto, false /* multicastLoop */); err == nil {
- nicID = r.NICID()
- r.Release()
- }
- }
- } else {
- nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr)
- }
- if nicID == 0 {
- return &tcpip.ErrUnknownDevice{}
- }
-
- memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
-
- e.mu.Lock()
- defer e.mu.Unlock()
-
- if _, ok := e.multicastMemberships[memToInsert]; ok {
- return &tcpip.ErrPortInUse{}
- }
-
- if err := e.stack.JoinGroup(e.NetProto, nicID, v.MulticastAddr); err != nil {
- return err
- }
-
- e.multicastMemberships[memToInsert] = struct{}{}
-
- case *tcpip.RemoveMembershipOption:
- if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
- return &tcpip.ErrInvalidOptionValue{}
- }
-
- nicID := v.NIC
- if v.InterfaceAddr.Unspecified() {
- if nicID == 0 {
- if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.NetProto, false /* multicastLoop */); err == nil {
- nicID = r.NICID()
- r.Release()
- }
- }
- } else {
- nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr)
- }
- if nicID == 0 {
- return &tcpip.ErrUnknownDevice{}
- }
-
- memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
-
- e.mu.Lock()
- defer e.mu.Unlock()
-
- if _, ok := e.multicastMemberships[memToRemove]; !ok {
- return &tcpip.ErrBadLocalAddress{}
- }
-
- if err := e.stack.LeaveGroup(e.NetProto, nicID, v.MulticastAddr); err != nil {
- return err
- }
-
- delete(e.multicastMemberships, memToRemove)
-
- case *tcpip.SocketDetachFilterOption:
- return nil
- }
- return nil
+ return e.net.SetSockOpt(opt)
}
// GetSockOptInt implements tcpip.Endpoint.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
switch opt {
- case tcpip.IPv4TOSOption:
- e.mu.RLock()
- v := int(e.sendTOS)
- e.mu.RUnlock()
- return v, nil
-
- case tcpip.IPv6TrafficClassOption:
- e.mu.RLock()
- v := int(e.sendTOS)
- e.mu.RUnlock()
- return v, nil
-
- case tcpip.MTUDiscoverOption:
- // The only supported setting is path MTU discovery disabled.
- return tcpip.PMTUDiscoveryDont, nil
-
- case tcpip.MulticastTTLOption:
- e.mu.Lock()
- v := int(e.multicastTTL)
- e.mu.Unlock()
- return v, nil
-
case tcpip.ReceiveQueueSizeOption:
v := 0
e.rcvMu.Lock()
@@ -789,92 +508,22 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
e.rcvMu.Unlock()
return v, nil
- case tcpip.TTLOption:
- e.mu.Lock()
- v := int(e.ttl)
- e.mu.Unlock()
- return v, nil
-
default:
- return -1, &tcpip.ErrUnknownProtocolOption{}
+ return e.net.GetSockOptInt(opt)
}
}
// GetSockOpt implements tcpip.Endpoint.
func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
- switch o := opt.(type) {
- case *tcpip.MulticastInterfaceOption:
- e.mu.Lock()
- *o = tcpip.MulticastInterfaceOption{
- NIC: e.multicastNICID,
- InterfaceAddr: e.multicastAddr,
- }
- e.mu.Unlock()
-
- default:
- return &tcpip.ErrUnknownProtocolOption{}
- }
- return nil
-}
-
-// sendUDP sends a UDP segment via the provided network endpoint and under the
-// provided identity.
-func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) tcpip.Error {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()),
- Data: data,
- })
- pkt.Owner = owner
-
- // Initialize the UDP header.
- udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
- pkt.TransportProtocolNumber = ProtocolNumber
-
- length := uint16(pkt.Size())
- udp.Encode(&header.UDPFields{
- SrcPort: localPort,
- DstPort: remotePort,
- Length: length,
- })
-
- // Set the checksum field unless TX checksum offload is enabled.
- // On IPv4, UDP checksum is optional, and a zero value indicates the
- // transmitter skipped the checksum generation (RFC768).
- // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
- if r.RequiresTXTransportChecksum() &&
- (!noChecksum || r.NetProto() == header.IPv6ProtocolNumber) {
- xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
- for _, v := range data.Views() {
- xsum = header.Checksum(v, xsum)
- }
- udp.SetChecksum(^udp.CalculateChecksum(xsum))
- }
-
- if useDefaultTTL {
- ttl = r.DefaultTTL()
- }
- if err := r.WritePacket(stack.NetworkHeaderParams{
- Protocol: ProtocolNumber,
- TTL: ttl,
- TOS: tos,
- }, pkt); err != nil {
- r.Stats().UDP.PacketSendErrors.Increment()
- return err
- }
-
- // Track count of packets sent.
- r.Stats().UDP.PacketsSent.Increment()
- return nil
+ return e.net.GetSockOpt(opt)
}
-// checkV4MappedLocked determines the effective network protocol and converts
-// addr to its canonical form.
-func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
- unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only())
- if err != nil {
- return tcpip.FullAddress{}, 0, err
- }
- return unwrapped, netProto, nil
+// udpPacketInfo holds information needed to send a UDP packet.
+type udpPacketInfo struct {
+ ctx network.WriteContext
+ data buffer.View
+ localPort uint16
+ remotePort uint16
}
// Disconnect implements tcpip.Endpoint.
@@ -882,7 +531,7 @@ func (e *endpoint) Disconnect() tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- if e.EndpointState() != StateConnected {
+ if e.net.State() != transport.DatagramEndpointStateConnected {
return nil
}
var (
@@ -895,26 +544,28 @@ func (e *endpoint) Disconnect() tcpip.Error {
boundPortFlags := e.boundPortFlags
// Exclude ephemerally bound endpoints.
- if e.BindNICID != 0 || e.ID.LocalAddress == "" {
+ info := e.net.Info()
+ info.ID.LocalPort = e.localPort
+ info.ID.RemotePort = e.remotePort
+ if e.net.WasBound() {
var err tcpip.Error
id = stack.TransportEndpointID{
- LocalPort: e.ID.LocalPort,
- LocalAddress: e.ID.LocalAddress,
+ LocalPort: info.ID.LocalPort,
+ LocalAddress: info.ID.LocalAddress,
}
id, btd, err = e.registerWithStack(e.effectiveNetProtos, id)
if err != nil {
return err
}
- e.setEndpointState(StateBound)
boundPortFlags = e.boundPortFlags
} else {
- if e.ID.LocalPort != 0 {
+ if info.ID.LocalPort != 0 {
// Release the ephemeral port.
portRes := ports.Reservation{
Networks: e.effectiveNetProtos,
Transport: ProtocolNumber,
- Addr: e.ID.LocalAddress,
- Port: e.ID.LocalPort,
+ Addr: info.ID.LocalAddress,
+ Port: info.ID.LocalPort,
Flags: boundPortFlags,
BindToDevice: e.boundBindToDevice,
Dest: tcpip.FullAddress{},
@@ -922,15 +573,14 @@ func (e *endpoint) Disconnect() tcpip.Error {
e.stack.ReleasePort(portRes)
e.boundPortFlags = ports.Flags{}
}
- e.setEndpointState(StateInitial)
}
- e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice)
- e.ID = id
+ e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, info.ID, e, boundPortFlags, e.boundBindToDevice)
e.boundBindToDevice = btd
- e.route.Release()
- e.route = nil
- e.dstPort = 0
+ e.localPort = id.LocalPort
+ e.remotePort = id.RemotePort
+
+ e.net.Disconnect()
return nil
}
@@ -940,88 +590,48 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- nicID := addr.NIC
- var localPort uint16
- switch e.EndpointState() {
- case StateInitial:
- case StateBound, StateConnected:
- localPort = e.ID.LocalPort
- if e.BindNICID == 0 {
- break
- }
-
- if nicID != 0 && nicID != e.BindNICID {
- return &tcpip.ErrInvalidEndpointState{}
+ err := e.net.ConnectAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error {
+ nextID.LocalPort = e.localPort
+ nextID.RemotePort = addr.Port
+
+ // Even if we're connected, this endpoint can still be used to send
+ // packets on a different network protocol, so we register both even if
+ // v6only is set to false and this is an ipv6 endpoint.
+ netProtos := []tcpip.NetworkProtocolNumber{netProto}
+ if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() {
+ netProtos = []tcpip.NetworkProtocolNumber{
+ header.IPv4ProtocolNumber,
+ header.IPv6ProtocolNumber,
+ }
}
- nicID = e.BindNICID
- default:
- return &tcpip.ErrInvalidEndpointState{}
- }
-
- addr, netProto, err := e.checkV4MappedLocked(addr)
- if err != nil {
- return err
- }
-
- r, nicID, err := e.connectRoute(nicID, addr, netProto)
- if err != nil {
- return err
- }
-
- id := stack.TransportEndpointID{
- LocalAddress: e.ID.LocalAddress,
- LocalPort: localPort,
- RemotePort: addr.Port,
- RemoteAddress: r.RemoteAddress(),
- }
-
- if e.EndpointState() == StateInitial {
- id.LocalAddress = r.LocalAddress()
- }
+ oldPortFlags := e.boundPortFlags
- // Even if we're connected, this endpoint can still be used to send
- // packets on a different network protocol, so we register both even if
- // v6only is set to false and this is an ipv6 endpoint.
- netProtos := []tcpip.NetworkProtocolNumber{netProto}
- if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() {
- netProtos = []tcpip.NetworkProtocolNumber{
- header.IPv4ProtocolNumber,
- header.IPv6ProtocolNumber,
+ nextID, btd, err := e.registerWithStack(netProtos, nextID)
+ if err != nil {
+ return err
}
- }
- oldPortFlags := e.boundPortFlags
+ // Remove the old registration.
+ if e.localPort != 0 {
+ previousID.LocalPort = e.localPort
+ previousID.RemotePort = e.remotePort
+ e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, previousID, e, oldPortFlags, e.boundBindToDevice)
+ }
- id, btd, err := e.registerWithStack(netProtos, id)
+ e.localPort = nextID.LocalPort
+ e.remotePort = nextID.RemotePort
+ e.boundBindToDevice = btd
+ e.effectiveNetProtos = netProtos
+ return nil
+ })
if err != nil {
- r.Release()
return err
}
- // Remove the old registration.
- if e.ID.LocalPort != 0 {
- e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, oldPortFlags, e.boundBindToDevice)
- }
-
- e.ID = id
- e.boundBindToDevice = btd
- if e.route != nil {
- // If the endpoint was already connected then make sure we release the
- // previous route.
- e.route.Release()
- }
- e.route = r
- e.dstPort = addr.Port
- e.RegisterNICID = nicID
- e.effectiveNetProtos = netProtos
-
- e.setEndpointState(StateConnected)
-
e.rcvMu.Lock()
e.rcvReady = true
e.rcvMu.Unlock()
-
return nil
}
@@ -1036,15 +646,23 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- // A socket in the bound state can still receive multicast messages,
- // so we need to notify waiters on shutdown.
- if state := e.EndpointState(); state != StateBound && state != StateConnected {
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
return &tcpip.ErrNotConnected{}
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
}
- e.shutdownFlags |= flags
+ if flags&tcpip.ShutdownWrite != 0 {
+ if err := e.net.Shutdown(); err != nil {
+ return err
+ }
+ }
if flags&tcpip.ShutdownRead != 0 {
+ e.readShutdown = true
+
e.rcvMu.Lock()
wasClosed := e.rcvClosed
e.rcvClosed = true
@@ -1070,7 +688,7 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi
func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, tcpip.Error) {
bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
- if e.ID.LocalPort == 0 {
+ if e.localPort == 0 {
portRes := ports.Reservation{
Networks: netProtos,
Transport: ProtocolNumber,
@@ -1108,56 +726,43 @@ func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id
func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error {
// Don't allow binding once endpoint is not in the initial state
// anymore.
- if e.EndpointState() != StateInitial {
+ if e.net.State() != transport.DatagramEndpointStateInitial {
return &tcpip.ErrInvalidEndpointState{}
}
- addr, netProto, err := e.checkV4MappedLocked(addr)
- if err != nil {
- return err
- }
-
- // Expand netProtos to include v4 and v6 if the caller is binding to a
- // wildcard (empty) address, and this is an IPv6 endpoint with v6only
- // set to false.
- netProtos := []tcpip.NetworkProtocolNumber{netProto}
- if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && addr.Addr == "" {
- netProtos = []tcpip.NetworkProtocolNumber{
- header.IPv6ProtocolNumber,
- header.IPv4ProtocolNumber,
+ err := e.net.BindAndThen(addr, func(boundNetProto tcpip.NetworkProtocolNumber, boundAddr tcpip.Address) tcpip.Error {
+ // Expand netProtos to include v4 and v6 if the caller is binding to a
+ // wildcard (empty) address, and this is an IPv6 endpoint with v6only
+ // set to false.
+ netProtos := []tcpip.NetworkProtocolNumber{boundNetProto}
+ if boundNetProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && boundAddr == "" {
+ netProtos = []tcpip.NetworkProtocolNumber{
+ header.IPv6ProtocolNumber,
+ header.IPv4ProtocolNumber,
+ }
}
- }
- nicID := addr.NIC
- if len(addr.Addr) != 0 && !e.isBroadcastOrMulticast(addr.NIC, netProto, addr.Addr) {
- // A local unicast address was specified, verify that it's valid.
- nicID = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
- if nicID == 0 {
- return &tcpip.ErrBadLocalAddress{}
+ id := stack.TransportEndpointID{
+ LocalPort: addr.Port,
+ LocalAddress: boundAddr,
+ }
+ id, btd, err := e.registerWithStack(netProtos, id)
+ if err != nil {
+ return err
}
- }
- id := stack.TransportEndpointID{
- LocalPort: addr.Port,
- LocalAddress: addr.Addr,
- }
- id, btd, err := e.registerWithStack(netProtos, id)
+ e.localPort = id.LocalPort
+ e.boundBindToDevice = btd
+ e.effectiveNetProtos = netProtos
+ return nil
+ })
if err != nil {
return err
}
- e.ID = id
- e.boundBindToDevice = btd
- e.RegisterNICID = nicID
- e.effectiveNetProtos = netProtos
-
- // Mark endpoint as bound.
- e.setEndpointState(StateBound)
-
e.rcvMu.Lock()
e.rcvReady = true
e.rcvMu.Unlock()
-
return nil
}
@@ -1172,9 +777,6 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
return err
}
- // Save the effective NICID generated by bindLocked.
- e.BindNICID = e.RegisterNICID
-
return nil
}
@@ -1183,16 +785,9 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- addr := e.ID.LocalAddress
- if e.EndpointState() == StateConnected {
- addr = e.route.LocalAddress()
- }
-
- return tcpip.FullAddress{
- NIC: e.RegisterNICID,
- Addr: addr,
- Port: e.ID.LocalPort,
- }, nil
+ addr := e.net.GetLocalAddress()
+ addr.Port = e.localPort
+ return addr, nil
}
// GetRemoteAddress returns the address to which the endpoint is connected.
@@ -1200,15 +795,13 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- if e.EndpointState() != StateConnected || e.dstPort == 0 {
+ addr, connected := e.net.GetRemoteAddress()
+ if !connected || e.remotePort == 0 {
return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
}
- return tcpip.FullAddress{
- NIC: e.RegisterNICID,
- Addr: e.ID.RemoteAddress,
- Port: e.ID.RemotePort,
- }, nil
+ addr.Port = e.remotePort
+ return addr, nil
}
// Readiness returns the current readiness of the endpoint. For example, if
@@ -1303,6 +896,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
// Push new packet into receive list and increment the buffer size.
packet := &udpPacket{
+ netProto: pkt.NetworkProtocolNumber,
senderAddress: tcpip.FullAddress{
NIC: pkt.NICID,
Addr: id.RemoteAddress,
@@ -1358,19 +952,20 @@ func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, p
payload = udp.Payload()
}
+ id := e.net.Info().ID
e.SocketOptions().QueueErr(&tcpip.SockError{
Err: err,
Cause: transErr,
Payload: payload,
Dst: tcpip.FullAddress{
NIC: pkt.NICID,
- Addr: e.ID.RemoteAddress,
- Port: e.ID.RemotePort,
+ Addr: id.RemoteAddress,
+ Port: e.remotePort,
},
Offender: tcpip.FullAddress{
NIC: pkt.NICID,
- Addr: e.ID.LocalAddress,
- Port: e.ID.LocalPort,
+ Addr: id.LocalAddress,
+ Port: e.localPort,
},
NetProto: pkt.NetworkProtocolNumber,
})
@@ -1385,7 +980,7 @@ func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketB
// TODO(gvisor.dev/issues/5270): Handle all transport errors.
switch transErr.Kind() {
case stack.DestinationPortUnreachableTransportError:
- if e.EndpointState() == StateConnected {
+ if e.net.State() == transport.DatagramEndpointStateConnected {
e.onICMPError(&tcpip.ErrConnectionRefused{}, transErr, pkt)
}
}
@@ -1393,16 +988,17 @@ func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketB
// State implements tcpip.Endpoint.
func (e *endpoint) State() uint32 {
- return uint32(e.EndpointState())
+ return uint32(e.net.State())
}
// Info returns a copy of the endpoint info.
func (e *endpoint) Info() tcpip.EndpointInfo {
e.mu.RLock()
- // Make a copy of the endpoint info.
- ret := e.TransportEndpointInfo
- e.mu.RUnlock()
- return &ret
+ defer e.mu.RUnlock()
+ info := e.net.Info()
+ info.ID.LocalPort = e.localPort
+ info.ID.RemotePort = e.remotePort
+ return &info
}
// Stats returns a pointer to the endpoint stats.
@@ -1413,13 +1009,9 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
// Wait implements tcpip.Endpoint.
func (*endpoint) Wait() {}
-func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
- return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) || e.stack.IsSubnetBroadcast(nicID, netProto, addr)
-}
-
// SetOwner implements tcpip.Endpoint.
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
- e.owner = owner
+ e.net.SetOwner(owner)
}
// SocketOptions implements tcpip.Endpoint.
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 1f638c3f6..2ff8b0482 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -15,12 +15,13 @@
package udp
import (
+ "fmt"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
)
// saveReceivedAt is invoked by stateify.
@@ -35,17 +36,11 @@ func (p *udpPacket) loadReceivedAt(nsec int64) {
// saveData saves udpPacket.data field.
func (p *udpPacket) saveData() buffer.VectorisedView {
- // We cannot save p.data directly as p.data.views may alias to p.views,
- // which is not allowed by state framework (in-struct pointer).
return p.data.Clone(nil)
}
// loadData loads udpPacket.data field.
func (p *udpPacket) loadData(data buffer.VectorisedView) {
- // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization
- // here because data.views is not guaranteed to be loaded by now. Plus,
- // data.views will be allocated anyway so there really is little point
- // of utilizing p.views for data.views.
p.data = data
}
@@ -66,50 +61,28 @@ func (e *endpoint) Resume(s *stack.Stack) {
e.mu.Lock()
defer e.mu.Unlock()
+ e.net.Resume(s)
+
e.stack = s
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
- for m := range e.multicastMemberships {
- if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil {
- panic(err)
- }
- }
-
- state := e.EndpointState()
- if state != StateBound && state != StateConnected {
- return
- }
-
- netProto := e.effectiveNetProtos[0]
- // Connect() and bindLocked() both assert
- //
- // netProto == header.IPv6ProtocolNumber
- //
- // before creating a multi-entry effectiveNetProtos.
- if len(e.effectiveNetProtos) > 1 {
- netProto = header.IPv6ProtocolNumber
- }
-
- var err tcpip.Error
- if state == StateConnected {
- e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.ops.GetMulticastLoop())
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ // Our saved state had a port, but we don't actually have a
+ // reservation. We need to remove the port from our state, but still
+ // pass it to the reservation machinery.
+ var err tcpip.Error
+ id := e.net.Info().ID
+ id.LocalPort = e.localPort
+ id.RemotePort = e.remotePort
+ id, e.boundBindToDevice, err = e.registerWithStack(e.effectiveNetProtos, id)
if err != nil {
panic(err)
}
- } else if len(e.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.RegisterNICID, netProto, e.ID.LocalAddress) { // stateBound
- // A local unicast address is specified, verify that it's valid.
- if e.stack.CheckLocalAddress(e.RegisterNICID, netProto, e.ID.LocalAddress) == 0 {
- panic(&tcpip.ErrBadLocalAddress{})
- }
- }
-
- // Our saved state had a port, but we don't actually have a
- // reservation. We need to remove the port from our state, but still
- // pass it to the reservation machinery.
- id := e.ID
- e.ID.LocalPort = 0
- e.ID, e.boundBindToDevice, err = e.registerWithStack(e.effectiveNetProtos, id)
- if err != nil {
- panic(err)
+ e.localPort = id.LocalPort
+ e.remotePort = id.RemotePort
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
}
}
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index 7c357cb09..7238fc019 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -70,28 +70,29 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID {
// CreateEndpoint creates a connected UDP endpoint for the session request.
func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
+ ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue)
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
netHdr := r.pkt.Network()
- route, err := r.stack.FindRoute(r.pkt.NICID, netHdr.DestinationAddress(), netHdr.SourceAddress(), r.pkt.NetworkProtocolNumber, false /* multicastLoop */)
- if err != nil {
+ if err := ep.net.Bind(tcpip.FullAddress{NIC: r.pkt.NICID, Addr: netHdr.DestinationAddress(), Port: r.id.LocalPort}); err != nil {
+ return nil, err
+ }
+
+ if err := ep.net.Connect(tcpip.FullAddress{NIC: r.pkt.NICID, Addr: netHdr.SourceAddress(), Port: r.id.RemotePort}); err != nil {
return nil, err
}
- ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue)
if err := r.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil {
ep.Close()
- route.Release()
return nil, err
}
- ep.ID = r.id
- ep.route = route
- ep.dstPort = r.id.RemotePort
+ ep.localPort = r.id.LocalPort
+ ep.remotePort = r.id.RemotePort
ep.effectiveNetProtos = []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}
- ep.RegisterNICID = r.pkt.NICID
ep.boundPortFlags = ep.portFlags
- ep.state = uint32(StateConnected)
-
ep.rcvMu.Lock()
ep.rcvReady = true
ep.rcvMu.Unlock()
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 4008cacf2..2d15830a7 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -290,6 +290,7 @@ type testContext struct {
t *testing.T
linkEP *channel.Endpoint
s *stack.Stack
+ nicID tcpip.NICID
ep tcpip.Endpoint
wq waiter.Queue
@@ -301,6 +302,8 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext {
}
func newDualTestContextWithHandleLocal(t *testing.T, mtu uint32, handleLocal bool) *testContext {
+ const nicID = 1
+
t.Helper()
options := stack.Options{
@@ -316,32 +319,41 @@ func newDualTestContextWithHandleLocal(t *testing.T, mtu uint32, handleLocal boo
if testing.Verbose() {
wep = sniffer.New(ep)
}
- if err := s.CreateNIC(1, wep); err != nil {
- t.Fatalf("CreateNIC failed: %s", err)
+ if err := s.CreateNIC(nicID, wep); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress failed: %s", err)
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.Address(stackAddr).WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV4, err)
}
- if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil {
- t.Fatalf("AddAddress failed: %s", err)
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: tcpip.Address(stackV6Addr).WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV6, err)
}
s.SetRouteTable([]tcpip.Route{
{
Destination: header.IPv4EmptySubnet,
- NIC: 1,
+ NIC: nicID,
},
{
Destination: header.IPv6EmptySubnet,
- NIC: 1,
+ NIC: nicID,
},
})
return &testContext{
t: t,
s: s,
+ nicID: nicID,
linkEP: ep,
}
}
@@ -1644,8 +1656,10 @@ func TestSetTTL(t *testing.T) {
}
}
+var v4PacketFlows = [...]testFlow{unicastV4, multicastV4, broadcast, unicastV4in6, multicastV4in6, broadcastIn6}
+
func TestSetTOS(t *testing.T) {
- for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
+ for _, flow := range v4PacketFlows {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
@@ -1680,8 +1694,10 @@ func TestSetTOS(t *testing.T) {
}
}
+var v6PacketFlows = [...]testFlow{unicastV6, unicastV6Only, multicastV6}
+
func TestSetTClass(t *testing.T) {
- for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} {
+ for _, flow := range v6PacketFlows {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
@@ -1725,8 +1741,14 @@ func TestReceiveTosTClass(t *testing.T) {
name string
tests []testFlow
}{
- {RcvTOSOpt, []testFlow{unicastV4, broadcast}},
- {RcvTClassOpt, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}},
+ {
+ name: RcvTOSOpt,
+ tests: v4PacketFlows[:],
+ },
+ {
+ name: RcvTClassOpt,
+ tests: v6PacketFlows[:],
+ },
}
for _, testCase := range testCases {
for _, flow := range testCase.tests {
@@ -1737,6 +1759,14 @@ func TestReceiveTosTClass(t *testing.T) {
c.createEndpointForFlow(flow)
name := testCase.name
+ if flow.isMulticast() {
+ netProto := flow.netProto()
+ addr := flow.getMcastAddr()
+ if err := c.s.JoinGroup(netProto, c.nicID, addr); err != nil {
+ c.t.Fatalf("JoinGroup(%d, %d, %s): %s", netProto, c.nicID, addr, err)
+ }
+ }
+
var optionGetter func() bool
var optionSetter func(bool)
switch name {
@@ -2482,8 +2512,8 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
if err := s.CreateNIC(nicID1, e); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
}
- if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err)
+ if err := s.AddProtocolAddress(nicID1, test.nicAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, test.nicAddr, err)
}
s.SetRouteTable(test.routes)