summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/BUILD1
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go24
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go112
-rw-r--r--pkg/tcpip/checker/checker.go13
-rw-r--r--pkg/tcpip/header/ipv4.go9
-rw-r--r--pkg/tcpip/header/parse/parse.go68
-rw-r--r--pkg/tcpip/link/pipe/pipe.go4
-rw-r--r--pkg/tcpip/link/rawfile/rawfile_unsafe.go16
-rw-r--r--pkg/tcpip/link/sharedmem/BUILD31
-rw-r--r--pkg/tcpip/link/sharedmem/queue/rx.go2
-rw-r--r--pkg/tcpip/link/sharedmem/queuepair.go199
-rw-r--r--pkg/tcpip/link/sharedmem/rx.go30
-rw-r--r--pkg/tcpip/link/sharedmem/server_rx.go142
-rw-r--r--pkg/tcpip/link/sharedmem/server_tx.go175
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go230
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_server.go344
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_server_test.go220
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go114
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_unsafe.go33
-rw-r--r--pkg/tcpip/link/sharedmem/tx.go32
-rw-r--r--pkg/tcpip/network/arp/arp.go1
-rw-r--r--pkg/tcpip/network/arp/arp_test.go16
-rw-r--r--pkg/tcpip/network/ip_test.go86
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go124
-rw-r--r--pkg/tcpip/network/ipv4/igmp_test.go28
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go114
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go218
-rw-r--r--pkg/tcpip/network/ipv6/BUILD1
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go166
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go71
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go137
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go243
-rw-r--r--pkg/tcpip/network/ipv6/mld_test.go33
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go6
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go48
-rw-r--r--pkg/tcpip/network/multicast_group_test.go21
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/main.go8
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/main.go16
-rw-r--r--pkg/tcpip/socketops.go18
-rw-r--r--pkg/tcpip/stack/BUILD2
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state.go24
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state_test.go4
-rw-r--r--pkg/tcpip/stack/conntrack.go685
-rw-r--r--pkg/tcpip/stack/conntrack_test.go132
-rw-r--r--pkg/tcpip/stack/forwarding_test.go26
-rw-r--r--pkg/tcpip/stack/icmp_rate_limit.go39
-rw-r--r--pkg/tcpip/stack/iptables.go216
-rw-r--r--pkg/tcpip/stack/iptables_targets.go201
-rw-r--r--pkg/tcpip/stack/iptables_types.go28
-rw-r--r--pkg/tcpip/stack/ndp_test.go107
-rw-r--r--pkg/tcpip/stack/nic.go43
-rw-r--r--pkg/tcpip/stack/nic_test.go5
-rw-r--r--pkg/tcpip/stack/packet_buffer.go69
-rw-r--r--pkg/tcpip/stack/packet_buffer_test.go38
-rw-r--r--pkg/tcpip/stack/registration.go29
-rw-r--r--pkg/tcpip/stack/stack.go57
-rw-r--r--pkg/tcpip/stack/stack_test.go642
-rw-r--r--pkg/tcpip/stack/tcp.go6
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go48
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go18
-rw-r--r--pkg/tcpip/stack/transport_test.go40
-rw-r--r--pkg/tcpip/tcpip.go54
-rw-r--r--pkg/tcpip/tcpip_state.go (renamed from pkg/tcpip/stack/iptables_state.go)25
-rw-r--r--pkg/tcpip/tests/integration/BUILD26
-rw-r--r--pkg/tcpip/tests/integration/forward_test.go20
-rw-r--r--pkg/tcpip/tests/integration/iptables_test.go695
-rw-r--r--pkg/tcpip/tests/integration/istio_test.go365
-rw-r--r--pkg/tcpip/tests/integration/link_resolution_test.go24
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go47
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go20
-rw-r--r--pkg/tcpip/tests/integration/route_test.go38
-rw-r--r--pkg/tcpip/tests/utils/utils.go68
-rw-r--r--pkg/tcpip/transport/icmp/BUILD2
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go438
-rw-r--r--pkg/tcpip/transport/icmp/endpoint_state.go35
-rw-r--r--pkg/tcpip/transport/icmp/icmp_test.go8
-rw-r--r--pkg/tcpip/transport/internal/network/BUILD1
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint.go177
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint_state.go12
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint_test.go33
-rw-r--r--pkg/tcpip/transport/internal/noop/BUILD14
-rw-r--r--pkg/tcpip/transport/internal/noop/endpoint.go172
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go120
-rw-r--r--pkg/tcpip/transport/packet/endpoint_state.go18
-rw-r--r--pkg/tcpip/transport/raw/BUILD1
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go44
-rw-r--r--pkg/tcpip/transport/raw/protocol.go16
-rw-r--r--pkg/tcpip/transport/tcp/BUILD13
-rw-r--r--pkg/tcpip/transport/tcp/accept.go346
-rw-r--r--pkg/tcpip/transport/tcp/connect.go9
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go115
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go8
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go12
-rw-r--r--pkg/tcpip/transport/tcp/rcv_test.go2
-rw-r--r--pkg/tcpip/transport/tcp/segment_test.go2
-rw-r--r--pkg/tcpip/transport/tcp/snd.go127
-rw-r--r--pkg/tcpip/transport/tcp/tcp_rack_test.go15
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go255
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go110
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go8
-rw-r--r--pkg/tcpip/transport/udp/BUILD1
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go27
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go123
103 files changed, 6742 insertions, 2517 deletions
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index dbe4506cc..b98de54c5 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -25,6 +25,7 @@ go_library(
"stdclock.go",
"stdclock_state.go",
"tcpip.go",
+ "tcpip_state.go",
"timer.go",
],
visibility = ["//visibility:public"],
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index 010e2e833..1f2bcaf65 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -19,6 +19,7 @@ import (
"bytes"
"context"
"errors"
+ "fmt"
"io"
"net"
"time"
@@ -471,9 +472,9 @@ func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtoc
return DialContextTCP(context.Background(), s, addr, network)
}
-// DialContextTCP creates a new TCPConn connected to the specified address
-// with the option of adding cancellation and timeouts.
-func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) {
+// DialTCPWithBind creates a new TCPConn connected to the specified
+// remoteAddr with its local address bound to localAddr.
+func DialTCPWithBind(ctx context.Context, s *stack.Stack, localAddr, remoteAddr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) {
// Create TCP endpoint, then connect.
var wq waiter.Queue
ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq)
@@ -494,7 +495,14 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress,
default:
}
- err = ep.Connect(addr)
+ // Bind before connect if requested.
+ if localAddr != (tcpip.FullAddress{}) {
+ if err = ep.Bind(localAddr); err != nil {
+ return nil, fmt.Errorf("ep.Bind(%+v) = %s", localAddr, err)
+ }
+ }
+
+ err = ep.Connect(remoteAddr)
if _, ok := err.(*tcpip.ErrConnectStarted); ok {
select {
case <-ctx.Done():
@@ -510,7 +518,7 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress,
return nil, &net.OpError{
Op: "connect",
Net: "tcp",
- Addr: fullToTCPAddr(addr),
+ Addr: fullToTCPAddr(remoteAddr),
Err: errors.New(err.String()),
}
}
@@ -518,6 +526,12 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress,
return NewTCPConn(&wq, ep), nil
}
+// DialContextTCP creates a new TCPConn connected to the specified address
+// with the option of adding cancellation and timeouts.
+func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) {
+ return DialTCPWithBind(ctx, s, tcpip.FullAddress{} /* localAddr */, addr /* remoteAddr */, network)
+}
+
// A UDPConn is a wrapper around a UDP tcpip.Endpoint that implements
// net.Conn and net.PacketConn.
type UDPConn struct {
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index 48b24692b..dcc9fff17 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -137,7 +137,13 @@ func TestCloseReader(t *testing.T) {
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
l, e := ListenTCP(s, addr, ipv4.ProtocolNumber)
if e != nil {
@@ -190,7 +196,13 @@ func TestCloseReaderWithForwarder(t *testing.T) {
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
done := make(chan struct{})
@@ -244,7 +256,13 @@ func TestCloseRead(t *testing.T) {
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
var wq waiter.Queue
@@ -288,7 +306,13 @@ func TestCloseWrite(t *testing.T) {
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
var wq waiter.Queue
@@ -349,10 +373,22 @@ func TestUDPForwarder(t *testing.T) {
ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr1 := tcpip.FullAddress{NICID, ip1, 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip1)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip1.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr1, err)
+ }
ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4())
addr2 := tcpip.FullAddress{NICID, ip2, 11311}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip2)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr2, err)
+ }
done := make(chan struct{})
fwd := udp.NewForwarder(s, func(r *udp.ForwarderRequest) {
@@ -410,7 +446,13 @@ func TestDeadlineChange(t *testing.T) {
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
l, e := ListenTCP(s, addr, ipv4.ProtocolNumber)
if e != nil {
@@ -465,10 +507,22 @@ func TestPacketConnTransfer(t *testing.T) {
ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr1 := tcpip.FullAddress{NICID, ip1, 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip1)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip1.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr1, err)
+ }
ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4())
addr2 := tcpip.FullAddress{NICID, ip2, 11311}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip2)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr2, err)
+ }
c1, err := DialUDP(s, &addr1, nil, ipv4.ProtocolNumber)
if err != nil {
@@ -521,7 +575,13 @@ func TestConnectedPacketConnTransfer(t *testing.T) {
ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr := tcpip.FullAddress{NICID, ip, 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
c1, err := DialUDP(s, &addr, nil, ipv4.ProtocolNumber)
if err != nil {
@@ -565,24 +625,30 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) {
ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr := tcpip.FullAddress{NICID, ip, 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ return nil, nil, nil, fmt.Errorf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
l, err := ListenTCP(s, addr, ipv4.ProtocolNumber)
if err != nil {
- return nil, nil, nil, fmt.Errorf("NewListener: %v", err)
+ return nil, nil, nil, fmt.Errorf("NewListener: %w", err)
}
c1, err = DialTCP(s, addr, ipv4.ProtocolNumber)
if err != nil {
l.Close()
- return nil, nil, nil, fmt.Errorf("DialTCP: %v", err)
+ return nil, nil, nil, fmt.Errorf("DialTCP: %w", err)
}
c2, err = l.Accept()
if err != nil {
l.Close()
c1.Close()
- return nil, nil, nil, fmt.Errorf("l.Accept: %v", err)
+ return nil, nil, nil, fmt.Errorf("l.Accept: %w", err)
}
stop = func() {
@@ -594,7 +660,7 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) {
if err := l.Close(); err != nil {
stop()
- return nil, nil, nil, fmt.Errorf("l.Close(): %v", err)
+ return nil, nil, nil, fmt.Errorf("l.Close(): %w", err)
}
return c1, c2, stop, nil
@@ -681,7 +747,13 @@ func TestDialContextTCPCanceled(t *testing.T) {
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
@@ -703,7 +775,13 @@ func TestDialContextTCPTimeout(t *testing.T) {
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
time.Sleep(time.Second)
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 2f34bf8dd..24c2c3e6b 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -324,6 +324,19 @@ func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker {
}
}
+// ReceiveIPv6PacketInfo creates a checker that checks the IPv6PacketInfo field
+// in ControlMessages.
+func ReceiveIPv6PacketInfo(want tcpip.IPv6PacketInfo) ControlMessagesChecker {
+ return func(t *testing.T, cm tcpip.ControlMessages) {
+ t.Helper()
+ if !cm.HasIPv6PacketInfo {
+ t.Errorf("got cm.HasIPv6PacketInfo = %t, want = true", cm.HasIPv6PacketInfo)
+ } else if diff := cmp.Diff(want, cm.IPv6PacketInfo); diff != "" {
+ t.Errorf("IPv6PacketInfo mismatch (-want +got):\n%s", diff)
+ }
+ }
+}
+
// ReceiveOriginalDstAddr creates a checker that checks the OriginalDstAddress
// field in ControlMessages.
func ReceiveOriginalDstAddr(want tcpip.FullAddress) ControlMessagesChecker {
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index dcc549c7b..7baaf0d17 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -208,6 +208,15 @@ var IPv4EmptySubnet = func() tcpip.Subnet {
return subnet
}()
+// IPv4LoopbackSubnet is the loopback subnet for IPv4.
+var IPv4LoopbackSubnet = func() tcpip.Subnet {
+ subnet, err := tcpip.NewSubnet(tcpip.Address("\x7f\x00\x00\x00"), tcpip.AddressMask("\xff\x00\x00\x00"))
+ if err != nil {
+ panic(err)
+ }
+ return subnet
+}()
+
// IPVersion returns the version of IP used in the given packet. It returns -1
// if the packet is not large enough to contain the version field.
func IPVersion(b []byte) int {
diff --git a/pkg/tcpip/header/parse/parse.go b/pkg/tcpip/header/parse/parse.go
index 1c913b5e1..80a9ad6be 100644
--- a/pkg/tcpip/header/parse/parse.go
+++ b/pkg/tcpip/header/parse/parse.go
@@ -110,6 +110,16 @@ traverseExtensions:
switch extHdr := extHdr.(type) {
case header.IPv6FragmentExtHdr:
+ if extHdr.IsAtomic() {
+ // This fragment extension header indicates that this packet is an
+ // atomic fragment. An atomic fragment is a fragment that contains
+ // all the data required to reassemble a full packet. As per RFC 6946,
+ // atomic fragments must not interfere with "normal" fragmented traffic
+ // so we skip processing the fragment instead of feeding it through the
+ // reassembly process below.
+ continue
+ }
+
if fragID == 0 && fragOffset == 0 && !fragMore {
fragID = extHdr.ID()
fragOffset = extHdr.FragmentOffset()
@@ -175,3 +185,61 @@ func TCP(pkt *stack.PacketBuffer) bool {
pkt.TransportProtocolNumber = header.TCPProtocolNumber
return ok
}
+
+// ICMPv4 populates the packet buffer's transport header with an ICMPv4 header,
+// if present.
+//
+// Returns true if an ICMPv4 header was successfully parsed.
+func ICMPv4(pkt *stack.PacketBuffer) bool {
+ if _, ok := pkt.TransportHeader().Consume(header.ICMPv4MinimumSize); ok {
+ pkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
+ return true
+ }
+ return false
+}
+
+// ICMPv6 populates the packet buffer's transport header with an ICMPv6 header,
+// if present.
+//
+// Returns true if an ICMPv6 header was successfully parsed.
+func ICMPv6(pkt *stack.PacketBuffer) bool {
+ hdr, ok := pkt.Data().PullUp(header.ICMPv6MinimumSize)
+ if !ok {
+ return false
+ }
+
+ h := header.ICMPv6(hdr)
+ switch h.Type() {
+ case header.ICMPv6RouterSolicit,
+ header.ICMPv6RouterAdvert,
+ header.ICMPv6NeighborSolicit,
+ header.ICMPv6NeighborAdvert,
+ header.ICMPv6RedirectMsg:
+ size := pkt.Data().Size()
+ if _, ok := pkt.TransportHeader().Consume(size); !ok {
+ panic(fmt.Sprintf("expected to consume the full data of size = %d bytes into transport header", size))
+ }
+ case header.ICMPv6MulticastListenerQuery,
+ header.ICMPv6MulticastListenerReport,
+ header.ICMPv6MulticastListenerDone:
+ size := header.ICMPv6HeaderSize + header.MLDMinimumSize
+ if _, ok := pkt.TransportHeader().Consume(size); !ok {
+ return false
+ }
+ case header.ICMPv6DstUnreachable,
+ header.ICMPv6PacketTooBig,
+ header.ICMPv6TimeExceeded,
+ header.ICMPv6ParamProblem,
+ header.ICMPv6EchoRequest,
+ header.ICMPv6EchoReply:
+ fallthrough
+ default:
+ if _, ok := pkt.TransportHeader().Consume(header.ICMPv6MinimumSize); !ok {
+ // Checked above if the packet buffer holds at least the minimum size for
+ // an ICMPv6 packet.
+ panic(fmt.Sprintf("expected to consume %d bytes", header.ICMPv6MinimumSize))
+ }
+ }
+ pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
+ return true
+}
diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go
index 3ed0aa3fe..c67ca98ea 100644
--- a/pkg/tcpip/link/pipe/pipe.go
+++ b/pkg/tcpip/link/pipe/pipe.go
@@ -123,4 +123,6 @@ func (*Endpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber
}
// WriteRawPacket implements stack.LinkEndpoint.
-func (*Endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} }
+func (e *Endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error {
+ return e.WritePacket(stack.RouteInfo{}, 0, pkt)
+}
diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
index 87a0b9a62..e53789d92 100644
--- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go
+++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
@@ -152,10 +152,22 @@ type PollEvent struct {
// no data is available, it will block in a poll() syscall until the file
// descriptor becomes readable.
func BlockingRead(fd int, b []byte) (int, tcpip.Error) {
+ n, err := BlockingReadUntranslated(fd, b)
+ if err != 0 {
+ return n, TranslateErrno(err)
+ }
+ return n, nil
+}
+
+// BlockingReadUntranslated reads from a file descriptor that is set up as
+// non-blocking. If no data is available, it will block in a poll() syscall
+// until the file descriptor becomes readable. It returns the raw unix.Errno
+// value returned by the underlying syscalls.
+func BlockingReadUntranslated(fd int, b []byte) (int, unix.Errno) {
for {
n, _, e := unix.RawSyscall(unix.SYS_READ, uintptr(fd), uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)))
if e == 0 {
- return int(n), nil
+ return int(n), 0
}
event := PollEvent{
@@ -165,7 +177,7 @@ func BlockingRead(fd int, b []byte) (int, tcpip.Error) {
_, e = BlockingPoll(&event, 1, nil)
if e != 0 && e != unix.EINTR {
- return 0, TranslateErrno(e)
+ return 0, e
}
}
}
diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD
index 4215ee852..af755473c 100644
--- a/pkg/tcpip/link/sharedmem/BUILD
+++ b/pkg/tcpip/link/sharedmem/BUILD
@@ -5,19 +5,27 @@ package(licenses = ["notice"])
go_library(
name = "sharedmem",
srcs = [
+ "queuepair.go",
"rx.go",
+ "server_rx.go",
+ "server_tx.go",
"sharedmem.go",
+ "sharedmem_server.go",
"sharedmem_unsafe.go",
"tx.go",
],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/cleanup",
+ "//pkg/eventfd",
"//pkg/log",
+ "//pkg/memutil",
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/link/rawfile",
+ "//pkg/tcpip/link/sharedmem/pipe",
"//pkg/tcpip/link/sharedmem/queue",
"//pkg/tcpip/stack",
"@org_golang_x_sys//unix:go_default_library",
@@ -26,9 +34,7 @@ go_library(
go_test(
name = "sharedmem_test",
- srcs = [
- "sharedmem_test.go",
- ],
+ srcs = ["sharedmem_test.go"],
library = ":sharedmem",
deps = [
"//pkg/sync",
@@ -41,3 +47,22 @@ go_test(
"@org_golang_x_sys//unix:go_default_library",
],
)
+
+go_test(
+ name = "sharedmem_server_test",
+ size = "small",
+ srcs = ["sharedmem_server_test.go"],
+ deps = [
+ ":sharedmem",
+ "//pkg/tcpip",
+ "//pkg/tcpip/adapters/gonet",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/tcp",
+ "//pkg/tcpip/transport/udp",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/tcpip/link/sharedmem/queue/rx.go b/pkg/tcpip/link/sharedmem/queue/rx.go
index 696e6c9e5..a78826ebc 100644
--- a/pkg/tcpip/link/sharedmem/queue/rx.go
+++ b/pkg/tcpip/link/sharedmem/queue/rx.go
@@ -119,7 +119,6 @@ func (r *Rx) PostBuffers(buffers []RxBuffer) bool {
}
r.tx.Flush()
-
return true
}
@@ -131,7 +130,6 @@ func (r *Rx) PostBuffers(buffers []RxBuffer) bool {
func (r *Rx) Dequeue(bufs []RxBuffer) ([]RxBuffer, uint32) {
for {
outBufs := bufs
-
// Pull the next descriptor from the rx pipe.
b := r.rx.Pull()
if b == nil {
diff --git a/pkg/tcpip/link/sharedmem/queuepair.go b/pkg/tcpip/link/sharedmem/queuepair.go
new file mode 100644
index 000000000..b12647fdd
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/queuepair.go
@@ -0,0 +1,199 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build linux
+// +build linux
+
+package sharedmem
+
+import (
+ "fmt"
+ "io/ioutil"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/eventfd"
+)
+
+const (
+ // defaultQueueDataSize is the size of the shared memory data region that
+ // holds the scatter/gather buffers.
+ defaultQueueDataSize = 1 << 20 // 1MiB
+
+ // defaultQueuePipeSize is the size of the pipe that holds the packet descriptors.
+ //
+ // Assuming each packet data is approximately 1280 bytes (IPv6 Minimum MTU)
+ // then we can hold approximately 1024*1024/1280 ~ 819 packets in the data
+ // area. Which means the pipe needs to be big enough to hold 819
+ // descriptors.
+ //
+ // Each descriptor is approximately 8 (slot descriptor in pipe) +
+ // 16 (packet descriptor) + 12 (for buffer descriptor) assuming each packet is
+ // stored in exactly 1 buffer descriptor (see queue/tx.go and pipe/tx.go.)
+ //
+ // Which means we need approximately 36*819 ~ 29 KiB to store all packet
+ // descriptors. We could go with a 32 KiB pipe but to give it some slack in
+ // how the upper layer may make use of the scatter gather buffers we double
+ // this to hold enough descriptors.
+ defaultQueuePipeSize = 64 << 10 // 64KiB
+
+ // defaultSharedDataSize is the size of the sharedData region used to
+ // enable/disable notifications.
+ defaultSharedDataSize = 4 << 10 // 4KiB
+)
+
+// A QueuePair represents a pair of TX/RX queues.
+type QueuePair struct {
+ // txCfg is the QueueConfig to be used for transmit queue.
+ txCfg QueueConfig
+
+ // rxCfg is the QueueConfig to be used for receive queue.
+ rxCfg QueueConfig
+}
+
+// NewQueuePair creates a shared memory QueuePair.
+func NewQueuePair() (*QueuePair, error) {
+ txCfg, err := createQueueFDs(queueSizes{
+ dataSize: defaultQueueDataSize,
+ txPipeSize: defaultQueuePipeSize,
+ rxPipeSize: defaultQueuePipeSize,
+ sharedDataSize: defaultSharedDataSize,
+ })
+
+ if err != nil {
+ return nil, fmt.Errorf("failed to create tx queue: %s", err)
+ }
+
+ rxCfg, err := createQueueFDs(queueSizes{
+ dataSize: defaultQueueDataSize,
+ txPipeSize: defaultQueuePipeSize,
+ rxPipeSize: defaultQueuePipeSize,
+ sharedDataSize: defaultSharedDataSize,
+ })
+
+ if err != nil {
+ closeFDs(txCfg)
+ return nil, fmt.Errorf("failed to create rx queue: %s", err)
+ }
+
+ return &QueuePair{
+ txCfg: txCfg,
+ rxCfg: rxCfg,
+ }, nil
+}
+
+// Close closes underlying tx/rx queue fds.
+func (q *QueuePair) Close() {
+ closeFDs(q.txCfg)
+ closeFDs(q.rxCfg)
+}
+
+// TXQueueConfig returns the QueueConfig for the transmit queue.
+func (q *QueuePair) TXQueueConfig() QueueConfig {
+ return q.txCfg
+}
+
+// RXQueueConfig returns the QueueConfig for the receive queue.
+func (q *QueuePair) RXQueueConfig() QueueConfig {
+ return q.rxCfg
+}
+
+type queueSizes struct {
+ dataSize int64
+ txPipeSize int64
+ rxPipeSize int64
+ sharedDataSize int64
+}
+
+func createQueueFDs(s queueSizes) (QueueConfig, error) {
+ success := false
+ var eventFD eventfd.Eventfd
+ var dataFD, txPipeFD, rxPipeFD, sharedDataFD int
+ defer func() {
+ if success {
+ return
+ }
+ closeFDs(QueueConfig{
+ EventFD: eventFD,
+ DataFD: dataFD,
+ TxPipeFD: txPipeFD,
+ RxPipeFD: rxPipeFD,
+ SharedDataFD: sharedDataFD,
+ })
+ }()
+ eventFD, err := eventfd.Create()
+ if err != nil {
+ return QueueConfig{}, fmt.Errorf("eventfd failed: %v", err)
+ }
+ dataFD, err = createFile(s.dataSize, false)
+ if err != nil {
+ return QueueConfig{}, fmt.Errorf("failed to create dataFD: %s", err)
+ }
+ txPipeFD, err = createFile(s.txPipeSize, true)
+ if err != nil {
+ return QueueConfig{}, fmt.Errorf("failed to create txPipeFD: %s", err)
+ }
+ rxPipeFD, err = createFile(s.rxPipeSize, true)
+ if err != nil {
+ return QueueConfig{}, fmt.Errorf("failed to create rxPipeFD: %s", err)
+ }
+ sharedDataFD, err = createFile(s.sharedDataSize, false)
+ if err != nil {
+ return QueueConfig{}, fmt.Errorf("failed to create sharedDataFD: %s", err)
+ }
+ success = true
+ return QueueConfig{
+ EventFD: eventFD,
+ DataFD: dataFD,
+ TxPipeFD: txPipeFD,
+ RxPipeFD: rxPipeFD,
+ SharedDataFD: sharedDataFD,
+ }, nil
+}
+
+func createFile(size int64, initQueue bool) (fd int, err error) {
+ const tmpDir = "/dev/shm/"
+ f, err := ioutil.TempFile(tmpDir, "sharedmem_test")
+ if err != nil {
+ return -1, fmt.Errorf("TempFile failed: %v", err)
+ }
+ defer f.Close()
+ unix.Unlink(f.Name())
+
+ if initQueue {
+ // Write the "slot-free" flag in the initial queue.
+ if _, err := f.WriteAt([]byte{0, 0, 0, 0, 0, 0, 0, 0x80}, 0); err != nil {
+ return -1, fmt.Errorf("WriteAt failed: %v", err)
+ }
+ }
+
+ fd, err = unix.Dup(int(f.Fd()))
+ if err != nil {
+ return -1, fmt.Errorf("unix.Dup(%d) failed: %v", f.Fd(), err)
+ }
+
+ if err := unix.Ftruncate(fd, size); err != nil {
+ unix.Close(fd)
+ return -1, fmt.Errorf("ftruncate(%d, %d) failed: %v", fd, size, err)
+ }
+
+ return fd, nil
+}
+
+func closeFDs(c QueueConfig) {
+ unix.Close(c.DataFD)
+ c.EventFD.Close()
+ unix.Close(c.TxPipeFD)
+ unix.Close(c.RxPipeFD)
+ unix.Close(c.SharedDataFD)
+}
diff --git a/pkg/tcpip/link/sharedmem/rx.go b/pkg/tcpip/link/sharedmem/rx.go
index e882a128c..87747dcc7 100644
--- a/pkg/tcpip/link/sharedmem/rx.go
+++ b/pkg/tcpip/link/sharedmem/rx.go
@@ -21,7 +21,7 @@ import (
"sync/atomic"
"golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
+ "gvisor.dev/gvisor/pkg/eventfd"
"gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
)
@@ -30,7 +30,7 @@ type rx struct {
data []byte
sharedData []byte
q queue.Rx
- eventFD int
+ eventFD eventfd.Eventfd
}
// init initializes all state needed by the rx queue based on the information
@@ -68,7 +68,7 @@ func (r *rx) init(mtu uint32, c *QueueConfig) error {
// Duplicate the eventFD so that caller can close it but we can still
// use it.
- efd, err := unix.Dup(c.EventFD)
+ efd, err := c.EventFD.Dup()
if err != nil {
unix.Munmap(txPipe)
unix.Munmap(rxPipe)
@@ -77,16 +77,6 @@ func (r *rx) init(mtu uint32, c *QueueConfig) error {
return err
}
- // Set the eventfd as non-blocking.
- if err := unix.SetNonblock(efd, true); err != nil {
- unix.Munmap(txPipe)
- unix.Munmap(rxPipe)
- unix.Munmap(data)
- unix.Munmap(sharedData)
- unix.Close(efd)
- return err
- }
-
// Initialize state based on buffers.
r.q.Init(txPipe, rxPipe, sharedDataPointer(sharedData))
r.data = data
@@ -105,7 +95,13 @@ func (r *rx) cleanup() {
unix.Munmap(r.data)
unix.Munmap(r.sharedData)
- unix.Close(r.eventFD)
+ r.eventFD.Close()
+}
+
+// notify writes to the rx.eventFD to indicate to the peer that there is data to
+// be read.
+func (r *rx) notify() {
+ r.eventFD.Notify()
}
// postAndReceive posts the provided buffers (if any), and then tries to read
@@ -122,8 +118,7 @@ func (r *rx) postAndReceive(b []queue.RxBuffer, stopRequested *uint32) ([]queue.
if len(b) != 0 && !r.q.PostBuffers(b) {
r.q.EnableNotification()
for !r.q.PostBuffers(b) {
- var tmp [8]byte
- rawfile.BlockingRead(r.eventFD, tmp[:])
+ r.eventFD.Wait()
if atomic.LoadUint32(stopRequested) != 0 {
r.q.DisableNotification()
return nil, 0
@@ -147,8 +142,7 @@ func (r *rx) postAndReceive(b []queue.RxBuffer, stopRequested *uint32) ([]queue.
}
// Wait for notification.
- var tmp [8]byte
- rawfile.BlockingRead(r.eventFD, tmp[:])
+ r.eventFD.Wait()
if atomic.LoadUint32(stopRequested) != 0 {
r.q.DisableNotification()
return nil, 0
diff --git a/pkg/tcpip/link/sharedmem/server_rx.go b/pkg/tcpip/link/sharedmem/server_rx.go
new file mode 100644
index 000000000..6ea21ffd1
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/server_rx.go
@@ -0,0 +1,142 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build linux
+// +build linux
+
+package sharedmem
+
+import (
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/cleanup"
+ "gvisor.dev/gvisor/pkg/eventfd"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
+)
+
+// serverRx represents the server end of the sharedmem queue and is used to
+// receive packets sent by the peer on the packetPipe.
+type serverRx struct {
+	// packetPipe represents the receive end of the pipe that carries the packet
+	// descriptors sent by the client.
+	packetPipe pipe.Rx
+
+	// completionPipe represents the transmit end of the pipe that will carry
+	// completion notifications from the server to the client.
+	completionPipe pipe.Tx
+
+	// data represents the buffer area where the packet payload is held.
+	data []byte
+
+	// eventFD is used to notify the peer when transmission is completed.
+	//
+	// NOTE(review): serverRx itself only waits on this eventFD (see
+	// waitForPackets) and closes it; confirm the description above.
+	eventFD eventfd.Eventfd
+
+	// sharedData is the memory region to use to enable/disable notifications.
+	sharedData []byte
+}
+
+// init initializes all state needed by the serverRx queue based on the
+// information provided.
+//
+// The caller always retains ownership of all file descriptors passed in. The
+// queue implementation will duplicate any that it may need in the future.
+func (s *serverRx) init(c *QueueConfig) error {
+	// Everything acquired below is registered with cu so that a failure part
+	// way through releases whatever was acquired up to that point.
+	cu := cleanup.Make(func() {})
+	defer cu.Clean()
+
+	// mapFD maps in the buffer backing fd and schedules it to be unmapped
+	// should a later step fail.
+	mapFD := func(fd int) ([]byte, error) {
+		mem, err := getBuffer(fd)
+		if err != nil {
+			return nil, err
+		}
+		cu.Add(func() { unix.Munmap(mem) })
+		return mem, nil
+	}
+
+	// The client's tx pipe carries the packet descriptors that the server
+	// receives; the client's rx pipe carries the completions sent back.
+	packetPipeMem, err := mapFD(c.TxPipeFD)
+	if err != nil {
+		return err
+	}
+	completionPipeMem, err := mapFD(c.RxPipeFD)
+	if err != nil {
+		return err
+	}
+	payload, err := mapFD(c.DataFD)
+	if err != nil {
+		return err
+	}
+	shared, err := mapFD(c.SharedDataFD)
+	if err != nil {
+		return err
+	}
+
+	// Duplicate the eventFD so that the caller can close its copy while we
+	// keep using ours.
+	dupEventFD, err := c.EventFD.Dup()
+	if err != nil {
+		return err
+	}
+	cu.Add(func() { dupEventFD.Close() })
+
+	s.packetPipe.Init(packetPipeMem)
+	s.completionPipe.Init(completionPipeMem)
+	s.data = payload
+	s.eventFD = dupEventFD
+	s.sharedData = shared
+
+	// Success: the queue now owns the resources, so disarm the cleanup.
+	cu.Release()
+	return nil
+}
+
+// cleanup unmaps every memory region mapped by init and closes the duplicated
+// eventFD.
+func (s *serverRx) cleanup() {
+	regions := [...][]byte{
+		s.packetPipe.Bytes(),
+		s.completionPipe.Bytes(),
+		s.data,
+		s.sharedData,
+	}
+	for _, mem := range regions {
+		unix.Munmap(mem)
+	}
+	s.eventFD.Close()
+}
+
+// completionNotificationSize is the size in bytes of a completion notification
+// sent on the completion queue after a transmitted packet has been handled.
+const completionNotificationSize = 8
+
+// receive pulls a single packet descriptor from the packetPipe, copies the
+// packet payload it describes out of the shared data area and posts a
+// completion for it on the completionPipe.
+//
+// It returns nil when no descriptor is available; otherwise it returns a newly
+// allocated copy of the packet contents.
+func (s *serverRx) receive() []byte {
+	desc := s.packetPipe.Pull()
+	if desc == nil {
+		return nil
+	}
+
+	pktInfo := queue.DecodeTxPacketHeader(desc)
+	payload := make([]byte, 0, pktInfo.Size)
+	// remaining tracks how many bytes of the packet are still to be gathered;
+	// never copy more than the packet header declared.
+	remaining := pktInfo.Size
+	for idx := 0; idx < pktInfo.BufferCount; idx++ {
+		bufHdr := queue.DecodeTxBufferHeader(desc, idx)
+		if bufHdr.Size > remaining {
+			// This buffer holds more than is left of the packet: take only the
+			// remainder and stop.
+			payload = append(payload, s.data[bufHdr.Offset:][:remaining]...)
+			break
+		}
+		payload = append(payload, s.data[bufHdr.Offset:][:bufHdr.Size]...)
+		remaining -= bufHdr.Size
+	}
+
+	// Flush to let the peer know that the slots queued for transmission have
+	// been handled and that it is free to reuse them.
+	s.packetPipe.Flush()
+
+	// Encode and publish the packet completion.
+	completion := s.completionPipe.Push(completionNotificationSize)
+	queue.EncodeTxCompletion(completion, pktInfo.ID)
+	s.completionPipe.Flush()
+	return payload
+}
+
+// waitForPackets waits on the queue's eventFD. The dispatch loop calls it when
+// receive finds no pending packet descriptor.
+func (s *serverRx) waitForPackets() {
+	s.eventFD.Wait()
+}
diff --git a/pkg/tcpip/link/sharedmem/server_tx.go b/pkg/tcpip/link/sharedmem/server_tx.go
new file mode 100644
index 000000000..13a82903f
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/server_tx.go
@@ -0,0 +1,175 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build linux
+// +build linux
+
+package sharedmem
+
+import (
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/cleanup"
+ "gvisor.dev/gvisor/pkg/eventfd"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
+)
+
+// serverTx represents the server end of the sharedmem queue and is used to send
+// packets to the peer in the buffers posted by the peer in the fillPipe.
+type serverTx struct {
+	// fillPipe represents the receive end of the pipe that carries the RxBuffers
+	// posted by the peer.
+	fillPipe pipe.Rx
+
+	// completionPipe represents the transmit end of the pipe that carries the
+	// descriptors for filled RxBuffers.
+	completionPipe pipe.Tx
+
+	// data represents the buffer area where the packet payload is held.
+	data []byte
+
+	// eventFD is used to notify the peer when fill requests are fulfilled.
+	eventFD eventfd.Eventfd
+
+	// sharedData is the memory region to use to enable/disable notifications.
+	sharedData []byte
+}
+
+// init initializes all state needed by the serverTx queue based on the
+// information provided.
+//
+// The caller always retains ownership of all file descriptors passed in. The
+// queue implementation will duplicate any that it may need in the future.
+func (s *serverTx) init(c *QueueConfig) error {
+	// Everything acquired below is registered with cu so that a failure part
+	// way through releases whatever was acquired up to that point.
+	cu := cleanup.Make(func() {})
+	defer cu.Clean()
+
+	// mapFD maps in the buffer backing fd and schedules it to be unmapped
+	// should a later step fail.
+	mapFD := func(fd int) ([]byte, error) {
+		mem, err := getBuffer(fd)
+		if err != nil {
+			return nil, err
+		}
+		cu.Add(func() { unix.Munmap(mem) })
+		return mem, nil
+	}
+
+	fillPipeMem, err := mapFD(c.TxPipeFD)
+	if err != nil {
+		return err
+	}
+	completionPipeMem, err := mapFD(c.RxPipeFD)
+	if err != nil {
+		return err
+	}
+	payload, err := mapFD(c.DataFD)
+	if err != nil {
+		return err
+	}
+	shared, err := mapFD(c.SharedDataFD)
+	if err != nil {
+		return err
+	}
+
+	// Duplicate the eventFD so that the caller can close its copy while we
+	// keep using ours.
+	dupEventFD, err := c.EventFD.Dup()
+	if err != nil {
+		return err
+	}
+	cu.Add(func() { dupEventFD.Close() })
+
+	// Nothing below can fail: the queue now owns the resources, so disarm the
+	// cleanup.
+	cu.Release()
+
+	s.fillPipe.Init(fillPipeMem)
+	s.completionPipe.Init(completionPipeMem)
+	s.data = payload
+	s.eventFD = dupEventFD
+	s.sharedData = shared
+
+	return nil
+}
+
+// cleanup unmaps every memory region mapped by init and closes the duplicated
+// eventFD.
+func (s *serverTx) cleanup() {
+	regions := [...][]byte{
+		s.fillPipe.Bytes(),
+		s.completionPipe.Bytes(),
+		s.data,
+		s.sharedData,
+	}
+	for _, mem := range regions {
+		unix.Munmap(mem)
+	}
+	s.eventFD.Close()
+}
+
+// fillPacket copies the data in the provided views into buffers pulled from the
+// fillPipe and returns a slice of RxBuffers that contain the copied data as
+// well as the total number of bytes copied.
+//
+// To avoid allocations the filledBuffers are appended to the buffers slice
+// which will be grown as required.
+//
+// The views are consumed as they are copied (each one is trimmed from the
+// front), and fillPacket spins until the peer has posted enough buffers to hold
+// all of the data.
+func (s *serverTx) fillPacket(views []buffer.View, buffers []queue.RxBuffer) (filledBuffers []queue.RxBuffer, totalCopied uint32) {
+	filledBuffers = buffers[:0]
+	// fillBuffer copies as much of the views as possible into the posted buffer
+	// dst, records the number of bytes written in dst.Size and returns any left
+	// over views (if any).
+	fillBuffer := func(dst *queue.RxBuffer, views []buffer.View) (left []buffer.View) {
+		if len(views) == 0 {
+			return nil
+		}
+		capacity := uint64(dst.Size)
+		copied := uint64(0)
+		for copied < capacity && len(views) > 0 {
+			n := copy(s.data[dst.Offset+copied:][:capacity-copied], views[0])
+			views[0].TrimFront(n)
+			// Account for the bytes before deciding whether the view is done:
+			// a view that only partially fits has still had n bytes written to
+			// the buffer and trimmed from it, and they must be reflected in the
+			// reported size or they are lost.
+			copied += uint64(n)
+			if views[0].IsEmpty() {
+				views = views[1:]
+			}
+		}
+		dst.Size = uint32(copied)
+		return views
+	}
+
+	for len(views) > 0 {
+		var b []byte
+		// Spin till we get a free buffer reposted by the peer.
+		for {
+			if b = s.fillPipe.Pull(); b != nil {
+				break
+			}
+		}
+		rxBuffer := queue.DecodeRxBufferHeader(b)
+		// Copy the packet into the posted buffer.
+		views = fillBuffer(&rxBuffer, views)
+		totalCopied += rxBuffer.Size
+		filledBuffers = append(filledBuffers, rxBuffer)
+	}
+
+	return filledBuffers, totalCopied
+}
+
+// transmit copies the packet held in views into buffers posted by the peer and
+// publishes a completion describing the filled buffers on the completionPipe.
+//
+// It returns false if no room could be reserved on the completionPipe for the
+// completion, and true once the completion has been flushed.
+func (s *serverTx) transmit(views []buffer.View) bool {
+	// fillPacket reslices its scratch argument to zero length, so only the
+	// capacity of this slice matters.
+	scratch := make([]queue.RxBuffer, 8)
+	filled, totalCopied := s.fillPacket(views, scratch)
+
+	completion := s.completionPipe.Push(queue.RxCompletionSize(len(filled)))
+	if completion == nil {
+		return false
+	}
+
+	queue.EncodeRxCompletion(completion, totalCopied, 0 /* reserved */)
+	for idx, rxBuf := range filled {
+		queue.EncodeRxCompletionBuffer(completion, idx, rxBuf)
+	}
+	s.completionPipe.Flush()
+	s.fillPipe.Flush()
+	return true
+}
+
+// notify signals the peer by calling Notify on the queue's eventFD.
+func (s *serverTx) notify() {
+	s.eventFD.Notify()
+}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index 66efe6472..b75522a51 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -24,14 +24,16 @@
package sharedmem
import (
+ "fmt"
"sync/atomic"
- "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/eventfd"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
"gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -47,7 +49,7 @@ type QueueConfig struct {
// EventFD is a file descriptor for the event that is signaled when
// data is becomes available in this queue.
- EventFD int
+ EventFD eventfd.Eventfd
// TxPipeFD is a file descriptor for the tx pipe associated with the
// queue.
@@ -63,16 +65,97 @@ type QueueConfig struct {
SharedDataFD int
}
+// FDs returns the FDs in the QueueConfig as a slice of ints. This must
+// be used in conjunction with QueueConfigFromFDs to ensure the order
+// of FDs matches when reconstructing the config when serialized or sent
+// as part of control messages.
+//
+// The order is: DataFD, EventFD, TxPipeFD, RxPipeFD, SharedDataFD.
+func (q *QueueConfig) FDs() []int {
+	return []int{q.DataFD, q.EventFD.FD(), q.TxPipeFD, q.RxPipeFD, q.SharedDataFD}
+}
+
+// QueueConfigFromFDs constructs a QueueConfig out of a slice of ints where each
+// entry represents a file descriptor. The order of FDs in the slice must be in
+// the order specified below for the config to be valid. QueueConfig.FDs()
+// should be used when the config needs to be serialized or sent as part of a
+// control message to ensure the correct order.
+//
+// An error is returned unless exactly five file descriptors are supplied.
+func QueueConfigFromFDs(fds []int) (QueueConfig, error) {
+	// Both too few and too many FDs are rejected, so the message must not
+	// claim that the count is merely insufficient.
+	if len(fds) != 5 {
+		return QueueConfig{}, fmt.Errorf("unexpected number of fds: len(fds): %d, want: 5", len(fds))
+	}
+	return QueueConfig{
+		DataFD:       fds[0],
+		EventFD:      eventfd.Wrap(fds[1]),
+		TxPipeFD:     fds[2],
+		RxPipeFD:     fds[3],
+		SharedDataFD: fds[4],
+	}, nil
+}
+
+// Options specifies the details about the sharedmem endpoint to be created.
+type Options struct {
+	// MTU is the mtu to use for this endpoint.
+	MTU uint32
+
+	// BufferSize is the size of each scatter/gather buffer that will hold packet
+	// data.
+	//
+	// NOTE: This directly determines the number of packets that can be held in
+	// the ring buffer at any time. This does not have to be sized to the MTU as
+	// the shared memory queue design allows more than one buffer to be used to
+	// make up a given packet.
+	BufferSize uint32
+
+	// LinkAddress is the link address for this endpoint (required).
+	LinkAddress tcpip.LinkAddress
+
+	// TX is the transmit queue configuration for this shared memory endpoint.
+	TX QueueConfig
+
+	// RX is the receive queue configuration for this shared memory endpoint.
+	RX QueueConfig
+
+	// PeerFD is the fd for the connected peer which can be used to detect
+	// peer disconnects.
+	PeerFD int
+
+	// OnClosed is a function that is called when the endpoint is being closed
+	// (probably due to the peer going away).
+	OnClosed func(err tcpip.Error)
+
+	// TXChecksumOffload, if true, indicates that this endpoint's capability
+	// set should include CapabilityTXChecksumOffload.
+	TXChecksumOffload bool
+
+	// RXChecksumOffload, if true, indicates that this endpoint's capability
+	// set should include CapabilityRXChecksumOffload.
+	RXChecksumOffload bool
+}
+
type endpoint struct {
// mtu (maximum transmission unit) is the maximum size of a packet.
+ // mtu is immutable.
mtu uint32
// bufferSize is the size of each individual buffer.
+ // bufferSize is immutable.
bufferSize uint32
// addr is the local address of this endpoint.
+ // addr is immutable.
addr tcpip.LinkAddress
+ // peerFD is an fd to the peer that can be used to detect when the
+ // peer is gone.
+ // peerFD is immutable.
+ peerFD int
+
+ // caps holds the endpoint capabilities.
+ caps stack.LinkEndpointCapabilities
+
+ // hdrSize is the size of the link layer header if any.
+ // hdrSize is immutable.
+ hdrSize uint32
+
// rx is the receive queue.
rx rx
@@ -83,34 +166,55 @@ type endpoint struct {
// Wait group used to indicate that all workers have stopped.
completed sync.WaitGroup
+ // onClosed is a function to be called when the FD's peer (if any) closes
+ // its end of the communication pipe.
+ onClosed func(tcpip.Error)
+
// mu protects the following fields.
mu sync.Mutex
// tx is the transmit queue.
+ // +checklocks:mu
tx tx
// workerStarted specifies whether the worker goroutine was started.
+ // +checklocks:mu
workerStarted bool
}
// New creates a new shared-memory-based endpoint. Buffers will be broken up
// into buffers of "bufferSize" bytes.
-func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (stack.LinkEndpoint, error) {
+func New(opts Options) (stack.LinkEndpoint, error) {
e := &endpoint{
- mtu: mtu,
- bufferSize: bufferSize,
- addr: addr,
+ mtu: opts.MTU,
+ bufferSize: opts.BufferSize,
+ addr: opts.LinkAddress,
+ peerFD: opts.PeerFD,
+ onClosed: opts.OnClosed,
}
- if err := e.tx.init(bufferSize, &tx); err != nil {
+ if err := e.tx.init(opts.BufferSize, &opts.TX); err != nil {
return nil, err
}
- if err := e.rx.init(bufferSize, &rx); err != nil {
+ if err := e.rx.init(opts.BufferSize, &opts.RX); err != nil {
e.tx.cleanup()
return nil, err
}
+ e.caps = stack.LinkEndpointCapabilities(0)
+ if opts.RXChecksumOffload {
+ e.caps |= stack.CapabilityRXChecksumOffload
+ }
+
+ if opts.TXChecksumOffload {
+ e.caps |= stack.CapabilityTXChecksumOffload
+ }
+
+ if opts.LinkAddress != "" {
+ e.hdrSize = header.EthernetMinimumSize
+ e.caps |= stack.CapabilityResolutionRequired
+ }
return e, nil
}
@@ -119,13 +223,13 @@ func (e *endpoint) Close() {
// Tell dispatch goroutine to stop, then write to the eventfd so that
// it wakes up in case it's sleeping.
atomic.StoreUint32(&e.stopRequested, 1)
- unix.Write(e.rx.eventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ e.rx.eventFD.Notify()
// Cleanup the queues inline if the worker hasn't started yet; we also
// know it won't start from now on because stopRequested is set to 1.
e.mu.Lock()
+ defer e.mu.Unlock()
workerPresent := e.workerStarted
- e.mu.Unlock()
if !workerPresent {
e.tx.cleanup()
@@ -146,6 +250,22 @@ func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
if !e.workerStarted && atomic.LoadUint32(&e.stopRequested) == 0 {
e.workerStarted = true
e.completed.Add(1)
+
+ // Spin up a goroutine to monitor for peer shutdown.
+ if e.peerFD >= 0 {
+ e.completed.Add(1)
+ go func() {
+ defer e.completed.Done()
+ b := make([]byte, 1)
+ // When sharedmem endpoint is in use the peerFD is never used for any data
+ // transfer and this Read should only return if the peer is shutting down.
+ _, err := rawfile.BlockingRead(e.peerFD, b)
+ if e.onClosed != nil {
+ e.onClosed(err)
+ }
+ }()
+ }
+
// Link endpoints are not savable. When transportation endpoints
// are saved, they stop sending outgoing packets and all
// incoming packets are rejected.
@@ -164,18 +284,18 @@ func (e *endpoint) IsAttached() bool {
// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
// during construction.
func (e *endpoint) MTU() uint32 {
- return e.mtu - header.EthernetMinimumSize
+ return e.mtu - e.hdrSize
}
// Capabilities implements stack.LinkEndpoint.Capabilities.
-func (*endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return 0
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.caps
}
// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It returns the
// ethernet frame header size.
-func (*endpoint) MaxHeaderLength() uint16 {
- return header.EthernetMinimumSize
+func (e *endpoint) MaxHeaderLength() uint16 {
+ return uint16(e.hdrSize)
}
// LinkAddress implements stack.LinkEndpoint.LinkAddress. It returns the local
@@ -205,17 +325,15 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
// WriteRawPacket implements stack.LinkEndpoint.
func (*endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} }
-// WritePacket writes outbound packets to the file descriptor. If it is not
-// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+// +checklocks:e.mu
+func (e *endpoint) writePacketLocked(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+ if e.addr != "" {
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+ }
views := pkt.Views()
// Transmit the packet.
- e.mu.Lock()
ok := e.tx.transmit(views...)
- e.mu.Unlock()
-
if !ok {
return &tcpip.ErrWouldBlock{}
}
@@ -223,9 +341,37 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol
return nil
}
+// WritePacket writes outbound packets to the file descriptor. If it is not
+// currently writable, the packet is dropped.
+//
+// The route and protocol arguments are ignored; pkt.EgressRoute and
+// pkt.NetworkProtocolNumber are used instead.
+func (e *endpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+	if err := e.writePacketLocked(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil {
+		return err
+	}
+	// Only notify once the packet has been queued successfully.
+	e.tx.notify()
+	return nil
+}
+
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (*endpoint) WritePackets(stack.RouteInfo, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
- panic("not implemented")
+func (e *endpoint) WritePackets(_ stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+ n := 0
+ var err tcpip.Error
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if err = e.writePacketLocked(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil {
+ break
+ }
+ n++
+ }
+ // WritePackets never returns an error if it successfully transmitted at least
+ // one packet.
+ if err != nil && n == 0 {
+ return 0, err
+ }
+ e.tx.notify()
+ return n, nil
}
// dispatchLoop reads packets from the rx queue in a loop and dispatches them
@@ -268,16 +414,42 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) {
Data: buffer.View(b).ToVectorisedView(),
})
- hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
- if !ok {
- continue
+ var src, dst tcpip.LinkAddress
+ var proto tcpip.NetworkProtocolNumber
+ if e.addr != "" {
+ hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
+ if !ok {
+ continue
+ }
+ eth := header.Ethernet(hdr)
+ src = eth.SourceAddress()
+ dst = eth.DestinationAddress()
+ proto = eth.Type()
+ } else {
+ // We don't get any indication of what the packet is, so try to guess
+ // if it's an IPv4 or IPv6 packet.
+ // IP version information is at the first octet, so pulling up 1 byte.
+ h, ok := pkt.Data().PullUp(1)
+ if !ok {
+ continue
+ }
+ switch header.IPVersion(h) {
+ case header.IPv4Version:
+ proto = header.IPv4ProtocolNumber
+ case header.IPv6Version:
+ proto = header.IPv6ProtocolNumber
+ default:
+ continue
+ }
}
- eth := header.Ethernet(hdr)
// Send packet up the stack.
- d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), pkt)
+ d.DeliverNetworkPacket(src, dst, proto, pkt)
}
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
// Clean state.
e.tx.cleanup()
e.rx.cleanup()
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_server.go b/pkg/tcpip/link/sharedmem/sharedmem_server.go
new file mode 100644
index 000000000..43c5b8c63
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/sharedmem_server.go
@@ -0,0 +1,344 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build linux
+// +build linux
+
+package sharedmem
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// serverEndpoint is the server side of a shared-memory-based link endpoint.
+// NewServerEndpoint returns it as a stack.LinkEndpoint.
+type serverEndpoint struct {
+	// mtu (maximum transmission unit) is the maximum size of a packet.
+	// mtu is immutable.
+	mtu uint32
+
+	// bufferSize is the size of each individual buffer.
+	// bufferSize is immutable.
+	bufferSize uint32
+
+	// addr is the local address of this endpoint.
+	// addr is immutable.
+	addr tcpip.LinkAddress
+
+	// rx is the receive queue.
+	rx serverRx
+
+	// stopRequested is to be accessed atomically only, and determines if the
+	// worker goroutines should stop.
+	stopRequested uint32
+
+	// Wait group used to indicate that all workers have stopped.
+	completed sync.WaitGroup
+
+	// peerFD is an fd to the peer that can be used to detect when the peer is
+	// gone.
+	// peerFD is immutable.
+	peerFD int
+
+	// caps holds the endpoint capabilities.
+	caps stack.LinkEndpointCapabilities
+
+	// hdrSize is the size of the link layer header if any.
+	// hdrSize is immutable.
+	hdrSize uint32
+
+	// onClosed is a function to be called when the FD's peer (if any) closes its
+	// end of the communication pipe.
+	onClosed func(tcpip.Error)
+
+	// mu protects the following fields.
+	mu sync.Mutex
+
+	// tx is the transmit queue.
+	// +checklocks:mu
+	tx serverTx
+
+	// workerStarted specifies whether the worker goroutine was started.
+	// +checklocks:mu
+	workerStarted bool
+}
+
+// NewServerEndpoint creates a new shared-memory-based endpoint. Buffers will be
+// broken up into buffers of "bufferSize" bytes.
+func NewServerEndpoint(opts Options) (stack.LinkEndpoint, error) {
+ e := &serverEndpoint{
+ mtu: opts.MTU,
+ bufferSize: opts.BufferSize,
+ addr: opts.LinkAddress,
+ peerFD: opts.PeerFD,
+ onClosed: opts.OnClosed,
+ }
+
+ if err := e.tx.init(&opts.RX); err != nil {
+ return nil, err
+ }
+
+ if err := e.rx.init(&opts.TX); err != nil {
+ e.tx.cleanup()
+ return nil, err
+ }
+
+ e.caps = stack.LinkEndpointCapabilities(0)
+ if opts.RXChecksumOffload {
+ e.caps |= stack.CapabilityRXChecksumOffload
+ }
+
+ if opts.TXChecksumOffload {
+ e.caps |= stack.CapabilityTXChecksumOffload
+ }
+
+ if opts.LinkAddress != "" {
+ e.hdrSize = header.EthernetMinimumSize
+ e.caps |= stack.CapabilityResolutionRequired
+ }
+
+ return e, nil
+}
+
+// Close requests that the endpoint's workers stop. If the worker goroutine was
+// never started the queues are cleaned up here; otherwise that is left to the
+// dispatch loop when it exits.
+func (e *serverEndpoint) Close() {
+	// Ask the dispatch goroutine to stop and signal the eventfd so that it
+	// wakes up in case it is sleeping.
+	atomic.StoreUint32(&e.stopRequested, 1)
+	e.rx.eventFD.Notify()
+
+	e.mu.Lock()
+	defer e.mu.Unlock()
+
+	// A running worker cleans up after itself.
+	if e.workerStarted {
+		return
+	}
+
+	// No worker was started, and none will be from now on because
+	// stopRequested is set to 1, so clean up the queues inline.
+	e.tx.cleanup()
+	e.rx.cleanup()
+}
+
+// Wait implements stack.LinkEndpoint.Wait. It waits until all workers have
+// stopped after a Close() call.
+func (e *serverEndpoint) Wait() {
+ e.completed.Wait()
+}
+
+// Attach implements stack.LinkEndpoint.Attach. It launches the goroutine that
+// reads packets from the rx queue and, when a peer fd was supplied, a second
+// goroutine that watches for the peer going away.
+//
+// Attach does nothing if the workers were already started or if Close has
+// already been requested.
+func (e *serverEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+
+	if e.workerStarted || atomic.LoadUint32(&e.stopRequested) != 0 {
+		return
+	}
+	e.workerStarted = true
+	e.completed.Add(1)
+
+	if e.peerFD >= 0 {
+		e.completed.Add(1)
+		// Spin up a goroutine to monitor for peer shutdown.
+		go func() {
+			// Deferred so that the wait group is released however this
+			// goroutine exits, matching the client endpoint.
+			defer e.completed.Done()
+			b := make([]byte, 1)
+			// When sharedmem endpoint is in use the peerFD is never used for any
+			// data transfer and this Read should only return if the peer is
+			// shutting down.
+			_, err := rawfile.BlockingRead(e.peerFD, b)
+			if e.onClosed != nil {
+				e.onClosed(err)
+			}
+		}()
+	}
+
+	// Link endpoints are not savable. When transportation endpoints are saved,
+	// they stop sending outgoing packets and all incoming packets are rejected.
+	go e.dispatchLoop(dispatcher) // S/R-SAFE: see above.
+}
+
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (e *serverEndpoint) IsAttached() bool {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.workerStarted
+}
+
+// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
+// during construction.
+func (e *serverEndpoint) MTU() uint32 {
+ return e.mtu - e.hdrSize
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities.
+func (e *serverEndpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.caps
+}
+
+// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It returns the
+// ethernet frame header size.
+func (e *serverEndpoint) MaxHeaderLength() uint16 {
+ return uint16(e.hdrSize)
+}
+
+// LinkAddress implements stack.LinkEndpoint.LinkAddress. It returns the local
+// link address.
+func (e *serverEndpoint) LinkAddress() tcpip.LinkAddress {
+ return e.addr
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader. It pushes an ethernet
+// header addressed to remote and carrying protocol onto pkt.
+func (e *serverEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+	// Preserve the src address if it's set in the route, otherwise fall back
+	// to the endpoint's own address.
+	src := e.addr
+	if local != "" {
+		src = local
+	}
+
+	eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize))
+	eth.Encode(&header.EthernetFields{
+		SrcAddr: src,
+		DstAddr: remote,
+		Type:    protocol,
+	})
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. The packet's
+// views are transmitted as they are, without adding a link header, and the peer
+// is notified on success.
+func (e *serverEndpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error {
+	views := pkt.Views()
+
+	e.mu.Lock()
+	defer e.mu.Unlock()
+
+	if !e.tx.transmit(views) {
+		return &tcpip.ErrWouldBlock{}
+	}
+	e.tx.notify()
+	return nil
+}
+
+// writePacketLocked queues pkt on the tx queue, prepending an ethernet header
+// first when the endpoint has a link address. It does not notify the peer;
+// that is left to the caller.
+//
+// +checklocks:e.mu
+func (e *serverEndpoint) writePacketLocked(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+	if e.addr != "" {
+		e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+	}
+
+	if sent := e.tx.transmit(pkt.Views()); !sent {
+		return &tcpip.ErrWouldBlock{}
+	}
+	return nil
+}
+
+// WritePacket implements stack.LinkEndpoint.WritePacket. It writes the
+// outbound packet to the tx queue and notifies the peer; if the queue is not
+// currently writable, the packet is dropped and an error is returned.
+//
+// The route and protocol arguments are ignored; pkt.EgressRoute and
+// pkt.NetworkProtocolNumber are used instead.
+func (e *serverEndpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+
+	// Transmit the packet.
+	err := e.writePacketLocked(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt)
+	if err != nil {
+		return err
+	}
+	e.tx.notify()
+	return nil
+}
+
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (e *serverEndpoint) WritePackets(_ stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+ n := 0
+ var err tcpip.Error
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if err = e.writePacketLocked(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil {
+ break
+ }
+ n++
+ }
+ // WritePackets never returns an error if it successfully transmitted at least
+ // one packet.
+ if err != nil && n == 0 {
+ return 0, err
+ }
+ e.tx.notify()
+ return n, nil
+}
+
+// dispatchLoop reads packets from the rx queue in a loop and dispatches them
+// to the network stack until a stop is requested. On exit it cleans up both
+// queues and marks the worker as completed.
+func (e *serverEndpoint) dispatchLoop(d stack.NetworkDispatcher) {
+	for atomic.LoadUint32(&e.stopRequested) == 0 {
+		b := e.rx.receive()
+		if b == nil {
+			// Nothing pending; wait to be signaled and try again.
+			e.rx.waitForPackets()
+			continue
+		}
+		pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+			Data: buffer.View(b).ToVectorisedView(),
+		})
+
+		var (
+			src, dst tcpip.LinkAddress
+			proto    tcpip.NetworkProtocolNumber
+		)
+		if e.addr == "" {
+			// We don't get any indication of what the packet is, so try to guess
+			// if it's an IPv4 or IPv6 packet.
+			// IP version information is at the first octet, so pulling up 1 byte.
+			h, ok := pkt.Data().PullUp(1)
+			if !ok {
+				continue
+			}
+			switch header.IPVersion(h) {
+			case header.IPv4Version:
+				proto = header.IPv4ProtocolNumber
+			case header.IPv6Version:
+				proto = header.IPv6ProtocolNumber
+			default:
+				// Neither IPv4 nor IPv6: drop the packet.
+				continue
+			}
+		} else {
+			// The endpoint has a link address, so packets carry an ethernet
+			// header describing the addresses and the network protocol.
+			hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
+			if !ok {
+				continue
+			}
+			eth := header.Ethernet(hdr)
+			src, dst, proto = eth.SourceAddress(), eth.DestinationAddress(), eth.Type()
+		}
+
+		// Send packet up the stack.
+		d.DeliverNetworkPacket(src, dst, proto, pkt)
+	}
+
+	e.mu.Lock()
+	defer e.mu.Unlock()
+
+	// Clean state.
+	e.tx.cleanup()
+	e.rx.cleanup()
+
+	e.completed.Done()
+}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. It reports
+// ethernet when the endpoint uses a link layer header and none otherwise.
+func (e *serverEndpoint) ARPHardwareType() header.ARPHardwareType {
+	if e.hdrSize == 0 {
+		return header.ARPHardwareNone
+	}
+	return header.ARPHardwareEther
+}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_server_test.go b/pkg/tcpip/link/sharedmem/sharedmem_server_test.go
new file mode 100644
index 000000000..1bc58614e
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/sharedmem_server_test.go
@@ -0,0 +1,220 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build linux
+// +build linux
+
+package sharedmem_server_test
+
+import (
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "syscall"
+ "testing"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+)
+
+const (
+ localLinkAddr = "\xde\xad\xbe\xef\x56\x78" // de:ad:be:ef:56:78, link address of the client endpoint.
+ remoteLinkAddr = "\xde\xad\xbe\xef\x12\x34" // de:ad:be:ef:12:34, link address of the server endpoint.
+ localIPv4Address = tcpip.Address("\x0a\x00\x00\x01") // 10.0.0.1, assigned to the client stack.
+ remoteIPv4Address = tcpip.Address("\x0a\x00\x00\x02") // 10.0.0.2, assigned to the server stack.
+ serverPort = 10001 // TCP port the test server listens on.
+
+ defaultMTU = 1500 // MTU passed to both endpoints.
+ defaultBufferSize = 1500 // BufferSize passed to both endpoints.
+)
+
+type stackOptions struct {
+ ep stack.LinkEndpoint // Link endpoint the NIC is created on (wrapped in a sniffer first).
+ addr tcpip.Address // Address added to the NIC; a 16-byte value selects IPv6.
+}
+
+func newStackWithOptions(stackOpts stackOptions) (*stack.Stack, error) { // Builds a one-NIC stack for stackOpts.
+ st := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ ipv4.NewProtocolWithOptions(ipv4.Options{
+ AllowExternalLoopbackTraffic: true,
+ }),
+ ipv6.NewProtocolWithOptions(ipv6.Options{
+ AllowExternalLoopbackTraffic: true,
+ }),
+ },
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
+ })
+ nicID := tcpip.NICID(1)
+ sniffEP := sniffer.New(stackOpts.ep) // Wrap the endpoint in a sniffer before attaching it to the NIC.
+ opts := stack.NICOptions{Name: "eth0"}
+ if err := st.CreateNICWithOptions(nicID, sniffEP, opts); err != nil {
+ return nil, fmt.Errorf("method CreateNICWithOptions(%d, _, %v) failed: %s", nicID, opts, err)
+ }
+
+ // Add the protocol address: IPv4 by default, IPv6 when addr is 16 bytes long.
+ protocolNum := ipv4.ProtocolNumber
+ routeTable := []tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}}
+ if len(stackOpts.addr) == 16 {
+ routeTable = []tcpip.Route{{Destination: header.IPv6EmptySubnet, NIC: nicID}}
+ protocolNum = ipv6.ProtocolNumber
+ }
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: protocolNum,
+ AddressWithPrefix: stackOpts.addr.WithPrefix(),
+ }
+ if err := st.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ return nil, fmt.Errorf("AddProtocolAddress(%d, %v, {}): %s", nicID, protocolAddr, err)
+ }
+
+ // Install the single empty-subnet route through the NIC for the chosen protocol.
+ st.SetRouteTable(routeTable)
+
+ return st, nil
+}
+
+func newClientStack(t *testing.T, qPair *sharedmem.QueuePair, peerFD int) (*stack.Stack, error) { // NOTE(review): t is unused here.
+ ep, err := sharedmem.New(sharedmem.Options{ // Endpoint built from qPair's TX/RX queue configs.
+ MTU: defaultMTU,
+ BufferSize: defaultBufferSize,
+ LinkAddress: localLinkAddr,
+ TX: qPair.TXQueueConfig(),
+ RX: qPair.RXQueueConfig(),
+ PeerFD: peerFD,
+ })
+ if err != nil {
+ return nil, fmt.Errorf("failed to create sharedmem endpoint: %s", err)
+ }
+ st, err := newStackWithOptions(stackOptions{ep: ep, addr: localIPv4Address}) // Stack addressed as localIPv4Address.
+ if err != nil {
+ return nil, fmt.Errorf("failed to create client stack: %s", err)
+ }
+ return st, nil
+}
+
+func newServerStack(t *testing.T, qPair *sharedmem.QueuePair, peerFD int) (*stack.Stack, error) { // NOTE(review): t is unused here.
+	ep, err := sharedmem.NewServerEndpoint(sharedmem.Options{ // Server endpoint built from qPair's TX/RX queue configs.
+		MTU:         defaultMTU,
+		BufferSize:  defaultBufferSize,
+		LinkAddress: remoteLinkAddr,
+		TX:          qPair.TXQueueConfig(),
+		RX:          qPair.RXQueueConfig(),
+		PeerFD:      peerFD,
+	})
+	if err != nil {
+		return nil, fmt.Errorf("failed to create sharedmem endpoint: %s", err)
+	}
+	st, err := newStackWithOptions(stackOptions{ep: ep, addr: remoteIPv4Address}) // Stack addressed as remoteIPv4Address.
+	if err != nil {
+		return nil, fmt.Errorf("failed to create server stack: %s", err)
+	}
+	return st, nil
+}
+
+type testContext struct {
+ clientStk *stack.Stack // Stack returned by newClientStack.
+ serverStk *stack.Stack // Stack returned by newServerStack.
+ peerFDs [2]int // Socket pair; [0] is given to the client endpoint, [1] to the server endpoint.
+}
+
+func newTestContext(t *testing.T) *testContext { // Builds a client and a server stack sharing one queue pair and one socket pair.
+ peerFDs, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET|syscall.SOCK_NONBLOCK, 0)
+ if err != nil {
+ t.Fatalf("failed to create peerFDs: %s", err)
+ }
+ q, err := sharedmem.NewQueuePair()
+ if err != nil {
+ t.Fatalf("failed to create sharedmem queue: %s", err) // NOTE(review): peerFDs are left open on this path.
+ }
+ clientStack, err := newClientStack(t, q, peerFDs[0])
+ if err != nil {
+ q.Close() // Undo everything created so far before failing the test.
+ unix.Close(peerFDs[0])
+ unix.Close(peerFDs[1])
+ t.Fatalf("failed to create client stack: %s", err)
+ }
+ serverStack, err := newServerStack(t, q, peerFDs[1])
+ if err != nil {
+ q.Close() // Same teardown as above, plus the client stack that now exists.
+ unix.Close(peerFDs[0])
+ unix.Close(peerFDs[1])
+ clientStack.Close()
+ t.Fatalf("failed to create server stack: %s", err)
+ }
+ return &testContext{ // NOTE(review): q is neither stored nor closed on success; confirm the endpoints own it.
+ clientStk: clientStack,
+ serverStk: serverStack,
+ peerFDs: peerFDs,
+ }
+}
+
+func (ctx *testContext) cleanup() { // Closes both peer FDs, then the client and server stacks.
+ unix.Close(ctx.peerFDs[0]) // Close errors are ignored; this is best-effort test teardown.
+ unix.Close(ctx.peerFDs[1])
+ ctx.clientStk.Close()
+ ctx.serverStk.Close()
+}
+
+// TestServerRoundTrip serves HTTP from the server stack and fetches "/" from
+// the client stack, checking the status code and body that come back.
+func TestServerRoundTrip(t *testing.T) {
+	ctx := newTestContext(t)
+	defer ctx.cleanup()
+	listenAddr := tcpip.FullAddress{Addr: remoteIPv4Address, Port: serverPort}
+	l, err := gonet.ListenTCP(ctx.serverStk, listenAddr, ipv4.ProtocolNumber)
+	if err != nil {
+		t.Fatalf("failed to start TCP Listener: %s", err)
+	}
+	defer l.Close()
+	var responseString = "response"
+	go http.Serve(l, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Write([]byte(responseString))
+	}))
+
+	dialFunc := func(_, _ string) (net.Conn, error) { // Always dials listenAddr; the network and address arguments are ignored.
+		return gonet.DialTCP(ctx.clientStk, listenAddr, ipv4.ProtocolNumber)
+	}
+
+	httpClient := &http.Client{
+		Transport: &http.Transport{
+			Dial: dialFunc,
+		},
+	}
+	serverURL := fmt.Sprintf("http://[%s]:%d/", net.IP(remoteIPv4Address), serverPort)
+	response, err := httpClient.Get(serverURL)
+	if err != nil {
+		t.Fatalf("httpClient.Get(\"/\") failed: %s", err)
+	}
+	defer response.Body.Close() // Deferred so the body is closed even when a later check fails the test.
+	if got, want := response.StatusCode, http.StatusOK; got != want {
+		t.Fatalf("unexpected status code got: %d, want: %d", got, want)
+	}
+	body, err := io.ReadAll(response.Body)
+	if err != nil {
+		t.Fatalf("io.ReadAll(response.Body) failed: %s", err)
+	}
+	if got, want := string(body), responseString; got != want {
+		t.Fatalf("unexpected response got: %s, want: %s", got, want)
+	}
+}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index d6d953085..a49f5f87d 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -19,9 +19,7 @@ package sharedmem
import (
"bytes"
- "io/ioutil"
"math/rand"
- "os"
"strings"
"testing"
"time"
@@ -104,24 +102,36 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress
t: t,
packetCh: make(chan struct{}, 1000000),
}
- c.txCfg = createQueueFDs(t, queueSizes{
+ c.txCfg, err = createQueueFDs(queueSizes{
dataSize: queueDataSize,
txPipeSize: queuePipeSize,
rxPipeSize: queuePipeSize,
sharedDataSize: 4096,
})
-
- c.rxCfg = createQueueFDs(t, queueSizes{
+ if err != nil {
+ t.Fatalf("createQueueFDs for tx failed: %s", err)
+ }
+ c.rxCfg, err = createQueueFDs(queueSizes{
dataSize: queueDataSize,
txPipeSize: queuePipeSize,
rxPipeSize: queuePipeSize,
sharedDataSize: 4096,
})
+ if err != nil {
+ t.Fatalf("createQueueFDs for rx failed: %s", err)
+ }
initQueue(t, &c.txq, &c.txCfg)
initQueue(t, &c.rxq, &c.rxCfg)
- ep, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg)
+ ep, err := New(Options{
+ MTU: mtu,
+ BufferSize: bufferSize,
+ LinkAddress: addr,
+ TX: c.txCfg,
+ RX: c.rxCfg,
+ PeerFD: -1,
+ })
if err != nil {
t.Fatalf("New failed: %v", err)
}
@@ -150,8 +160,8 @@ func (c *testContext) DeliverOutboundPacket(remoteLinkAddr, localLinkAddr tcpip.
func (c *testContext) cleanup() {
c.ep.Close()
- closeFDs(&c.txCfg)
- closeFDs(&c.rxCfg)
+ closeFDs(c.txCfg)
+ closeFDs(c.rxCfg)
c.txq.cleanup()
c.rxq.cleanup()
}
@@ -191,69 +201,6 @@ func shuffle(b []int) {
}
}
-func createFile(t *testing.T, size int64, initQueue bool) int {
- tmpDir, ok := os.LookupEnv("TEST_TMPDIR")
- if !ok {
- tmpDir = os.Getenv("TMPDIR")
- }
- f, err := ioutil.TempFile(tmpDir, "sharedmem_test")
- if err != nil {
- t.Fatalf("TempFile failed: %v", err)
- }
- defer f.Close()
- unix.Unlink(f.Name())
-
- if initQueue {
- // Write the "slot-free" flag in the initial queue.
- _, err := f.WriteAt([]byte{0, 0, 0, 0, 0, 0, 0, 0x80}, 0)
- if err != nil {
- t.Fatalf("WriteAt failed: %v", err)
- }
- }
-
- fd, err := unix.Dup(int(f.Fd()))
- if err != nil {
- t.Fatalf("Dup failed: %v", err)
- }
-
- if err := unix.Ftruncate(fd, size); err != nil {
- unix.Close(fd)
- t.Fatalf("Ftruncate failed: %v", err)
- }
-
- return fd
-}
-
-func closeFDs(c *QueueConfig) {
- unix.Close(c.DataFD)
- unix.Close(c.EventFD)
- unix.Close(c.TxPipeFD)
- unix.Close(c.RxPipeFD)
- unix.Close(c.SharedDataFD)
-}
-
-type queueSizes struct {
- dataSize int64
- txPipeSize int64
- rxPipeSize int64
- sharedDataSize int64
-}
-
-func createQueueFDs(t *testing.T, s queueSizes) QueueConfig {
- fd, _, err := unix.RawSyscall(unix.SYS_EVENTFD2, 0, 0, 0)
- if err != 0 {
- t.Fatalf("eventfd failed: %v", error(err))
- }
-
- return QueueConfig{
- EventFD: int(fd),
- DataFD: createFile(t, s.dataSize, false),
- TxPipeFD: createFile(t, s.txPipeSize, true),
- RxPipeFD: createFile(t, s.rxPipeSize, true),
- SharedDataFD: createFile(t, s.sharedDataSize, false),
- }
-}
-
// TestSimpleSend sends 1000 packets with random header and payload sizes,
// then checks that the right payload is received on the shared memory queues.
func TestSimpleSend(t *testing.T) {
@@ -263,6 +210,7 @@ func TestSimpleSend(t *testing.T) {
// Prepare route.
var r stack.RouteInfo
r.RemoteLinkAddress = remoteLinkAddr
+ r.LocalLinkAddress = localLinkAddr
for iters := 1000; iters > 0; iters-- {
func() {
@@ -280,8 +228,11 @@ func TestSimpleSend(t *testing.T) {
Data: data.ToVectorisedView(),
})
copy(pkt.NetworkHeader().Push(hdrLen), hdrBuf)
-
proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000))
+ // Every PacketBuffer must have these set:
+ // See nic.writePacket.
+ pkt.EgressRoute = r
+ pkt.NetworkProtocolNumber = proto
if err := c.ep.WritePacket(r, proto, pkt); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
@@ -350,8 +301,11 @@ func TestPreserveSrcAddressInSend(t *testing.T) {
// the minimum size of the ethernet header.
ReserveHeaderBytes: header.EthernetMinimumSize,
})
-
proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000))
+ // Every PacketBuffer must have these set:
+ // See nic.writePacket.
+ pkt.EgressRoute = r
+ pkt.NetworkProtocolNumber = proto
if err := c.ep.WritePacket(r, proto, pkt); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
@@ -672,7 +626,7 @@ func TestSimpleReceive(t *testing.T) {
// Push completion.
c.pushRxCompletion(uint32(len(contents)), bufs)
c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Wait for packet to be received, then check it.
c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet")
@@ -718,7 +672,7 @@ func TestRxBuffersReposted(t *testing.T) {
// Complete the buffer.
c.pushRxCompletion(buffers[i].Size, buffers[i:][:1])
c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Wait for it to be reposted.
bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffer to be reposted"))
@@ -734,7 +688,7 @@ func TestRxBuffersReposted(t *testing.T) {
// Complete with two buffers.
c.pushRxCompletion(2*bufferSize, buffers[2*i:][:2])
c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Wait for them to be reposted.
for j := 0; j < 2; j++ {
@@ -759,7 +713,7 @@ func TestReceivePostingIsFull(t *testing.T) {
first := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for first buffer to be posted"))
c.pushRxCompletion(first.Size, []queue.RxBuffer{first})
c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Check that packet is received.
c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet")
@@ -768,7 +722,7 @@ func TestReceivePostingIsFull(t *testing.T) {
second := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for second buffer to be posted"))
c.pushRxCompletion(second.Size, []queue.RxBuffer{second})
c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Check that no packet is received yet, as the worker is blocked trying
// to repost.
@@ -781,7 +735,7 @@ func TestReceivePostingIsFull(t *testing.T) {
// Flush tx queue, which will allow the first buffer to be reposted,
// and the second completion to be pulled.
c.rxq.tx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Check that second packet completes.
c.waitForPackets(1, time.After(time.Second), "Timeout waiting for second completed packet")
@@ -803,7 +757,7 @@ func TestCloseWhileWaitingToPost(t *testing.T) {
bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for initial buffer to be posted"))
c.pushRxCompletion(bi.Size, []queue.RxBuffer{bi})
c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Wait for packet to be indicated.
c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet")
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go b/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go
index f7e816a41..d974c266e 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go
@@ -15,7 +15,12 @@
package sharedmem
import (
+ "fmt"
+ "reflect"
"unsafe"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/memutil"
)
// sharedDataPointer converts the shared data slice into a pointer so that it
@@ -23,3 +28,31 @@ import (
func sharedDataPointer(sharedData []byte) *uint32 {
return (*uint32)(unsafe.Pointer(&sharedData[0:4][0]))
}
+
+// getBuffer maps the full contents of the given file descriptor as shared,
+// readable and writable memory and returns the mapping as a byte slice.
+func getBuffer(fd int) ([]byte, error) {
+ var s unix.Stat_t
+ if err := unix.Fstat(fd, &s); err != nil { // Fstat supplies the size used as the mapping length.
+ return nil, err
+ }
+
+ // Reject sizes that would overflow an int; EDOM is returned in that case.
+ if s.Size > int64(^uint(0)>>1) {
+ return nil, unix.EDOM
+ }
+
+ addr, err := memutil.MapFile(0 /* addr */, uintptr(s.Size), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED|unix.MAP_FILE, uintptr(fd), 0 /*offset*/)
+ if err != nil {
+ return nil, fmt.Errorf("failed to map memory for buffer fd: %d, error: %s", fd, err)
+ }
+
+ // Use unsafe to convert addr into a []byte by filling in b's slice header.
+ var b []byte
+ hdr := (*reflect.SliceHeader)(unsafe.Pointer(&b))
+ hdr.Data = addr
+ hdr.Len = int(s.Size)
+ hdr.Cap = int(s.Size)
+
+ return b, nil
+}
diff --git a/pkg/tcpip/link/sharedmem/tx.go b/pkg/tcpip/link/sharedmem/tx.go
index e3210051f..d6c61afee 100644
--- a/pkg/tcpip/link/sharedmem/tx.go
+++ b/pkg/tcpip/link/sharedmem/tx.go
@@ -18,6 +18,7 @@ import (
"math"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/eventfd"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
)
@@ -28,10 +29,12 @@ const (
// tx holds all state associated with a tx queue.
type tx struct {
- data []byte
- q queue.Tx
- ids idManager
- bufs bufferManager
+ data []byte
+ q queue.Tx
+ ids idManager
+ bufs bufferManager
+ eventFD eventfd.Eventfd
+ sharedDataFD int
}
// init initializes all state needed by the tx queue based on the information
@@ -64,7 +67,8 @@ func (t *tx) init(mtu uint32, c *QueueConfig) error {
t.ids.init()
t.bufs.init(0, len(data), int(mtu))
t.data = data
-
+ t.eventFD = c.EventFD
+ t.sharedDataFD = c.SharedDataFD
return nil
}
@@ -142,20 +146,10 @@ func (t *tx) transmit(bufs ...buffer.View) bool {
return true
}
-// getBuffer returns a memory region mapped to the full contents of the given
-// file descriptor.
-func getBuffer(fd int) ([]byte, error) {
- var s unix.Stat_t
- if err := unix.Fstat(fd, &s); err != nil {
- return nil, err
- }
-
- // Check that size doesn't overflow an int.
- if s.Size > int64(^uint(0)>>1) {
- return nil, unix.EDOM
- }
-
- return unix.Mmap(fd, 0, int(s.Size), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED|unix.MAP_FILE)
+// notify calls Notify on t.eventFD to indicate to the peer that there is data
+// to be read.
+func (t *tx) notify() {
+ t.eventFD.Notify()
}
// idDescriptor is used by idManager to either point to a tx buffer (in case
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 6515c31e5..e08243547 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -272,7 +272,6 @@ type protocol struct {
func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber }
func (p *protocol) MinimumPacketSize() int { return header.ARPSize }
-func (p *protocol) DefaultPrefixLen() int { return 0 }
func (*protocol) ParseAddresses(buffer.View) (src, dst tcpip.Address) {
return "", ""
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 5fcbfeaa2..061cc35ae 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -153,8 +153,12 @@ func makeTestContext(t *testing.T, eventDepth int, packetDepth int) testContext
t.Fatalf("CreateNIC failed: %s", err)
}
- if err := tc.s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress for ipv4 failed: %s", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: stackAddr.WithPrefix(),
+ }
+ if err := tc.s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
tc.s.SetRouteTable([]tcpip.Route{{
@@ -569,8 +573,12 @@ func TestLinkAddressRequest(t *testing.T) {
}
if len(test.nicAddr) != 0 {
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, test.nicAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, test.nicAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: test.nicAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 2179302d3..87f650661 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -233,7 +233,13 @@ func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) {
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
})
s.CreateNIC(nicID, loopback.New())
- s.AddAddress(nicID, ipv4.ProtocolNumber, local)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: local.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ return nil, err
+ }
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv4EmptySubnet,
Gateway: ipv4Gateway,
@@ -249,7 +255,13 @@ func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) {
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
})
s.CreateNIC(nicID, loopback.New())
- s.AddAddress(nicID, ipv6.ProtocolNumber, local)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: local.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ return nil, err
+ }
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv6EmptySubnet,
Gateway: ipv6Gateway,
@@ -272,13 +284,13 @@ func buildDummyStackWithLinkEndpoint(t *testing.T, mtu uint32) (*stack.Stack, *c
}
v4Addr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: localIPv4AddrWithPrefix}
- if err := s.AddProtocolAddress(nicID, v4Addr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v4Addr, err)
+ if err := s.AddProtocolAddress(nicID, v4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, v4Addr, err)
}
v6Addr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: localIPv6AddrWithPrefix}
- if err := s.AddProtocolAddress(nicID, v6Addr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v6Addr, err)
+ if err := s.AddProtocolAddress(nicID, v6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, v6Addr, err)
}
return s, e
@@ -713,8 +725,8 @@ func TestReceive(t *testing.T) {
if !ok {
t.Fatalf("expected network endpoint with number = %d to implement stack.AddressableEndpoint", test.protoNum)
}
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", test.epAddr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", test.epAddr, err)
} else {
ep.DecRef()
}
@@ -885,8 +897,8 @@ func TestIPv4ReceiveControl(t *testing.T) {
t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint")
}
addr := localIPv4Addr.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
} else {
ep.DecRef()
}
@@ -971,8 +983,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint")
}
addr := localIPv4Addr.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
} else {
ep.DecRef()
}
@@ -1237,8 +1249,8 @@ func TestIPv6ReceiveControl(t *testing.T) {
t.Fatal("expected IPv6 network endpoint to implement stack.AddressableEndpoint")
}
addr := localIPv6Addr.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
} else {
ep.DecRef()
}
@@ -1304,7 +1316,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name string
protoFactory stack.NetworkProtocolFactory
protoNum tcpip.NetworkProtocolNumber
- nicAddr tcpip.Address
+ nicAddr tcpip.AddressWithPrefix
remoteAddr tcpip.Address
pktGen func(*testing.T, tcpip.Address) buffer.VectorisedView
checker func(*testing.T, *stack.PacketBuffer, tcpip.Address)
@@ -1314,7 +1326,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv4MinimumSize + len(data)
@@ -1355,7 +1367,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4 with IHL too small",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv4MinimumSize + len(data)
@@ -1379,7 +1391,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4 too small",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
@@ -1397,7 +1409,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4 minimum size",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
@@ -1433,7 +1445,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4 with options",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ipHdrLen := int(header.IPv4MinimumSize + ipv4Options.Length())
@@ -1478,7 +1490,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4 with options and data across views",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.Length()))
@@ -1519,7 +1531,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv6",
protoFactory: ipv6.NewProtocol,
protoNum: ipv6.ProtocolNumber,
- nicAddr: localIPv6Addr,
+ nicAddr: localIPv6AddrWithPrefix,
remoteAddr: remoteIPv6Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv6MinimumSize + len(data)
@@ -1559,7 +1571,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv6 with extension header",
protoFactory: ipv6.NewProtocol,
protoNum: ipv6.ProtocolNumber,
- nicAddr: localIPv6Addr,
+ nicAddr: localIPv6AddrWithPrefix,
remoteAddr: remoteIPv6Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data)
@@ -1604,7 +1616,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv6 minimum size",
protoFactory: ipv6.NewProtocol,
protoNum: ipv6.ProtocolNumber,
- nicAddr: localIPv6Addr,
+ nicAddr: localIPv6AddrWithPrefix,
remoteAddr: remoteIPv6Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
@@ -1639,7 +1651,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv6 too small",
protoFactory: ipv6.NewProtocol,
protoNum: ipv6.ProtocolNumber,
- nicAddr: localIPv6Addr,
+ nicAddr: localIPv6AddrWithPrefix,
remoteAddr: remoteIPv6Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
@@ -1663,11 +1675,11 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
}{
{
name: "unspecified source",
- srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr))),
+ srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr.Address))),
},
{
name: "random source",
- srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr))),
+ srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr.Address))),
},
}
@@ -1680,15 +1692,19 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, test.protoNum, test.nicAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.protoNum, test.nicAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: test.protoNum,
+ AddressWithPrefix: test.nicAddr,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{Destination: test.remoteAddr.WithPrefix().Subnet(), NIC: nicID}})
- r, err := s.FindRoute(nicID, test.nicAddr, test.remoteAddr, test.protoNum, false /* multicastLoop */)
+ r, err := s.FindRoute(nicID, test.nicAddr.Address, test.remoteAddr, test.protoNum, false /* multicastLoop */)
if err != nil {
- t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr, test.protoNum, err)
+ t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr.Address, test.protoNum, err)
}
defer r.Release()
@@ -2072,8 +2088,12 @@ func TestSetNICIDBeforeDeliveringToRawEndpoint(t *testing.T) {
if err := s.CreateNIC(nicID, loopback.New()); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddressWithPrefix(nicID, test.proto, test.addr); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, test.proto, test.addr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: test.proto,
+ AddressWithPrefix: test.addr,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 2aa38eb98..3eff0bbd8 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -167,23 +167,22 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet
p := hdr.TransportProtocol()
dstAddr := hdr.DestinationAddress()
// Skip the ip header, then deliver the error.
- pkt.Data().DeleteFront(hlen)
+ if _, ok := pkt.Data().Consume(hlen); !ok {
+ panic(fmt.Sprintf("could not consume the IP header of %d bytes", hlen))
+ }
e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, errInfo, pkt)
}
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
received := e.stats.icmp.packetsReceived
- // ICMP packets don't have their TransportHeader fields set. See
- // icmp/protocol.go:protocol.Parse for a full explanation.
- v, ok := pkt.Data().PullUp(header.ICMPv4MinimumSize)
- if !ok {
+ h := header.ICMPv4(pkt.TransportHeader().View())
+ if len(h) < header.ICMPv4MinimumSize {
received.invalid.Increment()
return
}
- h := header.ICMPv4(v)
// Only do in-stack processing if the checksum is correct.
- if pkt.Data().AsRange().Checksum() != 0xffff {
+ if header.Checksum(h, pkt.Data().AsRange().Checksum()) != 0xffff {
received.invalid.Increment()
// It's possible that a raw socket expects to receive this regardless
// of checksum errors. If it's an echo request we know it's safe because
@@ -240,20 +239,15 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
case header.ICMPv4Echo:
received.echoRequest.Increment()
- sent := e.stats.icmp.packetsSent
- if !e.protocol.stack.AllowICMPMessage() {
- sent.rateLimited.Increment()
- return
- }
-
// DeliverTransportPacket will take ownership of pkt so don't use it beyond
// this point. Make a deep copy of the data before pkt gets sent as we will
- // be modifying fields.
+ // be modifying fields. Both the ICMP header (with its type modified to
+ // EchoReply) and payload are reused in the reply packet.
//
// TODO(gvisor.dev/issue/4399): The copy may not be needed if there are no
// waiting endpoints. Consider moving responsibility for doing the copy to
// DeliverTransportPacket so that is is only done when needed.
- replyData := pkt.Data().AsRange().ToOwnedView()
+ replyData := stack.PayloadSince(pkt.TransportHeader())
ipHdr := header.IPv4(pkt.NetworkHeader().View())
localAddressBroadcast := pkt.NetworkPacketInfo.LocalAddressBroadcast
@@ -281,6 +275,12 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
}
defer r.Release()
+ sent := e.stats.icmp.packetsSent
+ if !e.protocol.allowICMPReply(header.ICMPv4EchoReply, header.ICMPv4UnusedCode) {
+ sent.rateLimited.Increment()
+ return
+ }
+
// TODO(gvisor.dev/issue/3810:) When adding protocol numbers into the
// header information, we may have to change this code to handle the
// ICMP header no longer being in the data buffer.
@@ -331,6 +331,8 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
case header.ICMPv4EchoReply:
received.echoReply.Increment()
+ // ICMP sockets expect the ICMP header to be present, so we don't consume
+ // the ICMP header.
e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt)
case header.ICMPv4DstUnreachable:
@@ -338,7 +340,6 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
mtu := h.MTU()
code := h.Code()
- pkt.Data().DeleteFront(header.ICMPv4MinimumSize)
switch code {
case header.ICMPv4HostUnreachable:
e.handleControl(&icmpv4DestinationHostUnreachableSockError{}, pkt)
@@ -562,31 +563,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
return &tcpip.ErrNotConnected{}
}
- sent := netEP.stats.icmp.packetsSent
-
- if !p.stack.AllowICMPMessage() {
- sent.rateLimited.Increment()
- return nil
- }
-
transportHeader := pkt.TransportHeader().View()
// Don't respond to icmp error packets.
if origIPHdr.Protocol() == uint8(header.ICMPv4ProtocolNumber) {
- // TODO(gvisor.dev/issue/3810):
- // Unfortunately the current stack pretty much always has ICMPv4 headers
- // in the Data section of the packet but there is no guarantee that is the
- // case. If this is the case grab the header to make it like all other
- // packet types. When this is cleaned up the Consume should be removed.
- if transportHeader.IsEmpty() {
- var ok bool
- transportHeader, ok = pkt.TransportHeader().Consume(header.ICMPv4MinimumSize)
- if !ok {
- return nil
- }
- } else if transportHeader.Size() < header.ICMPv4MinimumSize {
- return nil
- }
// We need to decide to explicitly name the packets we can respond to or
// the ones we can not respond to. The decision is somewhat arbitrary and
// if problems arise this could be reversed. It was judged less of a breach
@@ -606,6 +586,35 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
}
}
+ sent := netEP.stats.icmp.packetsSent
+ icmpType, icmpCode, counter, pointer := func() (header.ICMPv4Type, header.ICMPv4Code, tcpip.MultiCounterStat, byte) {
+ switch reason := reason.(type) {
+ case *icmpReasonPortUnreachable:
+ return header.ICMPv4DstUnreachable, header.ICMPv4PortUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonProtoUnreachable:
+ return header.ICMPv4DstUnreachable, header.ICMPv4ProtoUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonNetworkUnreachable:
+ return header.ICMPv4DstUnreachable, header.ICMPv4NetUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonHostUnreachable:
+ return header.ICMPv4DstUnreachable, header.ICMPv4HostUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonFragmentationNeeded:
+ return header.ICMPv4DstUnreachable, header.ICMPv4FragmentationNeeded, sent.dstUnreachable, 0
+ case *icmpReasonTTLExceeded:
+ return header.ICMPv4TimeExceeded, header.ICMPv4TTLExceeded, sent.timeExceeded, 0
+ case *icmpReasonReassemblyTimeout:
+ return header.ICMPv4TimeExceeded, header.ICMPv4ReassemblyTimeout, sent.timeExceeded, 0
+ case *icmpReasonParamProblem:
+ return header.ICMPv4ParamProblem, header.ICMPv4UnusedCode, sent.paramProblem, reason.pointer
+ default:
+ panic(fmt.Sprintf("unsupported ICMP type %T", reason))
+ }
+ }()
+
+ if !p.allowICMPReply(icmpType, icmpCode) {
+ sent.rateLimited.Increment()
+ return nil
+ }
+
// Now work out how much of the triggering packet we should return.
// As per RFC 1812 Section 4.3.2.3
//
@@ -658,44 +667,9 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize))
- var counter tcpip.MultiCounterStat
- switch reason := reason.(type) {
- case *icmpReasonPortUnreachable:
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4PortUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonProtoUnreachable:
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4ProtoUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonNetworkUnreachable:
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4NetUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonHostUnreachable:
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4HostUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonFragmentationNeeded:
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4FragmentationNeeded)
- counter = sent.dstUnreachable
- case *icmpReasonTTLExceeded:
- icmpHdr.SetType(header.ICMPv4TimeExceeded)
- icmpHdr.SetCode(header.ICMPv4TTLExceeded)
- counter = sent.timeExceeded
- case *icmpReasonReassemblyTimeout:
- icmpHdr.SetType(header.ICMPv4TimeExceeded)
- icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout)
- counter = sent.timeExceeded
- case *icmpReasonParamProblem:
- icmpHdr.SetType(header.ICMPv4ParamProblem)
- icmpHdr.SetCode(header.ICMPv4UnusedCode)
- icmpHdr.SetPointer(reason.pointer)
- counter = sent.paramProblem
- default:
- panic(fmt.Sprintf("unsupported ICMP type %T", reason))
- }
+ icmpHdr.SetCode(icmpCode)
+ icmpHdr.SetType(icmpType)
+ icmpHdr.SetPointer(pointer)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data().AsRange().Checksum()))
if err := route.WritePacket(
diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go
index 4bd6f462e..c6576fcbc 100644
--- a/pkg/tcpip/network/ipv4/igmp_test.go
+++ b/pkg/tcpip/network/ipv4/igmp_test.go
@@ -120,9 +120,12 @@ func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, ma
// cycles.
func TestIGMPV1Present(t *testing.T) {
e, s, clock := createStack(t, true)
- addr := tcpip.AddressWithPrefix{Address: stackAddr, PrefixLen: defaultPrefixLength}
- if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{Address: stackAddr, PrefixLen: defaultPrefixLength},
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
@@ -215,8 +218,15 @@ func TestSendQueuedIGMPReports(t *testing.T) {
// The initial set of IGMP reports that were queued should be sent once an
// address is assigned.
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, stackAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: stackAddr,
+ PrefixLen: defaultPrefixLength,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if got := reportStat.Value(); got != 1 {
t.Errorf("got reportStat.Value() = %d, want = 1", got)
@@ -350,8 +360,12 @@ func TestIGMPPacketValidation(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
e, s, _ := createStack(t, true)
for _, address := range test.stackAddresses {
- if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, address); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, address, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: address,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
stats := s.Stats()
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index e2472c851..d1d509702 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -167,6 +167,13 @@ func (p *protocol) findEndpointWithAddress(addr tcpip.Address) *endpoint {
return nil
}
+func (p *protocol) getEndpointForNIC(id tcpip.NICID) (*endpoint, bool) {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+ ep, ok := p.mu.eps[id]
+ return ep, ok
+}
+
func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
p.mu.Lock()
defer p.mu.Unlock()
@@ -240,7 +247,7 @@ func (e *endpoint) Enable() tcpip.Error {
}
// Create an endpoint to receive broadcast packets on this interface.
- ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(ipv4BroadcastAddr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */)
+ ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(ipv4BroadcastAddr, stack.AddressProperties{PEB: stack.NeverPrimaryEndpoint})
if err != nil {
return err
}
@@ -419,7 +426,7 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams,
// iptables filtering. All packets that reach here are locally
// generated.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
+ if ok := e.protocol.stack.IPTables().CheckOutput(pkt, r, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesOutputDropped.Increment()
return nil
@@ -432,7 +439,7 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams,
// We should do this for every packet, rather than only NATted packets, but
// removing this check short circuits broadcasts before they are sent out to
// other hosts.
- if pkt.NatDone {
+ if pkt.DNATDone {
netHeader := header.IPv4(pkt.NetworkHeader().View())
if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil {
// Since we rewrote the packet but it is being routed back to us, we
@@ -459,7 +466,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, headerIn
// Postrouting NAT can only change the source address, and does not alter the
// route or outgoing interface of the packet.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
+ if ok := e.protocol.stack.IPTables().CheckPostrouting(pkt, r, e, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesPostroutingDropped.Increment()
return nil
@@ -542,7 +549,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
// iptables filtering. All packets that reach here are locally
// generated.
- outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, r, "" /* inNicName */, outNicName)
+ outputDropped, natPkts := e.protocol.stack.IPTables().CheckOutputPackets(pkts, r, outNicName)
stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped)))
for pkt := range outputDropped {
pkts.Remove(pkt)
@@ -569,7 +576,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par
// We ignore the list of NAT-ed packets here because Postrouting NAT can only
// change the source address, and does not alter the route or outgoing
// interface of the packet.
- postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, r, "" /* inNicName */, outNicName)
+ postroutingDropped, _ := e.protocol.stack.IPTables().CheckPostroutingPackets(pkts, r, e, outNicName)
stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped)))
for pkt := range postroutingDropped {
pkts.Remove(pkt)
@@ -710,7 +717,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
inNicName := stk.FindNICNameFromID(e.nic.ID())
outNicName := stk.FindNICNameFromID(ep.nic.ID())
- if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+ if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesForwardDropped.Increment()
return nil
@@ -737,7 +744,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
inNicName := stk.FindNICNameFromID(e.nic.ID())
outNicName := stk.FindNICNameFromID(r.NICID())
- if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+ if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesForwardDropped.Increment()
return nil
@@ -746,7 +753,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
// We need to do a deep copy of the IP packet because
// WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do
// not own it.
- newHdr := header.IPv4(stack.PayloadSince(pkt.NetworkHeader()))
+ newPkt := pkt.DeepCopyForForwarding(int(r.MaxHeaderLength()))
+ newHdr := header.IPv4(newPkt.NetworkHeader().View())
// As per RFC 791 page 30, Time to Live,
//
@@ -755,12 +763,19 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
// Even if no local information is available on the time actually
// spent, the field must be decremented by 1.
newHdr.SetTTL(ttl - 1)
+ // We perform a full checksum as we may have updated options above. The IP
+ // header is relatively small so this is not expected to be an expensive
+ // operation.
+ newHdr.SetChecksum(0)
+ newHdr.SetChecksum(^newHdr.CalculateChecksum())
+
+ forwardToEp, ok := e.protocol.getEndpointForNIC(r.NICID())
+ if !ok {
+ // The interface was removed after we obtained the route.
+ return &ip.ErrOther{Err: &tcpip.ErrUnknownDevice{}}
+ }
- switch err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()),
- Data: buffer.View(newHdr).ToVectorisedView(),
- IsForwardedPacket: true,
- })); err.(type) {
+ switch err := forwardToEp.writePacket(r, newPkt, true /* headerIncluded */); err.(type) {
case nil:
return nil
case *tcpip.ErrMessageTooLong:
@@ -826,7 +841,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Loopback traffic skips the prerouting chain.
inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok {
+ if ok := e.protocol.stack.IPTables().CheckPrerouting(pkt, e, inNicName); !ok {
// iptables is telling us to drop the packet.
stats.IPTablesPreroutingDropped.Increment()
return
@@ -925,7 +940,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer,
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
- if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok {
+ if ok := e.protocol.stack.IPTables().CheckInput(pkt, inNICName); !ok {
// iptables is telling us to drop the packet.
stats.ip.IPTablesInputDropped.Increment()
return
@@ -969,7 +984,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer,
}
proto := h.Protocol()
- resPkt, _, ready, err := e.protocol.fragmentation.Process(
+ resPkt, transProtoNum, ready, err := e.protocol.fragmentation.Process(
// As per RFC 791 section 2.3, the identification value is unique
// for a source-destination pair and protocol.
fragmentation.FragmentID{
@@ -1000,6 +1015,8 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer,
h.SetTotalLength(uint16(pkt.Data().Size() + len(h)))
h.SetFlagsFragmentOffset(0, 0)
+ e.protocol.parseTransport(pkt, tcpip.TransportProtocolNumber(transProtoNum))
+
// Now that the packet is reassembled, it can be sent to raw sockets.
e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt)
}
@@ -1075,11 +1092,11 @@ func (e *endpoint) Close() {
}
// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint.
-func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
+func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties stack.AddressProperties) (stack.AddressEndpoint, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
+ ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, properties)
if err == nil {
e.mu.igmp.sendQueuedReports()
}
@@ -1200,6 +1217,9 @@ type protocol struct {
// eps is keyed by NICID to allow protocol methods to retrieve an endpoint
// when handling a packet, by looking at which NIC handled the packet.
eps map[tcpip.NICID]*endpoint
+
+ // ICMP types for which the stack's global rate limiting must apply.
+ icmpRateLimitedTypes map[header.ICMPv4Type]struct{}
}
// defaultTTL is the current default TTL for the protocol. Only the
@@ -1226,11 +1246,6 @@ func (p *protocol) MinimumPacketSize() int {
return header.IPv4MinimumSize
}
-// DefaultPrefixLen returns the IPv4 default prefix length.
-func (p *protocol) DefaultPrefixLen() int {
- return header.IPv4AddressSize * 8
-}
-
// ParseAddresses implements stack.NetworkProtocol.
func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
h := header.IPv4(v)
@@ -1297,19 +1312,29 @@ func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv4, bool)
}
if hasTransportHdr {
- switch err := p.stack.ParsePacketBufferTransport(transProtoNum, pkt); err {
- case stack.ParsedOK:
- case stack.UnknownTransportProtocol, stack.TransportLayerParseError:
- // The transport layer will handle unknown protocols and transport layer
- // parsing errors.
- default:
- panic(fmt.Sprintf("unexpected error parsing transport header = %d", err))
- }
+ p.parseTransport(pkt, transProtoNum)
}
return h, true
}
+func (p *protocol) parseTransport(pkt *stack.PacketBuffer, transProtoNum tcpip.TransportProtocolNumber) {
+ if transProtoNum == header.ICMPv4ProtocolNumber {
+ // The transport layer will handle transport layer parsing errors.
+ _ = parse.ICMPv4(pkt)
+ return
+ }
+
+ switch err := p.stack.ParsePacketBufferTransport(transProtoNum, pkt); err {
+ case stack.ParsedOK:
+ case stack.UnknownTransportProtocol, stack.TransportLayerParseError:
+ // The transport layer will handle unknown protocols and transport layer
+ // parsing errors.
+ default:
+ panic(fmt.Sprintf("unexpected error parsing transport header = %d", err))
+ }
+}
+
// Parse implements stack.NetworkProtocol.
func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
if ok := parse.IPv4(pkt); !ok {
@@ -1320,6 +1345,23 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true
}
+// allowICMPReply reports whether an ICMP reply with provided type and code may
+// be sent following the rate mask options and global ICMP rate limiter.
+func (p *protocol) allowICMPReply(icmpType header.ICMPv4Type, code header.ICMPv4Code) bool {
+ // Mimic linux and never rate limit for PMTU discovery.
+ // https://github.com/torvalds/linux/blob/9e9fb7655ed585da8f468e29221f0ba194a5f613/net/ipv4/icmp.c#L288
+ if icmpType == header.ICMPv4DstUnreachable && code == header.ICMPv4FragmentationNeeded {
+ return true
+ }
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+
+ if _, ok := p.mu.icmpRateLimitedTypes[icmpType]; ok {
+ return p.stack.AllowICMPMessage()
+ }
+ return true
+}
+
// calculateNetworkMTU calculates the network-layer payload MTU based on the
// link-layer payload mtu.
func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error) {
@@ -1399,6 +1441,14 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
}
p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p)
p.mu.eps = make(map[tcpip.NICID]*endpoint)
+ // Set ICMP rate limiting to Linux defaults.
+ // See https://man7.org/linux/man-pages/man7/icmp.7.html.
+ p.mu.icmpRateLimitedTypes = map[header.ICMPv4Type]struct{}{
+ header.ICMPv4DstUnreachable: struct{}{},
+ header.ICMPv4SrcQuench: struct{}{},
+ header.ICMPv4TimeExceeded: struct{}{},
+ header.ICMPv4ParamProblem: struct{}{},
+ }
return p
}
}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 73407be67..ef91245d7 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -101,8 +101,12 @@ func TestExcludeBroadcast(t *testing.T) {
defer ep.Close()
// Add a valid primary endpoint address, now we can connect.
- if err := s.AddAddress(1, ipv4.ProtocolNumber, "\x0a\x00\x00\x02"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.Address("\x0a\x00\x00\x02").WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
if err := ep.Connect(randomAddr); err != nil {
t.Errorf("Connect failed: %v", err)
@@ -356,8 +360,8 @@ func TestForwarding(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err)
}
incomingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: incomingIPv4Addr}
- if err := s.AddProtocolAddress(incomingNICID, incomingIPv4ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv4ProtoAddr, err)
+ if err := s.AddProtocolAddress(incomingNICID, incomingIPv4ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingIPv4ProtoAddr, err)
}
expectedEmittedPacketCount := 1
@@ -369,8 +373,8 @@ func TestForwarding(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err)
}
outgoingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: outgoingIPv4Addr}
- if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv4ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv4ProtoAddr, err)
+ if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv4ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingIPv4ProtoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -1184,8 +1188,8 @@ func TestIPv4Sanity(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr}
- if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err)
+ if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtoAddr, err)
}
// Default routes for IPv4 so ICMP can find a route to the remote
@@ -1745,8 +1749,8 @@ func TestInvalidFragments(t *testing.T) {
const (
nicID = 1
linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
- addr1 = "\x0a\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x02"
+ addr1 = tcpip.Address("\x0a\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x02")
tos = 0
ident = 1
ttl = 48
@@ -2012,8 +2016,12 @@ func TestInvalidFragments(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
for _, f := range test.fragments {
@@ -2061,8 +2069,8 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
const (
nicID = 1
linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
- addr1 = "\x0a\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x02"
+ addr1 = tcpip.Address("\x0a\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x02")
tos = 0
ident = 1
ttl = 48
@@ -2237,8 +2245,12 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv4EmptySubnet,
@@ -2308,9 +2320,9 @@ func TestReceiveFragments(t *testing.T) {
const (
nicID = 1
- addr1 = "\x0c\xa8\x00\x01" // 192.168.0.1
- addr2 = "\x0c\xa8\x00\x02" // 192.168.0.2
- addr3 = "\x0c\xa8\x00\x03" // 192.168.0.3
+ addr1 = tcpip.Address("\x0c\xa8\x00\x01") // 192.168.0.1
+ addr2 = tcpip.Address("\x0c\xa8\x00\x02") // 192.168.0.2
+ addr3 = tcpip.Address("\x0c\xa8\x00\x03") // 192.168.0.3
)
// Build and return a UDP header containing payload.
@@ -2703,8 +2715,12 @@ func TestReceiveFragments(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
wq := waiter.Queue{}
@@ -2985,11 +3001,15 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route {
t.Fatalf("CreateNIC(1, _) failed: %s", err)
}
const (
- src = "\x10\x00\x00\x01"
- dst = "\x10\x00\x00\x02"
+ src = tcpip.Address("\x10\x00\x00\x01")
+ dst = tcpip.Address("\x10\x00\x00\x02")
)
- if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil {
- t.Fatalf("AddAddress(1, %d, %s) failed: %s", ipv4.ProtocolNumber, src, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: src.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
mask := tcpip.AddressMask(header.IPv4Broadcast)
@@ -3161,8 +3181,8 @@ func TestPacketQueuing(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddProtocolAddress(nicID, host1IPv4Addr); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv4Addr, err)
+ if err := s.AddProtocolAddress(nicID, host1IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv4Addr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -3285,8 +3305,12 @@ func TestCloseLocking(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
}
- if err := s.AddAddress(nicID1, ipv4.ProtocolNumber, src); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) failed: %s", nicID1, ipv4.ProtocolNumber, src, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: src.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{
@@ -3349,3 +3373,139 @@ func TestCloseLocking(t *testing.T) {
}
}()
}
+
+func TestIcmpRateLimit(t *testing.T) { // Injects icmpBurst+1 IPv4 packets per case and checks which rounds produce an ICMP reply.
+ var (
+ host1IPv4Addr = tcpip.ProtocolAddress{ // Address added to the NIC under test; destination of every injected packet.
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()),
+ PrefixLen: 24,
+ },
+ }
+ host2IPv4Addr = tcpip.ProtocolAddress{ // Remote peer: source of injected packets and expected destination of replies.
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()),
+ PrefixLen: 24,
+ },
+ }
+ )
+ const icmpBurst = 5 // Passed to s.SetICMPBurst below; each case runs icmpBurst+1 rounds.
+ e := channel.New(1, defaultMTU, tcpip.LinkAddress(""))
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ Clock: faketime.NewManualClock(), // Never advanced in this test; presumably keeps the limiter from refilling - TODO confirm.
+ })
+ s.SetICMPBurst(icmpBurst)
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddProtocolAddress(nicID, host1IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv4Addr, err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), // Same /24 as host2, so replies to host2 are routed out nicID.
+ NIC: nicID,
+ },
+ })
+ tests := []struct {
+ name string
+ createPacket func() buffer.View // Builds one inbound IPv4 packet from host2 to host1.
+ check func(*testing.T, *channel.Endpoint, int) // Reads the endpoint and validates the outcome for the given round.
+ }{
+ {
+ name: "echo",
+ createPacket: func() buffer.View {
+ totalLength := header.IPv4MinimumSize + header.ICMPv4MinimumSize
+ hdr := buffer.NewPrependable(totalLength)
+ icmpH := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) // Headers are prepended innermost first: ICMP, then IPv4.
+ icmpH.SetIdent(1)
+ icmpH.SetSequence(1)
+ icmpH.SetType(header.ICMPv4Echo)
+ icmpH.SetCode(header.ICMPv4UnusedCode)
+ icmpH.SetChecksum(0) // Zero the field before the checksum is computed over the header.
+ icmpH.SetChecksum(^header.Checksum(icmpH, 0))
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: uint16(totalLength),
+ Protocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: 1,
+ SrcAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv4Addr.AddressWithPrefix.Address,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+ return hdr.View()
+ },
+ check: func(t *testing.T, e *channel.Endpoint, round int) { // An echo reply is required in every round, including round == icmpBurst.
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected echo response, no packet read in endpoint in round %d", round)
+ }
+ if got, want := p.Proto, header.IPv4ProtocolNumber; got != want {
+ t.Errorf("got p.Proto = %d, want = %d", got, want)
+ }
+ checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4EchoReply),
+ ))
+ },
+ },
+ {
+ name: "dst unreachable",
+ createPacket: func() buffer.View { // UDP datagram to port 101; the test opens no UDP endpoint on the stack.
+ totalLength := header.IPv4MinimumSize + header.UDPMinimumSize
+ hdr := buffer.NewPrependable(totalLength)
+ udpH := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ udpH.Encode(&header.UDPFields{
+ SrcPort: 100,
+ DstPort: 101,
+ Length: header.UDPMinimumSize,
+ })
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: uint16(totalLength),
+ Protocol: uint8(header.UDPProtocolNumber),
+ TTL: 1,
+ SrcAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv4Addr.AddressWithPrefix.Address,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+ return hdr.View()
+ },
+ check: func(t *testing.T, e *channel.Endpoint, round int) { // Replies are expected only while round < icmpBurst.
+ p, ok := e.Read()
+ if round >= icmpBurst { // Rounds at or past the burst must produce no packet.
+ if ok {
+ t.Errorf("got packet %x in round %d, expected ICMP rate limit to stop it", p.Pkt.Data().Views(), round)
+ }
+ return
+ }
+ if !ok {
+ t.Fatalf("expected unreachable in round %d, no packet read in endpoint", round)
+ }
+ checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4DstUnreachable),
+ ))
+ },
+ },
+ }
+ for _, testCase := range tests {
+ t.Run(testCase.name, func(t *testing.T) {
+ for round := 0; round < icmpBurst+1; round++ { // One round more than the configured burst.
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: testCase.createPacket().ToVectorisedView(),
+ }))
+ testCase.check(t, e, round) // Checked right after each injection, so at most one reply is pending.
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index f99cbf8f3..f814926a3 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -51,6 +51,7 @@ go_test(
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
"@com_github_google_go_cmp//cmp:go_default_library",
+ "@org_golang_x_time//rate:go_default_library",
],
)
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 94caaae6c..adfc8d8da 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -187,7 +187,9 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe
// Skip the IP header, then handle the fragmentation header if there
// is one.
- pkt.Data().DeleteFront(header.IPv6MinimumSize)
+ if _, ok := pkt.Data().Consume(header.IPv6MinimumSize); !ok {
+ panic("could not consume IPv6MinimumSize bytes")
+ }
if p == header.IPv6FragmentHeader {
f, ok := pkt.Data().PullUp(header.IPv6FragmentHeaderSize)
if !ok {
@@ -203,7 +205,9 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe
// Skip fragmentation header and find out the actual protocol
// number.
- pkt.Data().DeleteFront(header.IPv6FragmentHeaderSize)
+ if _, ok := pkt.Data().Consume(header.IPv6FragmentHeaderSize); !ok {
+ panic("could not consume IPv6FragmentHeaderSize bytes")
+ }
}
e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, transErr, pkt)
@@ -270,7 +274,7 @@ func isMLDValid(pkt *stack.PacketBuffer, iph header.IPv6, routerAlert *header.IP
if routerAlert == nil || routerAlert.Value != header.IPv6RouterAlertMLD {
return false
}
- if pkt.Data().Size() < header.ICMPv6HeaderSize+header.MLDMinimumSize {
+ if pkt.TransportHeader().View().Size() < header.ICMPv6HeaderSize+header.MLDMinimumSize {
return false
}
if iph.HopLimit() != header.MLDHopLimit {
@@ -285,20 +289,17 @@ func isMLDValid(pkt *stack.PacketBuffer, iph header.IPv6, routerAlert *header.IP
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, routerAlert *header.IPv6RouterAlertOption) {
sent := e.stats.icmp.packetsSent
received := e.stats.icmp.packetsReceived
- // ICMP packets don't have their TransportHeader fields set. See
- // icmp/protocol.go:protocol.Parse for a full explanation.
- v, ok := pkt.Data().PullUp(header.ICMPv6HeaderSize)
- if !ok {
+ h := header.ICMPv6(pkt.TransportHeader().View())
+ if len(h) < header.ICMPv6MinimumSize {
received.invalid.Increment()
return
}
- h := header.ICMPv6(v)
iph := header.IPv6(pkt.NetworkHeader().View())
srcAddr := iph.SourceAddress()
dstAddr := iph.DestinationAddress()
// Validate ICMPv6 checksum before processing the packet.
- payload := pkt.Data().AsRange().SubRange(len(h))
+ payload := pkt.Data().AsRange()
if got, want := h.Checksum(), header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: h,
Src: srcAddr,
@@ -325,28 +326,15 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
switch icmpType := h.Type(); icmpType {
case header.ICMPv6PacketTooBig:
received.packetTooBig.Increment()
- hdr, ok := pkt.Data().PullUp(header.ICMPv6PacketTooBigMinimumSize)
- if !ok {
- received.invalid.Increment()
- return
- }
- networkMTU, err := calculateNetworkMTU(header.ICMPv6(hdr).MTU(), header.IPv6MinimumSize)
+ networkMTU, err := calculateNetworkMTU(h.MTU(), header.IPv6MinimumSize)
if err != nil {
networkMTU = 0
}
- pkt.Data().DeleteFront(header.ICMPv6PacketTooBigMinimumSize)
e.handleControl(&icmpv6PacketTooBigSockError{mtu: networkMTU}, pkt)
case header.ICMPv6DstUnreachable:
received.dstUnreachable.Increment()
- hdr, ok := pkt.Data().PullUp(header.ICMPv6DstUnreachableMinimumSize)
- if !ok {
- received.invalid.Increment()
- return
- }
- code := header.ICMPv6(hdr).Code()
- pkt.Data().DeleteFront(header.ICMPv6DstUnreachableMinimumSize)
- switch code {
+ switch h.Code() {
case header.ICMPv6NetworkUnreachable:
e.handleControl(&icmpv6DestinationNetworkUnreachableSockError{}, pkt)
case header.ICMPv6PortUnreachable:
@@ -354,16 +342,12 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
}
case header.ICMPv6NeighborSolicit:
received.neighborSolicit.Increment()
- if !isNDPValid() || pkt.Data().Size() < header.ICMPv6NeighborSolicitMinimumSize {
+ if !isNDPValid() || len(h) < header.ICMPv6NeighborSolicitMinimumSize {
received.invalid.Increment()
return
}
- // The remainder of payload must be only the neighbor solicitation, so
- // payload.AsView() always returns the solicitation. Per RFC 6980 section 5,
- // NDP messages cannot be fragmented. Also note that in the common case NDP
- // datagrams are very small and AsView() will not incur allocations.
- ns := header.NDPNeighborSolicit(payload.AsView())
+ ns := header.NDPNeighborSolicit(h.MessageBody())
targetAddr := ns.TargetAddress()
// As per RFC 4861 section 4.3, the Target Address MUST NOT be a multicast
@@ -576,16 +560,12 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
case header.ICMPv6NeighborAdvert:
received.neighborAdvert.Increment()
- if !isNDPValid() || pkt.Data().Size() < header.ICMPv6NeighborAdvertMinimumSize {
+ if !isNDPValid() || len(h) < header.ICMPv6NeighborAdvertMinimumSize {
received.invalid.Increment()
return
}
- // The remainder of payload must be only the neighbor advertisement, so
- // payload.AsView() always returns the advertisement. Per RFC 6980 section
- // 5, NDP messages cannot be fragmented. Also note that in the common case
- // NDP datagrams are very small and AsView() will not incur allocations.
- na := header.NDPNeighborAdvert(payload.AsView())
+ na := header.NDPNeighborAdvert(h.MessageBody())
it, err := na.Options().Iter(false /* check */)
if err != nil {
@@ -672,12 +652,6 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
case header.ICMPv6EchoRequest:
received.echoRequest.Increment()
- icmpHdr, ok := pkt.TransportHeader().Consume(header.ICMPv6EchoMinimumSize)
- if !ok {
- received.invalid.Increment()
- return
- }
-
// As per RFC 4291 section 2.7, multicast addresses must not be used as
// source addresses in IPv6 packets.
localAddr := dstAddr
@@ -692,13 +666,18 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
}
defer r.Release()
+ if !e.protocol.allowICMPReply(header.ICMPv6EchoReply) {
+ sent.rateLimited.Increment()
+ return
+ }
+
replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize,
Data: pkt.Data().ExtractVV(),
})
icmp := header.ICMPv6(replyPkt.TransportHeader().Push(header.ICMPv6EchoMinimumSize))
pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
- copy(icmp, icmpHdr)
+ copy(icmp, h)
icmp.SetType(header.ICMPv6EchoReply)
dataRange := replyPkt.Data().AsRange()
icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
@@ -720,7 +699,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
case header.ICMPv6EchoReply:
received.echoReply.Increment()
- if pkt.Data().Size() < header.ICMPv6EchoMinimumSize {
+ if len(h) < header.ICMPv6EchoMinimumSize {
received.invalid.Increment()
return
}
@@ -740,7 +719,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
//
// Is the NDP payload of sufficient size to hold a Router Solictation?
- if !isNDPValid() || pkt.Data().Size()-header.ICMPv6HeaderSize < header.NDPRSMinimumSize {
+ if !isNDPValid() || len(h)-header.ICMPv6HeaderSize < header.NDPRSMinimumSize {
received.invalid.Increment()
return
}
@@ -750,9 +729,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
return
}
- // Note that in the common case NDP datagrams are very small and AsView()
- // will not incur allocations.
- rs := header.NDPRouterSolicit(payload.AsView())
+ rs := header.NDPRouterSolicit(h.MessageBody())
it, err := rs.Options().Iter(false /* check */)
if err != nil {
// Options are not valid as per the wire format, silently drop the packet.
@@ -796,7 +773,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
//
// Is the NDP payload of sufficient size to hold a Router Advertisement?
- if !isNDPValid() || pkt.Data().Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize {
+ if !isNDPValid() || len(h)-header.ICMPv6HeaderSize < header.NDPRAMinimumSize {
received.invalid.Increment()
return
}
@@ -810,9 +787,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
return
}
- // Note that in the common case NDP datagrams are very small and AsView()
- // will not incur allocations.
- ra := header.NDPRouterAdvert(payload.AsView())
+ ra := header.NDPRouterAdvert(h.MessageBody())
it, err := ra.Options().Iter(false /* check */)
if err != nil {
// Options are not valid as per the wire format, silently drop the packet.
@@ -890,11 +865,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
switch icmpType {
case header.ICMPv6MulticastListenerQuery:
e.mu.Lock()
- e.mu.mld.handleMulticastListenerQuery(header.MLD(payload.AsView()))
+ e.mu.mld.handleMulticastListenerQuery(header.MLD(h.MessageBody()))
e.mu.Unlock()
case header.ICMPv6MulticastListenerReport:
e.mu.Lock()
- e.mu.mld.handleMulticastListenerReport(header.MLD(payload.AsView()))
+ e.mu.mld.handleMulticastListenerReport(header.MLD(h.MessageBody()))
e.mu.Unlock()
case header.ICMPv6MulticastListenerDone:
default:
@@ -1174,28 +1149,37 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
return &tcpip.ErrNotConnected{}
}
- sent := netEP.stats.icmp.packetsSent
-
- if !p.stack.AllowICMPMessage() {
- sent.rateLimited.Increment()
- return nil
- }
-
if pkt.TransportProtocolNumber == header.ICMPv6ProtocolNumber {
- // TODO(gvisor.dev/issues/3810): Sort this out when ICMP headers are stored.
- // Unfortunately at this time ICMP Packets do not have a transport
- // header separated out. It is in the Data part so we need to
- // separate it out now. We will just pretend it is a minimal length
- // ICMP packet as we don't really care if any later bits of a
- // larger ICMP packet are in the header view or in the Data view.
- transport, ok := pkt.TransportHeader().Consume(header.ICMPv6MinimumSize)
- if !ok {
+ if typ := header.ICMPv6(pkt.TransportHeader().View()).Type(); typ.IsErrorType() || typ == header.ICMPv6RedirectMsg {
return nil
}
- typ := header.ICMPv6(transport).Type()
- if typ.IsErrorType() || typ == header.ICMPv6RedirectMsg {
- return nil
+ }
+
+ sent := netEP.stats.icmp.packetsSent
+ icmpType, icmpCode, counter, typeSpecific := func() (header.ICMPv6Type, header.ICMPv6Code, tcpip.MultiCounterStat, uint32) {
+ switch reason := reason.(type) {
+ case *icmpReasonParameterProblem:
+ return header.ICMPv6ParamProblem, reason.code, sent.paramProblem, reason.pointer
+ case *icmpReasonPortUnreachable:
+ return header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonNetUnreachable:
+ return header.ICMPv6DstUnreachable, header.ICMPv6NetworkUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonHostUnreachable:
+ return header.ICMPv6DstUnreachable, header.ICMPv6AddressUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonPacketTooBig:
+ return header.ICMPv6PacketTooBig, header.ICMPv6UnusedCode, sent.packetTooBig, 0
+ case *icmpReasonHopLimitExceeded:
+ return header.ICMPv6TimeExceeded, header.ICMPv6HopLimitExceeded, sent.timeExceeded, 0
+ case *icmpReasonReassemblyTimeout:
+ return header.ICMPv6TimeExceeded, header.ICMPv6ReassemblyTimeout, sent.timeExceeded, 0
+ default:
+ panic(fmt.Sprintf("unsupported ICMP type %T", reason))
}
+ }()
+
+ if !p.allowICMPReply(icmpType) {
+ sent.rateLimited.Increment()
+ return nil
}
network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
@@ -1232,40 +1216,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
newPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
icmpHdr := header.ICMPv6(newPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize))
- var counter tcpip.MultiCounterStat
- switch reason := reason.(type) {
- case *icmpReasonParameterProblem:
- icmpHdr.SetType(header.ICMPv6ParamProblem)
- icmpHdr.SetCode(reason.code)
- icmpHdr.SetTypeSpecific(reason.pointer)
- counter = sent.paramProblem
- case *icmpReasonPortUnreachable:
- icmpHdr.SetType(header.ICMPv6DstUnreachable)
- icmpHdr.SetCode(header.ICMPv6PortUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonNetUnreachable:
- icmpHdr.SetType(header.ICMPv6DstUnreachable)
- icmpHdr.SetCode(header.ICMPv6NetworkUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonHostUnreachable:
- icmpHdr.SetType(header.ICMPv6DstUnreachable)
- icmpHdr.SetCode(header.ICMPv6AddressUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonPacketTooBig:
- icmpHdr.SetType(header.ICMPv6PacketTooBig)
- icmpHdr.SetCode(header.ICMPv6UnusedCode)
- counter = sent.packetTooBig
- case *icmpReasonHopLimitExceeded:
- icmpHdr.SetType(header.ICMPv6TimeExceeded)
- icmpHdr.SetCode(header.ICMPv6HopLimitExceeded)
- counter = sent.timeExceeded
- case *icmpReasonReassemblyTimeout:
- icmpHdr.SetType(header.ICMPv6TimeExceeded)
- icmpHdr.SetCode(header.ICMPv6ReassemblyTimeout)
- counter = sent.timeExceeded
- default:
- panic(fmt.Sprintf("unsupported ICMP type %T", reason))
- }
+ icmpHdr.SetType(icmpType)
+ icmpHdr.SetCode(icmpCode)
+ icmpHdr.SetTypeSpecific(typeSpecific)
+
dataRange := newPkt.Data().AsRange()
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 7c2a3e56b..03d9f425c 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -22,6 +22,7 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
+ "golang.org/x/time/rate"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -225,8 +226,8 @@ func TestICMPCounts(t *testing.T) {
t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint")
}
addr := lladdr0.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
} else {
ep.DecRef()
}
@@ -407,8 +408,12 @@ func newTestContext(t *testing.T) *testContext {
if err := c.s0.CreateNIC(nicID, wrappedEP0); err != nil {
t.Fatalf("CreateNIC s0: %v", err)
}
- if err := c.s0.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress lladdr0: %v", err)
+ llProtocolAddr0 := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := c.s0.AddProtocolAddress(nicID, llProtocolAddr0, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, llProtocolAddr0, err)
}
c.linkEP1 = channel.New(defaultChannelSize, defaultMTU, linkAddr1)
@@ -416,8 +421,12 @@ func newTestContext(t *testing.T) *testContext {
if err := c.s1.CreateNIC(nicID, wrappedEP1); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
- if err := c.s1.AddAddress(nicID, ProtocolNumber, lladdr1); err != nil {
- t.Fatalf("AddAddress lladdr1: %v", err)
+ llProtocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr1.WithPrefix(),
+ }
+ if err := c.s1.AddProtocolAddress(nicID, llProtocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, llProtocolAddr1, err)
}
subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -690,8 +699,12 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
t.Fatalf("CreateNIC(_, _) = %s", err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -883,8 +896,12 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
t.Fatalf("CreateNIC(_, _) = %s", err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -1065,8 +1082,12 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -1240,8 +1261,12 @@ func TestLinkAddressRequest(t *testing.T) {
}
if len(test.nicAddr) != 0 {
- if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: test.nicAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
@@ -1411,12 +1436,14 @@ func TestPacketQueing(t *testing.T) {
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
Clock: clock,
})
+ // Make sure ICMP rate limiting doesn't get in our way.
+ s.SetICMPLimit(rate.Inf)
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddProtocolAddress(nicID, host1IPv6Addr); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv6Addr, err)
+ if err := s.AddProtocolAddress(nicID, host1IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv6Addr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -1669,8 +1696,12 @@ func TestCallsToNeighborCache(t *testing.T) {
if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
t.Fatalf("CreateNIC(_, _) = %s", err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
{
@@ -1704,8 +1735,8 @@ func TestCallsToNeighborCache(t *testing.T) {
t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint")
}
addr := lladdr0.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
} else {
ep.DecRef()
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index d4bd61748..7d3e1fd53 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -748,7 +748,7 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams,
// iptables filtering. All packets that reach here are locally
// generated.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
+ if ok := e.protocol.stack.IPTables().CheckOutput(pkt, r, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesOutputDropped.Increment()
return nil
@@ -761,7 +761,7 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams,
// We should do this for every packet, rather than only NATted packets, but
// removing this check short circuits broadcasts before they are sent out to
// other hosts.
- if pkt.NatDone {
+ if pkt.DNATDone {
netHeader := header.IPv6(pkt.NetworkHeader().View())
if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil {
// Since we rewrote the packet but it is being routed back to us, we
@@ -788,7 +788,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, protocol
// Postrouting NAT can only change the source address, and does not alter the
// route or outgoing interface of the packet.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
+ if ok := e.protocol.stack.IPTables().CheckPostrouting(pkt, r, e, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesPostroutingDropped.Increment()
return nil
@@ -871,7 +871,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par
// iptables filtering. All packets that reach here are locally
// generated.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, r, "" /* inNicName */, outNicName)
+ outputDropped, natPkts := e.protocol.stack.IPTables().CheckOutputPackets(pkts, r, outNicName)
stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped)))
for pkt := range outputDropped {
pkts.Remove(pkt)
@@ -897,7 +897,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par
// We ignore the list of NAT-ed packets here because Postrouting NAT can only
// change the source address, and does not alter the route or outgoing
// interface of the packet.
- postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, r, "" /* inNicName */, outNicName)
+ postroutingDropped, _ := e.protocol.stack.IPTables().CheckPostroutingPackets(pkts, r, e, outNicName)
stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped)))
for pkt := range postroutingDropped {
pkts.Remove(pkt)
@@ -984,7 +984,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
inNicName := stk.FindNICNameFromID(e.nic.ID())
outNicName := stk.FindNICNameFromID(ep.nic.ID())
- if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+ if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesForwardDropped.Increment()
return nil
@@ -1015,7 +1015,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
inNicName := stk.FindNICNameFromID(e.nic.ID())
outNicName := stk.FindNICNameFromID(r.NICID())
- if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+ if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesForwardDropped.Increment()
return nil
@@ -1024,7 +1024,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
// We need to do a deep copy of the IP packet because
// WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do
// not own it.
- newHdr := header.IPv6(stack.PayloadSince(pkt.NetworkHeader()))
+ newPkt := pkt.DeepCopyForForwarding(int(r.MaxHeaderLength()))
+ newHdr := header.IPv6(newPkt.NetworkHeader().View())
// As per RFC 8200 section 3,
//
@@ -1032,11 +1033,13 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
// each node that forwards the packet.
newHdr.SetHopLimit(hopLimit - 1)
- switch err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()),
- Data: buffer.View(newHdr).ToVectorisedView(),
- IsForwardedPacket: true,
- })); err.(type) {
+ forwardToEp, ok := e.protocol.getEndpointForNIC(r.NICID())
+ if !ok {
+ // The interface was removed after we obtained the route.
+ return &ip.ErrOther{Err: &tcpip.ErrUnknownDevice{}}
+ }
+
+ switch err := forwardToEp.writePacket(r, newPkt, newPkt.TransportProtocolNumber, true /* headerIncluded */); err.(type) {
case nil:
return nil
case *tcpip.ErrMessageTooLong:
@@ -1097,7 +1100,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Loopback traffic skips the prerouting chain.
inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok {
+ if ok := e.protocol.stack.IPTables().CheckPrerouting(pkt, e, inNicName); !ok {
// iptables is telling us to drop the packet.
stats.IPTablesPreroutingDropped.Increment()
return
@@ -1180,7 +1183,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer,
// iptables filtering. All packets that reach here are intended for
// this machine and need not be forwarded.
- if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok {
+ if ok := e.protocol.stack.IPTables().CheckInput(pkt, inNICName); !ok {
// iptables is telling us to drop the packet.
stats.IPTablesInputDropped.Increment()
return
@@ -1534,27 +1537,36 @@ func (e *endpoint) processExtensionHeaders(h header.IPv6, pkt *stack.PacketBuffe
// If the last header in the payload isn't a known IPv6 extension header,
// handle it as if it is transport layer data.
- // Calculate the number of octets parsed from data. We want to remove all
- // the data except the unparsed portion located at the end, which its size
- // is extHdr.Buf.Size().
+ // Calculate the number of octets parsed from data. We want to consume all
+ // the data except the unparsed portion located at the end, whose size is
+ // extHdr.Buf.Size().
trim := pkt.Data().Size() - extHdr.Buf.Size()
// For unfragmented packets, extHdr still contains the transport header.
- // Get rid of it.
+ // Consume that too.
//
// For reassembled fragments, pkt.TransportHeader is unset, so this is a
// no-op and pkt.Data begins with the transport header.
trim += pkt.TransportHeader().View().Size()
- pkt.Data().DeleteFront(trim)
+ if _, ok := pkt.Data().Consume(trim); !ok {
+ stats.MalformedPacketsReceived.Increment()
+ return fmt.Errorf("could not consume %d bytes", trim)
+ }
+
+ proto := tcpip.TransportProtocolNumber(extHdr.Identifier)
+ // If the packet was reassembled from a fragment, it will not have a
+ // transport header set yet.
+ if pkt.TransportHeader().View().IsEmpty() {
+ e.protocol.parseTransport(pkt, proto)
+ }
stats.PacketsDelivered.Increment()
- if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber {
- pkt.TransportProtocolNumber = p
+ if proto == header.ICMPv6ProtocolNumber {
e.handleICMP(pkt, hasFragmentHeader, routerAlert)
} else {
stats.PacketsDelivered.Increment()
- switch res := e.dispatcher.DeliverTransportPacket(p, pkt); res {
+ switch res := e.dispatcher.DeliverTransportPacket(proto, pkt); res {
case stack.TransportPacketHandled:
case stack.TransportPacketDestinationPortUnreachable:
// As per RFC 4443 section 3.1:
@@ -1628,12 +1640,12 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
}
// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint.
-func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
+func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties stack.AddressProperties) (stack.AddressEndpoint, tcpip.Error) {
// TODO(b/169350103): add checks here after making sure we no longer receive
// an empty address.
e.mu.Lock()
defer e.mu.Unlock()
- return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated)
+ return e.addAndAcquirePermanentAddressLocked(addr, properties)
}
// addAndAcquirePermanentAddressLocked is like AddAndAcquirePermanentAddress but
@@ -1643,8 +1655,8 @@ func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, p
// solicited-node multicast group and start duplicate address detection.
//
// Precondition: e.mu must be write locked.
-func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
- addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
+func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, properties stack.AddressProperties) (stack.AddressEndpoint, tcpip.Error) {
+ addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, properties)
if err != nil {
return nil, err
}
@@ -1987,6 +1999,9 @@ type protocol struct {
// eps is keyed by NICID to allow protocol methods to retrieve an endpoint
// when handling a packet, by looking at which NIC handled the packet.
eps map[tcpip.NICID]*endpoint
+
+ // ICMP types for which the stack's global rate limiting must apply.
+ icmpRateLimitedTypes map[header.ICMPv6Type]struct{}
}
ids []uint32
@@ -1998,7 +2013,8 @@ type protocol struct {
// Must be accessed using atomic operations.
defaultTTL uint32
- fragmentation *fragmentation.Fragmentation
+ fragmentation *fragmentation.Fragmentation
+ icmpRateLimiter *stack.ICMPRateLimiter
}
// Number returns the ipv6 protocol number.
@@ -2011,11 +2027,6 @@ func (p *protocol) MinimumPacketSize() int {
return header.IPv6MinimumSize
}
-// DefaultPrefixLen returns the IPv6 default prefix length.
-func (p *protocol) DefaultPrefixLen() int {
- return header.IPv6AddressSize * 8
-}
-
// ParseAddresses implements stack.NetworkProtocol.
func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
h := header.IPv6(v)
@@ -2087,6 +2098,13 @@ func (p *protocol) findEndpointWithAddress(addr tcpip.Address) *endpoint {
return nil
}
+func (p *protocol) getEndpointForNIC(id tcpip.NICID) (*endpoint, bool) {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+ ep, ok := p.mu.eps[id]
+ return ep, ok
+}
+
func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
p.mu.Lock()
defer p.mu.Unlock()
@@ -2149,19 +2167,29 @@ func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv6, bool)
}
if hasTransportHdr {
- switch err := p.stack.ParsePacketBufferTransport(transProtoNum, pkt); err {
- case stack.ParsedOK:
- case stack.UnknownTransportProtocol, stack.TransportLayerParseError:
- // The transport layer will handle unknown protocols and transport layer
- // parsing errors.
- default:
- panic(fmt.Sprintf("unexpected error parsing transport header = %d", err))
- }
+ p.parseTransport(pkt, transProtoNum)
}
return h, true
}
+func (p *protocol) parseTransport(pkt *stack.PacketBuffer, transProtoNum tcpip.TransportProtocolNumber) {
+ if transProtoNum == header.ICMPv6ProtocolNumber {
+ // The transport layer will handle transport layer parsing errors.
+ _ = parse.ICMPv6(pkt)
+ return
+ }
+
+ switch err := p.stack.ParsePacketBufferTransport(transProtoNum, pkt); err {
+ case stack.ParsedOK:
+ case stack.UnknownTransportProtocol, stack.TransportLayerParseError:
+ // The transport layer will handle unknown protocols and transport layer
+ // parsing errors.
+ default:
+ panic(fmt.Sprintf("unexpected error parsing transport header = %d", err))
+ }
+}
+
// Parse implements stack.NetworkProtocol.
func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
proto, _, fragOffset, fragMore, ok := parse.IPv6(pkt)
@@ -2172,6 +2200,18 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
return proto, !fragMore && fragOffset == 0, true
}
+// allowICMPReply reports whether an ICMP reply with provided type may
+// be sent following the rate mask options and global ICMP rate limiter.
+func (p *protocol) allowICMPReply(icmpType header.ICMPv6Type) bool {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+
+ if _, ok := p.mu.icmpRateLimitedTypes[icmpType]; ok {
+ return p.stack.AllowICMPMessage()
+ }
+ return true
+}
+
// calculateNetworkMTU calculates the network-layer payload MTU based on the
// link-layer payload MTU and the length of every IPv6 header.
// Note that this is different than the Payload Length field of the IPv6 header,
@@ -2268,6 +2308,21 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
p.fragmentation = fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p)
p.mu.eps = make(map[tcpip.NICID]*endpoint)
p.SetDefaultTTL(DefaultTTL)
+ // Set default ICMP rate limiting to Linux defaults.
+ //
+ // Default: 0-1,3-127 (rate limit ICMPv6 errors except Packet Too Big)
+ // See https://www.kernel.org/doc/Documentation/networking/ip-sysctl.txt.
+ defaultIcmpTypes := make(map[header.ICMPv6Type]struct{})
+ for i := header.ICMPv6Type(0); i < header.ICMPv6EchoRequest; i++ {
+ switch i {
+ case header.ICMPv6PacketTooBig:
+ // Do not rate limit packet too big by default.
+ default:
+ defaultIcmpTypes[i] = struct{}{}
+ }
+ }
+ p.mu.icmpRateLimitedTypes = defaultIcmpTypes
+
return p
}
}
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index d2a23fd4f..e5286081e 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -41,12 +41,12 @@ import (
)
const (
- addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
// The least significant 3 bytes are the same as addr2 so both addr2 and
// addr3 will have the same solicited-node address.
- addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02"
- addr4 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x03"
+ addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02")
+ addr4 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x03")
// Tests use the extension header identifier values as uint8 instead of
// header.IPv6ExtensionHeaderIdentifier.
@@ -298,16 +298,24 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
// addr2/addr3 yet as we haven't added those addresses.
test.rxf(t, s, e, addr1, snmc, 0)
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr2, err)
}
// Should receive a packet destined to the solicited node address of
// addr2/addr3 now that we have added added addr2.
test.rxf(t, s, e, addr1, snmc, 1)
- if err := s.AddAddress(nicID, ProtocolNumber, addr3); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr3, err)
+ protocolAddr3 := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr3.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr3, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr3, err)
}
// Should still receive a packet destined to the solicited node address of
@@ -374,8 +382,12 @@ func TestAddIpv6Address(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, test.addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, nil) = %s", nicID, ProtocolNumber, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: test.addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if addr, err := s.GetMainNICAddress(nicID, ProtocolNumber); err != nil {
@@ -898,8 +910,12 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
// Add a default route so that a return packet knows where to go.
@@ -1992,8 +2008,12 @@ func TestReceiveIPv6Fragments(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
wq := waiter.Queue{}
@@ -2060,8 +2080,8 @@ func TestReceiveIPv6Fragments(t *testing.T) {
func TestInvalidIPv6Fragments(t *testing.T) {
const (
- addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
nicID = 1
hoplimit = 255
@@ -2150,8 +2170,12 @@ func TestInvalidIPv6Fragments(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv6EmptySubnet,
@@ -2216,8 +2240,8 @@ func TestInvalidIPv6Fragments(t *testing.T) {
func TestFragmentReassemblyTimeout(t *testing.T) {
const (
- addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
nicID = 1
hoplimit = 255
@@ -2402,8 +2426,12 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv6EmptySubnet,
@@ -2645,11 +2673,15 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route {
t.Fatalf("CreateNIC(1, _) failed: %s", err)
}
const (
- src = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ src = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ dst = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
)
- if err := s.AddAddress(1, ProtocolNumber, src); err != nil {
- t.Fatalf("AddAddress(1, %d, %s) failed: %s", ProtocolNumber, src, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: src.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
mask := tcpip.AddressMask("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff")
@@ -3297,8 +3329,8 @@ func TestForwarding(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err)
}
incomingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: incomingIPv6Addr}
- if err := s.AddProtocolAddress(incomingNICID, incomingIPv6ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv6ProtoAddr, err)
+ if err := s.AddProtocolAddress(incomingNICID, incomingIPv6ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingIPv6ProtoAddr, err)
}
outgoingEndpoint := channel.New(1, header.IPv6MinimumMTU, "")
@@ -3306,8 +3338,8 @@ func TestForwarding(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err)
}
outgoingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: outgoingIPv6Addr}
- if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv6ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv6ProtoAddr, err)
+ if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv6ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingIPv6ProtoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -3341,7 +3373,8 @@ func TestForwarding(t *testing.T) {
ipHeaderLength := header.IPv6MinimumSize
icmpHeaderLength := header.ICMPv6MinimumSize
- totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength + extHdrLen
+ payloadLength := icmpHeaderLength + test.payloadLength + extHdrLen
+ totalLength := ipHeaderLength + payloadLength
hdr := buffer.NewPrependable(totalLength)
hdr.Prepend(test.payloadLength)
icmpH := header.ICMPv6(hdr.Prepend(icmpHeaderLength))
@@ -3359,7 +3392,7 @@ func TestForwarding(t *testing.T) {
copy(hdr.Prepend(extHdrLen), extHdrBytes)
ip := header.IPv6(hdr.Prepend(ipHeaderLength))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(header.ICMPv6MinimumSize + test.payloadLength),
+ PayloadLength: uint16(payloadLength),
TransportProtocol: transportProtocol,
HopLimit: test.TTL,
SrcAddr: test.sourceAddr,
@@ -3489,3 +3522,149 @@ func TestMultiCounterStatsInitialization(t *testing.T) {
t.Error(err)
}
}
+
+func TestIcmpRateLimit(t *testing.T) {
+ var (
+ host1IPv6Addr = tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("10::1").To16()),
+ PrefixLen: 64,
+ },
+ }
+ host2IPv6Addr = tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("10::2").To16()),
+ PrefixLen: 64,
+ },
+ }
+ )
+ const icmpBurst = 5
+ e := channel.New(1, defaultMTU, tcpip.LinkAddress(""))
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ Clock: faketime.NewManualClock(),
+ })
+ s.SetICMPBurst(icmpBurst)
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddProtocolAddress(nicID, host1IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv6Addr, err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: nicID,
+ },
+ })
+ tests := []struct {
+ name string
+ createPacket func() buffer.View
+ check func(*testing.T, *channel.Endpoint, int)
+ }{
+ {
+ name: "echo",
+ createPacket: func() buffer.View {
+ totalLength := header.IPv6MinimumSize + header.ICMPv6MinimumSize
+ hdr := buffer.NewPrependable(totalLength)
+ icmpH := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
+ icmpH.SetIdent(1)
+ icmpH.SetSequence(1)
+ icmpH.SetType(header.ICMPv6EchoRequest)
+ icmpH.SetCode(header.ICMPv6UnusedCode)
+ icmpH.SetChecksum(0)
+ icmpH.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: icmpH,
+ Src: host2IPv6Addr.AddressWithPrefix.Address,
+ Dst: host1IPv6Addr.AddressWithPrefix.Address,
+ }))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 1,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ })
+ return hdr.View()
+ },
+ check: func(t *testing.T, e *channel.Endpoint, round int) {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected echo response, no packet read in endpoint in round %d", round)
+ }
+ if got, want := p.Proto, header.IPv6ProtocolNumber; got != want {
+ t.Errorf("got p.Proto = %d, want = %d", got, want)
+ }
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6EchoReply),
+ ))
+ },
+ },
+ {
+ name: "dst unreachable",
+ createPacket: func() buffer.View {
+ totalLength := header.IPv6MinimumSize + header.UDPMinimumSize
+ hdr := buffer.NewPrependable(totalLength)
+ udpH := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ udpH.Encode(&header.UDPFields{
+ SrcPort: 100,
+ DstPort: 101,
+ Length: header.UDPMinimumSize,
+ })
+
+ // Calculate the UDP checksum and set it.
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, header.UDPMinimumSize)
+ sum = header.Checksum(nil, sum)
+ udpH.SetChecksum(^udpH.CalculateChecksum(sum))
+
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: 1,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ })
+ return hdr.View()
+ },
+ check: func(t *testing.T, e *channel.Endpoint, round int) {
+ p, ok := e.Read()
+ if round >= icmpBurst {
+ if ok {
+ t.Errorf("got packet %x in round %d, expected ICMP rate limit to stop it", p.Pkt.Data().Views(), round)
+ }
+ return
+ }
+ if !ok {
+ t.Fatalf("expected unreachable in round %d, no packet read in endpoint", round)
+ }
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6DstUnreachable),
+ ))
+ },
+ },
+ }
+ for _, testCase := range tests {
+ t.Run(testCase.name, func(t *testing.T) {
+ for round := 0; round < icmpBurst+1; round++ {
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: testCase.createPacket().ToVectorisedView(),
+ }))
+ testCase.check(t, e, round)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go
index bc9cf6999..3e5c438d3 100644
--- a/pkg/tcpip/network/ipv6/mld_test.go
+++ b/pkg/tcpip/network/ipv6/mld_test.go
@@ -75,8 +75,12 @@ func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) {
// The stack will join an address's solicited node multicast address when
// an address is added. An MLD report message should be sent for the
// solicited-node group.
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: linkLocalAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if p, ok := e.Read(); !ok {
t.Fatal("expected a report message to be sent")
@@ -216,8 +220,13 @@ func TestSendQueuedMLDReports(t *testing.T) {
// Note, we will still expect to send a report for the global address's
// solicited node address from the unspecified address as per RFC 3590
// section 4.
- if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint, err)
+ properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint}
+ globalProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: globalAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, globalProtocolAddr, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, globalProtocolAddr, properties, err)
}
reportCounter++
if got := reportStat.Value(); got != reportCounter {
@@ -252,8 +261,12 @@ func TestSendQueuedMLDReports(t *testing.T) {
// Adding a link-local address should send a report for its solicited node
// address and globalMulticastAddr.
- if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint, err)
+ linkLocalProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: linkLocalAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, linkLocalProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, linkLocalProtocolAddr, err)
}
if dadResolutionTime != 0 {
reportCounter++
@@ -567,8 +580,12 @@ func TestMLDSkipProtocol(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: linkLocalAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if p, ok := e.Read(); !ok {
t.Fatal("expected a report message to be sent")
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index 8837d66d8..938427420 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -1130,7 +1130,11 @@ func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, config
return nil
}
- addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.FirstPrimaryEndpoint, configType, deprecated)
+ addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.AddressProperties{
+ PEB: stack.FirstPrimaryEndpoint,
+ ConfigType: configType,
+ Deprecated: deprecated,
+ })
if err != nil {
panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", addr, err))
}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index f0186c64e..8297a7e10 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -144,8 +144,12 @@ func TestNeighborSolicitationWithSourceLinkLayerOption(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf)
@@ -406,8 +410,12 @@ func TestNeighborSolicitationResponse(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: nicAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -602,8 +610,12 @@ func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf)
@@ -831,8 +843,12 @@ func TestNDPValidation(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber)
@@ -962,8 +978,12 @@ func TestNeighborAdvertisementValidation(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
ndpNASize := header.ICMPv6NeighborAdvertMinimumSize
@@ -1283,8 +1303,12 @@ func TestCheckDuplicateAddress(t *testing.T) {
checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonces[dadPacketsSent])}),
))
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
checkDADMsg()
diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go
index 1b96b1fb8..26640b7ee 100644
--- a/pkg/tcpip/network/multicast_group_test.go
+++ b/pkg/tcpip/network/multicast_group_test.go
@@ -151,15 +151,22 @@ func createStackWithLinkEndpoint(t *testing.T, v4, mgpEnabled bool, e stack.Link
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- addr := tcpip.AddressWithPrefix{
- Address: stackIPv4Addr,
- PrefixLen: defaultIPv4PrefixLength,
+ addr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: stackIPv4Addr,
+ PrefixLen: defaultIPv4PrefixLength,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err)
}
- if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: linkLocalIPv6Addr1.WithPrefix(),
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalIPv6Addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, linkLocalIPv6Addr1, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
return s, clock
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
index 009cab643..05b879543 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -146,8 +146,12 @@ func main() {
log.Fatal(err)
}
- if err := s.AddAddress(1, ipv4.ProtocolNumber, addr); err != nil {
- log.Fatal(err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ log.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
// Add default route.
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
index c10b19aa0..a72afadda 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/main.go
+++ b/pkg/tcpip/sample/tun_tcp_echo/main.go
@@ -124,13 +124,13 @@ func main() {
log.Fatalf("Bad IP address: %v", addrName)
}
- var addr tcpip.Address
+ var addrWithPrefix tcpip.AddressWithPrefix
var proto tcpip.NetworkProtocolNumber
if parsedAddr.To4() != nil {
- addr = tcpip.Address(parsedAddr.To4())
+ addrWithPrefix = tcpip.Address(parsedAddr.To4()).WithPrefix()
proto = ipv4.ProtocolNumber
} else if parsedAddr.To16() != nil {
- addr = tcpip.Address(parsedAddr.To16())
+ addrWithPrefix = tcpip.Address(parsedAddr.To16()).WithPrefix()
proto = ipv6.ProtocolNumber
} else {
log.Fatalf("Unknown IP type: %v", addrName)
@@ -176,11 +176,15 @@ func main() {
log.Fatal(err)
}
- if err := s.AddAddress(1, proto, addr); err != nil {
- log.Fatal(err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: proto,
+ AddressWithPrefix: addrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ log.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
- subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addr))), tcpip.AddressMask(strings.Repeat("\x00", len(addr))))
+ subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addrWithPrefix.Address))), tcpip.AddressMask(strings.Repeat("\x00", len(addrWithPrefix.Address))))
if err != nil {
log.Fatal(err)
}
diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go
index 34ac62444..b0b2d0afd 100644
--- a/pkg/tcpip/socketops.go
+++ b/pkg/tcpip/socketops.go
@@ -170,10 +170,14 @@ type SocketOptions struct {
// message is passed with incoming packets.
receiveTClassEnabled uint32
- // receivePacketInfoEnabled is used to specify if more inforamtion is
- // provided with incoming packets such as interface index and address.
+ // receivePacketInfoEnabled is used to specify if more information is
+ // provided with incoming IPv4 packets.
receivePacketInfoEnabled uint32
+ // receiveIPv6PacketInfoEnabled is used to specify if more information is
+ // provided with incoming IPv6 packets.
+ receiveIPv6PacketInfoEnabled uint32
+
// hdrIncludeEnabled is used to indicate for a raw endpoint that all packets
// being written have an IP header and the endpoint should not attach an IP
// header.
@@ -360,6 +364,16 @@ func (so *SocketOptions) SetReceivePacketInfo(v bool) {
storeAtomicBool(&so.receivePacketInfoEnabled, v)
}
+// GetIPv6ReceivePacketInfo gets value for IPV6_RECVPKTINFO option.
+func (so *SocketOptions) GetIPv6ReceivePacketInfo() bool {
+ return atomic.LoadUint32(&so.receiveIPv6PacketInfoEnabled) != 0
+}
+
+// SetIPv6ReceivePacketInfo sets value for IPV6_RECVPKTINFO option.
+func (so *SocketOptions) SetIPv6ReceivePacketInfo(v bool) {
+ storeAtomicBool(&so.receiveIPv6PacketInfoEnabled, v)
+}
+
// GetHeaderIncluded gets value for IP_HDRINCL option.
func (so *SocketOptions) GetHeaderIncluded() bool {
return atomic.LoadUint32(&so.hdrIncludedEnabled) != 0
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 6c42ab29b..ead36880f 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -48,7 +48,6 @@ go_library(
"hook_string.go",
"icmp_rate_limit.go",
"iptables.go",
- "iptables_state.go",
"iptables_targets.go",
"iptables_types.go",
"neighbor_cache.go",
@@ -133,6 +132,7 @@ go_test(
name = "stack_test",
size = "small",
srcs = [
+ "conntrack_test.go",
"forwarding_test.go",
"neighbor_cache_test.go",
"neighbor_entry_test.go",
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
index ae0bb4ace..7e4b5bf74 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -117,10 +117,10 @@ func (a *AddressableEndpointState) releaseAddressStateLocked(addrState *addressS
}
// AddAndAcquirePermanentAddress implements AddressableEndpoint.
-func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, tcpip.Error) {
+func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties AddressProperties) (AddressEndpoint, tcpip.Error) {
a.mu.Lock()
defer a.mu.Unlock()
- ep, err := a.addAndAcquireAddressLocked(addr, peb, configType, deprecated, true /* permanent */)
+ ep, err := a.addAndAcquireAddressLocked(addr, properties, true /* permanent */)
// From https://golang.org/doc/faq#nil_error:
//
// Under the covers, interfaces are implemented as two elements, a type T and
@@ -149,7 +149,7 @@ func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.Addr
func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) (AddressEndpoint, tcpip.Error) {
a.mu.Lock()
defer a.mu.Unlock()
- ep, err := a.addAndAcquireAddressLocked(addr, peb, AddressConfigStatic, false /* deprecated */, false /* permanent */)
+ ep, err := a.addAndAcquireAddressLocked(addr, AddressProperties{PEB: peb}, false /* permanent */)
// From https://golang.org/doc/faq#nil_error:
//
// Under the covers, interfaces are implemented as two elements, a type T and
@@ -180,7 +180,7 @@ func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.Addr
// returned, regardless the kind of address that is being added.
//
// Precondition: a.mu must be write locked.
-func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated, permanent bool) (*addressState, tcpip.Error) {
+func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, properties AddressProperties, permanent bool) (*addressState, tcpip.Error) {
// attemptAddToPrimary is false when the address is already in the primary
// address list.
attemptAddToPrimary := true
@@ -208,7 +208,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
// We now promote the address.
for i, s := range a.mu.primary {
if s == addrState {
- switch peb {
+ switch properties.PEB {
case CanBePrimaryEndpoint:
// The address is already in the primary address list.
attemptAddToPrimary = false
@@ -222,7 +222,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
case NeverPrimaryEndpoint:
a.mu.primary = append(a.mu.primary[:i], a.mu.primary[i+1:]...)
default:
- panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb))
+ panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", properties.PEB))
}
break
}
@@ -262,11 +262,11 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
}
// Acquire the address before returning it.
addrState.mu.refs++
- addrState.mu.deprecated = deprecated
- addrState.mu.configType = configType
+ addrState.mu.deprecated = properties.Deprecated
+ addrState.mu.configType = properties.ConfigType
if attemptAddToPrimary {
- switch peb {
+ switch properties.PEB {
case NeverPrimaryEndpoint:
case CanBePrimaryEndpoint:
a.mu.primary = append(a.mu.primary, addrState)
@@ -285,7 +285,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
a.mu.primary[0] = addrState
}
default:
- panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb))
+ panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", properties.PEB))
}
}
@@ -489,12 +489,12 @@ func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tc
// Proceed to add a new temporary endpoint.
addr := localAddr.WithPrefix()
- ep, err := a.addAndAcquireAddressLocked(addr, tempPEB, AddressConfigStatic, false /* deprecated */, false /* permanent */)
+ ep, err := a.addAndAcquireAddressLocked(addr, AddressProperties{PEB: tempPEB}, false /* permanent */)
if err != nil {
// addAndAcquireAddressLocked only returns an error if the address is
// already assigned but we just checked above if the address exists so we
// expect no error.
- panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, %d, %d, false, false): %s", addr, tempPEB, AddressConfigStatic, err))
+ panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, AddressProperties{PEB: %s}, false): %s", addr, tempPEB, err))
}
// From https://golang.org/doc/faq#nil_error:
diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go
index 140f146f6..c55f85743 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state_test.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go
@@ -38,9 +38,9 @@ func TestAddressableEndpointStateCleanup(t *testing.T) {
}
{
- ep, err := s.AddAndAcquirePermanentAddress(addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */)
+ ep, err := s.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{PEB: stack.NeverPrimaryEndpoint})
if err != nil {
- t.Fatalf("s.AddAndAcquirePermanentAddress(%s, %d, %d, false): %s", addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, err)
+ t.Fatalf("s.AddAndAcquirePermanentAddress(%s, AddressProperties{PEB: NeverPrimaryEndpoint}): %s", addr, err)
}
// We don't need the address endpoint.
ep.DecRef()
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 068dab7ce..a3f403855 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -37,23 +37,9 @@ import (
// Our hash table has 16K buckets.
const numBuckets = 1 << 14
-// Direction of the tuple.
-type direction int
-
-const (
- dirOriginal direction = iota
- dirReply
-)
-
-// Manipulation type for the connection.
-// TODO(gvisor.dev/issue/5696): Define this as a bit set and support SNAT and
-// DNAT at the same time.
-type manipType int
-
const (
- manipNone manipType = iota
- manipSource
- manipDestination
+ establishedTimeout time.Duration = 5 * 24 * time.Hour
+ unestablishedTimeout time.Duration = 120 * time.Second
)
// tuple holds a connection's identifying and manipulating data in one
@@ -64,13 +50,22 @@ type tuple struct {
// tupleEntry is used to build an intrusive list of tuples.
tupleEntry
- tupleID
-
// conn is the connection tracking entry this tuple belongs to.
conn *conn
- // direction is the direction of the tuple.
- direction direction
+ // reply is true iff the tuple's direction is opposite that of the first
+ // packet seen on the connection.
+ reply bool
+
+ mu sync.RWMutex `state:"nosave"`
+ // +checklocks:mu
+ tupleID tupleID
+}
+
+func (t *tuple) id() tupleID {
+ t.mu.RLock()
+ defer t.mu.RUnlock()
+ return t.tupleID
}
// tupleID uniquely identifies a connection in one direction. It currently
@@ -103,50 +98,43 @@ func (ti tupleID) reply() tupleID {
//
// +stateify savable
type conn struct {
+ ct *ConnTrack
+
// original is the tuple in original direction. It is immutable.
original tuple
- // reply is the tuple in reply direction. It is immutable.
+ // reply is the tuple in reply direction.
reply tuple
- // manip indicates if the packet should be manipulated. It is immutable.
- // TODO(gvisor.dev/issue/5696): Support updating manipulation type.
- manip manipType
-
- // tcbHook indicates if the packet is inbound or outbound to
- // update the state of tcb. It is immutable.
- tcbHook Hook
-
- // mu protects all mutable state.
- mu sync.Mutex `state:"nosave"`
+ mu sync.RWMutex `state:"nosave"`
+ // finalized indicates that the connection has been finalized and may handle
+ // replies.
+ //
+ // +checklocks:mu
+ finalized bool
+ // sourceManip indicates the packet's source is manipulated.
+ //
+ // +checklocks:mu
+ sourceManip bool
+ // destinationManip indicates the packet's destination is manipulated.
+ //
+ // +checklocks:mu
+ destinationManip bool
// tcb is TCB control block. It is used to keep track of states
- // of tcp connection and is protected by mu.
+ // of tcp connection.
+ //
+ // +checklocks:mu
tcb tcpconntrack.TCB
// lastUsed is the last time the connection saw a relevant packet, and
- // is updated by each packet on the connection. It is protected by mu.
+ // is updated by each packet on the connection.
//
- // TODO(gvisor.dev/issue/5939): do not use the ambient clock.
- lastUsed time.Time `state:".(unixTime)"`
-}
-
-// newConn creates new connection.
-func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn {
- conn := conn{
- manip: manip,
- tcbHook: hook,
- lastUsed: time.Now(),
- }
- conn.original = tuple{conn: &conn, tupleID: orig}
- conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply}
- return &conn
+ // +checklocks:mu
+ lastUsed tcpip.MonotonicTime
}
// timedOut returns whether the connection timed out based on its state.
-func (cn *conn) timedOut(now time.Time) bool {
- const establishedTimeout = 5 * 24 * time.Hour
- const defaultTimeout = 120 * time.Second
- cn.mu.Lock()
- defer cn.mu.Unlock()
+func (cn *conn) timedOut(now tcpip.MonotonicTime) bool {
+ cn.mu.RLock()
+ defer cn.mu.RUnlock()
if cn.tcb.State() == tcpconntrack.ResultAlive {
// Use the same default as Linux, which doesn't delete
// established connections for 5(!) days.
@@ -154,22 +142,31 @@ func (cn *conn) timedOut(now time.Time) bool {
}
// Use the same default as Linux, which lets connections in most states
// other than established remain for <= 120 seconds.
- return now.Sub(cn.lastUsed) > defaultTimeout
+ return now.Sub(cn.lastUsed) > unestablishedTimeout
}
// update the connection tracking state.
//
-// Precondition: cn.mu must be held.
-func (cn *conn) updateLocked(tcpHeader header.TCP, hook Hook) {
+// +checklocks:cn.mu
+func (cn *conn) updateLocked(pkt *PacketBuffer, reply bool) {
+ if pkt.TransportProtocolNumber != header.TCPProtocolNumber {
+ return
+ }
+
+ tcpHeader := header.TCP(pkt.TransportHeader().View())
+
// Update the state of tcb. tcb assumes it's always initialized on the
// client. However, we only need to know whether the connection is
// established or not, so the client/server distinction isn't important.
if cn.tcb.IsEmpty() {
cn.tcb.Init(tcpHeader)
- } else if hook == cn.tcbHook {
- cn.tcb.UpdateStateOutbound(tcpHeader)
- } else {
+ return
+ }
+
+ if reply {
cn.tcb.UpdateStateInbound(tcpHeader)
+ } else {
+ cn.tcb.UpdateStateOutbound(tcpHeader)
}
}
@@ -194,44 +191,37 @@ type ConnTrack struct {
// It is immutable.
seed uint32
+ // clock provides timing used to determine conntrack reapings.
+ clock tcpip.Clock
+
+ mu sync.RWMutex `state:"nosave"`
// mu protects the buckets slice, but not buckets' contents. Only take
// the write lock if you are modifying the slice or saving for S/R.
- mu sync.RWMutex `state:"nosave"`
-
- // buckets is protected by mu.
+ //
+ // +checklocks:mu
buckets []bucket
}
// +stateify savable
type bucket struct {
- // mu protects tuples.
- mu sync.Mutex `state:"nosave"`
+ mu sync.RWMutex `state:"nosave"`
+ // +checklocks:mu
tuples tupleList
}
-// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid
-// TCP header.
-//
-// Preconditions: pkt.NetworkHeader() is valid.
-func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) {
- netHeader := pkt.Network()
- if netHeader.TransportProtocol() != header.TCPProtocolNumber {
- return tupleID{}, &tcpip.ErrUnknownProtocol{}
- }
-
- tcpHeader := header.TCP(pkt.TransportHeader().View())
- if len(tcpHeader) < header.TCPMinimumSize {
- return tupleID{}, &tcpip.ErrUnknownProtocol{}
+func getTransportHeader(pkt *PacketBuffer) (header.ChecksummableTransport, bool) {
+ switch pkt.TransportProtocolNumber {
+ case header.TCPProtocolNumber:
+ if tcpHeader := header.TCP(pkt.TransportHeader().View()); len(tcpHeader) >= header.TCPMinimumSize {
+ return tcpHeader, true
+ }
+ case header.UDPProtocolNumber:
+ if udpHeader := header.UDP(pkt.TransportHeader().View()); len(udpHeader) >= header.UDPMinimumSize {
+ return udpHeader, true
+ }
}
- return tupleID{
- srcAddr: netHeader.SourceAddress(),
- srcPort: tcpHeader.SourcePort(),
- dstAddr: netHeader.DestinationAddress(),
- dstPort: tcpHeader.DestinationPort(),
- transProto: netHeader.TransportProtocol(),
- netProto: pkt.NetworkProtocolNumber,
- }, nil
+ return nil, false
}
func (ct *ConnTrack) init() {
@@ -240,278 +230,285 @@ func (ct *ConnTrack) init() {
ct.buckets = make([]bucket, numBuckets)
}
-// connFor gets the conn for pkt if it exists, or returns nil
-// if it does not. It returns an error when pkt does not contain a valid TCP
-// header.
-// TODO(gvisor.dev/issue/6168): Support UDP.
-func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) {
- tid, err := packetToTupleID(pkt)
- if err != nil {
- return nil, dirOriginal
+func (ct *ConnTrack) getConnOrMaybeInsertNoop(pkt *PacketBuffer) *tuple {
+ netHeader := pkt.Network()
+ transportHeader, ok := getTransportHeader(pkt)
+ if !ok {
+ return nil
+ }
+
+ tid := tupleID{
+ srcAddr: netHeader.SourceAddress(),
+ srcPort: transportHeader.SourcePort(),
+ dstAddr: netHeader.DestinationAddress(),
+ dstPort: transportHeader.DestinationPort(),
+ transProto: pkt.TransportProtocolNumber,
+ netProto: pkt.NetworkProtocolNumber,
}
- return ct.connForTID(tid)
-}
-func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) {
- bucket := ct.bucket(tid)
- now := time.Now()
+ bktID := ct.bucket(tid)
ct.mu.RLock()
- defer ct.mu.RUnlock()
- ct.buckets[bucket].mu.Lock()
- defer ct.buckets[bucket].mu.Unlock()
-
- // Iterate over the tuples in a bucket, cleaning up any unused
- // connections we find.
- for other := ct.buckets[bucket].tuples.Front(); other != nil; other = other.Next() {
- // Clean up any timed-out connections we happen to find.
- if ct.reapTupleLocked(other, bucket, now) {
- // The tuple expired.
- continue
- }
- if tid == other.tupleID {
- return other.conn, other.direction
- }
+ bkt := &ct.buckets[bktID]
+ ct.mu.RUnlock()
+
+ now := ct.clock.NowMonotonic()
+ if t := bkt.connForTID(tid, now); t != nil {
+ return t
}
- return nil, dirOriginal
-}
+ bkt.mu.Lock()
+ defer bkt.mu.Unlock()
-func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn {
- tid, err := packetToTupleID(pkt)
- if err != nil {
- return nil
+ // Make sure a connection wasn't added between when we last checked the
+ // bucket and acquired the bucket's write lock.
+ if t := bkt.connForTIDRLocked(tid, now); t != nil {
+ return t
}
- if hook != Prerouting && hook != Output {
- return nil
+
+ // This is the first packet we're seeing for the connection. Create an entry
+ // for this new connection.
+ conn := &conn{
+ ct: ct,
+ original: tuple{tupleID: tid},
+ reply: tuple{tupleID: tid.reply(), reply: true},
+ lastUsed: now,
}
+ conn.original.conn = conn
+ conn.reply.conn = conn
- replyTID := tid.reply()
- replyTID.srcAddr = address
- replyTID.srcPort = port
+ // For now, we only map an entry for the packet's original tuple as NAT may be
+ // performed on this connection. Until the packet goes through all the hooks
+ // and its final address/port is known, we cannot know what the response
+ // packet's addresses/ports will look like.
+ //
+ // This is okay because the destination cannot send its response until it
+ // receives the packet; the packet will only be received once all the hooks
+ // have been performed.
+ //
+ // See (*conn).finalize.
+ bkt.tuples.PushFront(&conn.original)
+ return &conn.original
+}
- conn, _ := ct.connForTID(tid)
- if conn != nil {
- // The connection is already tracked.
- // TODO(gvisor.dev/issue/5696): Support updating an existing connection.
- return nil
- }
- conn = newConn(tid, replyTID, manipDestination, hook)
- ct.insertConn(conn)
- return conn
+func (ct *ConnTrack) connForTID(tid tupleID) *tuple {
+ bktID := ct.bucket(tid)
+
+ ct.mu.RLock()
+ bkt := &ct.buckets[bktID]
+ ct.mu.RUnlock()
+
+ return bkt.connForTID(tid, ct.clock.NowMonotonic())
}
-func (ct *ConnTrack) insertSNATConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn {
- tid, err := packetToTupleID(pkt)
- if err != nil {
- return nil
- }
- if hook != Input && hook != Postrouting {
- return nil
+func (bkt *bucket) connForTID(tid tupleID, now tcpip.MonotonicTime) *tuple {
+ bkt.mu.RLock()
+ defer bkt.mu.RUnlock()
+ return bkt.connForTIDRLocked(tid, now)
+}
+
+// +checklocksread:bkt.mu
+func (bkt *bucket) connForTIDRLocked(tid tupleID, now tcpip.MonotonicTime) *tuple {
+ for other := bkt.tuples.Front(); other != nil; other = other.Next() {
+ if tid == other.id() && !other.conn.timedOut(now) {
+ return other
+ }
}
+ return nil
+}
- replyTID := tid.reply()
- replyTID.dstAddr = address
- replyTID.dstPort = port
+func (ct *ConnTrack) finalize(cn *conn) {
+ tid := cn.reply.id()
+ id := ct.bucket(tid)
- conn, _ := ct.connForTID(tid)
- if conn != nil {
- // The connection is already tracked.
- // TODO(gvisor.dev/issue/5696): Support updating an existing connection.
- return nil
+ ct.mu.RLock()
+ bkt := &ct.buckets[id]
+ ct.mu.RUnlock()
+
+ bkt.mu.Lock()
+ defer bkt.mu.Unlock()
+
+ if t := bkt.connForTIDRLocked(tid, ct.clock.NowMonotonic()); t != nil {
+ // Another connection for the reply already exists. We can't do much about
+ // this so we leave the connection cn represents in a state where it can
+ // send packets but its responses will be mapped to some other connection.
+ // This may be okay if the connection only expects to send packets without
+ // any responses.
+ return
}
- conn = newConn(tid, replyTID, manipSource, hook)
- ct.insertConn(conn)
- return conn
+
+ bkt.tuples.PushFront(&cn.reply)
}
-// insertConn inserts conn into the appropriate table bucket.
-func (ct *ConnTrack) insertConn(conn *conn) {
- // Lock the buckets in the correct order.
- tupleBucket := ct.bucket(conn.original.tupleID)
- replyBucket := ct.bucket(conn.reply.tupleID)
- ct.mu.RLock()
- defer ct.mu.RUnlock()
- if tupleBucket < replyBucket {
- ct.buckets[tupleBucket].mu.Lock()
- ct.buckets[replyBucket].mu.Lock()
- } else if tupleBucket > replyBucket {
- ct.buckets[replyBucket].mu.Lock()
- ct.buckets[tupleBucket].mu.Lock()
- } else {
- // Both tuples are in the same bucket.
- ct.buckets[tupleBucket].mu.Lock()
- }
-
- // Now that we hold the locks, ensure the tuple hasn't been inserted by
- // another thread.
- // TODO(gvisor.dev/issue/5773): Should check conn.reply.tupleID, too?
- alreadyInserted := false
- for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() {
- if other.tupleID == conn.original.tupleID {
- alreadyInserted = true
- break
+func (cn *conn) finalize() {
+ {
+ cn.mu.RLock()
+ finalized := cn.finalized
+ cn.mu.RUnlock()
+ if finalized {
+ return
}
}
- if !alreadyInserted {
- // Add the tuple to the map.
- ct.buckets[tupleBucket].tuples.PushFront(&conn.original)
- ct.buckets[replyBucket].tuples.PushFront(&conn.reply)
+ cn.mu.Lock()
+ finalized := cn.finalized
+ cn.finalized = true
+ cn.mu.Unlock()
+ if finalized {
+ return
}
- // Unlocking can happen in any order.
- ct.buckets[tupleBucket].mu.Unlock()
- if tupleBucket != replyBucket {
- ct.buckets[replyBucket].mu.Unlock() // +checklocksforce
- }
+ cn.ct.finalize(cn)
}
-// handlePacket will manipulate the port and address of the packet if the
-// connection exists. Returns whether, after the packet traverses the tables,
-// it should create a new entry in the table.
-func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
- if pkt.NatDone {
- return false
+// performNAT sets up the connection for the specified NAT.
+//
+// Generally, only the first packet of a connection reaches this method;
+// other packets will be manipulated without needing to modify the connection.
+func (cn *conn) performNAT(pkt *PacketBuffer, hook Hook, r *Route, port uint16, address tcpip.Address, dnat bool) {
+ cn.performNATIfNoop(port, address, dnat)
+ cn.handlePacket(pkt, hook, r)
+}
+
+func (cn *conn) performNATIfNoop(port uint16, address tcpip.Address, dnat bool) {
+ cn.mu.Lock()
+ defer cn.mu.Unlock()
+
+ if cn.finalized {
+ return
}
- switch hook {
- case Prerouting, Input, Output, Postrouting:
- default:
- return false
+ if dnat {
+ if cn.destinationManip {
+ return
+ }
+ cn.destinationManip = true
+ } else {
+ if cn.sourceManip {
+ return
+ }
+ cn.sourceManip = true
}
- // TODO(gvisor.dev/issue/6168): Support UDP.
- if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
+ cn.reply.mu.Lock()
+ defer cn.reply.mu.Unlock()
+
+ if dnat {
+ cn.reply.tupleID.srcAddr = address
+ cn.reply.tupleID.srcPort = port
+ } else {
+ cn.reply.tupleID.dstAddr = address
+ cn.reply.tupleID.dstPort = port
+ }
+}
+
+// handlePacket attempts to handle a packet and perform NAT if the connection
+// has had NAT performed on it.
+//
+// Returns true if the packet can skip the NAT table.
+func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, rt *Route) bool {
+ transportHeader, ok := getTransportHeader(pkt)
+ if !ok {
return false
}
- conn, dir := ct.connFor(pkt)
- // Connection not found for the packet.
- if conn == nil {
- // If this is the last hook in the data path for this packet (Input if
- // incoming, Postrouting if outgoing), indicate that a connection should be
- // inserted by the end of this hook.
- return hook == Input || hook == Postrouting
+ fullChecksum := false
+ updatePseudoHeader := false
+ natDone := &pkt.SNATDone
+ dnat := false
+ switch hook {
+ case Prerouting:
+ // Packet came from outside the stack so it must have a checksum set
+ // already.
+ fullChecksum = true
+ updatePseudoHeader = true
+
+ natDone = &pkt.DNATDone
+ dnat = true
+ case Input:
+ case Forward:
+ panic("should not handle packet in the forwarding hook")
+ case Output:
+ natDone = &pkt.DNATDone
+ dnat = true
+ fallthrough
+ case Postrouting:
+ if pkt.TransportProtocolNumber == header.TCPProtocolNumber && pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum {
+ updatePseudoHeader = true
+ } else if rt.RequiresTXTransportChecksum() {
+ fullChecksum = true
+ updatePseudoHeader = true
+ }
+ default:
+ panic(fmt.Sprintf("unrecognized hook = %d", hook))
}
- netHeader := pkt.Network()
- tcpHeader := header.TCP(pkt.TransportHeader().View())
- if len(tcpHeader) < header.TCPMinimumSize {
- return false
+ if *natDone {
+ panic(fmt.Sprintf("packet already had NAT(dnat=%t) performed at hook=%s; pkt=%#v", dnat, hook, pkt))
}
// TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be
// validated if checksum offloading is off. It may require IP defrag if the
// packets are fragmented.
- var newAddr tcpip.Address
- var newPort uint16
-
- updateSRCFields := false
-
- switch hook {
- case Prerouting, Output:
- if conn.manip == manipDestination {
- switch dir {
- case dirOriginal:
- newPort = conn.reply.srcPort
- newAddr = conn.reply.srcAddr
- case dirReply:
- newPort = conn.original.dstPort
- newAddr = conn.original.dstAddr
-
- updateSRCFields = true
+ reply := pkt.tuple.reply
+ tid, performManip := func() (tupleID, bool) {
+ cn.mu.Lock()
+ defer cn.mu.Unlock()
+
+ // Mark the connection as having been used recently so it isn't reaped.
+ cn.lastUsed = cn.ct.clock.NowMonotonic()
+ // Update connection state.
+ cn.updateLocked(pkt, reply)
+
+ var tuple *tuple
+ if reply {
+ if dnat {
+ if !cn.sourceManip {
+ return tupleID{}, false
+ }
+ } else if !cn.destinationManip {
+ return tupleID{}, false
}
- pkt.NatDone = true
- }
- case Input, Postrouting:
- if conn.manip == manipSource {
- switch dir {
- case dirOriginal:
- newPort = conn.reply.dstPort
- newAddr = conn.reply.dstAddr
-
- updateSRCFields = true
- case dirReply:
- newPort = conn.original.srcPort
- newAddr = conn.original.srcAddr
+
+ tuple = &cn.original
+ } else {
+ if dnat {
+ if !cn.destinationManip {
+ return tupleID{}, false
+ }
+ } else if !cn.sourceManip {
+ return tupleID{}, false
}
- pkt.NatDone = true
+
+ tuple = &cn.reply
}
- default:
- panic(fmt.Sprintf("unrecognized hook = %s", hook))
- }
- if !pkt.NatDone {
+
+ return tuple.id(), true
+ }()
+ if !performManip {
return false
}
- fullChecksum := false
- updatePseudoHeader := false
- switch hook {
- case Prerouting, Input:
- case Output, Postrouting:
- // Calculate the TCP checksum and set it.
- if pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum {
- updatePseudoHeader = true
- } else if r.RequiresTXTransportChecksum() {
- fullChecksum = true
- updatePseudoHeader = true
- }
- default:
- panic(fmt.Sprintf("unrecognized hook = %s", hook))
+ newPort := tid.dstPort
+ newAddr := tid.dstAddr
+ if dnat {
+ newPort = tid.srcPort
+ newAddr = tid.srcAddr
}
rewritePacket(
- netHeader,
- tcpHeader,
- updateSRCFields,
+ pkt.Network(),
+ transportHeader,
+ !dnat,
fullChecksum,
updatePseudoHeader,
newPort,
newAddr,
)
- // Update the state of tcb.
- conn.mu.Lock()
- defer conn.mu.Unlock()
-
- // Mark the connection as having been used recently so it isn't reaped.
- conn.lastUsed = time.Now()
- // Update connection state.
- conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook)
-
- return false
-}
-
-// maybeInsertNoop tries to insert a no-op connection entry to keep connections
-// from getting clobbered when replies arrive. It only inserts if there isn't
-// already a connection for pkt.
-//
-// This should be called after traversing iptables rules only, to ensure that
-// pkt.NatDone is set correctly.
-func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
- // If there were a rule applying to this packet, it would be marked
- // with NatDone.
- if pkt.NatDone {
- return
- }
-
- // We only track TCP connections.
- if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
- return
- }
-
- // This is the first packet we're seeing for the TCP connection. Insert
- // the noop entry (an identity mapping) so that the response doesn't
- // get NATed, breaking the connection.
- tid, err := packetToTupleID(pkt)
- if err != nil {
- return
- }
- conn := newConn(tid, tid.reply(), manipNone, hook)
- conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook)
- ct.insertConn(conn)
+ *natDone = true
+ return true
}
// bucket gets the conntrack bucket for a tupleID.
@@ -555,7 +552,7 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim
const minInterval = 10 * time.Millisecond
const maxInterval = maxFullTraversal / fractionPerReaping
- now := time.Now()
+ now := ct.clock.NowMonotonic()
checked := 0
expired := 0
var idx int
@@ -563,14 +560,20 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim
defer ct.mu.RUnlock()
for i := 0; i < len(ct.buckets)/fractionPerReaping; i++ {
idx = (i + start) % len(ct.buckets)
- ct.buckets[idx].mu.Lock()
- for tuple := ct.buckets[idx].tuples.Front(); tuple != nil; tuple = tuple.Next() {
+ bkt := &ct.buckets[idx]
+ bkt.mu.Lock()
+ for tuple := bkt.tuples.Front(); tuple != nil; {
+ // reapTupleLocked updates tuple's next pointer so we grab it here.
+ nextTuple := tuple.Next()
+
checked++
- if ct.reapTupleLocked(tuple, idx, now) {
+ if ct.reapTupleLocked(tuple, idx, bkt, now) {
expired++
}
+
+ tuple = nextTuple
}
- ct.buckets[idx].mu.Unlock()
+ bkt.mu.Unlock()
}
// We already checked buckets[idx].
idx++
@@ -595,44 +598,51 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim
// reapTupleLocked tries to remove tuple and its reply from the table. It
// returns whether the tuple's connection has timed out.
//
-// Preconditions:
-// * ct.mu is locked for reading.
-// * bucket is locked.
-func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bool {
+// Precondition: ct.mu is read locked and bkt.mu is write locked.
+// +checklocksread:ct.mu
+// +checklocks:bkt.mu
+func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bktID int, bkt *bucket, now tcpip.MonotonicTime) bool {
if !tuple.conn.timedOut(now) {
return false
}
- // To maintain lock order, we can only reap these tuples if the reply
- // appears later in the table.
- replyBucket := ct.bucket(tuple.reply())
- if bucket > replyBucket {
+ // To maintain lock order, we can only reap both tuples if the reply appears
+ // later in the table.
+ replyBktID := ct.bucket(tuple.id().reply())
+ tuple.conn.mu.RLock()
+ replyTupleInserted := tuple.conn.finalized
+ tuple.conn.mu.RUnlock()
+ if bktID > replyBktID && replyTupleInserted {
return true
}
- // Don't re-lock if both tuples are in the same bucket.
- differentBuckets := bucket != replyBucket
- if differentBuckets {
- ct.buckets[replyBucket].mu.Lock()
+ // Reap the reply.
+ if replyTupleInserted {
+ // Don't re-lock if both tuples are in the same bucket.
+ if bktID != replyBktID {
+ replyBkt := &ct.buckets[replyBktID]
+ replyBkt.mu.Lock()
+ removeConnFromBucket(replyBkt, tuple)
+ replyBkt.mu.Unlock()
+ } else {
+ removeConnFromBucket(bkt, tuple)
+ }
}
- // We have the buckets locked and can remove both tuples.
- if tuple.direction == dirOriginal {
- ct.buckets[replyBucket].tuples.Remove(&tuple.conn.reply)
- } else {
- ct.buckets[replyBucket].tuples.Remove(&tuple.conn.original)
- }
- ct.buckets[bucket].tuples.Remove(tuple)
+ bkt.tuples.Remove(tuple)
+ return true
+}
- // Don't re-unlock if both tuples are in the same bucket.
- if differentBuckets {
- ct.buckets[replyBucket].mu.Unlock() // +checklocksforce
+// +checklocks:b.mu
+func removeConnFromBucket(b *bucket, tuple *tuple) {
+ if tuple.reply {
+ b.tuples.Remove(&tuple.conn.original)
+ } else {
+ b.tuples.Remove(&tuple.conn.reply)
}
-
- return true
}
-func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
+func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
// Lookup the connection. The reply's original destination
// describes the original address.
tid := tupleID{
@@ -640,17 +650,22 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ
srcPort: epID.LocalPort,
dstAddr: epID.RemoteAddress,
dstPort: epID.RemotePort,
- transProto: header.TCPProtocolNumber,
+ transProto: transProto,
netProto: netProto,
}
- conn, _ := ct.connForTID(tid)
- if conn == nil {
+ t := ct.connForTID(tid)
+ if t == nil {
// Not a tracked connection.
return "", 0, &tcpip.ErrNotConnected{}
- } else if conn.manip != manipDestination {
+ }
+
+ t.conn.mu.RLock()
+ defer t.conn.mu.RUnlock()
+ if !t.conn.destinationManip {
// Unmanipulated destination.
return "", 0, &tcpip.ErrInvalidOptionValue{}
}
- return conn.original.dstAddr, conn.original.dstPort, nil
+ id := t.conn.original.id()
+ return id.dstAddr, id.dstPort, nil
}
diff --git a/pkg/tcpip/stack/conntrack_test.go b/pkg/tcpip/stack/conntrack_test.go
new file mode 100644
index 000000000..fb0645ed1
--- /dev/null
+++ b/pkg/tcpip/stack/conntrack_test.go
@@ -0,0 +1,132 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
+)
+
+// TestReap checks that reaping honors the time of the most recent packet seen
+// for a connection: a second SYN sent before unestablishedTimeout expires must
+// keep the tuple alive, and the tuple must be gone once the timeout elapses
+// with no further traffic.
+func TestReap(t *testing.T) {
+ // Initialize conntrack.
+ clock := faketime.NewManualClock()
+ ct := ConnTrack{
+ clock: clock,
+ }
+ ct.init()
+ ct.checkNumTuples(t, 0)
+
+ // Simulate sending a SYN. This will get the connection into conntrack, but
+ // the connection won't be considered established. Thus the timeout for
+ // reaping is unestablishedTimeout.
+ pkt1 := genTCPPacket()
+ pkt1.tuple = ct.getConnOrMaybeInsertNoop(pkt1)
+ // We set rt.routeInfo.Loop to avoid a panic when handlePacket calls
+ // rt.RequiresTXTransportChecksum.
+ var rt Route
+ rt.routeInfo.Loop = PacketLoop
+ if pkt1.tuple.conn.handlePacket(pkt1, Output, &rt) {
+ t.Fatal("handlePacket() shouldn't perform any NAT")
+ }
+ ct.checkNumTuples(t, 1)
+
+ // Travel a little into the future and send the same SYN. This should update
+ // lastUsed, but per #6748 didn't.
+ clock.Advance(unestablishedTimeout / 2)
+ pkt2 := genTCPPacket()
+ pkt2.tuple = ct.getConnOrMaybeInsertNoop(pkt2)
+ if pkt2.tuple.conn.handlePacket(pkt2, Output, &rt) {
+ t.Fatal("handlePacket() shouldn't perform any NAT")
+ }
+ ct.checkNumTuples(t, 1)
+
+ // Travel farther into the future - enough that failing to update lastUsed
+ // would cause a reaping - and reap the whole table. Make sure the connection
+ // hasn't been reaped.
+ //
+ // After this advance the clock is 1.25x unestablishedTimeout past the first
+ // SYN but only 0.75x past the second.
+ clock.Advance(unestablishedTimeout * 3 / 4)
+ ct.reapEverything()
+ ct.checkNumTuples(t, 1)
+
+ // Travel past unestablishedTimeout to confirm the tuple is gone.
+ //
+ // After this advance the clock is 1.25x unestablishedTimeout past the second
+ // SYN, and no packet has been sent since.
+ clock.Advance(unestablishedTimeout / 2)
+ ct.reapEverything()
+ ct.checkNumTuples(t, 0)
+}
+
+// genTCPPacket returns an initialized IPv4 TCP packet.
+//
+// The packet consists of headers only (an IPv4 header followed by a minimum
+// size TCP header with the SYN flag set), going from 1.0.0.1:5555 to
+// 1.0.0.2:6666. Both checksums are left at zero.
+func genTCPPacket() *PacketBuffer {
+ const packetLen = header.IPv4MinimumSize + header.TCPMinimumSize
+ pkt := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: packetLen,
+ })
+ pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
+ pkt.TransportProtocolNumber = header.TCPProtocolNumber
+ tcpHdr := header.TCP(pkt.TransportHeader().Push(header.TCPMinimumSize))
+ tcpHdr.Encode(&header.TCPFields{
+ SrcPort: 5555,
+ DstPort: 6666,
+ SeqNum: 7777,
+ AckNum: 8888,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagSyn,
+ WindowSize: 50000,
+ Checksum: 0, // Conntrack doesn't verify the checksum.
+ })
+ ipHdr := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize))
+ ipHdr.Encode(&header.IPv4Fields{
+ TotalLength: packetLen,
+ Protocol: uint8(header.TCPProtocolNumber),
+ SrcAddr: testutil.MustParse4("1.0.0.1"),
+ DstAddr: testutil.MustParse4("1.0.0.2"),
+ Checksum: 0, // Conntrack doesn't verify the checksum.
+ })
+
+ return pkt
+}
+
+// checkNumTuples fails the test immediately unless the tuple lists of all
+// conntrack buckets together hold exactly want tuples.
+func (ct *ConnTrack) checkNumTuples(t *testing.T, want int) {
+ t.Helper()
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+
+ // Sum the per-bucket list lengths, taking each bucket's own lock only
+ // for the duration of its Len call.
+ var total int
+ for idx := range ct.buckets {
+ ct.buckets[idx].mu.RLock()
+ total += ct.buckets[idx].tuples.Len()
+ ct.buckets[idx].mu.RUnlock()
+ }
+
+ if total != want {
+ t.Fatalf("checkNumTuples: got %d, wanted %d", total, want)
+ }
+}
+
+// reapEverything calls reapUnused repeatedly, starting at bucket 0 and then at
+// whatever bucket index the previous call returned, until the returned index
+// stops advancing. The interval returned by reapUnused is discarded.
+func (ct *ConnTrack) reapEverything() {
+ var bucket int
+ for {
+ newBucket, _ := ct.reapUnused(bucket, 0 /* ignored */)
+ // We started reaping at bucket 0. If the next bucket isn't after our
+ // current bucket, we've gone through them all.
+ if newBucket <= bucket {
+ break
+ }
+ bucket = newBucket
+ }
+}
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index ccb69393b..c2f1f4798 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -181,10 +181,6 @@ func (*fwdTestNetworkProtocol) MinimumPacketSize() int {
return fwdTestNetHeaderLen
}
-func (*fwdTestNetworkProtocol) DefaultPrefixLen() int {
- return fwdTestNetDefaultPrefixLen
-}
-
func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1])
}
@@ -384,8 +380,15 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (*faketime.M
if err := s.CreateNIC(1, ep1); err != nil {
t.Fatal("CreateNIC #1 failed:", err)
}
- if err := s.AddAddress(1, fwdTestNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress #1 failed:", err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fwdTestNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fwdTestNetDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr1, AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err)
}
// NIC 2 has the link address "b", and added the network address 2.
@@ -397,8 +400,15 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (*faketime.M
if err := s.CreateNIC(2, ep2); err != nil {
t.Fatal("CreateNIC #2 failed:", err)
}
- if err := s.AddAddress(2, fwdTestNetNumber, "\x02"); err != nil {
- t.Fatal("AddAddress #2 failed:", err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: fwdTestNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x02",
+ PrefixLen: fwdTestNetDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(2, protocolAddr2, AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err)
}
nic, ok := s.nics[2]
diff --git a/pkg/tcpip/stack/icmp_rate_limit.go b/pkg/tcpip/stack/icmp_rate_limit.go
index 3a20839da..99e5d2df7 100644
--- a/pkg/tcpip/stack/icmp_rate_limit.go
+++ b/pkg/tcpip/stack/icmp_rate_limit.go
@@ -16,6 +16,7 @@ package stack
import (
"golang.org/x/time/rate"
+ "gvisor.dev/gvisor/pkg/tcpip"
)
const (
@@ -31,11 +32,41 @@ const (
// ICMPRateLimiter is a global rate limiter that controls the generation of
// ICMP messages generated by the stack.
type ICMPRateLimiter struct {
- *rate.Limiter
+ limiter *rate.Limiter
+ clock tcpip.Clock
}
// NewICMPRateLimiter returns a global rate limiter for controlling the rate
-// at which ICMP messages are generated by the stack.
-func NewICMPRateLimiter() *ICMPRateLimiter {
- return &ICMPRateLimiter{Limiter: rate.NewLimiter(icmpLimit, icmpBurst)}
+// at which ICMP messages are generated by the stack. The returned limiter
+// does not apply limits to any ICMP types by default.
+func NewICMPRateLimiter(clock tcpip.Clock) *ICMPRateLimiter {
+ return &ICMPRateLimiter{
+ clock: clock,
+ limiter: rate.NewLimiter(icmpLimit, icmpBurst),
+ }
+}
+
+// SetLimit sets a new Limit for the limiter. The change is applied as of the
+// current time reported by the limiter's clock.
+func (l *ICMPRateLimiter) SetLimit(limit rate.Limit) {
+ l.limiter.SetLimitAt(l.clock.Now(), limit)
+}
+
+// Limit returns the maximum overall event rate, as reported by the wrapped
+// rate.Limiter.
+func (l *ICMPRateLimiter) Limit() rate.Limit {
+ return l.limiter.Limit()
+}
+
+// SetBurst sets a new burst size for the limiter. The change is applied as of
+// the current time reported by the limiter's clock.
+func (l *ICMPRateLimiter) SetBurst(burst int) {
+ l.limiter.SetBurstAt(l.clock.Now(), burst)
+}
+
+// Burst returns the maximum burst size, as reported by the wrapped
+// rate.Limiter.
+func (l *ICMPRateLimiter) Burst() int {
+ return l.limiter.Burst()
+}
+
+// Allow reports whether one ICMP message may be sent now, where "now" is the
+// current time reported by the limiter's clock.
+func (l *ICMPRateLimiter) Allow() bool {
+ return l.limiter.AllowN(l.clock.Now(), 1)
}
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index f152c0d83..fd61387bf 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -42,7 +42,7 @@ const reaperDelay = 5 * time.Second
// DefaultTables returns a default set of tables. Each chain is set to accept
// all packets.
-func DefaultTables(seed uint32) *IPTables {
+func DefaultTables(seed uint32, clock tcpip.Clock) *IPTables {
return &IPTables{
v4Tables: [NumTables]Table{
NATID: {
@@ -182,7 +182,8 @@ func DefaultTables(seed uint32) *IPTables {
Postrouting: {MangleID, NATID},
},
connections: ConnTrack{
- seed: seed,
+ seed: seed,
+ clock: clock,
},
reaperDone: make(chan struct{}, 1),
}
@@ -264,33 +265,125 @@ const (
chainReturn
)
-// Check runs pkt through the rules for hook. It returns true when the packet
-// should continue traversing the network stack and false when it should be
-// dropped.
+// CheckPrerouting performs the prerouting hook on the packet.
+//
+// Returns true iff the packet may continue traversing the stack; the packet
+// must be dropped if false is returned.
+//
+// addressEP is passed through to the rule targets; RedirectTarget reads its
+// main address on this hook, so it is expected to be non-nil here.
+//
+// Precondition: The packet's network and transport header must be set.
+func (it *IPTables) CheckPrerouting(pkt *PacketBuffer, addressEP AddressableEndpoint, inNicName string) bool {
+ const hook = Prerouting
+
+ if it.shouldSkip(pkt.NetworkProtocolNumber) {
+ return true
+ }
+
+ // Attach the connection-tracking tuple to the packet; check and the NAT
+ // targets reach the tracked connection through pkt.tuple.
+ pkt.tuple = it.connections.getConnOrMaybeInsertNoop(pkt)
+
+ return it.check(hook, pkt, nil /* route */, addressEP, inNicName, "" /* outNicName */)
+}
+
+// CheckInput performs the input hook on the packet.
+//
+// Returns true iff the packet may continue traversing the stack; the packet
+// must be dropped if false is returned.
+//
+// After the rules have run, the connection referenced by pkt.tuple (if any)
+// is finalized and pkt.tuple is cleared.
+//
+// Precondition: The packet's network and transport header must be set.
+func (it *IPTables) CheckInput(pkt *PacketBuffer, inNicName string) bool {
+ const hook = Input
+
+ if it.shouldSkip(pkt.NetworkProtocolNumber) {
+ return true
+ }
+
+ ret := it.check(hook, pkt, nil /* route */, nil /* addressEP */, inNicName, "" /* outNicName */)
+ // NOTE(review): finalize runs whether or not the packet was accepted -
+ // confirm that finalizing the connection of a dropped packet is intended.
+ if t := pkt.tuple; t != nil {
+ t.conn.finalize()
+ }
+ pkt.tuple = nil
+ return ret
+}
+
+// CheckForward performs the forward hook on the packet.
+//
+// Returns true iff the packet may continue traversing the stack; the packet
+// must be dropped if false is returned.
+//
+// This hook neither attaches nor clears pkt.tuple; the rules run against
+// whatever tuple the packet already carries.
+//
+// Precondition: The packet's network and transport header must be set.
+func (it *IPTables) CheckForward(pkt *PacketBuffer, inNicName, outNicName string) bool {
+ if it.shouldSkip(pkt.NetworkProtocolNumber) {
+ return true
+ }
+ return it.check(Forward, pkt, nil /* route */, nil /* addressEP */, inNicName, outNicName)
+}
+
+// CheckOutput performs the output hook on the packet.
+//
+// Returns true iff the packet may continue traversing the stack; the packet
+// must be dropped if false is returned.
+//
+// Precondition: The packet's network and transport header must be set.
+func (it *IPTables) CheckOutput(pkt *PacketBuffer, r *Route, outNicName string) bool {
+ const hook = Output
+
+ if it.shouldSkip(pkt.NetworkProtocolNumber) {
+ return true
+ }
+
+ // Attach the connection-tracking tuple to the packet; check and the NAT
+ // targets reach the tracked connection through pkt.tuple.
+ pkt.tuple = it.connections.getConnOrMaybeInsertNoop(pkt)
+
+ return it.check(hook, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName)
+}
+
+// CheckPostrouting performs the postrouting hook on the packet.
//
-// Precondition: pkt.NetworkHeader is set.
-func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool {
- if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber {
+// Returns true iff the packet may continue traversing the stack; the packet
+// must be dropped if false is returned.
+//
+// Precondition: The packet's network and transport header must be set.
+func (it *IPTables) CheckPostrouting(pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, outNicName string) bool {
+ const hook = Postrouting
+
+ if it.shouldSkip(pkt.NetworkProtocolNumber) {
return true
}
+
+ ret := it.check(hook, pkt, r, addressEP, "" /* inNicName */, outNicName)
+ if t := pkt.tuple; t != nil {
+ t.conn.finalize()
+ }
+ pkt.tuple = nil
+ return ret
+}
+
+func (it *IPTables) shouldSkip(netProto tcpip.NetworkProtocolNumber) bool {
+ switch netProto {
+ case header.IPv4ProtocolNumber, header.IPv6ProtocolNumber:
+ default:
+ // IPTables only supports IPv4/IPv6.
+ return true
+ }
+
+ it.mu.RLock()
+ defer it.mu.RUnlock()
// Many users never configure iptables. Spare them the cost of rule
// traversal if rules have never been set.
+ return !it.modified
+}
+
+// check runs pkt through the rules for hook. It returns true when the packet
+// should continue traversing the network stack and false when it should be
+// dropped.
+//
+// Precondition: The packet's network and transport header must be set.
+func (it *IPTables) check(hook Hook, pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) bool {
it.mu.RLock()
defer it.mu.RUnlock()
- if !it.modified {
- return true
- }
-
- // Packets are manipulated only if connection and matching
- // NAT rule exists.
- shouldTrack := it.connections.handlePacket(pkt, hook, r)
// Go through each table containing the hook.
priorities := it.priorities[hook]
for _, tableID := range priorities {
- // If handlePacket already NATed the packet, we don't need to
- // check the NAT table.
- if tableID == NATID && pkt.NatDone {
+ if t := pkt.tuple; t != nil && tableID == NATID && t.conn.handlePacket(pkt, hook, r) {
continue
}
var table Table
@@ -300,7 +393,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr
table = it.v4Tables[tableID]
}
ruleIdx := table.BuiltinChains[hook]
- switch verdict := it.checkChain(hook, pkt, table, ruleIdx, r, preroutingAddr, inNicName, outNicName); verdict {
+ switch verdict := it.checkChain(hook, pkt, table, ruleIdx, r, addressEP, inNicName, outNicName); verdict {
// If the table returns Accept, move on to the next table.
case chainAccept:
continue
@@ -311,7 +404,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr
// Any Return from a built-in chain means we have to
// call the underflow.
underflow := table.Rules[table.Underflows[hook]]
- switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, r, preroutingAddr); v {
+ switch v, _ := underflow.Target.Action(pkt, hook, r, addressEP); v {
case RuleAccept:
continue
case RuleDrop:
@@ -327,21 +420,6 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr
}
}
- // If this connection should be tracked, try to add an entry for it. If
- // traversing the nat table didn't end in adding an entry,
- // maybeInsertNoop will add a no-op entry for the connection. This is
- // needeed when establishing connections so that the SYN/ACK reply to an
- // outgoing SYN is delivered to the correct endpoint rather than being
- // redirected by a prerouting rule.
- //
- // From the iptables documentation: "If there is no rule, a `null'
- // binding is created: this usually does not map the packet, but exists
- // to ensure we don't map another stream over an existing one."
- if shouldTrack {
- it.connections.maybeInsertNoop(pkt, hook)
- }
-
- // Every table returned Accept.
return true
}
@@ -375,30 +453,46 @@ func (it *IPTables) startReaper(interval time.Duration) {
}()
}
-// CheckPackets runs pkts through the rules for hook and returns a map of packets that
-// should not go forward.
+// CheckOutputPackets performs the output hook on the packets.
//
-// Preconditions:
-// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
-// * pkt.NetworkHeader is not nil.
+// Returns drop, the set of packets that must be dropped, and natPkts, the set
+// of packets whose DNATDone flag is set after the hook has run.
//
-// NOTE: unlike the Check API the returned map contains packets that should be
-// dropped.
-func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, r *Route, inNicName, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
+// Precondition: The packets' network and transport header must be set.
+func (it *IPTables) CheckOutputPackets(pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
+ return checkPackets(pkts, func(pkt *PacketBuffer) bool {
+ return it.CheckOutput(pkt, r, outNicName)
+ }, true /* dnat */)
+}
+
+// CheckPostroutingPackets performs the postrouting hook on the packets.
+//
+// Returns drop, the set of packets that must be dropped, and natPkts, the set
+// of packets whose SNATDone flag is set after the hook has run. Either map is
+// nil when it would be empty.
+//
+// Precondition: The packets' network and transport header must be set.
+func (it *IPTables) CheckPostroutingPackets(pkts PacketBufferList, r *Route, addressEP AddressableEndpoint, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
+ return checkPackets(pkts, func(pkt *PacketBuffer) bool {
+ return it.CheckPostrouting(pkt, r, addressEP, outNicName)
+ }, false /* dnat */)
+}
+
+func checkPackets(pkts PacketBufferList, f func(*PacketBuffer) bool, dnat bool) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- if !pkt.NatDone {
- if ok := it.Check(hook, pkt, r, "", inNicName, outNicName); !ok {
- if drop == nil {
- drop = make(map[*PacketBuffer]struct{})
- }
- drop[pkt] = struct{}{}
+ natDone := &pkt.SNATDone
+ if dnat {
+ natDone = &pkt.DNATDone
+ }
+
+ if ok := f(pkt); !ok {
+ if drop == nil {
+ drop = make(map[*PacketBuffer]struct{})
}
- if pkt.NatDone {
- if natPkts == nil {
- natPkts = make(map[*PacketBuffer]struct{})
- }
- natPkts[pkt] = struct{}{}
+ drop[pkt] = struct{}{}
+ }
+ if *natDone {
+ if natPkts == nil {
+ natPkts = make(map[*PacketBuffer]struct{})
}
+ natPkts[pkt] = struct{}{}
}
}
return drop, natPkts
@@ -407,11 +501,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, r *Route, inN
// Preconditions:
// * pkt is an IPv4 packet of at least length header.IPv4MinimumSize.
// * pkt.NetworkHeader is not nil.
-func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) chainVerdict {
+func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) chainVerdict {
// Start from ruleIdx and walk the list of rules until a rule gives us
// a verdict.
for ruleIdx < len(table.Rules) {
- switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, r, preroutingAddr, inNicName, outNicName); verdict {
+ switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, r, addressEP, inNicName, outNicName); verdict {
case RuleAccept:
return chainAccept
@@ -428,7 +522,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
ruleIdx++
continue
}
- switch verdict := it.checkChain(hook, pkt, table, jumpTo, r, preroutingAddr, inNicName, outNicName); verdict {
+ switch verdict := it.checkChain(hook, pkt, table, jumpTo, r, addressEP, inNicName, outNicName); verdict {
case chainAccept:
return chainAccept
case chainDrop:
@@ -454,7 +548,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
// Preconditions:
// * pkt is an IPv4 packet of at least length header.IPv4MinimumSize.
// * pkt.NetworkHeader is not nil.
-func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) (RuleVerdict, int) {
+func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) (RuleVerdict, int) {
rule := table.Rules[ruleIdx]
// Check whether the packet matches the IP header filter.
@@ -477,16 +571,16 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
}
// All the matchers matched, so run the target.
- return rule.Target.Action(pkt, &it.connections, hook, r, preroutingAddr)
+ return rule.Target.Action(pkt, hook, r, addressEP)
}
// OriginalDst returns the original destination of redirected connections. It
// returns an error if the connection doesn't exist or isn't redirected.
-func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
+func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
it.mu.RLock()
defer it.mu.RUnlock()
if !it.modified {
return "", 0, &tcpip.ErrNotConnected{}
}
- return it.connections.originalDst(epID, netProto)
+ return it.connections.originalDst(epID, netProto, transProto)
}
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 96cc899bb..ef515bdd2 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -29,7 +29,7 @@ type AcceptTarget struct {
}
// Action implements Target.Action.
-func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*AcceptTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
return RuleAccept, 0
}
@@ -40,7 +40,7 @@ type DropTarget struct {
}
// Action implements Target.Action.
-func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*DropTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
return RuleDrop, 0
}
@@ -52,7 +52,7 @@ type ErrorTarget struct {
}
// Action implements Target.Action.
-func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*ErrorTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
log.Debugf("ErrorTarget triggered.")
return RuleDrop, 0
}
@@ -67,7 +67,7 @@ type UserChainTarget struct {
}
// Action implements Target.Action.
-func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*UserChainTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
panic("UserChainTarget should never be called.")
}
@@ -79,10 +79,49 @@ type ReturnTarget struct {
}
// Action implements Target.Action.
-func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*ReturnTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
return RuleReturn, 0
}
+// DNATTarget modifies the destination port/IP of packets.
+//
+// Its Action is only supported on the Prerouting and Output hooks.
+type DNATTarget struct {
+ // The new destination address for packets.
+ //
+ // Immutable.
+ Addr tcpip.Address
+
+ // The new destination port for packets.
+ //
+ // Immutable.
+ Port uint16
+
+ // NetworkProtocol is the network protocol the target is used with.
+ //
+ // Immutable.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+// Action implements Target.Action.
+//
+// It rewrites the packet's destination to rt.Addr and rt.Port via natAction.
+// It panics if the target's network protocol differs from the packet's, or if
+// hook is anything other than Prerouting or Output. The route r is passed to
+// natAction; addressEP is unused.
+func (rt *DNATTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) {
+ // Sanity check.
+ if rt.NetworkProtocol != pkt.NetworkProtocolNumber {
+ panic(fmt.Sprintf(
+ "DNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
+ rt.NetworkProtocol, pkt.NetworkProtocolNumber))
+ }
+
+ switch hook {
+ case Prerouting, Output:
+ case Input, Forward, Postrouting:
+ panic(fmt.Sprintf("%s not supported for DNAT", hook))
+ default:
+ panic(fmt.Sprintf("%s unrecognized", hook))
+ }
+
+ return natAction(pkt, hook, r, rt.Port, rt.Addr, true /* dnat */)
+
+}
+
// RedirectTarget redirects the packet to this machine by modifying the
// destination port/IP. Outgoing packets are redirected to the loopback device,
// and incoming packets are redirected to the incoming interface (rather than
@@ -97,7 +136,7 @@ type RedirectTarget struct {
}
// Action implements Target.Action.
-func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) {
+func (rt *RedirectTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) {
// Sanity check.
if rt.NetworkProtocol != pkt.NetworkProtocolNumber {
panic(fmt.Sprintf(
@@ -105,18 +144,9 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r
rt.NetworkProtocol, pkt.NetworkProtocolNumber))
}
- // Packet is already manipulated.
- if pkt.NatDone {
- return RuleAccept, 0
- }
-
- // Drop the packet if network and transport header are not set.
- if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() {
- return RuleDrop, 0
- }
-
// Change the address to loopback (127.0.0.1 or ::1) in Output and to
// the primary address of the incoming interface in Prerouting.
+ var address tcpip.Address
switch hook {
case Output:
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
@@ -125,48 +155,13 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r
address = header.IPv6Loopback
}
case Prerouting:
- // No-op, as address is already set correctly.
+ // addressEP is expected to be set for the prerouting hook.
+ address = addressEP.MainAddress().Address
default:
panic("redirect target is supported only on output and prerouting hooks")
}
- switch protocol := pkt.TransportProtocolNumber; protocol {
- case header.UDPProtocolNumber:
- udpHeader := header.UDP(pkt.TransportHeader().View())
-
- if hook == Output {
- // Only calculate the checksum if offloading isn't supported.
- requiresChecksum := r.RequiresTXTransportChecksum()
- rewritePacket(
- pkt.Network(),
- udpHeader,
- false, /* updateSRCFields */
- requiresChecksum,
- requiresChecksum,
- rt.Port,
- address,
- )
- } else {
- udpHeader.SetDestinationPort(rt.Port)
- }
-
- pkt.NatDone = true
- case header.TCPProtocolNumber:
- if ct == nil {
- return RuleAccept, 0
- }
-
- // Set up conection for matching NAT rule. Only the first
- // packet of the connection comes here. Other packets will be
- // manipulated in connection tracking.
- if conn := ct.insertRedirectConn(pkt, hook, rt.Port, address); conn != nil {
- ct.handlePacket(pkt, hook, r)
- }
- default:
- return RuleDrop, 0
- }
-
- return RuleAccept, 0
+ return natAction(pkt, hook, r, rt.Port, address, true /* dnat */)
}
// SNATTarget modifies the source port/IP in the outgoing packets.
@@ -179,8 +174,36 @@ type SNATTarget struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
+// natAction is the common tail of the NAT targets in this file: it validates
+// pkt and then asks the connection tracked by pkt.tuple to perform NAT to
+// address and port. dnat is true for destination NAT and false for source
+// NAT.
+//
+// It returns RuleDrop if pkt lacks a network or transport header or carries
+// no tuple, and RuleAccept otherwise; the int result is always 0. It panics
+// if port is 0 and the packet is neither TCP nor UDP.
+func natAction(pkt *PacketBuffer, hook Hook, r *Route, port uint16, address tcpip.Address, dnat bool) (RuleVerdict, int) {
+ // Drop the packet if network and transport header are not set.
+ if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() {
+ return RuleDrop, 0
+ }
+
+ // Without a tuple there is no tracked connection to perform NAT on.
+ t := pkt.tuple
+ if t == nil {
+ return RuleDrop, 0
+ }
+
+ // TODO(https://gvisor.dev/issue/5773): If the port is in use, pick a
+ // different port.
+ //
+ // A zero port means "keep the packet's own port".
+ // NOTE(review): the fallback reads the packet's source port even when dnat
+ // is true - confirm callers never pass port 0 for destination NAT.
+ if port == 0 {
+ switch protocol := pkt.TransportProtocolNumber; protocol {
+ case header.UDPProtocolNumber:
+ port = header.UDP(pkt.TransportHeader().View()).SourcePort()
+ case header.TCPProtocolNumber:
+ port = header.TCP(pkt.TransportHeader().View()).SourcePort()
+ default:
+ panic(fmt.Sprintf("unsupported transport protocol = %d", pkt.TransportProtocolNumber))
+ }
+ }
+
+ t.conn.performNAT(pkt, hook, r, port, address, dnat)
+ return RuleAccept, 0
+}
+
// Action implements Target.Action.
-func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) {
+func (st *SNATTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, _ AddressableEndpoint) (RuleVerdict, int) {
// Sanity check.
if st.NetworkProtocol != pkt.NetworkProtocolNumber {
panic(fmt.Sprintf(
@@ -188,16 +211,6 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou
st.NetworkProtocol, pkt.NetworkProtocolNumber))
}
- // Packet is already manipulated.
- if pkt.NatDone {
- return RuleAccept, 0
- }
-
- // Drop the packet if network and transport header are not set.
- if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() {
- return RuleDrop, 0
- }
-
switch hook {
case Postrouting, Input:
case Prerouting, Output, Forward:
@@ -206,37 +219,43 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou
panic(fmt.Sprintf("%s unrecognized", hook))
}
- switch protocol := pkt.TransportProtocolNumber; protocol {
- case header.UDPProtocolNumber:
- // Only calculate the checksum if offloading isn't supported.
- requiresChecksum := r.RequiresTXTransportChecksum()
- rewritePacket(
- pkt.Network(),
- header.UDP(pkt.TransportHeader().View()),
- true, /* updateSRCFields */
- requiresChecksum,
- requiresChecksum,
- st.Port,
- st.Addr,
- )
-
- pkt.NatDone = true
- case header.TCPProtocolNumber:
- if ct == nil {
- return RuleAccept, 0
- }
+ return natAction(pkt, hook, r, st.Port, st.Addr, false /* dnat */)
+}
- // Set up conection for matching NAT rule. Only the first
- // packet of the connection comes here. Other packets will be
- // manipulated in connection tracking.
- if conn := ct.insertSNATConn(pkt, hook, st.Port, st.Addr); conn != nil {
- ct.handlePacket(pkt, hook, r)
- }
+// MasqueradeTarget modifies the source port/IP in the outgoing packets.
+type MasqueradeTarget struct {
+ // NetworkProtocol is the network protocol the target is used with. It
+ // is immutable.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+// Action implements Target.Action.
+func (mt *MasqueradeTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) {
+ // Sanity check.
+ if mt.NetworkProtocol != pkt.NetworkProtocolNumber {
+ panic(fmt.Sprintf(
+ "MasqueradeTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
+ mt.NetworkProtocol, pkt.NetworkProtocolNumber))
+ }
+
+ switch hook {
+ case Postrouting:
+ case Prerouting, Input, Forward, Output:
+ panic(fmt.Sprintf("masquerade target is supported only on postrouting hook; hook = %d", hook))
default:
+ panic(fmt.Sprintf("%s unrecognized", hook))
+ }
+
+ // addressEP is expected to be set for the postrouting hook.
+ ep := addressEP.AcquireOutgoingPrimaryAddress(pkt.Network().DestinationAddress(), false /* allowExpired */)
+ if ep == nil {
+ // No address exists that we can use as a source address.
return RuleDrop, 0
}
- return RuleAccept, 0
+ address := ep.AddressWithPrefix().Address
+ ep.DecRef()
+ return natAction(pkt, hook, r, 0 /* port */, address, false /* dnat */)
}
func rewritePacket(n header.Network, t header.ChecksummableTransport, updateSRCFields, fullChecksum, updatePseudoHeader bool, newPort uint16, newAddr tcpip.Address) {
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index 66e5f22ac..b22024667 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -81,17 +81,6 @@ const (
//
// +stateify savable
type IPTables struct {
- // mu protects v4Tables, v6Tables, and modified.
- mu sync.RWMutex
- // v4Tables and v6tables map tableIDs to tables. They hold builtin
- // tables only, not user tables. mu must be locked for accessing.
- v4Tables [NumTables]Table
- v6Tables [NumTables]Table
- // modified is whether tables have been modified at least once. It is
- // used to elide the iptables performance overhead for workloads that
- // don't utilize iptables.
- modified bool
-
// priorities maps each hook to a list of table names. The order of the
// list is the order in which each table should be visited for that
// hook. It is immutable.
@@ -101,6 +90,21 @@ type IPTables struct {
// reaperDone can be signaled to stop the reaper goroutine.
reaperDone chan struct{}
+
+ mu sync.RWMutex
+ // v4Tables and v6Tables map tableIDs to tables. They hold builtin
+ // tables only, not user tables.
+ //
+ // +checklocks:mu
+ v4Tables [NumTables]Table
+ // +checklocks:mu
+ v6Tables [NumTables]Table
+ // modified is whether tables have been modified at least once. It is
+ // used to elide the iptables performance overhead for workloads that
+ // don't utilize iptables.
+ //
+ // +checklocks:mu
+ modified bool
}
// VisitTargets traverses all the targets of all tables and replaces each with
@@ -352,5 +356,5 @@ type Target interface {
// Action takes an action on the packet and returns a verdict on how
// traversal should (or should not) continue. If the return value is
// Jump, it also returns the index of the rule to jump to.
- Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int)
+ Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int)
}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 4d5431da1..40b33b6b5 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -333,8 +333,12 @@ func TestDADDisabled(t *testing.T) {
Address: addr1,
PrefixLen: defaultPrefixLen,
}
- if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err)
}
// Should get the address immediately since we should not have performed
@@ -379,12 +383,15 @@ func TestDADResolveLoopback(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- addrWithPrefix := tcpip.AddressWithPrefix{
- Address: addr1,
- PrefixLen: defaultPrefixLen,
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addr1,
+ PrefixLen: defaultPrefixLen,
+ },
}
- if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err)
}
// Address should not be considered bound to the NIC yet (DAD ongoing).
@@ -517,8 +524,12 @@ func TestDADResolve(t *testing.T) {
Address: addr1,
PrefixLen: defaultPrefixLen,
}
- if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err)
}
// Make sure the address does not resolve before the resolution time has
@@ -740,8 +751,12 @@ func TestDADFail(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr1.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
// Address should not be considered bound to the NIC yet
@@ -778,8 +793,8 @@ func TestDADFail(t *testing.T) {
// Attempting to add the address again should not fail if the address's
// state was cleaned up when DAD failed.
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
})
}
@@ -851,8 +866,12 @@ func TestDADStop(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr1.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
// Address should not be considered bound to the NIC yet (DAD ongoing).
@@ -975,17 +994,29 @@ func TestSetNDPConfigurations(t *testing.T) {
// Add addresses for each NIC.
addrWithPrefix1 := tcpip.AddressWithPrefix{Address: addr1, PrefixLen: defaultPrefixLen}
- if err := s.AddAddressWithPrefix(nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addrWithPrefix1, err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addrWithPrefix1,
+ }
+ if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID1, protocolAddr1, err)
}
addrWithPrefix2 := tcpip.AddressWithPrefix{Address: addr2, PrefixLen: defaultPrefixLen}
- if err := s.AddAddressWithPrefix(nicID2, header.IPv6ProtocolNumber, addrWithPrefix2); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addrWithPrefix2, err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addrWithPrefix2,
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID2, protocolAddr2, err)
}
expectDADEvent(nicID2, addr2)
addrWithPrefix3 := tcpip.AddressWithPrefix{Address: addr3, PrefixLen: defaultPrefixLen}
- if err := s.AddAddressWithPrefix(nicID3, header.IPv6ProtocolNumber, addrWithPrefix3); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addrWithPrefix3, err)
+ protocolAddr3 := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addrWithPrefix3,
+ }
+ if err := s.AddProtocolAddress(nicID3, protocolAddr3, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID3, protocolAddr3, err)
}
expectDADEvent(nicID3, addr3)
@@ -2788,8 +2819,12 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
continue
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, test.addrs[j].Address); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, test.addrs[j].Address, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: test.addrs[j].Address.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
manuallyAssignedAddresses[test.addrs[j].Address] = struct{}{}
@@ -3644,8 +3679,9 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) {
Protocol: header.IPv6ProtocolNumber,
AddressWithPrefix: addr2,
}
- if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil {
- t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err)
+ properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint}
+ if err := s.AddProtocolAddress(nicID, protoAddr2, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v) = %s", nicID, protoAddr2, properties, err)
}
// addr2 should be more preferred now since it is at the front of the primary
// list.
@@ -3733,8 +3769,9 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
}
// Add the address as a static address before SLAAC tries to add it.
- if err := s.AddProtocolAddress(1, tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr.Address, err)
+ protocolAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}) = %s", protocolAddr, err)
}
if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) {
t.Fatalf("Should have %s in the list of addresses", addr1)
@@ -4073,8 +4110,12 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
// Attempting to add the address manually should not fail if the
// address's state was cleaned up when DAD failed.
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr.Address); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr.Address, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if err := s.RemoveAddress(nicID, addr.Address); err != nil {
t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err)
@@ -5362,8 +5403,12 @@ func TestRouterSolicitation(t *testing.T) {
}
if addr := test.nicAddr; addr != "" {
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index a796942ab..e251e3b24 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -97,6 +97,8 @@ type packetEndpointList struct {
mu sync.RWMutex
// eps is protected by mu, but the contained PacketEndpoint values are not.
+ //
+ // +checklocks:mu
eps []PacketEndpoint
}
@@ -117,6 +119,12 @@ func (p *packetEndpointList) remove(ep PacketEndpoint) {
}
}
+func (p *packetEndpointList) len() int {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+ return len(p.eps)
+}
+
// forEach calls fn with each endpoints in p while holding the read lock on p.
func (p *packetEndpointList) forEach(fn func(PacketEndpoint)) {
p.mu.RLock()
@@ -157,14 +165,8 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
resolutionRequired := ep.Capabilities()&CapabilityResolutionRequired != 0
- // Register supported packet and network endpoint protocols.
- for _, netProto := range header.Ethertypes {
- nic.packetEPs.eps[netProto] = new(packetEndpointList)
- }
for _, netProto := range stack.networkProtocols {
netNum := netProto.Number()
- nic.packetEPs.eps[netNum] = new(packetEndpointList)
-
netEP := netProto.NewEndpoint(nic, nic)
nic.networkEndpoints[netNum] = netEP
@@ -514,7 +516,7 @@ func (n *nic) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber,
// addAddress adds a new address to n, so that it starts accepting packets
// targeted at the given address (and network protocol).
-func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error {
+func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, properties AddressProperties) tcpip.Error {
ep, ok := n.networkEndpoints[protocolAddress.Protocol]
if !ok {
return &tcpip.ErrUnknownProtocol{}
@@ -525,7 +527,7 @@ func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo
return &tcpip.ErrNotSupported{}
}
- addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */)
+ addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, properties)
if err == nil {
// We have no need for the address endpoint.
addressEndpoint.DecRef()
@@ -831,24 +833,9 @@ func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt
transProto := state.proto
- // TransportHeader is empty only when pkt is an ICMP packet or was reassembled
- // from fragments.
if pkt.TransportHeader().View().IsEmpty() {
- // ICMP packets don't have their TransportHeader fields set yet, parse it
- // here. See icmp/protocol.go:protocol.Parse for a full explanation.
- if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber {
- // ICMP packets may be longer, but until icmp.Parse is implemented, here
- // we parse it using the minimum size.
- if _, ok := pkt.TransportHeader().Consume(transProto.MinimumPacketSize()); !ok {
- n.stats.malformedL4RcvdPackets.Increment()
- // We consider a malformed transport packet handled because there is
- // nothing the caller can do.
- return TransportPacketHandled
- }
- } else if !transProto.Parse(pkt) {
- n.stats.malformedL4RcvdPackets.Increment()
- return TransportPacketHandled
- }
+ n.stats.malformedL4RcvdPackets.Increment()
+ return TransportPacketHandled
}
srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader().View())
@@ -974,7 +961,8 @@ func (n *nic) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa
eps, ok := n.packetEPs.eps[netProto]
if !ok {
- return &tcpip.ErrNotSupported{}
+ eps = new(packetEndpointList)
+ n.packetEPs.eps[netProto] = eps
}
eps.add(ep)
@@ -990,6 +978,9 @@ func (n *nic) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep
return
}
eps.remove(ep)
+ if eps.len() == 0 {
+ delete(n.packetEPs.eps, netProto)
+ }
}
// isValidForOutgoing returns true if the endpoint can be used to send out a
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index 5cb342f78..c8ad93f29 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -127,11 +127,6 @@ func (*testIPv6Protocol) MinimumPacketSize() int {
return header.IPv6MinimumSize
}
-// DefaultPrefixLen implements NetworkProtocol.DefaultPrefixLen.
-func (*testIPv6Protocol) DefaultPrefixLen() int {
- return header.IPv6AddressSize * 8
-}
-
// ParseAddresses implements NetworkProtocol.ParseAddresses.
func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
h := header.IPv6(v)
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 29c22bfd4..c4a4bbd22 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -126,9 +126,13 @@ type PacketBuffer struct {
EgressRoute RouteInfo
GSOOptions GSO
- // NatDone indicates if the packet has been manipulated as per NAT
- // iptables rule.
- NatDone bool
+ // SNATDone indicates if the packet's source has been manipulated as per
+ // iptables NAT table.
+ SNATDone bool
+
+ // DNATDone indicates if the packet's destination has been manipulated as per
+ // iptables NAT table.
+ DNATDone bool
// PktType indicates the SockAddrLink.PacketType of the packet as defined in
// https://www.man7.org/linux/man-pages/man7/packet.7.html.
@@ -143,6 +147,8 @@ type PacketBuffer struct {
// NetworkPacketInfo holds an incoming packet's network-layer information.
NetworkPacketInfo NetworkPacketInfo
+
+ tuple *tuple
}
// NewPacketBuffer creates a new PacketBuffer with opts.
@@ -296,12 +302,14 @@ func (pk *PacketBuffer) Clone() *PacketBuffer {
Owner: pk.Owner,
GSOOptions: pk.GSOOptions,
NetworkProtocolNumber: pk.NetworkProtocolNumber,
- NatDone: pk.NatDone,
+ DNATDone: pk.DNATDone,
+ SNATDone: pk.SNATDone,
TransportProtocolNumber: pk.TransportProtocolNumber,
PktType: pk.PktType,
NICID: pk.NICID,
RXTransportChecksumValidated: pk.RXTransportChecksumValidated,
NetworkPacketInfo: pk.NetworkPacketInfo,
+ tuple: pk.tuple,
}
}
@@ -329,15 +337,41 @@ func (pk *PacketBuffer) CloneToInbound() *PacketBuffer {
buf: pk.buf.Clone(),
// Treat unfilled header portion as reserved.
reserved: pk.AvailableHeaderBytes(),
+ tuple: pk.tuple,
+ }
+ return newPk
+}
+
+// DeepCopyForForwarding creates a deep copy of the packet buffer for
+// forwarding.
+//
+// The returned packet buffer will have the network and transport headers
+// set if the original packet buffer did.
+func (pk *PacketBuffer) DeepCopyForForwarding(reservedHeaderBytes int) *PacketBuffer {
+ newPk := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: reservedHeaderBytes,
+ Data: PayloadSince(pk.NetworkHeader()).ToVectorisedView(),
+ IsForwardedPacket: true,
+ })
+
+ {
+ consumeBytes := pk.NetworkHeader().View().Size()
+ if _, consumed := newPk.NetworkHeader().Consume(consumeBytes); !consumed {
+ panic(fmt.Sprintf("expected to consume network header %d bytes from new packet", consumeBytes))
+ }
+ newPk.NetworkProtocolNumber = pk.NetworkProtocolNumber
}
- // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to
- // maintain this flag in the packet. Currently conntrack needs this flag to
- // tell if a noop connection should be inserted at Input hook. Once conntrack
- // redefines the manipulation field as mutable, we won't need the special noop
- // connection.
- if pk.NatDone {
- newPk.NatDone = true
+
+ {
+ consumeBytes := pk.TransportHeader().View().Size()
+ if _, consumed := newPk.TransportHeader().Consume(consumeBytes); !consumed {
+ panic(fmt.Sprintf("expected to consume transport header %d bytes from new packet", consumeBytes))
+ }
+ newPk.TransportProtocolNumber = pk.TransportProtocolNumber
}
+
+ newPk.tuple = pk.tuple
+
return newPk
}
@@ -389,13 +423,14 @@ func (d PacketData) PullUp(size int) (tcpipbuffer.View, bool) {
return d.pk.buf.PullUp(d.pk.dataOffset(), size)
}
-// DeleteFront removes count from the beginning of d. It panics if count >
-// d.Size(). All backing storage references after the front of the d are
-// invalidated.
-func (d PacketData) DeleteFront(count int) {
- if !d.pk.buf.Remove(d.pk.dataOffset(), count) {
- panic("count > d.Size()")
+// Consume is the same as PullUp except that it additionally consumes the
+// returned bytes. Subsequent PullUp or Consume will not return these bytes.
+func (d PacketData) Consume(size int) (tcpipbuffer.View, bool) {
+ v, ok := d.PullUp(size)
+ if ok {
+ d.pk.consumed += size
}
+ return v, ok
}
// CapLength reduces d to at most length bytes.
diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go
index 87b023445..c376ed1a1 100644
--- a/pkg/tcpip/stack/packet_buffer_test.go
+++ b/pkg/tcpip/stack/packet_buffer_test.go
@@ -123,32 +123,6 @@ func TestPacketHeaderPush(t *testing.T) {
}
}
-func TestPacketBufferClone(t *testing.T) {
- data := concatViews(makeView(20), makeView(30), makeView(40))
- pk := NewPacketBuffer(PacketBufferOptions{
- // Make a copy of data to make sure our truth data won't be taint by
- // PacketBuffer.
- Data: buffer.NewViewFromBytes(data).ToVectorisedView(),
- })
-
- bytesToDelete := 30
- originalSize := data.Size()
-
- clonedPks := []*PacketBuffer{
- pk.Clone(),
- pk.CloneToInbound(),
- }
- pk.Data().DeleteFront(bytesToDelete)
- if got, want := pk.Data().Size(), originalSize-bytesToDelete; got != want {
- t.Errorf("original packet was not changed: size expected = %d, got = %d", want, got)
- }
- for _, clonedPk := range clonedPks {
- if got := clonedPk.Data().Size(); got != originalSize {
- t.Errorf("cloned packet should not be modified: expected size = %d, got = %d", originalSize, got)
- }
- }
-}
-
func TestPacketHeaderConsume(t *testing.T) {
for _, test := range []struct {
name string
@@ -461,11 +435,17 @@ func TestPacketBufferData(t *testing.T) {
}
})
- // DeleteFront
+ // Consume.
for _, n := range []int{1, len(tc.data)} {
- t.Run(fmt.Sprintf("DeleteFront%d", n), func(t *testing.T) {
+ t.Run(fmt.Sprintf("Consume%d", n), func(t *testing.T) {
pkt := tc.makePkt(t)
- pkt.Data().DeleteFront(n)
+ v, ok := pkt.Data().Consume(n)
+ if !ok {
+ t.Fatalf("Consume failed")
+ }
+ if want := []byte(tc.data)[:n]; !bytes.Equal(v, want) {
+ t.Fatalf("pkt.Data().Consume(n) = 0x%x, want 0x%x", v, want)
+ }
checkData(t, pkt, []byte(tc.data)[n:])
})
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 113baaaae..31b3a554d 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -318,8 +318,7 @@ type PrimaryEndpointBehavior int
const (
// CanBePrimaryEndpoint indicates the endpoint can be used as a primary
- // endpoint for new connections with no local address. This is the
- // default when calling NIC.AddAddress.
+ // endpoint for new connections with no local address.
CanBePrimaryEndpoint PrimaryEndpointBehavior = iota
// FirstPrimaryEndpoint indicates the endpoint should be the first
@@ -332,6 +331,19 @@ const (
NeverPrimaryEndpoint
)
+func (peb PrimaryEndpointBehavior) String() string {
+ switch peb {
+ case CanBePrimaryEndpoint:
+ return "CanBePrimaryEndpoint"
+ case FirstPrimaryEndpoint:
+ return "FirstPrimaryEndpoint"
+ case NeverPrimaryEndpoint:
+ return "NeverPrimaryEndpoint"
+ default:
+ panic(fmt.Sprintf("unknown primary endpoint behavior: %d", peb))
+ }
+}
+
// AddressConfigType is the method used to add an address.
type AddressConfigType int
@@ -351,6 +363,14 @@ const (
AddressConfigSlaacTemp
)
+// AddressProperties contains additional properties that can be configured when
+// adding an address.
+type AddressProperties struct {
+ PEB PrimaryEndpointBehavior
+ ConfigType AddressConfigType
+ Deprecated bool
+}
+
// AssignableAddressEndpoint is a reference counted address endpoint that may be
// assigned to a NetworkEndpoint.
type AssignableAddressEndpoint interface {
@@ -457,7 +477,7 @@ type AddressableEndpoint interface {
// Returns *tcpip.ErrDuplicateAddress if the address exists.
//
// Acquires and returns the AddressEndpoint for the added address.
- AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, tcpip.Error)
+ AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties AddressProperties) (AddressEndpoint, tcpip.Error)
// RemovePermanentAddress removes the passed address if it is a permanent
// address.
@@ -685,9 +705,6 @@ type NetworkProtocol interface {
// than this targeted at this protocol.
MinimumPacketSize() int
- // DefaultPrefixLen returns the protocol's default prefix length.
- DefaultPrefixLen() int
-
// ParseAddresses returns the source and destination addresses stored in a
// packet of this protocol.
ParseAddresses(v buffer.View) (src, dst tcpip.Address)
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index cb741e540..a05fd7036 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -238,7 +238,7 @@ type Options struct {
// DefaultIPTables is an optional iptables rules constructor that is called
// if IPTables is nil. If both fields are nil, iptables will allow all
// traffic.
- DefaultIPTables func(uint32) *IPTables
+ DefaultIPTables func(seed uint32, clock tcpip.Clock) *IPTables
// SecureRNG is a cryptographically secure random number generator.
SecureRNG io.Reader
@@ -358,7 +358,7 @@ func New(opts Options) *Stack {
if opts.DefaultIPTables == nil {
opts.DefaultIPTables = DefaultTables
}
- opts.IPTables = opts.DefaultIPTables(seed)
+ opts.IPTables = opts.DefaultIPTables(seed, clock)
}
opts.NUDConfigs.resetInvalidFields()
@@ -375,7 +375,7 @@ func New(opts Options) *Stack {
stats: opts.Stats.FillIn(),
handleLocal: opts.HandleLocal,
tables: opts.IPTables,
- icmpRateLimiter: NewICMPRateLimiter(),
+ icmpRateLimiter: NewICMPRateLimiter(clock),
seed: seed,
nudConfigs: opts.NUDConfigs,
uniqueIDGenerator: opts.UniqueID,
@@ -916,46 +916,9 @@ type NICStateFlags struct {
Loopback bool
}
-// AddAddress adds a new network-layer address to the specified NIC.
-func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error {
- return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint)
-}
-
-// AddAddressWithPrefix is the same as AddAddress, but allows you to specify
-// the address prefix.
-func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.AddressWithPrefix) tcpip.Error {
- ap := tcpip.ProtocolAddress{
- Protocol: protocol,
- AddressWithPrefix: addr,
- }
- return s.AddProtocolAddressWithOptions(id, ap, CanBePrimaryEndpoint)
-}
-
-// AddProtocolAddress adds a new network-layer protocol address to the
-// specified NIC.
-func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) tcpip.Error {
- return s.AddProtocolAddressWithOptions(id, protocolAddress, CanBePrimaryEndpoint)
-}
-
-// AddAddressWithOptions is the same as AddAddress, but allows you to specify
-// whether the new endpoint can be primary or not.
-func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) tcpip.Error {
- netProto, ok := s.networkProtocols[protocol]
- if !ok {
- return &tcpip.ErrUnknownProtocol{}
- }
- return s.AddProtocolAddressWithOptions(id, tcpip.ProtocolAddress{
- Protocol: protocol,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: addr,
- PrefixLen: netProto.DefaultPrefixLen(),
- },
- }, peb)
-}
-
-// AddProtocolAddressWithOptions is the same as AddProtocolAddress, but allows
-// you to specify whether the new endpoint can be primary or not.
-func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error {
+// AddProtocolAddress adds an address to the specified NIC, possibly with extra
+// properties.
+func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, properties AddressProperties) tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -964,7 +927,7 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc
return &tcpip.ErrUnknownNICID{}
}
- return nic.addAddress(protocolAddress, peb)
+ return nic.addAddress(protocolAddress, properties)
}
// RemoveAddress removes an existing network-layer address from the specified
@@ -1902,12 +1865,6 @@ const (
// ParsePacketBufferTransport parses the provided packet buffer's transport
// header.
func (s *Stack) ParsePacketBufferTransport(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) ParseResult {
- // ICMP packets don't have their TransportHeader fields set yet, parse it
- // here. See icmp/protocol.go:protocol.Parse for a full explanation.
- if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber {
- return ParsedOK
- }
-
pkt.TransportProtocolNumber = protocol
// Parse the transport header if present.
state, ok := s.transportProtocols[protocol]
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 3089c0ef4..f5a35eac4 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -139,18 +139,15 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Handle control packets.
if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) {
- hdr, ok := pkt.Data().PullUp(fakeNetHeaderLen)
+ hdr, ok := pkt.Data().Consume(fakeNetHeaderLen)
if !ok {
return
}
- // DeleteFront invalidates slices. Make a copy before trimming.
- nb := append([]byte(nil), hdr...)
- pkt.Data().DeleteFront(fakeNetHeaderLen)
f.dispatcher.DeliverTransportError(
- tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]),
- tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]),
+ tcpip.Address(hdr[srcAddrOffset:srcAddrOffset+1]),
+ tcpip.Address(hdr[dstAddrOffset:dstAddrOffset+1]),
fakeNetNumber,
- tcpip.TransportProtocolNumber(nb[protocolNumberOffset]),
+ tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]),
// Nothing checks the error.
nil, /* transport error */
pkt,
@@ -158,8 +155,18 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) {
return
}
+ transProtoNum := tcpip.TransportProtocolNumber(netHdr[protocolNumberOffset])
+ switch err := f.proto.stack.ParsePacketBufferTransport(transProtoNum, pkt); err {
+ case stack.ParsedOK:
+ case stack.UnknownTransportProtocol, stack.TransportLayerParseError:
+ // The transport layer will handle unknown protocols and transport layer
+ // parsing errors.
+ default:
+ panic(fmt.Sprintf("unexpected error parsing transport header = %d", err))
+ }
+
// Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt)
+ f.dispatcher.DeliverTransportPacket(transProtoNum, pkt)
}
func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
@@ -221,6 +228,8 @@ func (*fakeNetworkEndpointStats) IsNetworkEndpointStats() {}
// number of packets sent and received via endpoints of this protocol. The index
// where packets are added is given by the packet's destination address MOD 10.
type fakeNetworkProtocol struct {
+ stack *stack.Stack
+
packetCount [10]int
sendPacketCount [10]int
defaultTTL uint8
@@ -234,10 +243,6 @@ func (*fakeNetworkProtocol) MinimumPacketSize() int {
return fakeNetHeaderLen
}
-func (*fakeNetworkProtocol) DefaultPrefixLen() int {
- return fakeDefaultPrefixLen
-}
-
func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int {
return f.packetCount[int(intfAddr)%len(f.packetCount)]
}
@@ -306,8 +311,8 @@ func (f *fakeNetworkEndpoint) SetForwarding(v bool) {
f.mu.forwarding = v
}
-func fakeNetFactory(*stack.Stack) stack.NetworkProtocol {
- return &fakeNetworkProtocol{}
+func fakeNetFactory(s *stack.Stack) stack.NetworkProtocol {
+ return &fakeNetworkProtocol{stack: s}
}
// linkEPWithMockedAttach is a stack.LinkEndpoint that tests can use to verify
@@ -349,12 +354,26 @@ func TestNetworkReceive(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x02"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x02",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr2, err)
}
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
@@ -517,8 +536,15 @@ func TestNetworkSend(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
// Make sure that the link-layer endpoint received the outbound packet.
@@ -538,12 +564,26 @@ func TestNetworkSendMultiRoute(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr3 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x03",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr3, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr3, err)
}
ep2 := channel.New(10, defaultMTU, "")
@@ -551,12 +591,26 @@ func TestNetworkSendMultiRoute(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x02",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(2, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err)
}
- if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr4 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x04",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(2, protocolAddr4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr4, err)
}
// Set a route table that sends all packets with odd destination
@@ -812,8 +866,15 @@ func TestRouteWithDownNIC(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
}
- if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addr1,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr1, err)
}
ep2 := channel.New(1, defaultMTU, "")
@@ -821,8 +882,15 @@ func TestRouteWithDownNIC(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
}
- if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addr2,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddr2, err)
}
// Set a route table that sends all packets with odd destination
@@ -978,12 +1046,26 @@ func TestRoutes(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr3 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x03",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr3, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr3, err)
}
ep2 := channel.New(10, defaultMTU, "")
@@ -991,12 +1073,26 @@ func TestRoutes(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x02",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(2, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err)
}
- if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr4 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x04",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(2, protocolAddr4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr4, err)
}
// Set a route table that sends all packets with odd destination
@@ -1058,8 +1154,15 @@ func TestAddressRemoval(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet("\x00", "\x00")
@@ -1108,8 +1211,15 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) {
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
buf := buffer.NewView(30)
- if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet("\x00", "\x00")
@@ -1242,8 +1352,15 @@ func TestEndpointExpiration(t *testing.T) {
// 2. Add Address, everything should work.
//-----------------------
- if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
verifyAddress(t, s, nicID, localAddr)
testRecv(t, fakeNet, localAddrByte, ep, buf)
@@ -1270,8 +1387,8 @@ func TestEndpointExpiration(t *testing.T) {
// 4. Add Address back, everything should work again.
//-----------------------
- if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
verifyAddress(t, s, nicID, localAddr)
testRecv(t, fakeNet, localAddrByte, ep, buf)
@@ -1310,8 +1427,8 @@ func TestEndpointExpiration(t *testing.T) {
// 7. Add Address back, everything should work again.
//-----------------------
- if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
verifyAddress(t, s, nicID, localAddr)
testRecv(t, fakeNet, localAddrByte, ep, buf)
@@ -1453,8 +1570,15 @@ func TestExternalSendWithHandleLocal(t *testing.T) {
if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: nicID}})
@@ -1510,8 +1634,15 @@ func TestSpoofingWithAddress(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
@@ -1633,8 +1764,8 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) {
}
protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{Address: header.IPv4Any}}
- if err := s.AddProtocolAddress(1, protoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err)
+ if err := s.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}) failed: %s", protoAddr, err)
}
r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
if err != nil {
@@ -1678,13 +1809,13 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
t.Fatalf("CreateNIC failed: %s", err)
}
nic1ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic1Addr}
- if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err)
+ if err := s.AddProtocolAddress(1, nic1ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}) failed: %s", nic1ProtoAddr, err)
}
nic2ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic2Addr}
- if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil {
- t.Fatalf("AddAddress(2, %v) failed: %v", nic2ProtoAddr, err)
+ if err := s.AddProtocolAddress(2, nic2ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(2, %+v, {}) failed: %s", nic2ProtoAddr, err)
}
// Set the initial route table.
@@ -1726,7 +1857,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
// 2. Case: Having an explicit route for broadcast will select that one.
rt = append(
[]tcpip.Route{
- {Destination: tcpip.AddressWithPrefix{Address: header.IPv4Broadcast, PrefixLen: 8 * header.IPv4AddressSize}.Subnet(), NIC: 1},
+ {Destination: header.IPv4Broadcast.WithPrefix().Subnet(), NIC: 1},
},
rt...,
)
@@ -1808,8 +1939,15 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, want)
}
- if err := s.AddAddress(1, fakeNetNumber, anyAddr); err != nil {
- t.Fatalf("AddAddress(%v, %v) failed: %v", fakeNetNumber, anyAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: anyAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
if r, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); tc.routeNeeded {
@@ -1886,22 +2024,27 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
// Add an address and in case of a primary one include a
// prefixLen.
address := tcpip.Address(bytes.Repeat([]byte{byte(i)}, addrLen))
+ properties := stack.AddressProperties{PEB: behavior}
if behavior == stack.CanBePrimaryEndpoint {
protocolAddress := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: address,
- PrefixLen: addrLen * 8,
- },
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: address.WithPrefix(),
}
- if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil {
- t.Fatalf("AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, protocolAddress, behavior, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddress, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddress, properties, err)
}
// Remember the address/prefix.
primaryAddrAdded[protocolAddress.AddressWithPrefix] = struct{}{}
} else {
- if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s:", nicID, fakeNetNumber, address, behavior, err)
+ protocolAddress := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddress, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddress, properties, err)
}
}
}
@@ -1996,8 +2139,8 @@ func TestGetMainNICAddressAddRemove(t *testing.T) {
PrefixLen: tc.prefixLen,
},
}
- if err := s.AddProtocolAddress(1, protocolAddress); err != nil {
- t.Fatal("AddProtocolAddress failed:", err)
+ if err := s.AddProtocolAddress(1, protocolAddress, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", protocolAddress, err)
}
// Check that we get the right initial address and prefix length.
@@ -2047,33 +2190,6 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto
}
}
-func TestAddAddress(t *testing.T) {
- const nicID = 1
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- var addrGen addressGenerator
- expectedAddresses := make([]tcpip.ProtocolAddress, 0, 2)
- for _, addrLen := range []int{4, 16} {
- address := addrGen.next(addrLen)
- if err := s.AddAddress(nicID, fakeNetNumber, address); err != nil {
- t.Fatalf("AddAddress(address=%s) failed: %s", address, err)
- }
- expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen},
- })
- }
-
- gotAddresses := s.AllAddresses()[nicID]
- verifyAddresses(t, expectedAddresses, gotAddresses)
-}
-
func TestAddProtocolAddress(t *testing.T) {
const nicID = 1
s := stack.New(stack.Options{
@@ -2084,96 +2200,43 @@ func TestAddProtocolAddress(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- var addrGen addressGenerator
- addrLenRange := []int{4, 16}
- prefixLenRange := []int{8, 13, 20, 32}
- expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange))
- for _, addrLen := range addrLenRange {
- for _, prefixLen := range prefixLenRange {
- protocolAddress := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: addrGen.next(addrLen),
- PrefixLen: prefixLen,
- },
- }
- if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil {
- t.Errorf("AddProtocolAddress(%+v) failed: %s", protocolAddress, err)
- }
- expectedAddresses = append(expectedAddresses, protocolAddress)
- }
- }
-
- gotAddresses := s.AllAddresses()[nicID]
- verifyAddresses(t, expectedAddresses, gotAddresses)
-}
-
-func TestAddAddressWithOptions(t *testing.T) {
- const nicID = 1
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
addrLenRange := []int{4, 16}
behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint}
- expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(behaviorRange))
+ configTypeRange := []stack.AddressConfigType{stack.AddressConfigStatic, stack.AddressConfigSlaac, stack.AddressConfigSlaacTemp}
+ deprecatedRange := []bool{false, true}
+ wantAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(behaviorRange)*len(configTypeRange)*len(deprecatedRange))
var addrGen addressGenerator
for _, addrLen := range addrLenRange {
for _, behavior := range behaviorRange {
- address := addrGen.next(addrLen)
- if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil {
- t.Fatalf("AddAddressWithOptions(address=%s, behavior=%d) failed: %s", address, behavior, err)
- }
- expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen},
- })
- }
- }
-
- gotAddresses := s.AllAddresses()[nicID]
- verifyAddresses(t, expectedAddresses, gotAddresses)
-}
-
-func TestAddProtocolAddressWithOptions(t *testing.T) {
- const nicID = 1
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- addrLenRange := []int{4, 16}
- prefixLenRange := []int{8, 13, 20, 32}
- behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint}
- expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)*len(behaviorRange))
- var addrGen addressGenerator
- for _, addrLen := range addrLenRange {
- for _, prefixLen := range prefixLenRange {
- for _, behavior := range behaviorRange {
- protocolAddress := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: addrGen.next(addrLen),
- PrefixLen: prefixLen,
- },
- }
- if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil {
- t.Fatalf("AddProtocolAddressWithOptions(%+v, %d) failed: %s", protocolAddress, behavior, err)
+ for _, configType := range configTypeRange {
+ for _, deprecated := range deprecatedRange {
+ address := addrGen.next(addrLen)
+ properties := stack.AddressProperties{
+ PEB: behavior,
+ ConfigType: configType,
+ Deprecated: deprecated,
+ }
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v) failed: %s", nicID, protocolAddr, properties, err)
+ }
+ wantAddresses = append(wantAddresses, tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen},
+ })
}
- expectedAddresses = append(expectedAddresses, protocolAddress)
}
}
}
gotAddresses := s.AllAddresses()[nicID]
- verifyAddresses(t, expectedAddresses, gotAddresses)
+ verifyAddresses(t, wantAddresses, gotAddresses)
}
func TestCreateNICWithOptions(t *testing.T) {
@@ -2290,8 +2353,15 @@ func TestNICStats(t *testing.T) {
if err := s.CreateNIC(nicid, ep); err != nil {
t.Fatal("CreateNIC failed: ", err)
}
- if err := s.AddAddress(nicid, fakeNetNumber, nic.addr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: nic.addr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicid, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicid, protocolAddr, err)
}
{
@@ -2735,8 +2805,16 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
// be returned by a call to GetMainNICAddress;
// else, it should.
const address1 = tcpip.Address("\x01")
- if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address1, pi); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address1, pi, err)
+ properties := stack.AddressProperties{PEB: pi}
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address1,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr, properties, err)
}
addr, err := s.GetMainNICAddress(nicID, fakeNetNumber)
if err != nil {
@@ -2785,16 +2863,31 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
// Add some other address with peb set to
// FirstPrimaryEndpoint.
const address3 = tcpip.Address("\x03")
- if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address3, stack.FirstPrimaryEndpoint); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address3, stack.FirstPrimaryEndpoint, err)
-
+ protocolAddr3 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address3,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ properties = stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint}
+ if err := s.AddProtocolAddress(nicID, protocolAddr3, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr3, properties, err)
}
// Add back the address we removed earlier and
// make sure the new peb was respected.
// (The address should just be promoted now).
- if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address1, ps); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address1, pi, err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address1,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ properties = stack.AddressProperties{PEB: ps}
+ if err := s.AddProtocolAddress(nicID, protocolAddr1, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr1, properties, err)
}
var primaryAddrs []tcpip.Address
for _, pa := range s.NICInfo()[nicID].ProtocolAddresses {
@@ -3096,8 +3189,12 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
}
for _, a := range test.nicAddrs {
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, a); err != nil {
- t.Errorf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, a, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: a.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
@@ -3203,8 +3300,12 @@ func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, addr1, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: addr1.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
// The NIC should have joined addr1's solicited node multicast address.
@@ -3359,8 +3460,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
PrefixLen: 128,
},
}
- if err := s.AddProtocolAddress(nicID, addr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err)
+ if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err)
}
// Address should be in the list of all addresses.
@@ -3687,8 +3788,8 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
if err := s.CreateNIC(nicID1, ep); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
}
- if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err)
+ if err := s.AddProtocolAddress(nicID1, test.nicAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, test.nicAddr, err)
}
s.SetRouteTable(test.routes)
@@ -3750,8 +3851,8 @@ func TestResolveWith(t *testing.T) {
PrefixLen: 24,
},
}
- if err := s.AddProtocolAddress(nicID, addr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, addr, err)
+ if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err)
}
s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}})
@@ -3792,8 +3893,15 @@ func TestRouteReleaseAfterAddrRemoval(t *testing.T) {
if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet("\x00", "\x00")
@@ -3881,8 +3989,8 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) {
PrefixLen: 8,
},
}
- if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protocolAddress, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddress, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddress, err)
}
// Check that we get the right initial address and prefix length.
@@ -3990,44 +4098,44 @@ func TestFindRouteWithForwarding(t *testing.T) {
)
type netCfg struct {
- proto tcpip.NetworkProtocolNumber
- factory stack.NetworkProtocolFactory
- nic1Addr tcpip.Address
- nic2Addr tcpip.Address
- remoteAddr tcpip.Address
+ proto tcpip.NetworkProtocolNumber
+ factory stack.NetworkProtocolFactory
+ nic1AddrWithPrefix tcpip.AddressWithPrefix
+ nic2AddrWithPrefix tcpip.AddressWithPrefix
+ remoteAddr tcpip.Address
}
fakeNetCfg := netCfg{
- proto: fakeNetNumber,
- factory: fakeNetFactory,
- nic1Addr: nic1Addr,
- nic2Addr: nic2Addr,
- remoteAddr: remoteAddr,
+ proto: fakeNetNumber,
+ factory: fakeNetFactory,
+ nic1AddrWithPrefix: tcpip.AddressWithPrefix{Address: nic1Addr, PrefixLen: fakeDefaultPrefixLen},
+ nic2AddrWithPrefix: tcpip.AddressWithPrefix{Address: nic2Addr, PrefixLen: fakeDefaultPrefixLen},
+ remoteAddr: remoteAddr,
}
globalIPv6Addr1 := tcpip.Address(net.ParseIP("a::1").To16())
globalIPv6Addr2 := tcpip.Address(net.ParseIP("a::2").To16())
ipv6LinkLocalNIC1WithGlobalRemote := netCfg{
- proto: ipv6.ProtocolNumber,
- factory: ipv6.NewProtocol,
- nic1Addr: llAddr1,
- nic2Addr: globalIPv6Addr2,
- remoteAddr: globalIPv6Addr1,
+ proto: ipv6.ProtocolNumber,
+ factory: ipv6.NewProtocol,
+ nic1AddrWithPrefix: llAddr1.WithPrefix(),
+ nic2AddrWithPrefix: globalIPv6Addr2.WithPrefix(),
+ remoteAddr: globalIPv6Addr1,
}
ipv6GlobalNIC1WithLinkLocalRemote := netCfg{
- proto: ipv6.ProtocolNumber,
- factory: ipv6.NewProtocol,
- nic1Addr: globalIPv6Addr1,
- nic2Addr: llAddr1,
- remoteAddr: llAddr2,
+ proto: ipv6.ProtocolNumber,
+ factory: ipv6.NewProtocol,
+ nic1AddrWithPrefix: globalIPv6Addr1.WithPrefix(),
+ nic2AddrWithPrefix: llAddr1.WithPrefix(),
+ remoteAddr: llAddr2,
}
ipv6GlobalNIC1WithLinkLocalMulticastRemote := netCfg{
- proto: ipv6.ProtocolNumber,
- factory: ipv6.NewProtocol,
- nic1Addr: globalIPv6Addr1,
- nic2Addr: globalIPv6Addr2,
- remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ proto: ipv6.ProtocolNumber,
+ factory: ipv6.NewProtocol,
+ nic1AddrWithPrefix: globalIPv6Addr1.WithPrefix(),
+ nic2AddrWithPrefix: globalIPv6Addr2.WithPrefix(),
+ remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
}
tests := []struct {
@@ -4036,8 +4144,8 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg netCfg
forwardingEnabled bool
- addrNIC tcpip.NICID
- localAddr tcpip.Address
+ addrNIC tcpip.NICID
+ localAddrWithPrefix tcpip.AddressWithPrefix
findRouteErr tcpip.Error
dependentOnForwarding bool
@@ -4047,7 +4155,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: false,
addrNIC: nicID1,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4056,7 +4164,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: true,
addrNIC: nicID1,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4065,7 +4173,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: false,
addrNIC: nicID1,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4074,7 +4182,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: true,
addrNIC: nicID1,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: true,
},
@@ -4083,7 +4191,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: false,
addrNIC: nicID2,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4092,7 +4200,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: true,
addrNIC: nicID2,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4101,7 +4209,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: false,
addrNIC: nicID2,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4110,7 +4218,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: true,
addrNIC: nicID2,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4118,7 +4226,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and localAddr on same NIC as route",
netCfg: fakeNetCfg,
forwardingEnabled: false,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4126,7 +4234,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and localAddr on same NIC as route",
netCfg: fakeNetCfg,
forwardingEnabled: false,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4134,7 +4242,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and localAddr on different NIC as route",
netCfg: fakeNetCfg,
forwardingEnabled: false,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4142,7 +4250,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and localAddr on different NIC as route",
netCfg: fakeNetCfg,
forwardingEnabled: true,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: true,
},
@@ -4166,7 +4274,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and link-local local addr with route on different NIC",
netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
forwardingEnabled: false,
- localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr,
+ localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4174,7 +4282,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and link-local local addr with route on same NIC",
netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
forwardingEnabled: true,
- localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr,
+ localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4182,7 +4290,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and global local addr with route on same NIC",
netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
forwardingEnabled: true,
- localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic2Addr,
+ localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4190,7 +4298,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and link-local local addr with route on same NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
forwardingEnabled: false,
- localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4198,7 +4306,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and link-local local addr with route on same NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
forwardingEnabled: true,
- localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4206,7 +4314,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and global local addr with link-local remote on different NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
forwardingEnabled: false,
- localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNetworkUnreachable{},
dependentOnForwarding: false,
},
@@ -4214,7 +4322,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and global local addr with link-local remote on different NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
forwardingEnabled: true,
- localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNetworkUnreachable{},
dependentOnForwarding: false,
},
@@ -4222,7 +4330,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and global local addr with link-local multicast remote on different NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
forwardingEnabled: false,
- localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNetworkUnreachable{},
dependentOnForwarding: false,
},
@@ -4230,7 +4338,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and global local addr with link-local multicast remote on different NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
forwardingEnabled: true,
- localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNetworkUnreachable{},
dependentOnForwarding: false,
},
@@ -4238,7 +4346,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and global local addr with link-local multicast remote on same NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
forwardingEnabled: false,
- localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4246,7 +4354,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and global local addr with link-local multicast remote on same NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
forwardingEnabled: true,
- localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4268,12 +4376,20 @@ func TestFindRouteWithForwarding(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s:", nicID2, err)
}
- if err := s.AddAddress(nicID1, test.netCfg.proto, test.netCfg.nic1Addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, test.netCfg.proto, test.netCfg.nic1Addr, err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: test.netCfg.proto,
+ AddressWithPrefix: test.netCfg.nic1AddrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr1, err)
}
- if err := s.AddAddress(nicID2, test.netCfg.proto, test.netCfg.nic2Addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, test.netCfg.proto, test.netCfg.nic2Addr, err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: test.netCfg.proto,
+ AddressWithPrefix: test.netCfg.nic2AddrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddr2, err)
}
if err := s.SetForwardingDefaultAndAllNICs(test.netCfg.proto, test.forwardingEnabled); err != nil {
@@ -4282,20 +4398,20 @@ func TestFindRouteWithForwarding(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}})
- r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */)
+ r, err := s.FindRoute(test.addrNIC, test.localAddrWithPrefix.Address, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */)
if err == nil {
defer r.Release()
}
if diff := cmp.Diff(test.findRouteErr, err); diff != "" {
- t.Fatalf("unexpected error from FindRoute(%d, %s, %s, %d, false), (-want, +got):\n%s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, diff)
+ t.Fatalf("unexpected error from FindRoute(%d, %s, %s, %d, false), (-want, +got):\n%s", test.addrNIC, test.localAddrWithPrefix.Address, test.netCfg.remoteAddr, test.netCfg.proto, diff)
}
if test.findRouteErr != nil {
return
}
- if r.LocalAddress() != test.localAddr {
- t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), test.localAddr)
+ if r.LocalAddress() != test.localAddrWithPrefix.Address {
+ t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), test.localAddrWithPrefix.Address)
}
if r.RemoteAddress() != test.netCfg.remoteAddr {
t.Errorf("got r.RemoteAddress() = %s, want = %s", r.RemoteAddress(), test.netCfg.remoteAddr)
@@ -4318,8 +4434,8 @@ func TestFindRouteWithForwarding(t *testing.T) {
if !ok {
t.Fatal("packet not sent through ep2")
}
- if pkt.Route.LocalAddress != test.localAddr {
- t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddr)
+ if pkt.Route.LocalAddress != test.localAddrWithPrefix.Address {
+ t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddrWithPrefix.Address)
}
if pkt.Route.RemoteAddress != test.netCfg.remoteAddr {
t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.netCfg.remoteAddr)
diff --git a/pkg/tcpip/stack/tcp.go b/pkg/tcpip/stack/tcp.go
index dc7289441..a941091b0 100644
--- a/pkg/tcpip/stack/tcp.go
+++ b/pkg/tcpip/stack/tcp.go
@@ -289,6 +289,12 @@ type TCPSenderState struct {
// RACKState holds the state related to RACK loss detection algorithm.
RACKState TCPRACKState
+
+ // RetransmitTS records the timestamp used to detect spurious recovery.
+ RetransmitTS uint32
+
+ // SpuriousRecovery indicates if the sender entered recovery spuriously.
+ SpuriousRecovery bool
}
// TCPSACKInfo holds TCP SACK related information for a given TCP endpoint.
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 824cf6526..3474c292a 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -32,11 +32,13 @@ type protocolIDs struct {
// transportEndpoints manages all endpoints of a given protocol. It has its own
// mutex so as to reduce interference between protocols.
type transportEndpoints struct {
- // mu protects all fields of the transportEndpoints.
- mu sync.RWMutex
+ mu sync.RWMutex
+ // +checklocks:mu
endpoints map[TransportEndpointID]*endpointsByNIC
// rawEndpoints contains endpoints for raw sockets, which receive all
// traffic of a given protocol regardless of port.
+ //
+ // +checklocks:mu
rawEndpoints []RawTransportEndpoint
}
@@ -69,7 +71,7 @@ func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint {
// descending order of match quality. If a call to yield returns false,
// iterEndpointsLocked stops iteration and returns immediately.
//
-// Preconditions: eps.mu must be locked.
+// +checklocksread:eps.mu
func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield func(*endpointsByNIC) bool) {
// Try to find a match with the id as provided.
if ep, ok := eps.endpoints[id]; ok {
@@ -110,7 +112,7 @@ func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield
// findAllEndpointsLocked returns all endpointsByNIC in eps that match id, in
// descending order of match quality.
//
-// Preconditions: eps.mu must be locked.
+// +checklocksread:eps.mu
func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []*endpointsByNIC {
var matchedEPs []*endpointsByNIC
eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
@@ -122,7 +124,7 @@ func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []
// findEndpointLocked returns the endpoint that most closely matches the given id.
//
-// Preconditions: eps.mu must be locked.
+// +checklocksread:eps.mu
func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpointsByNIC {
var matchedEP *endpointsByNIC
eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
@@ -133,10 +135,12 @@ func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpo
}
type endpointsByNIC struct {
- mu sync.RWMutex
- endpoints map[tcpip.NICID]*multiPortEndpoint
// seed is a random secret for a jenkins hash.
seed uint32
+
+ mu sync.RWMutex
+ // +checklocks:mu
+ endpoints map[tcpip.NICID]*multiPortEndpoint
}
func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint {
@@ -171,7 +175,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *Packet
return true
}
// multiPortEndpoints are guaranteed to have at least one element.
- transEP := selectEndpoint(id, mpep, epsByNIC.seed)
+ transEP := mpep.selectEndpoint(id, epsByNIC.seed)
if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue {
queuedProtocol.QueuePacket(transEP, id, pkt)
epsByNIC.mu.RUnlock()
@@ -200,7 +204,7 @@ func (epsByNIC *endpointsByNIC) handleError(n *nic, id TransportEndpointID, tran
// broadcast like we are doing with handlePacket above?
// multiPortEndpoints are guaranteed to have at least one element.
- selectEndpoint(id, mpep, epsByNIC.seed).HandleError(transErr, pkt)
+ mpep.selectEndpoint(id, epsByNIC.seed).HandleError(transErr, pkt)
}
// registerEndpoint returns true if it succeeds. It fails and returns
@@ -333,15 +337,18 @@ func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber
//
// +stateify savable
type multiPortEndpoint struct {
- mu sync.RWMutex `state:"nosave"`
demux *transportDemuxer
netProto tcpip.NetworkProtocolNumber
transProto tcpip.TransportProtocolNumber
+ flags ports.FlagCounter
+
+ mu sync.RWMutex `state:"nosave"`
// endpoints stores the transport endpoints in the order in which they
// were bound. This is required for UDP SO_REUSEADDR.
+ //
+ // +checklocks:mu
endpoints []TransportEndpoint
- flags ports.FlagCounter
}
func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint {
@@ -362,13 +369,16 @@ func reciprocalScale(val, n uint32) uint32 {
// selectEndpoint calculates a hash of destination and source addresses and
// ports then uses it to select a socket. In this case, all packets from one
// address will be sent to the same endpoint.
-func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint {
- if len(mpep.endpoints) == 1 {
- return mpep.endpoints[0]
+func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID, seed uint32) TransportEndpoint {
+ ep.mu.RLock()
+ defer ep.mu.RUnlock()
+
+ if len(ep.endpoints) == 1 {
+ return ep.endpoints[0]
}
- if mpep.flags.SharedFlags().ToFlags().Effective().MostRecent {
- return mpep.endpoints[len(mpep.endpoints)-1]
+ if ep.flags.SharedFlags().ToFlags().Effective().MostRecent {
+ return ep.endpoints[len(ep.endpoints)-1]
}
payload := []byte{
@@ -384,8 +394,8 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32
h.Write([]byte(id.RemoteAddress))
hash := h.Sum32()
- idx := reciprocalScale(hash, uint32(len(mpep.endpoints)))
- return mpep.endpoints[idx]
+ idx := reciprocalScale(hash, uint32(len(ep.endpoints)))
+ return ep.endpoints[idx]
}
func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *PacketBuffer) {
@@ -657,7 +667,7 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN
}
}
- ep := selectEndpoint(id, mpep, epsByNIC.seed)
+ ep := mpep.selectEndpoint(id, epsByNIC.seed)
epsByNIC.mu.RUnlock()
return ep
}
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 45b09110d..cd3a8c25a 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -35,7 +35,7 @@ import (
const (
testSrcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- testDstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ testDstAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
testSrcAddrV4 = "\x0a\x00\x00\x01"
testDstAddrV4 = "\x0a\x00\x00\x02"
@@ -64,12 +64,20 @@ func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICI
}
linkEps[linkEpID] = channelEp
- if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, testDstAddrV4); err != nil {
- t.Fatalf("AddAddress IPv4 failed: %s", err)
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.Address(testDstAddrV4).WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(linkEpID, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", linkEpID, protocolAddrV4, err)
}
- if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, testDstAddrV6); err != nil {
- t.Fatalf("AddAddress IPv6 failed: %s", err)
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: testDstAddrV6.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(linkEpID, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", linkEpID, protocolAddrV6, err)
}
}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 839178809..51870d03f 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -331,8 +331,11 @@ func (*fakeTransportProtocol) Wait() {}
// Parse implements TransportProtocol.Parse.
func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool {
- _, ok := pkt.TransportHeader().Consume(fakeTransHeaderLen)
- return ok
+ if _, ok := pkt.TransportHeader().Consume(fakeTransHeaderLen); ok {
+ pkt.TransportProtocolNumber = fakeTransNumber
+ return true
+ }
+ return false
}
func fakeTransFactory(s *stack.Stack) stack.TransportProtocol {
@@ -357,8 +360,15 @@ func TestTransportReceive(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
// Create endpoint and connect to remote address.
@@ -428,8 +438,15 @@ func TestTransportControlReceive(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
// Create endpoint and connect to remote address.
@@ -497,8 +514,15 @@ func TestTransportSend(t *testing.T) {
t.Fatalf("CreateNIC failed: %v", err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 55683b4fb..460a6afaf 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -19,7 +19,7 @@
// The starting point is the creation and configuration of a stack. A stack can
// be created by calling the New() function of the tcpip/stack/stack package;
// configuring a stack involves creating NICs (via calls to Stack.CreateNIC()),
-// adding network addresses (via calls to Stack.AddAddress()), and
+// adding network addresses (via calls to Stack.AddProtocolAddress()), and
// setting a route table (via a call to Stack.SetRouteTable()).
//
// Once a stack is configured, endpoints can be created by calling
@@ -423,9 +423,9 @@ type ControlMessages struct {
// HasTimestamp indicates whether Timestamp is valid/set.
HasTimestamp bool
- // Timestamp is the time (in ns) that the last packet used to create
- // the read data was received.
- Timestamp int64
+ // Timestamp is the time that the last packet used to create the read data
+ // was received.
+ Timestamp time.Time `state:".(int64)"`
// HasInq indicates whether Inq is valid/set.
HasInq bool
@@ -451,6 +451,12 @@ type ControlMessages struct {
// PacketInfo holds interface and address data on an incoming packet.
PacketInfo IPPacketInfo
+ // HasIPv6PacketInfo indicates whether IPv6PacketInfo is set.
+ HasIPv6PacketInfo bool
+
+ // IPv6PacketInfo holds interface and address data on an incoming packet.
+ IPv6PacketInfo IPv6PacketInfo
+
// HasOriginalDstAddress indicates whether OriginalDstAddress is
// set.
HasOriginalDstAddress bool
@@ -465,10 +471,10 @@ type ControlMessages struct {
// PacketOwner is used to get UID and GID of the packet.
type PacketOwner interface {
- // UID returns KUID of the packet.
+ // KUID returns KUID of the packet.
KUID() uint32
- // GID returns KGID of the packet.
+ // KGID returns KGID of the packet.
KGID() uint32
}
@@ -1164,6 +1170,14 @@ type IPPacketInfo struct {
DestinationAddr Address
}
+// IPv6PacketInfo is the message structure for IPV6_PKTINFO.
+//
+// +stateify savable
+type IPv6PacketInfo struct {
+ Addr Address
+ NIC NICID
+}
+
// SendBufferSizeOption is used by stack.(*Stack).Option/SetOption to
// get/set the default, min and max send buffer sizes.
type SendBufferSizeOption struct {
@@ -1231,11 +1245,11 @@ type Route struct {
// String implements the fmt.Stringer interface.
func (r Route) String() string {
var out strings.Builder
- fmt.Fprintf(&out, "%s", r.Destination)
+ _, _ = fmt.Fprintf(&out, "%s", r.Destination)
if len(r.Gateway) > 0 {
- fmt.Fprintf(&out, " via %s", r.Gateway)
+ _, _ = fmt.Fprintf(&out, " via %s", r.Gateway)
}
- fmt.Fprintf(&out, " nic %d", r.NIC)
+ _, _ = fmt.Fprintf(&out, " nic %d", r.NIC)
return out.String()
}
@@ -1255,6 +1269,8 @@ type TransportProtocolNumber uint32
type NetworkProtocolNumber uint32
// A StatCounter keeps track of a statistic.
+//
+// +stateify savable
type StatCounter struct {
count atomicbitops.AlignedAtomicUint64
}
@@ -1270,7 +1286,7 @@ func (s *StatCounter) Decrement() {
}
// Value returns the current value of the counter.
-func (s *StatCounter) Value(name ...string) uint64 {
+func (s *StatCounter) Value(...string) uint64 {
return s.count.Load()
}
@@ -1849,6 +1865,10 @@ type TCPStats struct {
// SegmentsAckedWithDSACK is the number of segments acknowledged with
// DSACK.
SegmentsAckedWithDSACK *StatCounter
+
+ // SpuriousRecovery is the number of times the connection entered loss
+ // recovery spuriously.
+ SpuriousRecovery *StatCounter
}
// UDPStats collects UDP-specific stats.
@@ -1981,6 +2001,8 @@ type Stats struct {
}
// ReceiveErrors collects packet receive errors within transport endpoint.
+//
+// +stateify savable
type ReceiveErrors struct {
// ReceiveBufferOverflow is the number of received packets dropped
// due to the receive buffer being full.
@@ -1998,8 +2020,10 @@ type ReceiveErrors struct {
ChecksumErrors StatCounter
}
-// SendErrors collects packet send errors within the transport layer for
-// an endpoint.
+// SendErrors collects packet send errors within the transport layer for an
+// endpoint.
+//
+// +stateify savable
type SendErrors struct {
// SendToNetworkFailed is the number of packets failed to be written to
// the network endpoint.
@@ -2010,6 +2034,8 @@ type SendErrors struct {
}
// ReadErrors collects segment read errors from an endpoint read call.
+//
+// +stateify savable
type ReadErrors struct {
// ReadClosed is the number of received packet drops because the endpoint
// was shut down for read.
@@ -2025,6 +2051,8 @@ type ReadErrors struct {
}
// WriteErrors collects packet write errors from an endpoint write call.
+//
+// +stateify savable
type WriteErrors struct {
// WriteClosed is the number of packet drops because the endpoint
// was shut down for write.
@@ -2040,6 +2068,8 @@ type WriteErrors struct {
}
// TransportEndpointStats collects statistics about the endpoint.
+//
+// +stateify savable
type TransportEndpointStats struct {
// PacketsReceived is the number of successful packet receives.
PacketsReceived StatCounter
diff --git a/pkg/tcpip/stack/iptables_state.go b/pkg/tcpip/tcpip_state.go
index 529e02a07..1953e24a1 100644
--- a/pkg/tcpip/stack/iptables_state.go
+++ b/pkg/tcpip/tcpip_state.go
@@ -1,4 +1,4 @@
-// Copyright 2020 The gVisor Authors.
+// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,29 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package stack
+package tcpip
import (
"time"
)
-// +stateify savable
-type unixTime struct {
- second int64
- nano int64
+func (c *ControlMessages) saveTimestamp() int64 {
+ return c.Timestamp.UnixNano()
}
-// saveLastUsed is invoked by stateify.
-func (cn *conn) saveLastUsed() unixTime {
- return unixTime{cn.lastUsed.Unix(), cn.lastUsed.UnixNano()}
-}
-
-// loadLastUsed is invoked by stateify.
-func (cn *conn) loadLastUsed(unix unixTime) {
- cn.lastUsed = time.Unix(unix.second, unix.nano)
-}
-
-// beforeSave is invoked by stateify.
-func (ct *ConnTrack) beforeSave() {
- ct.mu.Lock()
+func (c *ControlMessages) loadTimestamp(nsec int64) {
+ c.Timestamp = time.Unix(0, nsec)
}
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index 181ef799e..99f4d4d0e 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -34,12 +34,16 @@ go_test(
"//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
+ "//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
"//pkg/tcpip/tests/utils",
"//pkg/tcpip/testutil",
+ "//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
+ "//pkg/waiter",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
@@ -139,3 +143,25 @@ go_test(
"@com_github_google_go_cmp//cmp:go_default_library",
],
)
+
+go_test(
+ name = "istio_test",
+ size = "small",
+ srcs = ["istio_test.go"],
+ deps = [
+ "//pkg/context",
+ "//pkg/rand",
+ "//pkg/sync",
+ "//pkg/tcpip",
+ "//pkg/tcpip/adapters/gonet",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/link/pipe",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/testutil",
+ "//pkg/tcpip/transport/tcp",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ ],
+)
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
index 92fa6257d..6e1d4720d 100644
--- a/pkg/tcpip/tests/integration/forward_test.go
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -473,11 +473,19 @@ func TestMulticastForwarding(t *testing.T) {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err)
}
- if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err)
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv4Addr,
}
- if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err)
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err)
+ }
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv6Addr,
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err)
}
if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
@@ -612,8 +620,8 @@ func TestPerInterfaceForwarding(t *testing.T) {
addr: utils.RouterNIC2IPv6Addr,
},
} {
- if err := s.AddProtocolAddress(add.nicID, add.addr); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", add.nicID, add.addr, err)
+ if err := s.AddProtocolAddress(add.nicID, add.addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", add.nicID, add.addr, err)
}
}
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
index f9ab7d0af..957a779bf 100644
--- a/pkg/tcpip/tests/integration/iptables_test.go
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -15,19 +15,24 @@
package iptables_test
import (
+ "bytes"
"testing"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/tests/utils"
"gvisor.dev/gvisor/pkg/tcpip/testutil"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
)
type inputIfNameMatcher struct {
@@ -49,10 +54,10 @@ const (
nicName = "nic1"
anotherNicName = "nic2"
linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
- srcAddrV4 = "\x0a\x00\x00\x01"
- dstAddrV4 = "\x0a\x00\x00\x02"
- srcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- dstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ srcAddrV4 = tcpip.Address("\x0a\x00\x00\x01")
+ dstAddrV4 = tcpip.Address("\x0a\x00\x00\x02")
+ srcAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ dstAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
payloadSize = 20
)
@@ -66,8 +71,12 @@ func genStackV6(t *testing.T) (*stack.Stack, *channel.Endpoint) {
if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err)
}
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, dstAddrV6); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, dstAddrV6, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: dstAddrV6.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
return s, e
}
@@ -82,8 +91,12 @@ func genStackV4(t *testing.T) (*stack.Stack, *channel.Endpoint) {
if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err)
}
- if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, dstAddrV4); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, dstAddrV4, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: dstAddrV4.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
return s, e
}
@@ -601,11 +614,19 @@ func TestIPTableWritePackets(t *testing.T) {
if err := s.CreateNIC(nicID, &e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, srcAddrV6); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, srcAddrV6, err)
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: srcAddrV6.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV6, err)
+ }
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: srcAddrV4.WithPrefix(),
}
- if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, srcAddrV4); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, srcAddrV4, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV4, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -856,11 +877,19 @@ func TestForwardingHook(t *testing.T) {
t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err)
}
- if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err)
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv4Addr.Address.WithPrefix(),
}
- if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err)
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err)
+ }
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv6Addr.Address.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err)
}
if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
@@ -1037,22 +1066,22 @@ func TestInputHookWithLocalForwarding(t *testing.T) {
if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil {
t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err)
}
- if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv4Addr1, err)
+ if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID1, utils.Ipv4Addr1, err)
}
- if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv6Addr1, err)
+ if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID1, utils.Ipv6Addr1, err)
}
e2 := channel.New(1, header.IPv6MinimumMTU, "")
if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil {
t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err)
}
- if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv4Addr2, err)
+ if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID2, utils.Ipv4Addr2, err)
}
- if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv6Addr2, err)
+ if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID2, utils.Ipv6Addr2, err)
}
if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
@@ -1132,3 +1161,621 @@ func TestInputHookWithLocalForwarding(t *testing.T) {
})
}
}
+
+func TestNAT(t *testing.T) {
+ const listenPort uint16 = 8080
+
+ type endpointAndAddresses struct {
+ serverEP tcpip.Endpoint
+ serverAddr tcpip.FullAddress
+ serverReadableCH chan struct{}
+ serverConnectAddr tcpip.Address
+
+ clientEP tcpip.Endpoint
+ clientAddr tcpip.Address
+ clientReadableCH chan struct{}
+ clientConnectAddr tcpip.FullAddress
+ }
+
+ newEP := func(t *testing.T, s *stack.Stack, transProto tcpip.TransportProtocolNumber, netProto tcpip.NetworkProtocolNumber) (tcpip.Endpoint, chan struct{}) {
+ t.Helper()
+ var wq waiter.Queue
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.ReadableEvents)
+ t.Cleanup(func() {
+ wq.EventUnregister(&we)
+ })
+
+ ep, err := s.NewEndpoint(transProto, netProto, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", transProto, netProto, err)
+ }
+ t.Cleanup(ep.Close)
+
+ return ep, ch
+ }
+
+ setupNAT := func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, hook stack.Hook, filter stack.IPHeaderFilter, target stack.Target) {
+ t.Helper()
+
+ ipv6 := netProto == ipv6.ProtocolNumber
+ ipt := s.IPTables()
+ table := ipt.GetTable(stack.NATID, ipv6)
+ ruleIdx := table.BuiltinChains[hook]
+ table.Rules[ruleIdx].Filter = filter
+ table.Rules[ruleIdx].Target = target
+ // Make sure the packet is not dropped by the next rule.
+ table.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.NATID, table, ipv6); err != nil {
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err)
+ }
+ }
+
+ setupDNAT := func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, target stack.Target) {
+ t.Helper()
+
+ setupNAT(
+ t,
+ s,
+ netProto,
+ stack.Prerouting,
+ stack.IPHeaderFilter{
+ Protocol: transProto,
+ CheckProtocol: true,
+ InputInterface: utils.RouterNIC2Name,
+ },
+ target)
+ }
+
+ setupSNAT := func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, target stack.Target) {
+ t.Helper()
+
+ setupNAT(
+ t,
+ s,
+ netProto,
+ stack.Postrouting,
+ stack.IPHeaderFilter{
+ Protocol: transProto,
+ CheckProtocol: true,
+ OutputInterface: utils.RouterNIC1Name,
+ },
+ target)
+ }
+
+ type natType struct {
+ name string
+ setupNAT func(_ *testing.T, _ *stack.Stack, _ tcpip.NetworkProtocolNumber, _ tcpip.TransportProtocolNumber, snatAddr, dnatAddr tcpip.Address)
+ }
+
+ snatTypes := []natType{
+ {
+ name: "SNAT",
+ setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, snatAddr, _ tcpip.Address) {
+ t.Helper()
+
+ setupSNAT(t, s, netProto, transProto, &stack.SNATTarget{NetworkProtocol: netProto, Addr: snatAddr})
+ },
+ },
+ {
+ name: "Masquerade",
+ setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, _, _ tcpip.Address) {
+ t.Helper()
+
+ setupSNAT(t, s, netProto, transProto, &stack.MasqueradeTarget{NetworkProtocol: netProto})
+ },
+ },
+ }
+ dnatTypes := []natType{
+ {
+ name: "Redirect",
+ setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, _, _ tcpip.Address) {
+ t.Helper()
+
+ setupDNAT(t, s, netProto, transProto, &stack.RedirectTarget{NetworkProtocol: netProto, Port: listenPort})
+ },
+ },
+ {
+ name: "DNAT",
+ setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, _, dnatAddr tcpip.Address) {
+ t.Helper()
+
+ setupDNAT(t, s, netProto, transProto, &stack.DNATTarget{NetworkProtocol: netProto, Addr: dnatAddr, Port: listenPort})
+ },
+ },
+ }
+
+ setupTwiceNAT := func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, dnatAddr tcpip.Address, snatTarget stack.Target) {
+ t.Helper()
+
+ ipv6 := netProto == ipv6.ProtocolNumber
+ ipt := s.IPTables()
+
+ table := stack.Table{
+ Rules: []stack.Rule{
+ // Prerouting
+ {
+ Filter: stack.IPHeaderFilter{
+ Protocol: transProto,
+ CheckProtocol: true,
+ InputInterface: utils.RouterNIC2Name,
+ },
+ Target: &stack.DNATTarget{NetworkProtocol: netProto, Addr: dnatAddr, Port: listenPort},
+ },
+ {
+ Target: &stack.AcceptTarget{},
+ },
+
+ // Input
+ {
+ Target: &stack.AcceptTarget{},
+ },
+
+ // Forward
+ {
+ Target: &stack.AcceptTarget{},
+ },
+
+ // Output
+ {
+ Target: &stack.AcceptTarget{},
+ },
+
+ // Postrouting
+ {
+ Filter: stack.IPHeaderFilter{
+ Protocol: transProto,
+ CheckProtocol: true,
+ OutputInterface: utils.RouterNIC1Name,
+ },
+ Target: snatTarget,
+ },
+ {
+ Target: &stack.AcceptTarget{},
+ },
+ },
+ BuiltinChains: [stack.NumHooks]int{
+ stack.Prerouting: 0,
+ stack.Input: 2,
+ stack.Forward: 3,
+ stack.Output: 4,
+ stack.Postrouting: 5,
+ },
+ }
+
+ if err := ipt.ReplaceTable(stack.NATID, table, ipv6); err != nil {
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err)
+ }
+ }
+ twiceNATTypes := []natType{
+ {
+ name: "DNAT-Masquerade",
+ setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, snatAddr, dnatAddr tcpip.Address) {
+ t.Helper()
+
+ setupTwiceNAT(t, s, netProto, transProto, dnatAddr, &stack.MasqueradeTarget{NetworkProtocol: netProto})
+ },
+ },
+ {
+ name: "DNAT-SNAT",
+ setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, snatAddr, dnatAddr tcpip.Address) {
+ t.Helper()
+
+ setupTwiceNAT(t, s, netProto, transProto, dnatAddr, &stack.SNATTarget{NetworkProtocol: netProto, Addr: snatAddr})
+ },
+ },
+ }
+
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ // Sets up the stacks in such a way that:
+ //
+ // - Host2 is the client for all tests.
+ // - When performing SNAT only:
+ // + Host1 is the server.
+ // + NAT will transform client-originating packets' source addresses to
+ // the router's NIC1's address before reaching Host1.
+ // - When performing DNAT only:
+ // + Router is the server.
+ // + Client will send packets directed to Host1.
+ // + NAT will transform client-originating packets' destination addresses
+ // to the router's NIC2's address.
+ // - When performing Twice-NAT:
+ // + Host1 is the server.
+ // + Client will send packets directed to router's NIC2.
+ // + NAT will transform client-originating packets' destination addresses
+ // to Host1's address.
+ // + NAT will transform client-originating packets' source addresses to
+ // the router's NIC1's address before reaching Host1.
+ epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses
+ natTypes []natType
+ }{
+ {
+ name: "IPv4 SNAT",
+ netProto: ipv4.ProtocolNumber,
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
+ t.Helper()
+
+ listenerStack := host1Stack
+ serverAddr := tcpip.FullAddress{
+ Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
+ Port: listenPort,
+ }
+ serverConnectAddr := utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address
+ clientConnectPort := serverAddr.Port
+ ep1, ep1WECH := newEP(t, listenerStack, proto, ipv4.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: serverAddr,
+ serverReadableCH: ep1WECH,
+ serverConnectAddr: serverConnectAddr,
+
+ clientEP: ep2,
+ clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+ clientConnectAddr: tcpip.FullAddress{
+ Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
+ Port: clientConnectPort,
+ },
+ }
+ },
+ natTypes: snatTypes,
+ },
+ {
+ name: "IPv4 DNAT",
+ netProto: ipv4.ProtocolNumber,
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
+ t.Helper()
+
+ // If we are performing DNAT, then the packet will be redirected
+ // to the router.
+ listenerStack := routerStack
+ serverAddr := tcpip.FullAddress{
+ Addr: utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
+ Port: listenPort,
+ }
+ serverConnectAddr := utils.Host2IPv4Addr.AddressWithPrefix.Address
+ // DNAT will update the destination port to what the server is
+ // bound to.
+ clientConnectPort := serverAddr.Port + 1
+ ep1, ep1WECH := newEP(t, listenerStack, proto, ipv4.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: serverAddr,
+ serverReadableCH: ep1WECH,
+ serverConnectAddr: serverConnectAddr,
+
+ clientEP: ep2,
+ clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+ clientConnectAddr: tcpip.FullAddress{
+ Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
+ Port: clientConnectPort,
+ },
+ }
+ },
+ natTypes: dnatTypes,
+ },
+ {
+ name: "IPv4 Twice-NAT",
+ netProto: ipv4.ProtocolNumber,
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
+ t.Helper()
+
+ listenerStack := host1Stack
+ serverAddr := tcpip.FullAddress{
+ Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
+ Port: listenPort,
+ }
+ serverConnectAddr := utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address
+ clientConnectPort := serverAddr.Port
+ ep1, ep1WECH := newEP(t, listenerStack, proto, ipv4.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: serverAddr,
+ serverReadableCH: ep1WECH,
+ serverConnectAddr: serverConnectAddr,
+
+ clientEP: ep2,
+ clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+ clientConnectAddr: tcpip.FullAddress{
+ Addr: utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
+ Port: clientConnectPort,
+ },
+ }
+ },
+ natTypes: twiceNATTypes,
+ },
+ {
+ name: "IPv6 SNAT",
+ netProto: ipv6.ProtocolNumber,
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
+ t.Helper()
+
+ listenerStack := host1Stack
+ serverAddr := tcpip.FullAddress{
+ Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
+ Port: listenPort,
+ }
+ serverConnectAddr := utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address
+ clientConnectPort := serverAddr.Port
+ ep1, ep1WECH := newEP(t, listenerStack, proto, ipv6.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: serverAddr,
+ serverReadableCH: ep1WECH,
+ serverConnectAddr: serverConnectAddr,
+
+ clientEP: ep2,
+ clientAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+ clientConnectAddr: tcpip.FullAddress{
+ Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
+ Port: clientConnectPort,
+ },
+ }
+ },
+ natTypes: snatTypes,
+ },
+ {
+ name: "IPv6 DNAT",
+ netProto: ipv6.ProtocolNumber,
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
+ t.Helper()
+
+ // If we are performing DNAT, then the packet will be redirected
+ // to the router.
+ listenerStack := routerStack
+ serverAddr := tcpip.FullAddress{
+ Addr: utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
+ Port: listenPort,
+ }
+ serverConnectAddr := utils.Host2IPv6Addr.AddressWithPrefix.Address
+ // DNAT will update the destination port to what the server is
+ // bound to.
+ clientConnectPort := serverAddr.Port + 1
+ ep1, ep1WECH := newEP(t, listenerStack, proto, ipv6.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: serverAddr,
+ serverReadableCH: ep1WECH,
+ serverConnectAddr: serverConnectAddr,
+
+ clientEP: ep2,
+ clientAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+ clientConnectAddr: tcpip.FullAddress{
+ Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
+ Port: clientConnectPort,
+ },
+ }
+ },
+ natTypes: dnatTypes,
+ },
+ {
+ name: "IPv6 Twice-NAT",
+ netProto: ipv6.ProtocolNumber,
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
+ t.Helper()
+
+ listenerStack := host1Stack
+ serverAddr := tcpip.FullAddress{
+ Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
+ Port: listenPort,
+ }
+ serverConnectAddr := utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address
+ clientConnectPort := serverAddr.Port
+ ep1, ep1WECH := newEP(t, listenerStack, proto, ipv6.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: serverAddr,
+ serverReadableCH: ep1WECH,
+ serverConnectAddr: serverConnectAddr,
+
+ clientEP: ep2,
+ clientAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+ clientConnectAddr: tcpip.FullAddress{
+ Addr: utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
+ Port: clientConnectPort,
+ },
+ }
+ },
+ natTypes: twiceNATTypes,
+ },
+ }
+
+ subTests := []struct {
+ name string
+ proto tcpip.TransportProtocolNumber
+ expectedConnectErr tcpip.Error
+ setupServer func(t *testing.T, ep tcpip.Endpoint)
+ setupServerConn func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{})
+ needRemoteAddr bool
+ }{
+ {
+ name: "UDP",
+ proto: udp.ProtocolNumber,
+ expectedConnectErr: nil,
+ setupServerConn: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
+ t.Helper()
+
+ if err := ep.Connect(clientAddr); err != nil {
+ t.Fatalf("ep.Connect(%#v): %s", clientAddr, err)
+ }
+ return nil, nil
+ },
+ needRemoteAddr: true,
+ },
+ {
+ name: "TCP",
+ proto: tcp.ProtocolNumber,
+ expectedConnectErr: &tcpip.ErrConnectStarted{},
+ setupServer: func(t *testing.T, ep tcpip.Endpoint) {
+ t.Helper()
+
+ if err := ep.Listen(1); err != nil {
+ t.Fatalf("ep.Listen(1): %s", err)
+ }
+ },
+ setupServerConn: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
+ t.Helper()
+
+ var addr tcpip.FullAddress
+ for {
+ newEP, wq, err := ep.Accept(&addr)
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ <-ch
+ continue
+ }
+ if err != nil {
+ t.Fatalf("ep.Accept(_): %s", err)
+ }
+ if diff := cmp.Diff(clientAddr, addr, checker.IgnoreCmpPath(
+ "NIC",
+ )); diff != "" {
+ t.Errorf("accepted address mismatch (-want +got):\n%s", diff)
+ }
+
+ we, newCH := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.ReadableEvents)
+ return newEP, newCH
+ }
+ },
+ needRemoteAddr: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, subTest := range subTests {
+ t.Run(subTest.name, func(t *testing.T) {
+ for _, natType := range test.natTypes {
+ t.Run(natType.name, func(t *testing.T) {
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
+ }
+
+ host1Stack := stack.New(stackOpts)
+ routerStack := stack.New(stackOpts)
+ host2Stack := stack.New(stackOpts)
+ utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack)
+
+ epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto)
+ natType.setupNAT(t, routerStack, test.netProto, subTest.proto, epsAndAddrs.serverConnectAddr, epsAndAddrs.serverAddr.Addr)
+
+ if err := epsAndAddrs.serverEP.Bind(epsAndAddrs.serverAddr); err != nil {
+ t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", epsAndAddrs.serverAddr, err)
+ }
+ clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr}
+ if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil {
+ t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err)
+ }
+
+ if subTest.setupServer != nil {
+ subTest.setupServer(t, epsAndAddrs.serverEP)
+ }
+ {
+ err := epsAndAddrs.clientEP.Connect(epsAndAddrs.clientConnectAddr)
+ if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" {
+ t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", epsAndAddrs.clientConnectAddr, diff)
+ }
+ }
+ serverConnectAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverConnectAddr}
+ if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil {
+ t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err)
+ } else {
+ serverConnectAddr.Port = addr.Port
+ }
+
+ serverEP := epsAndAddrs.serverEP
+ serverCH := epsAndAddrs.serverReadableCH
+ if ep, ch := subTest.setupServerConn(t, serverEP, serverCH, serverConnectAddr); ep != nil {
+ defer ep.Close()
+ serverEP = ep
+ serverCH = ch
+ }
+
+ write := func(ep tcpip.Endpoint, data []byte) {
+ t.Helper()
+
+ var r bytes.Reader
+ r.Reset(data)
+ var wOpts tcpip.WriteOptions
+ n, err := ep.Write(&r, wOpts)
+ if err != nil {
+ t.Fatalf("ep.Write(_, %#v): %s", wOpts, err)
+ }
+ if want := int64(len(data)); n != want {
+ t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want)
+ }
+ }
+
+ read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) {
+ t.Helper()
+
+ var buf bytes.Buffer
+ var res tcpip.ReadResult
+ for {
+ var err tcpip.Error
+ opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr}
+ res, err = ep.Read(&buf, opts)
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ <-ch
+ continue
+ }
+ if err != nil {
+ t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err)
+ }
+ break
+ }
+
+ readResult := tcpip.ReadResult{
+ Count: len(data),
+ Total: len(data),
+ }
+ if subTest.needRemoteAddr {
+ readResult.RemoteAddr = expectedFrom
+ }
+ if diff := cmp.Diff(readResult, res, checker.IgnoreCmpPath(
+ "ControlMessages",
+ "RemoteAddr.NIC",
+ )); diff != "" {
+ t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
+ }
+ if diff := cmp.Diff(buf.Bytes(), data); diff != "" {
+ t.Errorf("received data mismatch (-want +got):\n%s", diff)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+
+ {
+ data := []byte{1, 2, 3, 4}
+ write(epsAndAddrs.clientEP, data)
+ read(serverCH, serverEP, data, serverConnectAddr)
+ }
+
+ {
+ data := []byte{5, 6, 7, 8, 9, 10, 11, 12}
+ write(serverEP, data)
+ read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, epsAndAddrs.clientConnectAddr)
+ }
+ })
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/istio_test.go b/pkg/tcpip/tests/integration/istio_test.go
new file mode 100644
index 000000000..95d994ef8
--- /dev/null
+++ b/pkg/tcpip/tests/integration/istio_test.go
@@ -0,0 +1,365 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package istio_test
+
+import (
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "strconv"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/link/pipe"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+)
+
+ // testContext encapsulates the state required to run tests that simulate
+ // an Istio-like environment.
+//
+// A diagram depicting the setup is shown below.
+// +-----------------------------------------------------------------------+
+// | +-------------------------------------------------+ |
+// | + ----------+ | + -----------------+ PROXY +----------+ | |
+// | | clientEP | | | serverListeningEP|--accepted-> | serverEP |-+ | |
+// | + ----------+ | + -----------------+ +----------+ | | |
+// | | -------|-------------+ +----------+ | | |
+// | | | | | proxyEP |-+ | |
+// | +-----redirect | +----------+ | |
+// | + ------------+---|------+---+ |
+// | | |
+// | Local Stack. | |
+// +-------------------------------------------------------|---------------+
+// |
+// +-----------------------------------------------------------------------+
+// | remoteStack | |
+// | +-------------SYN ---------------| |
+// | | | |
+// | +-------------------|--------------------------------|-_---+ |
+// | | + -----------------+ + ----------+ | | |
+// | | | remoteListeningEP|--accepted--->| remoteEP |<++ | |
+// | | + -----------------+ + ----------+ | |
+// | | Remote HTTP Server | |
+// | +----------------------------------------------------------+ |
+// +-----------------------------------------------------------------------+
+//
+type testContext struct {
+ // localServerListener is the listener for the server which will proxy
+ // all traffic to the remote EP.
+ localServerListener *gonet.TCPListener
+
+ // remoteServerListener is the remote listening endpoint that will receive
+ // connections from server.
+ remoteServerListener *gonet.TCPListener
+
+ // localStack is the stack used to create client/server endpoints and
+ // also the stack on which we install NAT redirect rules.
+ localStack *stack.Stack
+
+ // remoteStack is the stack that represents a *remote* server.
+ remoteStack *stack.Stack
+
+ // defaultResponse is the response served by the HTTP server for all GET requests.
+ defaultResponse []byte
+
+ // wg is used to wait for HTTP server and Proxy to terminate before
+ // returning from cleanup.
+ wg sync.WaitGroup
+}
+
+func (ctx *testContext) cleanup() {
+ ctx.localServerListener.Close()
+ ctx.localStack.Close()
+ ctx.remoteServerListener.Close()
+ ctx.remoteStack.Close()
+ ctx.wg.Wait()
+}
+
+const (
+ localServerPort = 8080
+ remoteServerPort = 9090
+)
+
+var (
+ localIPv4Addr1 = testutil.MustParse4("10.0.0.1")
+ localIPv4Addr2 = testutil.MustParse4("10.0.0.2")
+ loopbackIPv4Addr = testutil.MustParse4("127.0.0.1")
+ remoteIPv4Addr1 = testutil.MustParse4("10.0.0.3")
+)
+
+func newTestContext(t *testing.T) *testContext {
+ t.Helper()
+ localNIC, remoteNIC := pipe.New("" /* linkAddr1 */, "" /* linkAddr2 */)
+
+ localStack := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
+ HandleLocal: true,
+ })
+
+ remoteStack := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
+ HandleLocal: true,
+ })
+
+ // Add loopback NIC. We need a loopback NIC because the NAT redirect rule
+ // redirects to the loopback address + specified port.
+ loopbackNIC := loopback.New()
+ const loopbackNICID = tcpip.NICID(1)
+ if err := localStack.CreateNIC(loopbackNICID, sniffer.New(loopbackNIC)); err != nil {
+ t.Fatalf("localStack.CreateNIC(%d, _): %s", loopbackNICID, err)
+ }
+ loopbackAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: loopbackIPv4Addr.WithPrefix(),
+ }
+ if err := localStack.AddProtocolAddress(loopbackNICID, loopbackAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("localStack.AddProtocolAddress(%d, %+v, {}): %s", loopbackNICID, loopbackAddr, err)
+ }
+
+ // Create linked NICs that connect the local and remote stacks.
+ const localNICID = tcpip.NICID(2)
+ const remoteNICID = tcpip.NICID(3)
+ if err := localStack.CreateNIC(localNICID, sniffer.New(localNIC)); err != nil {
+ t.Fatalf("localStack.CreateNIC(%d, _): %s", localNICID, err)
+ }
+ if err := remoteStack.CreateNIC(remoteNICID, sniffer.New(remoteNIC)); err != nil {
+ t.Fatalf("remoteStack.CreateNIC(%d, _): %s", remoteNICID, err)
+ }
+
+ for _, addr := range []tcpip.Address{localIPv4Addr1, localIPv4Addr2} {
+ localProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: addr.WithPrefix(),
+ }
+ if err := localStack.AddProtocolAddress(localNICID, localProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("localStack.AddProtocolAddress(%d, %+v, {}): %s", localNICID, localProtocolAddr, err)
+ }
+ }
+
+ remoteProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: remoteIPv4Addr1.WithPrefix(),
+ }
+ if err := remoteStack.AddProtocolAddress(remoteNICID, remoteProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("remoteStack.AddProtocolAddress(%d, %+v, {}): %s", remoteNICID, remoteProtocolAddr, err)
+ }
+
+ // Set up the route tables for the local and remote stacks.
+ localStack.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4LoopbackSubnet,
+ NIC: loopbackNICID,
+ },
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: localNICID,
+ },
+ })
+ remoteStack.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: remoteNICID,
+ },
+ })
+
+ const netProto = ipv4.ProtocolNumber
+ localServerAddress := tcpip.FullAddress{
+ Port: localServerPort,
+ }
+
+ localServerListener, err := gonet.ListenTCP(localStack, localServerAddress, netProto)
+ if err != nil {
+ t.Fatalf("gonet.ListenTCP(_, %+v, %d) = %s", localServerAddress, netProto, err)
+ }
+
+ remoteServerAddress := tcpip.FullAddress{
+ Port: remoteServerPort,
+ }
+ remoteServerListener, err := gonet.ListenTCP(remoteStack, remoteServerAddress, netProto)
+ if err != nil {
+ t.Fatalf("gonet.ListenTCP(_, %+v, %d) = %s", remoteServerAddress, netProto, err)
+ }
+
+ // Initialize a random default response served by the HTTP server.
+ defaultResponse := make([]byte, 512<<10)
+ if _, err := rand.Read(defaultResponse); err != nil {
+ t.Fatalf("rand.Read(buf) failed: %s", err)
+ }
+
+ tc := &testContext{
+ localServerListener: localServerListener,
+ remoteServerListener: remoteServerListener,
+ localStack: localStack,
+ remoteStack: remoteStack,
+ defaultResponse: defaultResponse,
+ }
+
+ tc.startServers(t)
+ return tc
+}
+
+// startServers launches the remote HTTP server and the local TCP proxy server,
+// each on its own goroutine registered with ctx.wg.
+func (ctx *testContext) startServers(t *testing.T) {
+	// launch runs serve on a new goroutine that is accounted for in ctx.wg.
+	launch := func(serve func()) {
+		ctx.wg.Add(1)
+		go func() {
+			defer ctx.wg.Done()
+			serve()
+		}()
+	}
+
+	launch(ctx.startHTTPServer)
+	launch(func() { ctx.startTCPProxyServer(t) })
+}
+
+// startTCPProxyServer accepts connections on ctx.localServerListener and, for
+// each one, dials remoteIPv4Addr1:remoteServerPort from localIPv4Addr2 on the
+// local stack and relays bytes between the two connections. It returns once
+// Accept fails; the Accept error is logged.
+func (ctx *testContext) startTCPProxyServer(t *testing.T) {
+	t.Helper()
+	for {
+		conn, err := ctx.localServerListener.Accept()
+		if err != nil {
+			t.Logf("terminating local proxy server: %s", err)
+			return
+		}
+		// Start a goroutine to handle this inbound connection.
+		go func() {
+			remoteServerAddr := tcpip.FullAddress{
+				Addr: remoteIPv4Addr1,
+				Port: remoteServerPort,
+			}
+			// Bind the outbound leg to localIPv4Addr2 rather than leaving the
+			// source address unspecified.
+			localServerAddr := tcpip.FullAddress{
+				Addr: localIPv4Addr2,
+			}
+			serverConn, err := gonet.DialTCPWithBind(context.Background(), ctx.localStack, localServerAddr, remoteServerAddr, ipv4.ProtocolNumber)
+			if err != nil {
+				t.Logf("gonet.DialTCPWithBind(_, _, %+v, %+v, %d) = %s", localServerAddr, remoteServerAddr, ipv4.ProtocolNumber, err)
+				// Nobody will service the accepted connection now; close it
+				// so the client sees the failure instead of hanging.
+				conn.Close()
+				return
+			}
+			proxy(conn, serverConn)
+			t.Logf("proxying completed")
+		}()
+	}
+}
+
+// proxy transparently proxies the TCP payload from conn1 to conn2
+// and vice versa. It blocks until both directions have finished and
+// closes both connections before returning.
+func proxy(conn1, conn2 net.Conn) {
+	var wg sync.WaitGroup
+
+	// relay copies src into dst until the copy ends, then closes both
+	// connections so that the copy running in the opposite direction is
+	// unblocked as well.
+	relay := func(dst, src net.Conn) {
+		// Without Done the Wait below would block forever.
+		defer wg.Done()
+		// Copy and Close errors are ignored on purpose: whichever
+		// direction finishes first tears the connections down, which makes
+		// the other direction fail with a "closed" error that carries no
+		// useful information.
+		io.Copy(dst, src)
+		conn1.Close()
+		conn2.Close()
+	}
+
+	wg.Add(2)
+	go relay(conn2, conn1)
+	go relay(conn1, conn2)
+	wg.Wait()
+}
+
+// startHTTPServer serves HTTP on ctx.remoteServerListener, answering every
+// request with ctx.defaultResponse. It blocks until serving stops.
+func (ctx *testContext) startHTTPServer() {
+	serveDefault := func(w http.ResponseWriter, _ *http.Request) {
+		w.Write(ctx.defaultResponse)
+	}
+	http.Serve(ctx.remoteServerListener, http.HandlerFunc(serveDefault))
+}
+
+// TestOutboundNATRedirect installs a NAT Output-chain rule that redirects TCP
+// packets sourced from localIPv4Addr1 to localServerPort (where the local TCP
+// proxy listens), then issues an HTTP GET from localIPv4Addr1 to the remote
+// server and checks that the default response is received intact.
+func TestOutboundNATRedirect(t *testing.T) {
+	ctx := newTestContext(t)
+	defer ctx.cleanup()
+
+	// Install an IPTable rule to redirect all TCP traffic with the sourceIP of
+	// localIPv4Addr1 to the tcp proxy port.
+	ipt := ctx.localStack.IPTables()
+	tbl := ipt.GetTable(stack.NATID, false /* ipv6 */)
+	ruleIdx := tbl.BuiltinChains[stack.Output]
+	tbl.Rules[ruleIdx].Filter = stack.IPHeaderFilter{
+		Protocol:      tcp.ProtocolNumber,
+		CheckProtocol: true,
+		Src:           localIPv4Addr1,
+		SrcMask:       tcpip.Address("\xff\xff\xff\xff"),
+	}
+	tbl.Rules[ruleIdx].Target = &stack.RedirectTarget{
+		Port:            localServerPort,
+		NetworkProtocol: ipv4.ProtocolNumber,
+	}
+	tbl.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+	if err := ipt.ReplaceTable(stack.NATID, tbl, false /* ipv6 */); err != nil {
+		t.Fatalf("ipt.ReplaceTable(%d, _, false): %s", stack.NATID, err)
+	}
+
+	// dialFunc opens the HTTP client's connections on the local netstack
+	// instead of the host network.
+	dialFunc := func(dialCtx context.Context, _, address string) (net.Conn, error) {
+		host, port, err := net.SplitHostPort(address)
+		if err != nil {
+			return nil, fmt.Errorf("unable to parse address: %s, err: %w", address, err)
+		}
+
+		// To4 returns nil both for an unparsable host and for a non-IPv4
+		// address; fail loudly rather than dialing an empty address.
+		ip := net.ParseIP(host).To4()
+		if ip == nil {
+			return nil, fmt.Errorf("unable to parse IPv4 address from string %s", host)
+		}
+		// A bit size of 16 makes out-of-range ports an error instead of
+		// silently truncating them in the uint16 conversion below.
+		portNum, err := strconv.ParseUint(port, 10, 16)
+		if err != nil {
+			return nil, fmt.Errorf("unable to parse port from string %s, err: %w", port, err)
+		}
+		remoteAddress := tcpip.FullAddress{
+			Addr: tcpip.Address(ip),
+			Port: uint16(portNum),
+		}
+
+		// Dial with an explicit source address bound so that the redirect rule will
+		// be able to correctly redirect these packets.
+		localAddr := tcpip.FullAddress{Addr: localIPv4Addr1}
+		return gonet.DialTCPWithBind(dialCtx, ctx.localStack, localAddr, remoteAddress, ipv4.ProtocolNumber)
+	}
+
+	httpClient := &http.Client{
+		Transport: &http.Transport{
+			// DialContext supersedes the deprecated Dial field and lets the
+			// request's context bound the connection attempt.
+			DialContext: dialFunc,
+		},
+	}
+
+	// Square brackets in a URL host are IPv6-literal syntax; the server
+	// address here is IPv4, so it is written bare.
+	serverURL := fmt.Sprintf("http://%s:%d/", net.IP(remoteIPv4Addr1), remoteServerPort)
+	response, err := httpClient.Get(serverURL)
+	if err != nil {
+		t.Fatalf("httpClient.Get(\"/\") failed: %s", err)
+	}
+	// Deferred so the body is also released when a check below fails.
+	defer response.Body.Close()
+
+	if got, want := response.StatusCode, http.StatusOK; got != want {
+		t.Fatalf("unexpected status code got: %d, want: %d", got, want)
+	}
+	body, err := io.ReadAll(response.Body)
+	if err != nil {
+		t.Fatalf("io.ReadAll(response.Body) failed: %s", err)
+	}
+	// cmp.Diff marks its first argument with '-' and its second with '+', so
+	// the expected bytes go first to match the "(-want +got)" label.
+	if diff := cmp.Diff(ctx.defaultResponse, body); diff != "" {
+		t.Fatalf("unexpected response (-want +got): \n %s", diff)
+	}
+}
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
index 27caa0c28..95ddd8ec3 100644
--- a/pkg/tcpip/tests/integration/link_resolution_test.go
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -56,17 +56,17 @@ func setupStack(t *testing.T, stackOpts stack.Options, host1NICID, host2NICID tc
t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err)
}
- if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv4Addr1); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, utils.Ipv4Addr1, err)
+ if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv4Addr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", host1NICID, utils.Ipv4Addr1, err)
}
- if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv4Addr2); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, utils.Ipv4Addr2, err)
+ if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv4Addr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", host2NICID, utils.Ipv4Addr2, err)
}
- if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv6Addr1); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, utils.Ipv6Addr1, err)
+ if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv6Addr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", host1NICID, utils.Ipv6Addr1, err)
}
- if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv6Addr2); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, utils.Ipv6Addr2, err)
+ if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv6Addr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", host2NICID, utils.Ipv6Addr2, err)
}
host1Stack.SetRouteTable([]tcpip.Route{
@@ -568,8 +568,8 @@ func TestForwardingWithLinkResolutionFailure(t *testing.T) {
Protocol: test.networkProtocolNumber,
AddressWithPrefix: test.incomingAddr,
}
- if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingProtoAddr, err)
+ if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingProtoAddr, err)
}
// Set up endpoint through which we will attempt to forward packets.
@@ -582,8 +582,8 @@ func TestForwardingWithLinkResolutionFailure(t *testing.T) {
Protocol: test.networkProtocolNumber,
AddressWithPrefix: test.outgoingAddr,
}
- if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingProtoAddr, err)
+ if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingProtoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index b2008f0b2..f33223e79 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -195,8 +195,8 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) {
if err := s.CreateNIC(nicID, loopback.New()); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddProtocolAddress(nicID, test.addAddress); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, test.addAddress, err)
+ if err := s.AddProtocolAddress(nicID, test.addAddress, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, test.addAddress, err)
}
s.SetRouteTable([]tcpip.Route{
{
@@ -290,8 +290,8 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) {
if err := s.CreateNIC(nicID, loopback.New()); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
+ if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
{
@@ -431,8 +431,8 @@ func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) {
if err := s.CreateNIC(nicID, loopback.New()); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddProtocolAddress(nicID, test.addAddress); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, test.addAddress, err)
+ if err := s.AddProtocolAddress(nicID, test.addAddress, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, test.addAddress, err)
}
s.SetRouteTable([]tcpip.Route{
{
@@ -693,21 +693,40 @@ func TestExternalLoopbackTraffic(t *testing.T) {
if err := s.CreateNIC(nicID1, e); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
}
- if err := s.AddAddressWithPrefix(nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr, err)
+ v4Addr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv4Addr,
}
- if err := s.AddAddressWithPrefix(nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr, err)
+ if err := s.AddProtocolAddress(nicID1, v4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, v4Addr, err)
+ }
+ v6Addr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv6Addr,
+ }
+ if err := s.AddProtocolAddress(nicID1, v6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, v6Addr, err)
}
if err := s.CreateNIC(nicID2, loopback.New()); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
}
- if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, ipv4Loopback); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, ipv4Loopback, err)
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: ipv4Loopback,
+ PrefixLen: 8,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err)
+ }
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: header.IPv6Loopback.WithPrefix(),
}
- if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, header.IPv6Loopback); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, header.IPv6Loopback, err)
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err)
}
if test.forwarding {
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index 2d0a6e6a7..7753e7d6e 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -119,12 +119,12 @@ func TestPingMulticastBroadcast(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: utils.Ipv4Addr}
- if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err)
+ if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtoAddr, err)
}
ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: utils.Ipv6Addr}
- if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv6ProtoAddr, err)
+ if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtoAddr, err)
}
// Default routes for IPv4 and IPv6 so ICMP can find a route to the remote
@@ -396,8 +396,8 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr}
- if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
+ if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
}
var wq waiter.Queue
@@ -474,8 +474,8 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
PrefixLen: 8,
},
}
- if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
+ if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -642,8 +642,8 @@ func TestUDPAddRemoveMembershipSocketOption(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr}
- if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
+ if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
}
// Set the route table so that UDP can find a NIC that is
diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go
index ac3c703d4..422eb8408 100644
--- a/pkg/tcpip/tests/integration/route_test.go
+++ b/pkg/tcpip/tests/integration/route_test.go
@@ -47,7 +47,10 @@ func TestLocalPing(t *testing.T) {
// request/reply packets.
icmpDataOffset = 8
)
- ipv4Loopback := testutil.MustParse4("127.0.0.1")
+ ipv4Loopback := tcpip.AddressWithPrefix{
+ Address: testutil.MustParse4("127.0.0.1"),
+ PrefixLen: 8,
+ }
channelEP := func() stack.LinkEndpoint { return channel.New(1, header.IPv6MinimumMTU, "") }
channelEPCheck := func(t *testing.T, e stack.LinkEndpoint) {
@@ -82,7 +85,7 @@ func TestLocalPing(t *testing.T) {
transProto tcpip.TransportProtocolNumber
netProto tcpip.NetworkProtocolNumber
linkEndpoint func() stack.LinkEndpoint
- localAddr tcpip.Address
+ localAddr tcpip.AddressWithPrefix
icmpBuf func(*testing.T) buffer.View
expectedConnectErr tcpip.Error
checkLinkEndpoint func(t *testing.T, e stack.LinkEndpoint)
@@ -101,7 +104,7 @@ func TestLocalPing(t *testing.T) {
transProto: icmp.ProtocolNumber6,
netProto: ipv6.ProtocolNumber,
linkEndpoint: loopback.New,
- localAddr: header.IPv6Loopback,
+ localAddr: header.IPv6Loopback.WithPrefix(),
icmpBuf: ipv6ICMPBuf,
checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {},
},
@@ -110,7 +113,7 @@ func TestLocalPing(t *testing.T) {
transProto: icmp.ProtocolNumber4,
netProto: ipv4.ProtocolNumber,
linkEndpoint: channelEP,
- localAddr: utils.Ipv4Addr.Address,
+ localAddr: utils.Ipv4Addr,
icmpBuf: ipv4ICMPBuf,
checkLinkEndpoint: channelEPCheck,
},
@@ -119,7 +122,7 @@ func TestLocalPing(t *testing.T) {
transProto: icmp.ProtocolNumber6,
netProto: ipv6.ProtocolNumber,
linkEndpoint: channelEP,
- localAddr: utils.Ipv6Addr.Address,
+ localAddr: utils.Ipv6Addr,
icmpBuf: ipv6ICMPBuf,
checkLinkEndpoint: channelEPCheck,
},
@@ -182,9 +185,13 @@ func TestLocalPing(t *testing.T) {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if len(test.localAddr) != 0 {
- if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err)
+ if len(test.localAddr.Address) != 0 {
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: test.netProto,
+ AddressWithPrefix: test.localAddr,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
@@ -197,7 +204,7 @@ func TestLocalPing(t *testing.T) {
}
defer ep.Close()
- connAddr := tcpip.FullAddress{Addr: test.localAddr}
+ connAddr := tcpip.FullAddress{Addr: test.localAddr.Address}
if err := ep.Connect(connAddr); err != test.expectedConnectErr {
t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr)
}
@@ -229,8 +236,8 @@ func TestLocalPing(t *testing.T) {
if diff := cmp.Diff(buffer.View(w.Bytes()[icmpDataOffset:]), payload[icmpDataOffset:]); diff != "" {
t.Errorf("received data mismatch (-want +got):\n%s", diff)
}
- if rr.RemoteAddr.Addr != test.localAddr {
- t.Errorf("got addr.Addr = %s, want = %s", rr.RemoteAddr.Addr, test.localAddr)
+ if rr.RemoteAddr.Addr != test.localAddr.Address {
+ t.Errorf("got addr.Addr = %s, want = %s", rr.RemoteAddr.Addr, test.localAddr.Address)
}
test.checkLinkEndpoint(t, e)
@@ -302,11 +309,12 @@ func TestLocalUDP(t *testing.T) {
}
if subTest.addAddress {
- if err := s.AddProtocolAddressWithOptions(nicID, test.canBePrimaryAddr, stack.CanBePrimaryEndpoint); err != nil {
- t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.canBePrimaryAddr, stack.FirstPrimaryEndpoint, err)
+ if err := s.AddProtocolAddress(nicID, test.canBePrimaryAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, test.canBePrimaryAddr, err)
}
- if err := s.AddProtocolAddressWithOptions(nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint); err != nil {
- t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint, err)
+ properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint}
+ if err := s.AddProtocolAddress(nicID, test.firstPrimaryAddr, properties); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, %+v): %s", nicID, test.firstPrimaryAddr, properties, err)
}
}
diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go
index 2e6ae55ea..c69410859 100644
--- a/pkg/tcpip/tests/utils/utils.go
+++ b/pkg/tcpip/tests/utils/utils.go
@@ -40,6 +40,14 @@ const (
Host2NICID = 4
)
+// Common NIC names used by tests.
+const (
+ Host1NICName = "host1NIC"
+ RouterNIC1Name = "routerNIC1"
+ RouterNIC2Name = "routerNIC2"
+ Host2NICName = "host2NIC"
+)
+
// Common link addresses used by tests.
const (
LinkAddr1 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
@@ -211,17 +219,29 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.
host1NIC, routerNIC1 := pipe.New(LinkAddr1, LinkAddr2)
routerNIC2, host2NIC := pipe.New(LinkAddr3, LinkAddr4)
- if err := host1Stack.CreateNIC(Host1NICID, NewEthernetEndpoint(host1NIC)); err != nil {
- t.Fatalf("host1Stack.CreateNIC(%d, _): %s", Host1NICID, err)
+ {
+ opts := stack.NICOptions{Name: Host1NICName}
+ if err := host1Stack.CreateNICWithOptions(Host1NICID, NewEthernetEndpoint(host1NIC), opts); err != nil {
+ t.Fatalf("host1Stack.CreateNICWithOptions(%d, _, %#v): %s", Host1NICID, opts, err)
+ }
}
- if err := routerStack.CreateNIC(RouterNICID1, NewEthernetEndpoint(routerNIC1)); err != nil {
- t.Fatalf("routerStack.CreateNIC(%d, _): %s", RouterNICID1, err)
+ {
+ opts := stack.NICOptions{Name: RouterNIC1Name}
+ if err := routerStack.CreateNICWithOptions(RouterNICID1, NewEthernetEndpoint(routerNIC1), opts); err != nil {
+ t.Fatalf("routerStack.CreateNICWithOptions(%d, _, %#v): %s", RouterNICID1, opts, err)
+ }
}
- if err := routerStack.CreateNIC(RouterNICID2, NewEthernetEndpoint(routerNIC2)); err != nil {
- t.Fatalf("routerStack.CreateNIC(%d, _): %s", RouterNICID2, err)
+ {
+ opts := stack.NICOptions{Name: RouterNIC2Name}
+ if err := routerStack.CreateNICWithOptions(RouterNICID2, NewEthernetEndpoint(routerNIC2), opts); err != nil {
+ t.Fatalf("routerStack.CreateNICWithOptions(%d, _, %#v): %s", RouterNICID2, opts, err)
+ }
}
- if err := host2Stack.CreateNIC(Host2NICID, NewEthernetEndpoint(host2NIC)); err != nil {
- t.Fatalf("host2Stack.CreateNIC(%d, _): %s", Host2NICID, err)
+ {
+ opts := stack.NICOptions{Name: Host2NICName}
+ if err := host2Stack.CreateNICWithOptions(Host2NICID, NewEthernetEndpoint(host2NIC), opts); err != nil {
+ t.Fatalf("host2Stack.CreateNICWithOptions(%d, _, %#v): %s", Host2NICID, opts, err)
+ }
}
if err := routerStack.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
@@ -231,29 +251,29 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.
t.Fatalf("routerStack.SetForwardingDefaultAndAllNICs(%d): %s", ipv6.ProtocolNumber, err)
}
- if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv4Addr); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", Host1NICID, Host1IPv4Addr, err)
+ if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", Host1NICID, Host1IPv4Addr, err)
}
- if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv4Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID1, RouterNIC1IPv4Addr, err)
+ if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID1, RouterNIC1IPv4Addr, err)
}
- if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv4Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID2, RouterNIC2IPv4Addr, err)
+ if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID2, RouterNIC2IPv4Addr, err)
}
- if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv4Addr); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", Host2NICID, Host2IPv4Addr, err)
+ if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", Host2NICID, Host2IPv4Addr, err)
}
- if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv6Addr); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", Host1NICID, Host1IPv6Addr, err)
+ if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", Host1NICID, Host1IPv6Addr, err)
}
- if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv6Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID1, RouterNIC1IPv6Addr, err)
+ if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID1, RouterNIC1IPv6Addr, err)
}
- if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv6Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID2, RouterNIC2IPv6Addr, err)
+ if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID2, RouterNIC2IPv6Addr, err)
}
- if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv6Addr); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", Host2NICID, Host2IPv6Addr, err)
+ if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", Host2NICID, Host2IPv6Addr, err)
}
host1Stack.SetRouteTable([]tcpip.Route{
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
index bbc0e3ecc..4718ec4ec 100644
--- a/pkg/tcpip/transport/icmp/BUILD
+++ b/pkg/tcpip/transport/icmp/BUILD
@@ -33,6 +33,8 @@ go_library(
"//pkg/tcpip/header",
"//pkg/tcpip/ports",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport",
+ "//pkg/tcpip/transport/internal/network",
"//pkg/tcpip/transport/raw",
"//pkg/tcpip/transport/tcp",
"//pkg/waiter",
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 00497bf07..995f58616 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -15,6 +15,7 @@
package icmp
import (
+ "fmt"
"io"
"time"
@@ -24,6 +25,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -35,15 +38,6 @@ type icmpPacket struct {
receivedAt time.Time `state:".(int64)"`
}
-type endpointState int
-
-const (
- stateInitial endpointState = iota
- stateBound
- stateConnected
- stateClosed
-)
-
// endpoint represents an ICMP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
@@ -51,14 +45,17 @@ const (
//
// +stateify savable
type endpoint struct {
- stack.TransportEndpointInfo
tcpip.DefaultSocketOptionsHandler
// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
+ transProto tcpip.TransportProtocolNumber
waiterQueue *waiter.Queue
uniqueID uint64
+ net network.Endpoint
+ stats tcpip.TransportEndpointStats
+ ops tcpip.SocketOptions
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.
@@ -70,38 +67,23 @@ type endpoint struct {
// The following fields are protected by the mu mutex.
mu sync.RWMutex `state:"nosave"`
- // shutdownFlags represent the current shutdown state of the endpoint.
- shutdownFlags tcpip.ShutdownFlags
- state endpointState
- route *stack.Route `state:"manual"`
- ttl uint8
- stats tcpip.TransportEndpointStats `state:"nosave"`
-
- // owner is used to get uid and gid of the packet.
- owner tcpip.PacketOwner
-
- // ops is used to get socket level options.
- ops tcpip.SocketOptions
-
// frozen indicates if the packets should be delivered to the endpoint
// during restore.
frozen bool
+ ident uint16
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
ep := &endpoint{
- stack: s,
- TransportEndpointInfo: stack.TransportEndpointInfo{
- NetProto: netProto,
- TransProto: transProto,
- },
+ stack: s,
+ transProto: transProto,
waiterQueue: waiterQueue,
- state: stateInitial,
uniqueID: s.UniqueID(),
}
ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
ep.ops.SetSendBufferSize(32*1024, false /* notify */)
ep.ops.SetReceiveBufferSize(32*1024, false /* notify */)
+ ep.net.Init(s, netProto, transProto, &ep.ops)
// Override with stack defaults.
var ss tcpip.SendBufferSizeOption
@@ -128,35 +110,40 @@ func (e *endpoint) Abort() {
// Close puts the endpoint in a closed state and frees all resources
// associated with it.
func (e *endpoint) Close() {
- e.mu.Lock()
- e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
- switch e.state {
- case stateBound, stateConnected:
- bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
- e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, bindToDevice)
- }
-
- // Close the receive list and drain it.
- e.rcvMu.Lock()
- e.rcvClosed = true
- e.rcvBufSize = 0
- for !e.rcvList.Empty() {
- p := e.rcvList.Front()
- e.rcvList.Remove(p)
- }
- e.rcvMu.Unlock()
+ notify := func() bool {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial:
+ case transport.DatagramEndpointStateClosed:
+ return false
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ info := e.net.Info()
+ info.ID.LocalPort = e.ident
+ e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{info.NetProto}, e.transProto, info.ID, e, ports.Flags{}, tcpip.NICID(e.ops.GetBindToDevice()))
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
+ }
- if e.route != nil {
- e.route.Release()
- e.route = nil
- }
+ e.net.Shutdown()
+ e.net.Close()
- // Update the state.
- e.state = stateClosed
+ e.rcvMu.Lock()
+ defer e.rcvMu.Unlock()
+ e.rcvClosed = true
+ e.rcvBufSize = 0
+ for !e.rcvList.Empty() {
+ p := e.rcvList.Front()
+ e.rcvList.Remove(p)
+ }
- e.mu.Unlock()
+ return true
+ }()
- e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
+ if notify {
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
+ }
}
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
@@ -164,7 +151,7 @@ func (*endpoint) ModerateRecvBuf(int) {}
// SetOwner implements tcpip.Endpoint.SetOwner.
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
- e.owner = owner
+ e.net.SetOwner(owner)
}
// Read implements tcpip.Endpoint.Read.
@@ -193,7 +180,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
Total: p.data.Size(),
ControlMessages: tcpip.ControlMessages{
HasTimestamp: true,
- Timestamp: p.receivedAt.UnixNano(),
+ Timestamp: p.receivedAt,
},
}
if opts.NeedRemoteAddr {
@@ -213,14 +200,13 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// reacquire the mutex in exclusive mode.
//
// Returns true for retry if preparation should be retried.
-// +checklocks:e.mu
-func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
- switch e.state {
- case stateInitial:
- case stateConnected:
+// +checklocksread:e.mu
+func (e *endpoint) prepareForWriteInner(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
+ switch e.net.State() {
+ case transport.DatagramEndpointStateInitial:
+ case transport.DatagramEndpointStateConnected:
return false, nil
-
- case stateBound:
+ case transport.DatagramEndpointStateBound:
if to == nil {
return false, &tcpip.ErrDestinationRequired{}
}
@@ -235,7 +221,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip
// The state changed when we released the shared locked and re-acquired
// it in exclusive mode. Try again.
- if e.state != stateInitial {
+ if e.net.State() != transport.DatagramEndpointStateInitial {
return true, nil
}
@@ -270,27 +256,15 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
return n, err
}
-func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
- // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
- if opts.More {
- return 0, &tcpip.ErrInvalidOptionValue{}
- }
-
- to := opts.To
-
+func (e *endpoint) prepareForWrite(opts tcpip.WriteOptions) (network.WriteContext, uint16, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- // If we've shutdown with SHUT_WR we are in an invalid state for sending.
- if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
- return 0, &tcpip.ErrClosedForSend{}
- }
-
// Prepare for write.
for {
- retry, err := e.prepareForWrite(to)
+ retry, err := e.prepareForWriteInner(opts.To)
if err != nil {
- return 0, err
+ return network.WriteContext{}, 0, err
}
if !retry {
@@ -298,36 +272,16 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
}
}
- route := e.route
- if to != nil {
- // Reject destination address if it goes through a different
- // NIC than the endpoint was bound to.
- nicID := to.NIC
- if nicID == 0 {
- nicID = tcpip.NICID(e.ops.GetBindToDevice())
- }
- if e.BindNICID != 0 {
- if nicID != 0 && nicID != e.BindNICID {
- return 0, &tcpip.ErrNoRoute{}
- }
-
- nicID = e.BindNICID
- }
-
- dst, netProto, err := e.checkV4MappedLocked(*to)
- if err != nil {
- return 0, err
- }
-
- // Find the endpoint.
- r, err := e.stack.FindRoute(nicID, e.BindAddr, dst.Addr, netProto, false /* multicastLoop */)
- if err != nil {
- return 0, err
- }
- defer r.Release()
+ ctx, err := e.net.AcquireContextForWrite(opts)
+ return ctx, e.ident, err
+}
- route = r
+func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
+ ctx, ident, err := e.prepareForWrite(opts)
+ if err != nil {
+ return 0, err
}
+ defer ctx.Release()
// TODO(https://gvisor.dev/issue/6538): Avoid this allocation.
v := make([]byte, p.Len())
@@ -335,17 +289,18 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
return 0, &tcpip.ErrBadBuffer{}
}
- var err tcpip.Error
- switch e.NetProto {
+ switch netProto, pktInfo := e.net.NetProto(), ctx.PacketInfo(); netProto {
case header.IPv4ProtocolNumber:
- err = send4(route, e.ID.LocalPort, v, e.ttl, e.owner)
+ if err := send4(e.stack, &ctx, ident, v, pktInfo.MaxHeaderLength); err != nil {
+ return 0, err
+ }
case header.IPv6ProtocolNumber:
- err = send6(route, e.ID.LocalPort, v, e.ttl)
- }
-
- if err != nil {
- return 0, err
+ if err := send6(e.stack, &ctx, ident, v, pktInfo.LocalAddress, pktInfo.RemoteAddress, pktInfo.MaxHeaderLength); err != nil {
+ return 0, err
+ }
+ default:
+ panic(fmt.Sprintf("unhandled network protocol = %d", netProto))
}
return int64(len(v)), nil
@@ -358,24 +313,17 @@ func (e *endpoint) HasNIC(id int32) bool {
return e.stack.HasNIC(tcpip.NICID(id))
}
-// SetSockOpt sets a socket option.
-func (*endpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error {
- return nil
+// SetSockOpt implements tcpip.Endpoint.
+func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
+ return e.net.SetSockOpt(opt)
}
-// SetSockOptInt sets a socket option. Currently not supported.
+// SetSockOptInt implements tcpip.Endpoint.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
- switch opt {
- case tcpip.TTLOption:
- e.mu.Lock()
- e.ttl = uint8(v)
- e.mu.Unlock()
-
- }
- return nil
+ return e.net.SetSockOptInt(opt, v)
}
-// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
+// GetSockOptInt implements tcpip.Endpoint.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
@@ -388,31 +336,24 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
e.rcvMu.Unlock()
return v, nil
- case tcpip.TTLOption:
- e.rcvMu.Lock()
- v := int(e.ttl)
- e.rcvMu.Unlock()
- return v, nil
-
default:
- return -1, &tcpip.ErrUnknownProtocolOption{}
+ return e.net.GetSockOptInt(opt)
}
}
-// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error {
- return &tcpip.ErrUnknownProtocolOption{}
+// GetSockOpt implements tcpip.Endpoint.
+func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
+ return e.net.GetSockOpt(opt)
}
-func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) tcpip.Error {
+func send4(s *stack.Stack, ctx *network.WriteContext, ident uint16, data buffer.View, maxHeaderLength uint16) tcpip.Error {
if len(data) < header.ICMPv4MinimumSize {
return &tcpip.ErrInvalidEndpointState{}
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.ICMPv4MinimumSize + int(r.MaxHeaderLength()),
+ ReserveHeaderBytes: header.ICMPv4MinimumSize + int(maxHeaderLength),
})
- pkt.Owner = owner
icmpv4 := header.ICMPv4(pkt.TransportHeader().Push(header.ICMPv4MinimumSize))
pkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
@@ -427,36 +368,31 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi
return &tcpip.ErrInvalidEndpointState{}
}
- // Because this icmp endpoint is implemented in the transport layer, we can
- // only increment the 'stack-wide' stats but we can't increment the
- // 'per-NetworkEndpoint' stats.
- sentStat := r.Stats().ICMP.V4.PacketsSent.EchoRequest
-
icmpv4.SetChecksum(0)
icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0)))
-
pkt.Data().AppendView(data)
- if ttl == 0 {
- ttl = r.DefaultTTL()
- }
+ // Because this icmp endpoint is implemented in the transport layer, we can
+ // only increment the 'stack-wide' stats but we can't increment the
+ // 'per-NetworkEndpoint' stats.
+ stats := s.Stats().ICMP.V4.PacketsSent
- if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil {
- r.Stats().ICMP.V4.PacketsSent.Dropped.Increment()
+ if err := ctx.WritePacket(pkt, false /* headerIncluded */); err != nil {
+ stats.Dropped.Increment()
return err
}
- sentStat.Increment()
+ stats.EchoRequest.Increment()
return nil
}
-func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Error {
+func send6(s *stack.Stack, ctx *network.WriteContext, ident uint16, data buffer.View, src, dst tcpip.Address, maxHeaderLength uint16) tcpip.Error {
if len(data) < header.ICMPv6EchoMinimumSize {
return &tcpip.ErrInvalidEndpointState{}
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.ICMPv6MinimumSize + int(r.MaxHeaderLength()),
+ ReserveHeaderBytes: header.ICMPv6MinimumSize + int(maxHeaderLength),
})
icmpv6 := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6MinimumSize))
@@ -469,43 +405,31 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Erro
if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 {
return &tcpip.ErrInvalidEndpointState{}
}
- // Because this icmp endpoint is implemented in the transport layer, we can
- // only increment the 'stack-wide' stats but we can't increment the
- // 'per-NetworkEndpoint' stats.
- sentStat := r.Stats().ICMP.V6.PacketsSent.EchoRequest
pkt.Data().AppendView(data)
dataRange := pkt.Data().AsRange()
icmpv6.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpv6,
- Src: r.LocalAddress(),
- Dst: r.RemoteAddress(),
+ Src: src,
+ Dst: dst,
PayloadCsum: dataRange.Checksum(),
PayloadLen: dataRange.Size(),
}))
- if ttl == 0 {
- ttl = r.DefaultTTL()
- }
+ // Because this icmp endpoint is implemented in the transport layer, we can
+ // only increment the 'stack-wide' stats but we can't increment the
+ // 'per-NetworkEndpoint' stats.
+ stats := s.Stats().ICMP.V6.PacketsSent
- if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil {
- r.Stats().ICMP.V6.PacketsSent.Dropped.Increment()
+ if err := ctx.WritePacket(pkt, false /* headerIncluded */); err != nil {
+ stats.Dropped.Increment()
+ return err
}
- sentStat.Increment()
+ stats.EchoRequest.Increment()
return nil
}
-// checkV4MappedLocked determines the effective network protocol and converts
-// addr to its canonical form.
-func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
- unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, false /* v6only */)
- if err != nil {
- return tcpip.FullAddress{}, 0, err
- }
- return unwrapped, netProto, nil
-}
-
// Disconnect implements tcpip.Endpoint.Disconnect.
func (*endpoint) Disconnect() tcpip.Error {
return &tcpip.ErrNotSupported{}
@@ -516,59 +440,21 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- nicID := addr.NIC
- localPort := uint16(0)
- switch e.state {
- case stateInitial:
- case stateBound, stateConnected:
- localPort = e.ID.LocalPort
- if e.BindNICID == 0 {
- break
- }
+ err := e.net.ConnectAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error {
+ nextID.LocalPort = e.ident
- if nicID != 0 && nicID != e.BindNICID {
- return &tcpip.ErrInvalidEndpointState{}
+ nextID, err := e.registerWithStack(netProto, nextID)
+ if err != nil {
+ return err
}
- nicID = e.BindNICID
- default:
- return &tcpip.ErrInvalidEndpointState{}
- }
-
- addr, netProto, err := e.checkV4MappedLocked(addr)
- if err != nil {
- return err
- }
-
- // Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicID, e.BindAddr, addr.Addr, netProto, false /* multicastLoop */)
- if err != nil {
- return err
- }
-
- id := stack.TransportEndpointID{
- LocalAddress: r.LocalAddress(),
- LocalPort: localPort,
- RemoteAddress: r.RemoteAddress(),
- }
-
- // Even if we're connected, this endpoint can still be used to send
- // packets on a different network protocol, so we register both even if
- // v6only is set to false and this is an ipv6 endpoint.
- netProtos := []tcpip.NetworkProtocolNumber{netProto}
-
- id, err = e.registerWithStack(nicID, netProtos, id)
+ e.ident = nextID.LocalPort
+ return nil
+ })
if err != nil {
- r.Release()
return err
}
- e.ID = id
- e.route = r
- e.RegisterNICID = nicID
-
- e.state = stateConnected
-
e.rcvMu.Lock()
e.rcvReady = true
e.rcvMu.Unlock()
@@ -586,10 +472,19 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error {
func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- e.shutdownFlags |= flags
- if e.state != stateConnected {
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
return &tcpip.ErrNotConnected{}
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
+ }
+
+ if flags&tcpip.ShutdownWrite != 0 {
+ if err := e.net.Shutdown(); err != nil {
+ return err
+ }
}
if flags&tcpip.ShutdownRead != 0 {
@@ -616,19 +511,18 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi
return nil, nil, &tcpip.ErrNotSupported{}
}
-func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) {
+func (e *endpoint) registerWithStack(netProto tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) {
bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
if id.LocalPort != 0 {
// The endpoint already has a local port, just attempt to
// register it.
- err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice)
- return id, err
+ return id, e.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{netProto}, e.transProto, id, e, ports.Flags{}, bindToDevice)
}
// We need to find a port for the endpoint.
_, err := e.stack.PickEphemeralPort(e.stack.Rand(), func(p uint16) (bool, tcpip.Error) {
id.LocalPort = p
- err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice)
+ err := e.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{netProto}, e.transProto, id, e, ports.Flags{}, bindToDevice)
switch err.(type) {
case nil:
return true, nil
@@ -645,42 +539,27 @@ func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkPro
func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error {
// Don't allow binding once endpoint is not in the initial state
// anymore.
- if e.state != stateInitial {
+ if e.net.State() != transport.DatagramEndpointStateInitial {
return &tcpip.ErrInvalidEndpointState{}
}
- addr, netProto, err := e.checkV4MappedLocked(addr)
- if err != nil {
- return err
- }
-
- // Expand netProtos to include v4 and v6 if the caller is binding to a
- // wildcard (empty) address, and this is an IPv6 endpoint with v6only
- // set to false.
- netProtos := []tcpip.NetworkProtocolNumber{netProto}
-
- if len(addr.Addr) != 0 {
- // A local address was specified, verify that it's valid.
- if e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) == 0 {
- return &tcpip.ErrBadLocalAddress{}
+ err := e.net.BindAndThen(addr, func(boundNetProto tcpip.NetworkProtocolNumber, boundAddr tcpip.Address) tcpip.Error {
+ id := stack.TransportEndpointID{
+ LocalPort: addr.Port,
+ LocalAddress: addr.Addr,
+ }
+ id, err := e.registerWithStack(boundNetProto, id)
+ if err != nil {
+ return err
}
- }
- id := stack.TransportEndpointID{
- LocalPort: addr.Port,
- LocalAddress: addr.Addr,
- }
- id, err = e.registerWithStack(addr.NIC, netProtos, id)
+ e.ident = id.LocalPort
+ return nil
+ })
if err != nil {
return err
}
- e.ID = id
- e.RegisterNICID = addr.NIC
-
- // Mark endpoint as bound.
- e.state = stateBound
-
e.rcvMu.Lock()
e.rcvReady = true
e.rcvMu.Unlock()
@@ -688,21 +567,24 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error {
return nil
}
+func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, addr tcpip.Address) bool {
+ return addr == header.IPv4Broadcast ||
+ header.IsV4MulticastAddress(addr) ||
+ header.IsV6MulticastAddress(addr) ||
+ e.stack.IsSubnetBroadcast(nicID, e.net.NetProto(), addr)
+}
+
// Bind binds the endpoint to a specific local address and port.
// Specifying a NIC is optional.
func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
- e.mu.Lock()
- defer e.mu.Unlock()
-
- err := e.bindLocked(addr)
- if err != nil {
- return err
+ if len(addr.Addr) != 0 && e.isBroadcastOrMulticast(addr.NIC, addr.Addr) {
+ return &tcpip.ErrBadLocalAddress{}
}
- e.BindNICID = addr.NIC
- e.BindAddr = addr.Addr
+ e.mu.Lock()
+ defer e.mu.Unlock()
- return nil
+ return e.bindLocked(addr)
}
// GetLocalAddress returns the address to which the endpoint is bound.
@@ -710,11 +592,9 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- return tcpip.FullAddress{
- NIC: e.RegisterNICID,
- Addr: e.ID.LocalAddress,
- Port: e.ID.LocalPort,
- }, nil
+ addr := e.net.GetLocalAddress()
+ addr.Port = e.ident
+ return addr, nil
}
// GetRemoteAddress returns the address to which the endpoint is connected.
@@ -722,15 +602,11 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- if e.state != stateConnected {
- return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
+ if addr, connected := e.net.GetRemoteAddress(); connected {
+ return addr, nil
}
- return tcpip.FullAddress{
- NIC: e.RegisterNICID,
- Addr: e.ID.RemoteAddress,
- Port: e.ID.RemotePort,
- }, nil
+ return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
}
// Readiness returns the current readiness of the endpoint. For example, if
@@ -755,7 +631,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// endpoint.
func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
// Only accept echo replies.
- switch e.NetProto {
+ switch e.net.NetProto() {
case header.IPv4ProtocolNumber:
h := header.ICMPv4(pkt.TransportHeader().View())
if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply {
@@ -829,9 +705,9 @@ func (e *endpoint) State() uint32 {
// Info returns a copy of the endpoint info.
func (e *endpoint) Info() tcpip.EndpointInfo {
e.mu.RLock()
- // Make a copy of the endpoint info.
- ret := e.TransportEndpointInfo
- e.mu.RUnlock()
+ defer e.mu.RUnlock()
+ ret := e.net.Info()
+ ret.ID.LocalPort = e.ident
return &ret
}
diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go
index b8b839e4a..dfe453ff9 100644
--- a/pkg/tcpip/transport/icmp/endpoint_state.go
+++ b/pkg/tcpip/transport/icmp/endpoint_state.go
@@ -15,11 +15,13 @@
package icmp
import (
+ "fmt"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
)
// saveReceivedAt is invoked by stateify.
@@ -61,29 +63,24 @@ func (e *endpoint) beforeSave() {
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
e.thaw()
+
+ e.net.Resume(s)
+
e.stack = s
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
- if e.state != stateBound && e.state != stateConnected {
- return
- }
-
- var err tcpip.Error
- if e.state == stateConnected {
- e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.ID.RemoteAddress, e.NetProto, false /* multicastLoop */)
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ var err tcpip.Error
+ info := e.net.Info()
+ info.ID.LocalPort = e.ident
+ info.ID, err = e.registerWithStack(info.NetProto, info.ID)
if err != nil {
- panic(err)
+ panic(fmt.Sprintf("e.registerWithStack(%d, %#v): %s", info.NetProto, info.ID, err))
}
-
- e.ID.LocalAddress = e.route.LocalAddress()
- } else if len(e.ID.LocalAddress) != 0 { // stateBound
- if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.ID.LocalAddress) == 0 {
- panic(&tcpip.ErrBadLocalAddress{})
- }
- }
-
- e.ID, err = e.registerWithStack(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.ID)
- if err != nil {
- panic(err)
+ e.ident = info.ID.LocalPort
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
}
}
diff --git a/pkg/tcpip/transport/icmp/icmp_test.go b/pkg/tcpip/transport/icmp/icmp_test.go
index cc950cbde..729f50e9a 100644
--- a/pkg/tcpip/transport/icmp/icmp_test.go
+++ b/pkg/tcpip/transport/icmp/icmp_test.go
@@ -55,8 +55,12 @@ func addNICWithDefaultRoute(t *testing.T, s *stack.Stack, id tcpip.NICID, name s
t.Fatalf("s.CreateNIC(%d, _) = %s", id, err)
}
- if err := s.AddAddress(id, ipv4.ProtocolNumber, addrV4); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s) = %s", id, ipv4.ProtocolNumber, addrV4, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addrV4.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(id, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", id, protocolAddr, err)
}
s.AddRoute(tcpip.Route{
diff --git a/pkg/tcpip/transport/internal/network/BUILD b/pkg/tcpip/transport/internal/network/BUILD
index b1edce39b..3818cb04e 100644
--- a/pkg/tcpip/transport/internal/network/BUILD
+++ b/pkg/tcpip/transport/internal/network/BUILD
@@ -9,6 +9,7 @@ go_library(
"endpoint_state.go",
],
visibility = [
+ "//pkg/tcpip/transport/icmp:__pkg__",
"//pkg/tcpip/transport/raw:__pkg__",
"//pkg/tcpip/transport/udp:__pkg__",
],
diff --git a/pkg/tcpip/transport/internal/network/endpoint.go b/pkg/tcpip/transport/internal/network/endpoint.go
index 09b629022..fb31e5104 100644
--- a/pkg/tcpip/transport/internal/network/endpoint.go
+++ b/pkg/tcpip/transport/internal/network/endpoint.go
@@ -38,31 +38,65 @@ type Endpoint struct {
netProto tcpip.NetworkProtocolNumber
transProto tcpip.TransportProtocolNumber
- // state holds a transport.DatagramBasedEndpointState.
- //
- // state must be read from/written to atomically.
- state uint32
-
- // The following fields are protected by mu.
- mu sync.RWMutex `state:"nosave"`
+ mu sync.RWMutex `state:"nosave"`
+ // +checklocks:mu
wasBound bool
- info stack.TransportEndpointInfo
// owner is the owner of transmitted packets.
- owner tcpip.PacketOwner
- writeShutdown bool
- effectiveNetProto tcpip.NetworkProtocolNumber
- connectedRoute *stack.Route `state:"manual"`
+ //
+ // +checklocks:mu
+ owner tcpip.PacketOwner
+ // +checklocks:mu
+ writeShutdown bool
+ // +checklocks:mu
+ effectiveNetProto tcpip.NetworkProtocolNumber
+ // +checklocks:mu
+ connectedRoute *stack.Route `state:"manual"`
+ // +checklocks:mu
multicastMemberships map[multicastMembership]struct{}
// TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ // +checklocks:mu
ttl uint8
// TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ // +checklocks:mu
multicastTTL uint8
// TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ // +checklocks:mu
multicastAddr tcpip.Address
// TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ // +checklocks:mu
multicastNICID tcpip.NICID
- ipv4TOS uint8
- ipv6TClass uint8
+ // +checklocks:mu
+ ipv4TOS uint8
+ // +checklocks:mu
+ ipv6TClass uint8
+
+ // Lock ordering: mu > infoMu.
+ infoMu sync.RWMutex `state:"nosave"`
+ // info has a dedicated mutex so that we can avoid lock ordering violations
+ // when reading the endpoint's info. If we used mu, we need to guarantee
+ // that any lock taken while mu is held is not held when calling Info(),
+ // which is not true as of writing: we hold mu while registering transport
+ // endpoints (taking the transport demuxer lock), but we also hold the demuxer
+ // lock when delivering packets/errors to endpoints.
+ //
+ // Writes must be performed through setInfo.
+ //
+ // +checklocks:infoMu
+ info stack.TransportEndpointInfo
+
+ // state holds a transport.DatagramEndpointState.
+ //
+ // state must be accessed with atomics so that we can avoid lock ordering
+ // violations when reading the state. If we used mu, we need to guarantee
+ // that any lock taken while mu is held is not held when calling State(),
+ // which is not true as of writing: we hold mu while registering transport
+ // endpoints (taking the transport demuxer lock), but we also hold the demuxer
+ // lock when delivering packets/errors to endpoints.
+ //
+ // Writes must be performed through setEndpointState.
+ //
+ // +checkatomics
+ state uint32
}
// +stateify savable
@@ -73,8 +107,11 @@ type multicastMembership struct {
// Init initializes the endpoint.
func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ops *tcpip.SocketOptions) {
- if e.multicastMemberships != nil {
- panic(fmt.Sprintf("endpoint is already initialized; got e.multicastMemberships = %#v, want = nil", e.multicastMemberships))
+ e.mu.Lock()
+ memberships := e.multicastMemberships
+ e.mu.Unlock()
+ if memberships != nil {
+ panic(fmt.Sprintf("endpoint is already initialized; got e.multicastMemberships = %#v, want = nil", memberships))
}
switch netProto {
@@ -89,8 +126,6 @@ func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, tr
netProto: netProto,
transProto: transProto,
- state: uint32(transport.DatagramEndpointStateInitial),
-
info: stack.TransportEndpointInfo{
NetProto: netProto,
TransProto: transProto,
@@ -100,6 +135,10 @@ func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, tr
multicastTTL: 1,
multicastMemberships: make(map[multicastMembership]struct{}),
}
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.setEndpointState(transport.DatagramEndpointStateInitial)
}
// NetProto returns the network protocol the endpoint was initialized with.
@@ -107,7 +146,12 @@ func (e *Endpoint) NetProto() tcpip.NetworkProtocolNumber {
return e.netProto
}
-// setState sets the state of the endpoint.
+// setEndpointState sets the state of the endpoint.
+//
+// e.mu must be held to synchronize changes to state with the rest of the
+// endpoint.
+//
+// +checklocks:e.mu
func (e *Endpoint) setEndpointState(state transport.DatagramEndpointState) {
atomic.StoreUint32(&e.state, uint32(state))
}
@@ -242,23 +286,24 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext
if nicID == 0 {
nicID = tcpip.NICID(e.ops.GetBindToDevice())
}
- if e.info.BindNICID != 0 {
- if nicID != 0 && nicID != e.info.BindNICID {
+ info := e.Info()
+ if info.BindNICID != 0 {
+ if nicID != 0 && nicID != info.BindNICID {
return WriteContext{}, &tcpip.ErrNoRoute{}
}
- nicID = e.info.BindNICID
+ nicID = info.BindNICID
}
if nicID == 0 {
- nicID = e.info.RegisterNICID
+ nicID = info.RegisterNICID
}
- dst, netProto, err := e.checkV4MappedLocked(*opts.To)
+ dst, netProto, err := e.checkV4Mapped(*opts.To)
if err != nil {
return WriteContext{}, err
}
- route, _, err = e.connectRoute(nicID, dst, netProto)
+ route, _, err = e.connectRouteRLocked(nicID, dst, netProto)
if err != nil {
return WriteContext{}, err
}
@@ -297,26 +342,30 @@ func (e *Endpoint) Disconnect() {
return
}
+ info := e.Info()
// Exclude ephemerally bound endpoints.
if e.wasBound {
- e.info.ID = stack.TransportEndpointID{
- LocalAddress: e.info.BindAddr,
+ info.ID = stack.TransportEndpointID{
+ LocalAddress: info.BindAddr,
}
e.setEndpointState(transport.DatagramEndpointStateBound)
} else {
- e.info.ID = stack.TransportEndpointID{}
+ info.ID = stack.TransportEndpointID{}
e.setEndpointState(transport.DatagramEndpointStateInitial)
}
+ e.setInfo(info)
e.connectedRoute.Release()
e.connectedRoute = nil
}
-// connectRoute establishes a route to the specified interface or the
+// connectRouteRLocked establishes a route to the specified interface or the
// configured multicast interface if no interface is specified and the
// specified address is a multicast address.
-func (e *Endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) {
- localAddr := e.info.ID.LocalAddress
+//
+// +checklocksread:e.mu
+func (e *Endpoint) connectRouteRLocked(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) {
+ localAddr := e.Info().ID.LocalAddress
if e.isBroadcastOrMulticast(nicID, netProto, localAddr) {
// A packet can only originate from a unicast address (i.e., an interface).
localAddr = ""
@@ -359,42 +408,43 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip.
e.mu.Lock()
defer e.mu.Unlock()
+ info := e.Info()
nicID := addr.NIC
switch e.State() {
case transport.DatagramEndpointStateInitial:
case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
- if e.info.BindNICID == 0 {
+ if info.BindNICID == 0 {
break
}
- if nicID != 0 && nicID != e.info.BindNICID {
+ if nicID != 0 && nicID != info.BindNICID {
return &tcpip.ErrInvalidEndpointState{}
}
- nicID = e.info.BindNICID
+ nicID = info.BindNICID
default:
return &tcpip.ErrInvalidEndpointState{}
}
- addr, netProto, err := e.checkV4MappedLocked(addr)
+ addr, netProto, err := e.checkV4Mapped(addr)
if err != nil {
return err
}
- r, nicID, err := e.connectRoute(nicID, addr, netProto)
+ r, nicID, err := e.connectRouteRLocked(nicID, addr, netProto)
if err != nil {
return err
}
id := stack.TransportEndpointID{
- LocalAddress: e.info.ID.LocalAddress,
+ LocalAddress: info.ID.LocalAddress,
RemoteAddress: r.RemoteAddress(),
}
if e.State() == transport.DatagramEndpointStateInitial {
id.LocalAddress = r.LocalAddress()
}
- if err := f(r.NetProto(), e.info.ID, id); err != nil {
+ if err := f(r.NetProto(), info.ID, id); err != nil {
return err
}
@@ -403,8 +453,9 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip.
e.connectedRoute.Release()
}
e.connectedRoute = r
- e.info.ID = id
- e.info.RegisterNICID = nicID
+ info.ID = id
+ info.RegisterNICID = nicID
+ e.setInfo(info)
e.effectiveNetProto = netProto
e.setEndpointState(transport.DatagramEndpointStateConnected)
return nil
@@ -426,10 +477,11 @@ func (e *Endpoint) Shutdown() tcpip.Error {
}
}
-// checkV4MappedLocked determines the effective network protocol and converts
+// checkV4Mapped determines the effective network protocol and converts
// addr to its canonical form.
-func (e *Endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
- unwrapped, netProto, err := e.info.AddrNetProtoLocked(addr, e.ops.GetV6Only())
+func (e *Endpoint) checkV4Mapped(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
+ info := e.Info()
+ unwrapped, netProto, err := info.AddrNetProtoLocked(addr, e.ops.GetV6Only())
if err != nil {
return tcpip.FullAddress{}, 0, err
}
@@ -464,7 +516,7 @@ func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProto
return &tcpip.ErrInvalidEndpointState{}
}
- addr, netProto, err := e.checkV4MappedLocked(addr)
+ addr, netProto, err := e.checkV4Mapped(addr)
if err != nil {
return err
}
@@ -483,12 +535,14 @@ func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProto
e.wasBound = true
- e.info.ID = stack.TransportEndpointID{
+ info := e.Info()
+ info.ID = stack.TransportEndpointID{
LocalAddress: addr.Addr,
}
- e.info.BindNICID = addr.NIC
- e.info.RegisterNICID = nicID
- e.info.BindAddr = addr.Addr
+ info.BindNICID = addr.NIC
+ info.RegisterNICID = nicID
+ info.BindAddr = addr.Addr
+ e.setInfo(info)
e.effectiveNetProto = netProto
e.setEndpointState(transport.DatagramEndpointStateBound)
return nil
@@ -506,13 +560,14 @@ func (e *Endpoint) GetLocalAddress() tcpip.FullAddress {
e.mu.RLock()
defer e.mu.RUnlock()
- addr := e.info.BindAddr
+ info := e.Info()
+ addr := info.BindAddr
if e.State() == transport.DatagramEndpointStateConnected {
addr = e.connectedRoute.LocalAddress()
}
return tcpip.FullAddress{
- NIC: e.info.RegisterNICID,
+ NIC: info.RegisterNICID,
Addr: addr,
}
}
@@ -528,7 +583,7 @@ func (e *Endpoint) GetRemoteAddress() (tcpip.FullAddress, bool) {
return tcpip.FullAddress{
Addr: e.connectedRoute.RemoteAddress(),
- NIC: e.info.RegisterNICID,
+ NIC: e.Info().RegisterNICID,
}, true
}
@@ -610,7 +665,7 @@ func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
defer e.mu.Unlock()
fa := tcpip.FullAddress{Addr: v.InterfaceAddr}
- fa, netProto, err := e.checkV4MappedLocked(fa)
+ fa, netProto, err := e.checkV4Mapped(fa)
if err != nil {
return err
}
@@ -634,7 +689,7 @@ func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
}
}
- if e.info.BindNICID != 0 && e.info.BindNICID != nic {
+ if info := e.Info(); info.BindNICID != 0 && info.BindNICID != nic {
return &tcpip.ErrInvalidEndpointState{}
}
@@ -737,7 +792,19 @@ func (e *Endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
// Info returns a copy of the endpoint info.
func (e *Endpoint) Info() stack.TransportEndpointInfo {
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.infoMu.RLock()
+ defer e.infoMu.RUnlock()
return e.info
}
+
+// setInfo sets the endpoint's info.
+//
+// e.mu must be held to synchronize changes to info with the rest of the
+// endpoint.
+//
+// +checklocks:e.mu
+func (e *Endpoint) setInfo(info stack.TransportEndpointInfo) {
+ e.infoMu.Lock()
+ defer e.infoMu.Unlock()
+ e.info = info
+}
diff --git a/pkg/tcpip/transport/internal/network/endpoint_state.go b/pkg/tcpip/transport/internal/network/endpoint_state.go
index 858007156..68bd1fbf6 100644
--- a/pkg/tcpip/transport/internal/network/endpoint_state.go
+++ b/pkg/tcpip/transport/internal/network/endpoint_state.go
@@ -35,20 +35,22 @@ func (e *Endpoint) Resume(s *stack.Stack) {
}
}
+ info := e.Info()
+
switch state := e.State(); state {
case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
case transport.DatagramEndpointStateBound:
- if len(e.info.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress) {
- if e.stack.CheckLocalAddress(e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress) == 0 {
- panic(fmt.Sprintf("got e.stack.CheckLocalAddress(%d, %d, %s) = 0, want != 0", e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress))
+ if len(info.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(info.RegisterNICID, e.effectiveNetProto, info.ID.LocalAddress) {
+ if e.stack.CheckLocalAddress(info.RegisterNICID, e.effectiveNetProto, info.ID.LocalAddress) == 0 {
+ panic(fmt.Sprintf("got e.stack.CheckLocalAddress(%d, %d, %s) = 0, want != 0", info.RegisterNICID, e.effectiveNetProto, info.ID.LocalAddress))
}
}
case transport.DatagramEndpointStateConnected:
var err tcpip.Error
multicastLoop := e.ops.GetMulticastLoop()
- e.connectedRoute, err = e.stack.FindRoute(e.info.RegisterNICID, e.info.ID.LocalAddress, e.info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop)
+ e.connectedRoute, err = e.stack.FindRoute(info.RegisterNICID, info.ID.LocalAddress, info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop)
if err != nil {
- panic(fmt.Sprintf("e.stack.FindRoute(%d, %s, %s, %d, %t): %s", e.info.RegisterNICID, e.info.ID.LocalAddress, e.info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop, err))
+ panic(fmt.Sprintf("e.stack.FindRoute(%d, %s, %s, %d, %t): %s", info.RegisterNICID, info.ID.LocalAddress, info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop, err))
}
default:
panic(fmt.Sprintf("unhandled state = %s", state))
diff --git a/pkg/tcpip/transport/internal/network/endpoint_test.go b/pkg/tcpip/transport/internal/network/endpoint_test.go
index d99c961c3..f263a9ea2 100644
--- a/pkg/tcpip/transport/internal/network/endpoint_test.go
+++ b/pkg/tcpip/transport/internal/network/endpoint_test.go
@@ -124,11 +124,20 @@ func TestEndpointStateTransitions(t *testing.T) {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, ipv4NICAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, ipv4NICAddr, err)
+ ipv4ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ipv4NICAddr.WithPrefix(),
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, ipv6NICAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, ipv6NICAddr, err)
+ if err := s.AddProtocolAddress(nicID, ipv4ProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtocolAddr, err)
+ }
+ ipv6ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: ipv6NICAddr.WithPrefix(),
+ }
+
+ if err := s.AddProtocolAddress(nicID, ipv6ProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -257,11 +266,19 @@ func TestBindNICID(t *testing.T) {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, ipv4NICAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, ipv4NICAddr, err)
+ ipv4ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ipv4NICAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, ipv4ProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtocolAddr, err)
+ }
+ ipv6ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: ipv6NICAddr.WithPrefix(),
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, ipv6NICAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, ipv6NICAddr, err)
+ if err := s.AddProtocolAddress(nicID, ipv6ProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtocolAddr, err)
}
var ops tcpip.SocketOptions
diff --git a/pkg/tcpip/transport/internal/noop/BUILD b/pkg/tcpip/transport/internal/noop/BUILD
new file mode 100644
index 000000000..171c41eb1
--- /dev/null
+++ b/pkg/tcpip/transport/internal/noop/BUILD
@@ -0,0 +1,14 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "noop",
+ srcs = ["endpoint.go"],
+ visibility = ["//pkg/tcpip/transport/raw:__pkg__"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/stack",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/tcpip/transport/internal/noop/endpoint.go b/pkg/tcpip/transport/internal/noop/endpoint.go
new file mode 100644
index 000000000..443b4e416
--- /dev/null
+++ b/pkg/tcpip/transport/internal/noop/endpoint.go
@@ -0,0 +1,172 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package noop contains an endpoint that implements all tcpip.Endpoint
+// functions as noops.
+package noop
+
+import (
+ "fmt"
+ "io"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// endpoint can be created, but all interactions have no effect or
+// return errors.
+//
+// +stateify savable
+type endpoint struct {
+ tcpip.DefaultSocketOptionsHandler
+ ops tcpip.SocketOptions
+}
+
+// New returns an initialized noop endpoint.
+func New(stk *stack.Stack) tcpip.Endpoint {
+ // ep.ops must be in a valid, initialized state for callers of
+ // ep.SocketOptions.
+ var ep endpoint
+ ep.ops.InitHandler(&ep, stk, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
+ return &ep
+}
+
+// Abort implements stack.TransportEndpoint.Abort.
+func (*endpoint) Abort() {
+ // No-op.
+}
+
+// Close implements tcpip.Endpoint.Close.
+func (*endpoint) Close() {
+ // No-op.
+}
+
+// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
+func (*endpoint) ModerateRecvBuf(int) {
+ // No-op.
+}
+
+func (*endpoint) SetOwner(tcpip.PacketOwner) {
+ // No-op.
+}
+
+// Read implements tcpip.Endpoint.Read.
+func (*endpoint) Read(io.Writer, tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) {
+ return tcpip.ReadResult{}, &tcpip.ErrNotPermitted{}
+}
+
+// Write implements tcpip.Endpoint.Write.
+func (*endpoint) Write(tcpip.Payloader, tcpip.WriteOptions) (int64, tcpip.Error) {
+ return 0, &tcpip.ErrNotPermitted{}
+}
+
+// Disconnect implements tcpip.Endpoint.Disconnect.
+func (*endpoint) Disconnect() tcpip.Error {
+ return &tcpip.ErrNotSupported{}
+}
+
+// Connect implements tcpip.Endpoint.Connect.
+func (*endpoint) Connect(tcpip.FullAddress) tcpip.Error {
+ return &tcpip.ErrNotPermitted{}
+}
+
+// Shutdown implements tcpip.Endpoint.Shutdown.
+func (*endpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error {
+ return &tcpip.ErrNotPermitted{}
+}
+
+// Listen implements tcpip.Endpoint.Listen.
+func (*endpoint) Listen(int) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
+}
+
+// Accept implements tcpip.Endpoint.Accept.
+func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) {
+ return nil, nil, &tcpip.ErrNotSupported{}
+}
+
+// Bind implements tcpip.Endpoint.Bind.
+func (*endpoint) Bind(tcpip.FullAddress) tcpip.Error {
+ return &tcpip.ErrNotPermitted{}
+}
+
+// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
+func (*endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
+ return tcpip.FullAddress{}, &tcpip.ErrNotSupported{}
+}
+
+// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
+func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
+ return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
+}
+
+// Readiness implements tcpip.Endpoint.Readiness.
+func (*endpoint) Readiness(waiter.EventMask) waiter.EventMask {
+ return 0
+}
+
+// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
+func (*endpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error {
+ return &tcpip.ErrUnknownProtocolOption{}
+}
+
+func (*endpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error {
+ return &tcpip.ErrUnknownProtocolOption{}
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error {
+ return &tcpip.ErrUnknownProtocolOption{}
+}
+
+// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
+func (*endpoint) GetSockOptInt(tcpip.SockOptInt) (int, tcpip.Error) {
+ return 0, &tcpip.ErrUnknownProtocolOption{}
+}
+
+// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
+func (*endpoint) HandlePacket(pkt *stack.PacketBuffer) {
+ panic(fmt.Sprintf("unreachable: noop.endpoint should never be registered, but got packet: %+v", pkt))
+}
+
+// State implements socket.Socket.State.
+func (*endpoint) State() uint32 {
+ return 0
+}
+
+// Wait implements stack.TransportEndpoint.Wait.
+func (*endpoint) Wait() {
+ // No-op.
+}
+
+// LastError implements tcpip.Endpoint.LastError.
+func (*endpoint) LastError() tcpip.Error {
+ return nil
+}
+
+// SocketOptions implements tcpip.Endpoint.SocketOptions.
+func (ep *endpoint) SocketOptions() *tcpip.SocketOptions {
+ return &ep.ops
+}
+
+// Info implements tcpip.Endpoint.Info.
+func (*endpoint) Info() tcpip.EndpointInfo {
+ return &stack.TransportEndpointInfo{}
+}
+
+// Stats returns a pointer to the endpoint stats.
+func (*endpoint) Stats() tcpip.EndpointStats {
+ return &tcpip.TransportEndpointStats{}
+}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 0554d2f4a..80eef39e9 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -59,52 +59,47 @@ type packet struct {
//
// +stateify savable
type endpoint struct {
- stack.TransportEndpointInfo
tcpip.DefaultSocketOptionsHandler
// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
- netProto tcpip.NetworkProtocolNumber
waiterQueue *waiter.Queue
cooked bool
-
- // The following fields are used to manage the receive queue and are
- // protected by rcvMu.
- rcvMu sync.Mutex `state:"nosave"`
- rcvList packetList
+ ops tcpip.SocketOptions
+ stats tcpip.TransportEndpointStats
+
+ // The following fields are used to manage the receive queue.
+ rcvMu sync.Mutex `state:"nosave"`
+ // +checklocks:rcvMu
+ rcvList packetList
+ // +checklocks:rcvMu
rcvBufSize int
- rcvClosed bool
-
- // The following fields are protected by mu.
- mu sync.RWMutex `state:"nosave"`
- closed bool
- stats tcpip.TransportEndpointStats `state:"nosave"`
- bound bool
+ // +checklocks:rcvMu
+ rcvClosed bool
+ // +checklocks:rcvMu
+ rcvDisabled bool
+
+ mu sync.RWMutex `state:"nosave"`
+ // +checklocks:mu
+ closed bool
+ // +checklocks:mu
+ boundNetProto tcpip.NetworkProtocolNumber
+ // +checklocks:mu
boundNIC tcpip.NICID
- // lastErrorMu protects lastError.
lastErrorMu sync.Mutex `state:"nosave"`
- lastError tcpip.Error
-
- // ops is used to get socket level options.
- ops tcpip.SocketOptions
-
- // frozen indicates if the packets should be delivered to the endpoint
- // during restore.
- frozen bool
+ // +checklocks:lastErrorMu
+ lastError tcpip.Error
}
// NewEndpoint returns a new packet endpoint.
func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
ep := &endpoint{
- stack: s,
- TransportEndpointInfo: stack.TransportEndpointInfo{
- NetProto: netProto,
- },
- cooked: cooked,
- netProto: netProto,
- waiterQueue: waiterQueue,
+ stack: s,
+ cooked: cooked,
+ boundNetProto: netProto,
+ waiterQueue: waiterQueue,
}
ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
ep.ops.SetReceiveBufferSize(32*1024, false /* notify */)
@@ -140,7 +135,7 @@ func (ep *endpoint) Close() {
return
}
- ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep)
+ ep.stack.UnregisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep)
ep.rcvMu.Lock()
defer ep.rcvMu.Unlock()
@@ -153,7 +148,6 @@ func (ep *endpoint) Close() {
}
ep.closed = true
- ep.bound = false
ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
}
@@ -188,7 +182,7 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul
Total: packet.data.Size(),
ControlMessages: tcpip.ControlMessages{
HasTimestamp: true,
- Timestamp: packet.receivedAt.UnixNano(),
+ Timestamp: packet.receivedAt,
},
}
if opts.NeedRemoteAddr {
@@ -214,13 +208,13 @@ func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tc
ep.mu.Lock()
closed := ep.closed
nicID := ep.boundNIC
+ proto := ep.boundNetProto
ep.mu.Unlock()
if closed {
return 0, &tcpip.ErrClosedForSend{}
}
var remote tcpip.LinkAddress
- proto := ep.netProto
if to := opts.To; to != nil {
remote = tcpip.LinkAddress(to.Addr)
@@ -296,29 +290,42 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
- if ep.bound && ep.boundNIC == addr.NIC {
- // If the NIC being bound is the same then just return success.
+ netProto := tcpip.NetworkProtocolNumber(addr.Port)
+ if netProto == 0 {
+ // Do not allow unbinding the network protocol.
+ netProto = ep.boundNetProto
+ }
+
+ if ep.boundNIC == addr.NIC && ep.boundNetProto == netProto {
+ // Already bound to the requested NIC and network protocol.
return nil
}
- // Unregister endpoint with all the nics.
- ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep)
- ep.bound = false
+ // TODO(https://gvisor.dev/issue/6618): Unregister after registering the new
+ // binding.
+ ep.stack.UnregisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep)
+ ep.boundNIC = 0
+ ep.boundNetProto = 0
// Bind endpoint to receive packets from specific interface.
- if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil {
+ if err := ep.stack.RegisterPacketEndpoint(addr.NIC, netProto, ep); err != nil {
return err
}
- ep.bound = true
ep.boundNIC = addr.NIC
-
+ ep.boundNetProto = netProto
return nil
}
// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
-func (*endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
- return tcpip.FullAddress{}, &tcpip.ErrNotSupported{}
+func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
+ ep.mu.RLock()
+ defer ep.mu.RUnlock()
+
+ return tcpip.FullAddress{
+ NIC: ep.boundNIC,
+ Port: uint16(ep.boundNetProto),
+ }, nil
}
// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
@@ -402,7 +409,7 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
}
// HandlePacket implements stack.PacketEndpoint.HandlePacket.
-func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+func (ep *endpoint) HandlePacket(nicID tcpip.NICID, _ tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
ep.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
@@ -414,7 +421,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
}
rcvBufSize := ep.ops.GetReceiveBufferSize()
- if ep.frozen || ep.rcvBufSize >= int(rcvBufSize) {
+ if ep.rcvDisabled || ep.rcvBufSize >= int(rcvBufSize) {
ep.rcvMu.Unlock()
ep.stack.Stats().DroppedPackets.Increment()
ep.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
@@ -473,10 +480,8 @@ func (*endpoint) State() uint32 {
// Info returns a copy of the endpoint info.
func (ep *endpoint) Info() tcpip.EndpointInfo {
ep.mu.RLock()
- // Make a copy of the endpoint info.
- ret := ep.TransportEndpointInfo
- ep.mu.RUnlock()
- return &ret
+ defer ep.mu.RUnlock()
+ return &stack.TransportEndpointInfo{NetProto: ep.boundNetProto}
}
// Stats returns a pointer to the endpoint stats.
@@ -491,18 +496,3 @@ func (*endpoint) SetOwner(tcpip.PacketOwner) {}
func (ep *endpoint) SocketOptions() *tcpip.SocketOptions {
return &ep.ops
}
-
-// freeze prevents any more packets from being delivered to the endpoint.
-func (ep *endpoint) freeze() {
- ep.mu.Lock()
- ep.frozen = true
- ep.mu.Unlock()
-}
-
-// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows
-// new packets to be delivered again.
-func (ep *endpoint) thaw() {
- ep.mu.Lock()
- ep.frozen = false
- ep.mu.Unlock()
-}
diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go
index 5c688d286..88cd80ad3 100644
--- a/pkg/tcpip/transport/packet/endpoint_state.go
+++ b/pkg/tcpip/transport/packet/endpoint_state.go
@@ -15,6 +15,7 @@
package packet
import (
+ "fmt"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -44,17 +45,24 @@ func (p *packet) loadData(data buffer.VectorisedView) {
// beforeSave is invoked by stateify.
func (ep *endpoint) beforeSave() {
- ep.freeze()
+ ep.rcvMu.Lock()
+ defer ep.rcvMu.Unlock()
+ ep.rcvDisabled = true
}
// afterLoad is invoked by stateify.
func (ep *endpoint) afterLoad() {
- ep.thaw()
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
ep.stack = stack.StackFromEnv
ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
- // TODO(gvisor.dev/173): Once bind is supported, choose the right NIC.
- if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil {
- panic(err)
+ if err := ep.stack.RegisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep); err != nil {
+ panic(fmt.Sprintf("RegisterPacketEndpoint(%d, %d, _): %s", ep.boundNIC, ep.boundNetProto, err))
}
+
+ ep.rcvMu.Lock()
+ ep.rcvDisabled = false
+ ep.rcvMu.Unlock()
}
diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD
index b7e97e218..10b0c35fb 100644
--- a/pkg/tcpip/transport/raw/BUILD
+++ b/pkg/tcpip/transport/raw/BUILD
@@ -35,6 +35,7 @@ go_library(
"//pkg/tcpip/stack",
"//pkg/tcpip/transport",
"//pkg/tcpip/transport/internal/network",
+ "//pkg/tcpip/transport/internal/noop",
"//pkg/tcpip/transport/packet",
"//pkg/waiter",
],
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 3040a445b..ce76774af 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -49,6 +49,7 @@ type rawPacket struct {
receivedAt time.Time `state:".(int64)"`
// senderAddr is the network address of the sender.
senderAddr tcpip.FullAddress
+ packetInfo tcpip.IPPacketInfo
}
// endpoint is the raw socket implementation of tcpip.Endpoint. It is legal to
@@ -70,7 +71,7 @@ type endpoint struct {
associated bool
net network.Endpoint
- stats tcpip.TransportEndpointStats `state:"nosave"`
+ stats tcpip.TransportEndpointStats
ops tcpip.SocketOptions
// The following fields are used to manage the receive queue and are
@@ -202,12 +203,29 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
Total: pkt.data.Size(),
ControlMessages: tcpip.ControlMessages{
HasTimestamp: true,
- Timestamp: pkt.receivedAt.UnixNano(),
+ Timestamp: pkt.receivedAt,
},
}
if opts.NeedRemoteAddr {
res.RemoteAddr = pkt.senderAddr
}
+ switch netProto := e.net.NetProto(); netProto {
+ case header.IPv4ProtocolNumber:
+ if e.ops.GetReceivePacketInfo() {
+ res.ControlMessages.HasIPPacketInfo = true
+ res.ControlMessages.PacketInfo = pkt.packetInfo
+ }
+ case header.IPv6ProtocolNumber:
+ if e.ops.GetIPv6ReceivePacketInfo() {
+ res.ControlMessages.HasIPv6PacketInfo = true
+ res.ControlMessages.IPv6PacketInfo = tcpip.IPv6PacketInfo{
+ NIC: pkt.packetInfo.NIC,
+ Addr: pkt.packetInfo.DestinationAddr,
+ }
+ }
+ default:
+ panic(fmt.Sprintf("unrecognized network protocol = %d", netProto))
+ }
n, err := pkt.data.ReadTo(dst, opts.Peek)
if n == 0 && err != nil {
@@ -435,7 +453,9 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return false
}
- srcAddr := pkt.Network().SourceAddress()
+ net := pkt.Network()
+ dstAddr := net.DestinationAddress()
+ srcAddr := net.SourceAddress()
info := e.net.Info()
switch state := e.net.State(); state {
@@ -457,7 +477,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
}
// If bound to an address, only accept data for that address.
- if info.BindAddr != "" && info.BindAddr != pkt.Network().DestinationAddress() {
+ if info.BindAddr != "" && info.BindAddr != dstAddr {
return false
}
default:
@@ -472,6 +492,14 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
NIC: pkt.NICID,
Addr: srcAddr,
},
+ packetInfo: tcpip.IPPacketInfo{
+ // TODO(gvisor.dev/issue/3556): dstAddr may be a multicast or broadcast
+ // address. LocalAddr should hold a unicast address that can be
+ // used to respond to the incoming packet.
+ LocalAddr: dstAddr,
+ DestinationAddr: dstAddr,
+ NIC: pkt.NICID,
+ },
}
// Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not.
@@ -483,10 +511,10 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// overlapping slices.
var combinedVV buffer.VectorisedView
if info.NetProto == header.IPv4ProtocolNumber {
- network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
- headers := make(buffer.View, 0, len(network)+len(transport))
- headers = append(headers, network...)
- headers = append(headers, transport...)
+ networkHeader, transportHeader := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
+ headers := make(buffer.View, 0, len(networkHeader)+len(transportHeader))
+ headers = append(headers, networkHeader...)
+ headers = append(headers, transportHeader...)
combinedVV = headers.ToVectorisedView()
} else {
combinedVV = append(buffer.View(nil), pkt.TransportHeader().View()...).ToVectorisedView()
diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go
index e393b993d..624e2dbe7 100644
--- a/pkg/tcpip/transport/raw/protocol.go
+++ b/pkg/tcpip/transport/raw/protocol.go
@@ -17,6 +17,7 @@ package raw
import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/internal/noop"
"gvisor.dev/gvisor/pkg/tcpip/transport/packet"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -33,3 +34,18 @@ func (EndpointFactory) NewUnassociatedEndpoint(stack *stack.Stack, netProto tcpi
func (EndpointFactory) NewPacketEndpoint(stack *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
return packet.NewEndpoint(stack, cooked, netProto, waiterQueue)
}
+
+// CreateOnlyFactory implements stack.RawFactory. It allows creation of raw
+// endpoints that do not support reading, writing, binding, etc.
+type CreateOnlyFactory struct{}
+
+// NewUnassociatedEndpoint implements stack.RawFactory.NewUnassociatedEndpoint.
+func (CreateOnlyFactory) NewUnassociatedEndpoint(stk *stack.Stack, _ tcpip.NetworkProtocolNumber, _ tcpip.TransportProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
+ return noop.New(stk), nil
+}
+
+// NewPacketEndpoint implements stack.RawFactory.NewPacketEndpoint.
+func (CreateOnlyFactory) NewPacketEndpoint(*stack.Stack, bool, tcpip.NetworkProtocolNumber, *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
+ // This isn't needed by anything, so it isn't implemented.
+ return nil, &tcpip.ErrNotPermitted{}
+}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 5148fe157..20958d882 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -80,9 +80,10 @@ go_library(
go_test(
name = "tcp_x_test",
- size = "medium",
+ size = "large",
srcs = [
"dual_stack_test.go",
+ "rcv_test.go",
"sack_scoreboard_test.go",
"tcp_noracedetector_test.go",
"tcp_rack_test.go",
@@ -114,16 +115,6 @@ go_test(
)
go_test(
- name = "rcv_test",
- size = "small",
- srcs = ["rcv_test.go"],
- deps = [
- "//pkg/tcpip/header",
- "//pkg/tcpip/seqnum",
- ],
-)
-
-go_test(
name = "tcp_test",
size = "small",
srcs = [
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 03c9fafa1..caf14b0dc 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -15,12 +15,12 @@
package tcp
import (
+ "container/list"
"crypto/sha1"
"encoding/binary"
"fmt"
"hash"
"io"
- "sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/sleep"
@@ -100,18 +100,6 @@ type listenContext struct {
// netProto indicates the network protocol(IPv4/v6) for the listening
// endpoint.
netProto tcpip.NetworkProtocolNumber
-
- // pendingMu protects pendingEndpoints. This should only be accessed
- // by the listening endpoint's worker goroutine.
- //
- // Lock Ordering: listenEP.workerMu -> pendingMu
- pendingMu sync.Mutex
- // pending is used to wait for all pendingEndpoints to finish when
- // a socket is closed.
- pending sync.WaitGroup
- // pendingEndpoints is a map of all endpoints for which a handshake is
- // in progress.
- pendingEndpoints map[stack.TransportEndpointID]*endpoint
}
// timeStamp returns an 8-bit timestamp with a granularity of 64 seconds.
@@ -122,14 +110,13 @@ func timeStamp(clock tcpip.Clock) uint32 {
// newListenContext creates a new listen context.
func newListenContext(stk *stack.Stack, protocol *protocol, listenEP *endpoint, rcvWnd seqnum.Size, v6Only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
l := &listenContext{
- stack: stk,
- protocol: protocol,
- rcvWnd: rcvWnd,
- hasher: sha1.New(),
- v6Only: v6Only,
- netProto: netProto,
- listenEP: listenEP,
- pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint),
+ stack: stk,
+ protocol: protocol,
+ rcvWnd: rcvWnd,
+ hasher: sha1.New(),
+ v6Only: v6Only,
+ netProto: netProto,
+ listenEP: listenEP,
}
for i := range l.nonce {
@@ -193,14 +180,6 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu
return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true
}
-func (l *listenContext) useSynCookies() bool {
- var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies
- if err := l.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil {
- panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err))
- }
- return bool(alwaysUseSynCookies) || (l.listenEP != nil && l.listenEP.synRcvdBacklogFull())
-}
-
// createConnectingEndpoint creates a new endpoint in a connecting state, with
// the connection parameters given by the arguments.
func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) {
@@ -273,18 +252,15 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu
return nil, &tcpip.ErrConnectionAborted{}
}
- l.addPendingEndpoint(ep)
// Propagate any inheritable options from the listening endpoint
// to the newly created endpoint.
- l.listenEP.propagateInheritableOptionsLocked(ep)
+ l.listenEP.propagateInheritableOptionsLocked(ep) // +checklocksforce
if !ep.reserveTupleLocked() {
ep.mu.Unlock()
ep.Close()
- l.removePendingEndpoint(ep)
-
return nil, &tcpip.ErrConnectionAborted{}
}
@@ -303,10 +279,6 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu
ep.mu.Unlock()
ep.Close()
- if l.listenEP != nil {
- l.removePendingEndpoint(ep)
- }
-
ep.drainClosingSegmentQueue()
return nil, err
@@ -344,39 +316,12 @@ func (l *listenContext) performHandshake(s *segment, opts header.TCPSynOptions,
return ep, nil
}
-func (l *listenContext) addPendingEndpoint(n *endpoint) {
- l.pendingMu.Lock()
- l.pendingEndpoints[n.TransportEndpointInfo.ID] = n
- l.pending.Add(1)
- l.pendingMu.Unlock()
-}
-
-func (l *listenContext) removePendingEndpoint(n *endpoint) {
- l.pendingMu.Lock()
- delete(l.pendingEndpoints, n.TransportEndpointInfo.ID)
- l.pending.Done()
- l.pendingMu.Unlock()
-}
-
-func (l *listenContext) closeAllPendingEndpoints() {
- l.pendingMu.Lock()
- for _, n := range l.pendingEndpoints {
- n.notifyProtocolGoroutine(notifyClose)
- }
- l.pendingMu.Unlock()
- l.pending.Wait()
-}
-
-// Precondition: h.ep.mu must be held.
// +checklocks:h.ep.mu
func (l *listenContext) cleanupFailedHandshake(h *handshake) {
e := h.ep
e.mu.Unlock()
e.Close()
e.notifyAborted()
- if l.listenEP != nil {
- l.removePendingEndpoint(e)
- }
e.drainClosingSegmentQueue()
e.h = nil
}
@@ -384,12 +329,9 @@ func (l *listenContext) cleanupFailedHandshake(h *handshake) {
// cleanupCompletedHandshake transfers any state from the completed handshake to
// the new endpoint.
//
-// Precondition: h.ep.mu must be held.
+// +checklocks:h.ep.mu
func (l *listenContext) cleanupCompletedHandshake(h *handshake) {
e := h.ep
- if l.listenEP != nil {
- l.removePendingEndpoint(e)
- }
e.isConnectNotified = true
// Update the receive window scaling. We can't do it before the
@@ -401,47 +343,11 @@ func (l *listenContext) cleanupCompletedHandshake(h *handshake) {
e.h = nil
}
-// deliverAccepted delivers the newly-accepted endpoint to the listener. If the
-// listener has transitioned out of the listen state (accepted is the zero
-// value), the new endpoint is reset instead.
-func (e *endpoint) deliverAccepted(n *endpoint, withSynCookie bool) {
- e.mu.Lock()
- e.pendingAccepted.Add(1)
- e.mu.Unlock()
- defer e.pendingAccepted.Done()
-
- // Drop the lock before notifying to avoid deadlock in user-specified
- // callbacks.
- delivered := func() bool {
- e.acceptMu.Lock()
- defer e.acceptMu.Unlock()
- for {
- if e.accepted == (accepted{}) {
- return false
- }
- if e.accepted.endpoints.Len() == e.accepted.cap {
- e.acceptCond.Wait()
- continue
- }
-
- e.accepted.endpoints.PushBack(n)
- if !withSynCookie {
- atomic.AddInt32(&e.synRcvdCount, -1)
- }
- return true
- }
- }()
- if delivered {
- e.waiterQueue.Notify(waiter.ReadableEvents)
- } else {
- n.notifyProtocolGoroutine(notifyReset)
- }
-}
-
// propagateInheritableOptionsLocked propagates any options set on the listening
// endpoint to the newly created endpoint.
//
-// Precondition: e.mu and n.mu must be held.
+// +checklocks:e.mu
+// +checklocks:n.mu
func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) {
n.userTimeout = e.userTimeout
n.portFlags = e.portFlags
@@ -452,9 +358,9 @@ func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) {
// reserveTupleLocked reserves an accepted endpoint's tuple.
//
-// Preconditions:
-// * propagateInheritableOptionsLocked has been called.
-// * e.mu is held.
+// Precondition: e.propagateInheritableOptionsLocked has been called.
+//
+// +checklocks:e.mu
func (e *endpoint) reserveTupleLocked() bool {
dest := tcpip.FullAddress{
Addr: e.TransportEndpointInfo.ID.RemoteAddress,
@@ -489,70 +395,36 @@ func (e *endpoint) notifyAborted() {
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
}
-// handleSynSegment is called in its own goroutine once the listening endpoint
-// receives a SYN segment. It is responsible for completing the handshake and
-// queueing the new endpoint for acceptance.
-//
-// A limited number of these goroutines are allowed before TCP starts using SYN
-// cookies to accept connections.
-//
-// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked.
-func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts header.TCPSynOptions) tcpip.Error {
- defer s.decRef()
+func (e *endpoint) acceptQueueIsFull() bool {
+ e.acceptMu.Lock()
+ full := e.acceptQueue.isFull()
+ e.acceptMu.Unlock()
+ return full
+}
- h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner)
- if err != nil {
- e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
- e.stats.FailedConnectionAttempts.Increment()
- atomic.AddInt32(&e.synRcvdCount, -1)
- return err
- }
+// +stateify savable
+type acceptQueue struct {
+ // NB: this could be an endpointList, but ilist only permits endpoints to
+ // belong to one list at a time, and endpoints are already stored in the
+ // dispatcher's list.
+ endpoints list.List `state:".([]*endpoint)"`
- go func() {
- // Note that startHandshake returns a locked endpoint. The
- // force call here just makes it so.
- if err := h.complete(); err != nil { // +checklocksforce
- e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
- e.stats.FailedConnectionAttempts.Increment()
- ctx.cleanupFailedHandshake(h)
- atomic.AddInt32(&e.synRcvdCount, -1)
- return
- }
- ctx.cleanupCompletedHandshake(h)
- h.ep.startAcceptedLoop()
- e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
- e.deliverAccepted(h.ep, false /*withSynCookie*/)
- }()
+ // pendingEndpoints is a set of all endpoints for which a handshake is
+ // in progress.
+ pendingEndpoints map[*endpoint]struct{}
- return nil
+ // capacity is the maximum number of endpoints that can be in endpoints.
+ capacity int
}
-func (e *endpoint) synRcvdBacklogFull() bool {
- e.acceptMu.Lock()
- acceptedCap := e.accepted.cap
- e.acceptMu.Unlock()
- // The capacity of the accepted queue would always be one greater than the
- // listen backlog. But, the SYNRCVD connections count is always checked
- // against the listen backlog value for Linux parity reason.
- // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280
- //
- // We maintain an equality check here as the synRcvdCount is incremented
- // and compared only from a single listener context and the capacity of
- // the accepted queue can only increase by a new listen call.
- return int(atomic.LoadInt32(&e.synRcvdCount)) == acceptedCap-1
-}
-
-func (e *endpoint) acceptQueueIsFull() bool {
- e.acceptMu.Lock()
- full := e.accepted != (accepted{}) && e.accepted.endpoints.Len() == e.accepted.cap
- e.acceptMu.Unlock()
- return full
+func (a *acceptQueue) isFull() bool {
+ return a.endpoints.Len() == a.capacity
}
// handleListenSegment is called when a listening endpoint receives a segment
// and needs to handle it.
//
-// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked.
+// +checklocks:e.mu
func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Error {
e.rcvQueueInfo.rcvQueueMu.Lock()
rcvClosed := e.rcvQueueInfo.RcvClosed
@@ -580,11 +452,95 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
}
opts := parseSynSegmentOptions(s)
- if !ctx.useSynCookies() {
- s.incRef()
- atomic.AddInt32(&e.synRcvdCount, 1)
- return e.handleSynSegment(ctx, s, opts)
+
+ useSynCookies, err := func() (bool, tcpip.Error) {
+ var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies
+ if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil {
+ panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err))
+ }
+ if alwaysUseSynCookies {
+ return true, nil
+ }
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
+
+ // The capacity of the accepted queue would always be one greater than the
+ // listen backlog. But, the SYNRCVD connections count is always checked
+ // against the listen backlog value for Linux parity reason.
+ // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280
+ if len(e.acceptQueue.pendingEndpoints) == e.acceptQueue.capacity-1 {
+ return true, nil
+ }
+
+ h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner)
+ if err != nil {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ return false, err
+ }
+
+ e.acceptQueue.pendingEndpoints[h.ep] = struct{}{}
+ e.pendingAccepted.Add(1)
+
+ go func() {
+ defer func() {
+ e.pendingAccepted.Done()
+
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
+ delete(e.acceptQueue.pendingEndpoints, h.ep)
+ }()
+
+ // Note that startHandshake returns a locked endpoint. The force call
+ // here just makes it so.
+ if err := h.complete(); err != nil { // +checklocksforce
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ ctx.cleanupFailedHandshake(h)
+ return
+ }
+ ctx.cleanupCompletedHandshake(h)
+ h.ep.startAcceptedLoop()
+ e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
+
+ // Deliver the endpoint to the accept queue.
+ //
+ // Drop the lock before notifying to avoid deadlock in user-specified
+ // callbacks.
+ delivered := func() bool {
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
+ for {
+ // The listener is transitioning out of the Listen state; bail.
+ if e.acceptQueue.capacity == 0 {
+ return false
+ }
+ if e.acceptQueue.isFull() {
+ e.acceptCond.Wait()
+ continue
+ }
+
+ e.acceptQueue.endpoints.PushBack(h.ep)
+ return true
+ }
+ }()
+
+ if delivered {
+ e.waiterQueue.Notify(waiter.ReadableEvents)
+ } else {
+ h.ep.notifyProtocolGoroutine(notifyReset)
+ }
+ }()
+
+ return false, nil
+ }()
+ if err != nil {
+ return err
+ }
+ if !useSynCookies {
+ return nil
}
+
route, err := e.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */)
if err != nil {
return err
@@ -627,18 +583,6 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
return nil
case s.flags.Contains(header.TCPFlagAck):
- if e.acceptQueueIsFull() {
- // Silently drop the ack as the application can't accept
- // the connection at this point. The ack will be
- // retransmitted by the sender anyway and we can
- // complete the connection at the time of retransmit if
- // the backlog has space.
- e.stack.Stats().TCP.ListenOverflowAckDrop.Increment()
- e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment()
- e.stack.Stats().DroppedPackets.Increment()
- return nil
- }
-
iss := s.ackNumber - 1
irs := s.sequenceNumber - 1
@@ -674,6 +618,24 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
// ACK was received from the sender.
return replyWithReset(e.stack, s, e.sendTOS, e.ttl)
}
+
+ // Hold acceptMu until the new endpoint is in the accept queue (or until
+ // we bail out on an error), to guarantee that we keep our spot in the
+ // queue even if another handshake from the syn queue completes.
+ e.acceptMu.Lock()
+ if e.acceptQueue.isFull() {
+ // Silently drop the ack as the application can't accept
+ // the connection at this point. The ack will be
+ // retransmitted by the sender anyway and we can
+ // complete the connection at the time of retransmit if
+ // the backlog has space.
+ e.acceptMu.Unlock()
+ e.stack.Stats().TCP.ListenOverflowAckDrop.Increment()
+ e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment()
+ e.stack.Stats().DroppedPackets.Increment()
+ return nil
+ }
+
e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment()
// Create newly accepted endpoint and deliver it.
rcvdSynOptions := header.TCPSynOptions{
@@ -695,6 +657,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
n, err := ctx.createConnectingEndpoint(s, rcvdSynOptions, &waiter.Queue{})
if err != nil {
+ e.acceptMu.Unlock()
return err
}
@@ -706,6 +669,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
if !n.reserveTupleLocked() {
n.mu.Unlock()
+ e.acceptMu.Unlock()
n.Close()
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
@@ -723,6 +687,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
n.boundBindToDevice,
); err != nil {
n.mu.Unlock()
+ e.acceptMu.Unlock()
n.Close()
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
@@ -755,20 +720,15 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
n.newSegmentWaker.Assert()
}
- // Do the delivery in a separate goroutine so
- // that we don't block the listen loop in case
- // the application is slow to accept or stops
- // accepting.
- //
- // NOTE: This won't result in an unbounded
- // number of goroutines as we do check before
- // entering here that there was at least some
- // space available in the backlog.
-
// Start the protocol goroutine.
n.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
- go e.deliverAccepted(n, true /*withSynCookie*/)
+
+ // Deliver the endpoint to the accept queue.
+ e.acceptQueue.endpoints.PushBack(n)
+ e.acceptMu.Unlock()
+
+ e.waiterQueue.Notify(waiter.ReadableEvents)
return nil
default:
@@ -785,14 +745,8 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
ctx := newListenContext(e.stack, e.protocol, e, rcvWnd, v6Only, e.NetProto)
defer func() {
- // Mark endpoint as closed. This will prevent goroutines running
- // handleSynSegment() from attempting to queue new connections
- // to the endpoint.
e.setEndpointState(StateClose)
- // Close any endpoints in SYN-RCVD state.
- ctx.closeAllPendingEndpoints()
-
// Do cleanup if needed.
e.completeWorkerLocked()
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 5d8e18484..80cd07218 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -30,6 +30,10 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// InitialRTO is the initial retransmission timeout.
+// https://github.com/torvalds/linux/blob/7c636d4d20f/include/net/tcp.h#L142
+const InitialRTO = time.Second
+
// maxSegmentsPerWake is the maximum number of segments to process in the main
// protocol goroutine per wake-up. Yielding [after this number of segments are
// processed] allows other events to be processed as well (e.g., timeouts,
@@ -532,7 +536,7 @@ func (h *handshake) complete() tcpip.Error {
defer s.Done()
// Initialize the resend timer.
- timer, err := newBackoffTimer(h.ep.stack.Clock(), time.Second, MaxRTO, resendWaker.Assert)
+ timer, err := newBackoffTimer(h.ep.stack.Clock(), InitialRTO, MaxRTO, resendWaker.Assert)
if err != nil {
return err
}
@@ -578,6 +582,9 @@ func (h *handshake) complete() tcpip.Error {
if (n&notifyClose)|(n&notifyAbort) != 0 {
return &tcpip.ErrAborted{}
}
+ if n&notifyShutdown != 0 {
+ return &tcpip.ErrConnectionReset{}
+ }
if n&notifyDrain != 0 {
for !h.ep.segmentQueue.empty() {
s := h.ep.segmentQueue.dequeue()
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index d2b8f298f..066ffe051 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -15,7 +15,6 @@
package tcp
import (
- "container/list"
"encoding/binary"
"fmt"
"io"
@@ -187,6 +186,8 @@ const (
// say TIME_WAIT.
notifyTickleWorker
notifyError
+ // notifyShutdown means that a connecting socket was shut down.
+ notifyShutdown
)
// SACKInfo holds TCP SACK related information for a given endpoint.
@@ -203,6 +204,8 @@ type SACKInfo struct {
}
// ReceiveErrors collect segment receive errors within transport layer.
+//
+// +stateify savable
type ReceiveErrors struct {
tcpip.ReceiveErrors
@@ -232,6 +235,8 @@ type ReceiveErrors struct {
}
// SendErrors collect segment send errors within the transport layer.
+//
+// +stateify savable
type SendErrors struct {
tcpip.SendErrors
@@ -255,6 +260,8 @@ type SendErrors struct {
}
// Stats holds statistics about the endpoint.
+//
+// +stateify savable
type Stats struct {
// SegmentsReceived is the number of TCP segments received that
// the transport layer successfully parsed.
@@ -309,15 +316,6 @@ type rcvQueueInfo struct {
rcvQueue segmentList `state:"wait"`
}
-// +stateify savable
-type accepted struct {
- // NB: this could be an endpointList, but ilist only permits endpoints to
- // belong to one list at a time, and endpoints are already stored in the
- // dispatcher's list.
- endpoints list.List `state:".([]*endpoint)"`
- cap int
-}
-
// endpoint represents a TCP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
@@ -333,7 +331,7 @@ type accepted struct {
// The following three mutexes can be acquired independent of e.mu but if
// acquired with e.mu then e.mu must be acquired first.
//
-// e.acceptMu -> protects accepted.
+// e.acceptMu -> Protects e.acceptQueue.
// e.rcvQueueMu -> Protects e.rcvQueue and associated fields.
// e.sndQueueMu -> Protects the e.sndQueue and associated fields.
// e.lastErrorMu -> Protects the lastError field.
@@ -497,10 +495,6 @@ type endpoint struct {
// and dropped when it is.
segmentQueue segmentQueue `state:"wait"`
- // synRcvdCount is the number of connections for this endpoint that are
- // in SYN-RCVD state; this is only accessed atomically.
- synRcvdCount int32
-
// userMSS if non-zero is the MSS value explicitly set by the user
// for this endpoint using the TCP_MAXSEG setsockopt.
userMSS uint16
@@ -573,7 +567,8 @@ type endpoint struct {
// accepted is used by a listening endpoint protocol goroutine to
// send newly accepted connections to the endpoint so that they can be
// read by Accept() calls.
- accepted accepted
+ // +checklocks:acceptMu
+ acceptQueue acceptQueue
// The following are only used from the protocol goroutine, and
// therefore don't need locks to protect them.
@@ -606,8 +601,7 @@ type endpoint struct {
gso stack.GSO
- // TODO(b/142022063): Add ability to save and restore per endpoint stats.
- stats Stats `state:"nosave"`
+ stats Stats
// tcpLingerTimeout is the maximum amount of a time a socket
// a socket stays in TIME_WAIT state before being marked
@@ -819,10 +813,9 @@ func newEndpoint(s *stack.Stack, protocol *protocol, netProto tcpip.NetworkProto
waiterQueue: waiterQueue,
state: uint32(StateInitial),
keepalive: keepalive{
- // Linux defaults.
- idle: 2 * time.Hour,
- interval: 75 * time.Second,
- count: 9,
+ idle: DefaultKeepaliveIdle,
+ interval: DefaultKeepaliveInterval,
+ count: DefaultKeepaliveCount,
},
uniqueID: s.UniqueID(),
txHash: s.Rand().Uint32(),
@@ -904,7 +897,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// Check if there's anything in the accepted queue.
if (mask & waiter.ReadableEvents) != 0 {
e.acceptMu.Lock()
- if e.accepted.endpoints.Len() != 0 {
+ if e.acceptQueue.endpoints.Len() != 0 {
result |= waiter.ReadableEvents
}
e.acceptMu.Unlock()
@@ -1087,20 +1080,20 @@ func (e *endpoint) closeNoShutdownLocked() {
// handshake but not yet been delivered to the application.
func (e *endpoint) closePendingAcceptableConnectionsLocked() {
e.acceptMu.Lock()
- acceptedCopy := e.accepted
- e.accepted = accepted{}
- e.acceptMu.Unlock()
-
- if acceptedCopy == (accepted{}) {
- return
+ // Close any endpoints in SYN-RCVD state.
+ for n := range e.acceptQueue.pendingEndpoints {
+ n.notifyProtocolGoroutine(notifyClose)
}
-
- e.acceptCond.Broadcast()
-
+ e.acceptQueue.pendingEndpoints = nil
// Reset all connections that are waiting to be accepted.
- for n := acceptedCopy.endpoints.Front(); n != nil; n = n.Next() {
+ for n := e.acceptQueue.endpoints.Front(); n != nil; n = n.Next() {
n.Value.(*endpoint).notifyProtocolGoroutine(notifyReset)
}
+ e.acceptQueue.endpoints.Init()
+ e.acceptMu.Unlock()
+
+ e.acceptCond.Broadcast()
+
// Wait for reset of all endpoints that are still waiting to be delivered to
// the now closed accepted.
e.pendingAccepted.Wait()
@@ -2060,7 +2053,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
case *tcpip.OriginalDestinationOption:
e.LockUser()
ipt := e.stack.IPTables()
- addr, port, err := ipt.OriginalDst(e.TransportEndpointInfo.ID, e.NetProto)
+ addr, port, err := ipt.OriginalDst(e.TransportEndpointInfo.ID, e.NetProto, ProtocolNumber)
e.UnlockUser()
if err != nil {
return err
@@ -2380,6 +2373,18 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error {
func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
e.LockUser()
defer e.UnlockUser()
+
+ if e.EndpointState().connecting() {
+ // When calling shutdown(2) on a connecting socket, the endpoint must
+ // enter the error state. But this logic cannot belong to the shutdownLocked
+ // method because that method is called during a close(2) (and closing a
+ // connecting socket is not an error).
+ e.resetConnectionLocked(&tcpip.ErrConnectionReset{})
+ e.notifyProtocolGoroutine(notifyShutdown)
+ e.waiterQueue.Notify(waiter.WritableEvents | waiter.EventHUp | waiter.EventErr)
+ return nil
+ }
+
return e.shutdownLocked(flags)
}
@@ -2480,22 +2485,23 @@ func (e *endpoint) listen(backlog int) tcpip.Error {
if e.EndpointState() == StateListen && !e.closed {
e.acceptMu.Lock()
defer e.acceptMu.Unlock()
- if e.accepted == (accepted{}) {
- // listen is called after shutdown.
- e.accepted.cap = backlog
- e.shutdownFlags = 0
- e.rcvQueueInfo.rcvQueueMu.Lock()
- e.rcvQueueInfo.RcvClosed = false
- e.rcvQueueInfo.rcvQueueMu.Unlock()
- } else {
- // Adjust the size of the backlog iff we can fit
- // existing pending connections into the new one.
- if e.accepted.endpoints.Len() > backlog {
- return &tcpip.ErrInvalidEndpointState{}
- }
- e.accepted.cap = backlog
+
+ // Adjust the size of the backlog iff we can fit
+ // existing pending connections into the new one.
+ if e.acceptQueue.endpoints.Len() > backlog {
+ return &tcpip.ErrInvalidEndpointState{}
+ }
+ e.acceptQueue.capacity = backlog
+
+ if e.acceptQueue.pendingEndpoints == nil {
+ e.acceptQueue.pendingEndpoints = make(map[*endpoint]struct{})
}
+ e.shutdownFlags = 0
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ e.rcvQueueInfo.RcvClosed = false
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
+
// Notify any blocked goroutines that they can attempt to
// deliver endpoints again.
e.acceptCond.Broadcast()
@@ -2530,8 +2536,11 @@ func (e *endpoint) listen(backlog int) tcpip.Error {
// may be pre-populated with some previously accepted (but not Accepted)
// endpoints.
e.acceptMu.Lock()
- if e.accepted == (accepted{}) {
- e.accepted.cap = backlog
+ if e.acceptQueue.pendingEndpoints == nil {
+ e.acceptQueue.pendingEndpoints = make(map[*endpoint]struct{})
+ }
+ if e.acceptQueue.capacity == 0 {
+ e.acceptQueue.capacity = backlog
}
e.acceptMu.Unlock()
@@ -2571,8 +2580,8 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.
// Get the new accepted endpoint.
var n *endpoint
e.acceptMu.Lock()
- if element := e.accepted.endpoints.Front(); element != nil {
- n = e.accepted.endpoints.Remove(element).(*endpoint)
+ if element := e.acceptQueue.endpoints.Front(); element != nil {
+ n = e.acceptQueue.endpoints.Remove(element).(*endpoint)
}
e.acceptMu.Unlock()
if n == nil {
@@ -2989,6 +2998,8 @@ func (e *endpoint) completeStateLocked() stack.TCPEndpointState {
}
s.Sender.RACKState = e.snd.rc.TCPRACKState
+ s.Sender.RetransmitTS = e.snd.retransmitTS
+ s.Sender.SpuriousRecovery = e.snd.spuriousRecovery
return s
}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index f2e8b3840..94072a115 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -100,7 +100,7 @@ func (e *endpoint) beforeSave() {
}
// saveEndpoints is invoked by stateify.
-func (a *accepted) saveEndpoints() []*endpoint {
+func (a *acceptQueue) saveEndpoints() []*endpoint {
acceptedEndpoints := make([]*endpoint, a.endpoints.Len())
for i, e := 0, a.endpoints.Front(); e != nil; i, e = i+1, e.Next() {
acceptedEndpoints[i] = e.Value.(*endpoint)
@@ -109,7 +109,7 @@ func (a *accepted) saveEndpoints() []*endpoint {
}
// loadEndpoints is invoked by stateify.
-func (a *accepted) loadEndpoints(acceptedEndpoints []*endpoint) {
+func (a *acceptQueue) loadEndpoints(acceptedEndpoints []*endpoint) {
for _, ep := range acceptedEndpoints {
a.endpoints.PushBack(ep)
}
@@ -251,7 +251,9 @@ func (e *endpoint) Resume(s *stack.Stack) {
go func() {
connectedLoading.Wait()
bind()
- backlog := e.accepted.cap
+ e.acceptMu.Lock()
+ backlog := e.acceptQueue.capacity
+ e.acceptMu.Unlock()
if err := e.Listen(backlog); err != nil {
panic("endpoint listening failed: " + err.String())
}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index e4410ad93..f122ea009 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -66,6 +66,18 @@ const (
// DefaultSynRetries is the default value for the number of SYN retransmits
// before a connect is aborted.
DefaultSynRetries = 6
+
+ // DefaultKeepaliveIdle is the idle time for a connection before keep-alive
+ // probes are sent.
+ DefaultKeepaliveIdle = 2 * time.Hour
+
+ // DefaultKeepaliveInterval is the time between two successive keep-alive
+ // probes.
+ DefaultKeepaliveInterval = 75 * time.Second
+
+ // DefaultKeepaliveCount is the number of keep-alive probes that are sent
+ // before declaring the connection dead.
+ DefaultKeepaliveCount = 9
)
const (
diff --git a/pkg/tcpip/transport/tcp/rcv_test.go b/pkg/tcpip/transport/tcp/rcv_test.go
index 8a026ec46..e47a07030 100644
--- a/pkg/tcpip/transport/tcp/rcv_test.go
+++ b/pkg/tcpip/transport/tcp/rcv_test.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package rcv_test
+package tcp_test
import (
"testing"
diff --git a/pkg/tcpip/transport/tcp/segment_test.go b/pkg/tcpip/transport/tcp/segment_test.go
index 2e6ea06f5..2d5fdda19 100644
--- a/pkg/tcpip/transport/tcp/segment_test.go
+++ b/pkg/tcpip/transport/tcp/segment_test.go
@@ -34,7 +34,7 @@ func checkSegmentSize(t *testing.T, name string, seg *segment, want segmentSizeW
DataSize: seg.data.Size(),
SegMemSize: seg.segMemSize(),
}
- if diff := cmp.Diff(got, want); diff != "" {
+ if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("%s differs (-want +got):\n%s", name, diff)
}
}
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 2fabf1594..4377f07a0 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -144,6 +144,15 @@ type sender struct {
// probeTimer and probeWaker are used to schedule PTO for RACK TLP algorithm.
probeTimer timer `state:"nosave"`
probeWaker sleep.Waker `state:"nosave"`
+
+ // spuriousRecovery indicates whether the sender entered recovery
+ // spuriously as described in RFC3522 Section 3.2.
+ spuriousRecovery bool
+
+ // retransmitTS is the Timestamp Value (TSval) carried by the retransmit
+ // that is sent when loss recovery is first initiated, as described in
+ // RFC3522 Section 3.2.
+ retransmitTS uint32
}
// rtt is a synchronization wrapper used to appease stateify. See the comment
@@ -425,6 +434,13 @@ func (s *sender) retransmitTimerExpired() bool {
return true
}
+ // Initialize the variables used to detect spurious recovery after
+ // entering RTO.
+ //
+ // See: https://www.rfc-editor.org/rfc/rfc3522.html#section-3.2 Step 1.
+ s.spuriousRecovery = false
+ s.retransmitTS = 0
+
// TODO(b/147297758): Band-aid fix, retransmitTimer can fire in some edge cases
// when writeList is empty. Remove this once we have a proper fix for this
// issue.
@@ -495,6 +511,10 @@ func (s *sender) retransmitTimerExpired() bool {
s.leaveRecovery()
}
+ // Record retransmitTS if the sender is not in recovery as per:
+ // https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 2
+ s.recordRetransmitTS()
+
s.state = tcpip.RTORecovery
s.cc.HandleRTOExpired()
@@ -958,6 +978,13 @@ func (s *sender) sendData() {
}
func (s *sender) enterRecovery() {
+ // Initialize the variables used to detect spurious recovery after
+ // entering recovery.
+ //
+ // See: https://www.rfc-editor.org/rfc/rfc3522.html#section-3.2 Step 1.
+ s.spuriousRecovery = false
+ s.retransmitTS = 0
+
s.FastRecovery.Active = true
// Save state to reflect we're now in fast recovery.
//
@@ -972,6 +999,11 @@ func (s *sender) enterRecovery() {
s.FastRecovery.MaxCwnd = s.SndCwnd + s.Outstanding
s.FastRecovery.HighRxt = s.SndUna
s.FastRecovery.RescueRxt = s.SndUna
+
+ // Record retransmitTS if the sender is not in recovery as per:
+ // https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 2
+ s.recordRetransmitTS()
+
if s.ep.SACKPermitted {
s.state = tcpip.SACKRecovery
s.ep.stack.Stats().TCP.SACKRecovery.Increment()
@@ -1147,13 +1179,15 @@ func (s *sender) isDupAck(seg *segment) bool {
// Iterate the writeList and update RACK for each segment which is newly acked
// either cumulatively or selectively. Loop through the segments which are
// sacked, and update the RACK related variables and check for reordering.
+// It returns true if a DSACK block was detected in the received ACK.
//
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
// steps 2 and 3.
-func (s *sender) walkSACK(rcvdSeg *segment) {
+func (s *sender) walkSACK(rcvdSeg *segment) bool {
s.rc.setDSACKSeen(false)
// Look for DSACK block.
+ hasDSACK := false
idx := 0
n := len(rcvdSeg.parsedOptions.SACKBlocks)
if checkDSACK(rcvdSeg) {
@@ -1167,10 +1201,11 @@ func (s *sender) walkSACK(rcvdSeg *segment) {
s.rc.setDSACKSeen(true)
idx = 1
n--
+ hasDSACK = true
}
if n == 0 {
- return
+ return hasDSACK
}
// Sort the SACK blocks. The first block is the most recent unacked
@@ -1193,6 +1228,7 @@ func (s *sender) walkSACK(rcvdSeg *segment) {
seg = seg.Next()
}
}
+ return hasDSACK
}
// checkDSACK checks if a DSACK is reported.
@@ -1239,6 +1275,85 @@ func checkDSACK(rcvdSeg *segment) bool {
return false
}
+func (s *sender) recordRetransmitTS() {
+ // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2
+ //
+ // The Eifel detection algorithm is used, only upon initiation of loss
+ // recovery, i.e., when either the timeout-based retransmit or the fast
+ // retransmit is sent. The Eifel detection algorithm MUST NOT be
+ // reinitiated after loss recovery has already started. In particular,
+ // it must not be reinitiated upon subsequent timeouts for the same
+ // segment, and not upon retransmitting segments other than the oldest
+ // outstanding segment, e.g., during selective loss recovery.
+ if s.inRecovery() {
+ return
+ }
+
+ // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 2
+ //
+ // Set a "RetransmitTS" variable to the value of the Timestamp Value
+ // field of the Timestamps option included in the retransmit sent when
+ // loss recovery is initiated. A TCP sender must ensure that
+ // RetransmitTS does not get overwritten as loss recovery progresses,
+ // e.g., in case of a second timeout and subsequent second retransmit of
+ // the same octet.
+ s.retransmitTS = s.ep.tsValNow()
+}
+
+func (s *sender) detectSpuriousRecovery(hasDSACK bool, tsEchoReply uint32) {
+ // Return if the sender has already detected spurious recovery.
+ if s.spuriousRecovery {
+ return
+ }
+
+ // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 4
+ //
+ // If the value of the Timestamp Echo Reply field of the acceptable ACK's
+ // Timestamps option is smaller than the value of RetransmitTS, then
+ // proceed to next step, else return.
+ if tsEchoReply >= s.retransmitTS {
+ return
+ }
+
+ // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 5
+ //
+ // If the acceptable ACK carries a DSACK option [RFC2883], then return.
+ if hasDSACK {
+ return
+ }
+
+ // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 5
+ //
+ // If during the lifetime of the TCP connection the TCP sender has
+ // previously received an ACK with a DSACK option, or the acceptable ACK
+ // does not acknowledge all outstanding data, then proceed to next step,
+ // else return.
+ numDSACK := s.ep.stack.Stats().TCP.SegmentsAckedWithDSACK.Value()
+ if numDSACK == 0 && s.SndUna == s.SndNxt {
+ return
+ }
+
+ // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 6
+ //
+ // If the loss recovery has been initiated with a timeout-based
+ // retransmit, then set
+ // SpuriousRecovery <- SPUR_TO (equal 1),
+ // else set
+ // SpuriousRecovery <- dupacks+1
+ // Set the spurious recovery variable to true as we do not differentiate
+ // between fast, SACK or RTO recovery.
+ s.spuriousRecovery = true
+ s.ep.stack.Stats().TCP.SpuriousRecovery.Increment()
+}
+
+// inRecovery reports whether the sender is in RTORecovery, FastRecovery or SACKRecovery state.
+func (s *sender) inRecovery() bool {
+ if s.state == tcpip.RTORecovery || s.state == tcpip.FastRecovery || s.state == tcpip.SACKRecovery {
+ return true
+ }
+ return false
+}
+
// handleRcvdSegment is called when a segment is received; it is responsible for
// updating the send-related state.
func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
@@ -1254,6 +1369,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
}
// Insert SACKBlock information into our scoreboard.
+ hasDSACK := false
if s.ep.SACKPermitted {
for _, sb := range rcvdSeg.parsedOptions.SACKBlocks {
// Only insert the SACK block if the following holds
@@ -1288,7 +1404,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// RACK.fack, then the corresponding packet has been
// reordered and RACK.reord is set to TRUE.
if s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 {
- s.walkSACK(rcvdSeg)
+ hasDSACK = s.walkSACK(rcvdSeg)
}
s.SetPipe()
}
@@ -1418,6 +1534,11 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// Clear SACK information for all acked data.
s.ep.scoreboard.Delete(s.SndUna)
+ // Detect if the sender entered recovery spuriously.
+ if s.inRecovery() {
+ s.detectSpuriousRecovery(hasDSACK, rcvdSeg.parsedOptions.TSEcr)
+ }
+
// If we are not in fast recovery then update the congestion
// window based on the number of acknowledged packets.
if !s.FastRecovery.Active {
diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go
index c35db7c95..0d36d0dd0 100644
--- a/pkg/tcpip/transport/tcp/tcp_rack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go
@@ -1059,16 +1059,17 @@ func TestRACKWithWindowFull(t *testing.T) {
for i := 0; i < numPkts; i++ {
c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
bytesRead += maxPayload
- if i == 0 {
- // Send ACK for the first packet to establish RTT.
- c.SendAck(seq, maxPayload)
- }
}
- // SACK for #10 packet.
- start := c.IRS.Add(seqnum.Size(1 + (numPkts-1)*maxPayload))
+ // Expect retransmission of last packet due to TLP.
+ c.ReceiveAndCheckPacketWithOptions(data, (numPkts-1)*maxPayload, maxPayload, tsOptionSize)
+
+ // SACK for first and last packet.
+ start := c.IRS.Add(seqnum.Size(maxPayload))
end := start.Add(seqnum.Size(maxPayload))
- c.SendAckWithSACK(seq, 2*maxPayload, []header.SACKBlock{{start, end}})
+ dsackStart := c.IRS.Add(seqnum.Size(1 + (numPkts-1)*maxPayload))
+ dsackEnd := dsackStart.Add(seqnum.Size(maxPayload))
+ c.SendAckWithSACK(seq, 2*maxPayload, []header.SACKBlock{{dsackStart, dsackEnd}, {start, end}})
var info tcpip.TCPInfoOption
if err := c.EP.GetSockOpt(&info); err != nil {
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index 6255355bb..896249d2d 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -23,6 +23,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -702,3 +703,257 @@ func TestRecoveryEntry(t *testing.T) {
t.Error(err)
}
}
+// verifySpuriousRecoveryMetric reports a test error unless the stack's TCP SpuriousRecovery counter reaches numSpuriousRecovery within the polling window.
+func verifySpuriousRecoveryMetric(t *testing.T, c *context.Context, numSpuriousRecovery uint64) {
+ t.Helper()
+ // metricPollFn returns nil when every listed counter equals its wanted value, and a descriptive error otherwise.
+ metricPollFn := func() error {
+ tcpStats := c.Stack().Stats().TCP
+ stats := []struct {
+ stat *tcpip.StatCounter
+ name string
+ want uint64
+ }{
+ {tcpStats.SpuriousRecovery, "stats.TCP.SpuriousRecovery", numSpuriousRecovery},
+ }
+ for _, s := range stats {
+ if got, want := s.stat.Value(), s.want; got != want {
+ return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
+ }
+ }
+ return nil
+ }
+ // testutil.Poll is given a one-second budget (presumably it retries metricPollFn until it succeeds — confirm); failure is reported with t.Error, not Fatal.
+ if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+ t.Error(err)
+ }
+}
+// checkReceivedPacket validates one data segment: b is the IPv4 packet, tcpHdr its TCP header, and bytesRead the offset into data where the segment's payload must start.
+func checkReceivedPacket(t *testing.T, c *context.Context, tcpHdr header.TCP, bytesRead uint32, b, data []byte) {
+ payloadLen := uint32(len(tcpHdr.Payload()))
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPSeqNum(uint32(c.IRS)+1+bytesRead), // +1 presumably accounts for the SYN's sequence number — confirm
+ checker.TCPAckNum(context.TestInitialSequenceNumber+1),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), // ACK expected; the mask appears to exclude PSH from the comparison — confirm
+ ),
+ )
+ pdata := data[bytesRead : bytesRead+payloadLen] // slice of data this segment is expected to carry
+ if p := tcpHdr.Payload(); !bytes.Equal(pdata, p) {
+ t.Fatalf("got data = %v, want = %v", p, pdata)
+ }
+}
+// buildTSOptionFromHeader returns a 12-byte TCP options buffer (two NOPs, then a timestamp option) built from the timestamp values parsed out of tcpHdr.
+func buildTSOptionFromHeader(tcpHdr header.TCP) []byte {
+ parsedOpts := tcpHdr.ParsedOptions()
+ tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
+ header.EncodeTSOption(parsedOpts.TSEcr+1, parsedOpts.TSVal, tsOpt[2:]) // presumably (tsVal, tsEcr, buf), i.e. the segment's TSVal is echoed back — confirm argument order
+ return tsOpt[:]
+}
+// TestDetectSpuriousRecoveryWithRTO puts the sender into RTO recovery, then ACKs segment #1 with the timestamp captured from its first transmission and expects the SpuriousRecovery state and metric to be set.
+func TestDetectSpuriousRecoveryWithRTO(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+ // The probe checks the sender state it is handed and signals completion by closing probeDone.
+ probeDone := make(chan struct{})
+ c.Stack().AddTCPProbe(func(s stack.TCPEndpointState) {
+ if s.Sender.RetransmitTS == 0 {
+ t.Fatalf("RetransmitTS did not get updated, got: 0 want > 0")
+ }
+ if !s.Sender.SpuriousRecovery {
+ t.Fatalf("Spurious recovery was not detected")
+ }
+ close(probeDone)
+ })
+
+ setStackSACKPermitted(t, c, true)
+ createConnectedWithSACKAndTS(c)
+ numPackets := 5
+ data := make([]byte, numPackets*maxPayload)
+ for i := range data {
+ data[i] = byte(i)
+ }
+ // Write all numPackets*maxPayload bytes with a single Write call.
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ var options []byte
+ var bytesRead uint32
+ for i := 0; i < numPackets; i++ {
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ checkReceivedPacket(t, c, tcpHdr, bytesRead, b, data)
+
+ // Capture the timestamp option of the first packet only. It is attached
+ // to the ACK sent below so that the ACK indicates it acknowledges the
+ // original transmission of segment #1.
+ if i == 0 && c.TimeStampEnabled {
+ options = buildTSOptionFromHeader(tcpHdr)
+ }
+ bytesRead += uint32(len(tcpHdr.Payload()))
+ }
+
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ // Expect segment #5 (offset 4*maxPayload) to be sent again as the TLP.
+ c.ReceiveAndCheckPacketWithOptions(data, 4*maxPayload, maxPayload, tsOptionSize)
+
+ // Expect segment #1 (offset 0) to be retransmitted because of RTO.
+ c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize)
+
+ info := tcpip.TCPInfoOption{}
+ if err := c.EP.GetSockOpt(&info); err != nil {
+ t.Fatalf("c.EP.GetSockOpt(&%T) = %s", info, err)
+ }
+
+ if info.CcState != tcpip.RTORecovery {
+ t.Fatalf("Loss recovery did not happen, got: %v want: %v", info.CcState, tcpip.RTORecovery)
+ }
+
+ // Acknowledge only segment #1, carrying the timestamp option captured from its first transmission.
+ rcvWnd := seqnum.Size(30000)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seq,
+ AckNum: c.IRS.Add(1 + seqnum.Size(maxPayload)),
+ RcvWnd: rcvWnd,
+ TCPOpts: options,
+ })
+
+ // Wait for the probe function to finish processing the ACK before the test completes.
+ // NOTE(review): if the probe runs off the test goroutine, a t.Fatalf in it skips close(probeDone) and this receive blocks — confirm.
+ <-probeDone
+
+ verifySpuriousRecoveryMetric(t, c, 1 /* numSpuriousRecovery */)
+}
+// TestSACKDetectSpuriousRecoveryWithDupACK enters SACK recovery via duplicate ACKs, then ACKs segment #1 with the timestamp captured from its first transmission and expects the SpuriousRecovery state and metric to be set.
+func TestSACKDetectSpuriousRecoveryWithDupACK(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+
+ numAck := 0
+ probeDone := make(chan struct{})
+ c.Stack().AddTCPProbe(func(s stack.TCPEndpointState) {
+ if numAck < 3 { // ignore the first three probe calls (presumably the SACK and two plain ACKs sent below — confirm)
+ numAck++
+ return
+ }
+
+ if s.Sender.RetransmitTS == 0 {
+ t.Fatalf("RetransmitTS did not get updated, got: 0 want > 0")
+ }
+ if !s.Sender.SpuriousRecovery {
+ t.Fatalf("Spurious recovery was not detected")
+ }
+ close(probeDone)
+ })
+
+ setStackSACKPermitted(t, c, true)
+ createConnectedWithSACKAndTS(c)
+ numPackets := 5
+ data := make([]byte, numPackets*maxPayload)
+ for i := range data {
+ data[i] = byte(i)
+ }
+ // Write all numPackets*maxPayload bytes with a single Write call.
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ var options []byte
+ var bytesRead uint32
+ for i := 0; i < numPackets; i++ {
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ checkReceivedPacket(t, c, tcpHdr, bytesRead, b, data)
+
+ // Capture the timestamp option of the first packet only. It is attached
+ // to the ACK sent below so that the ACK indicates it acknowledges the
+ // original transmission of segment #1.
+ if i == 0 && c.TimeStampEnabled {
+ options = buildTSOptionFromHeader(tcpHdr)
+ }
+ bytesRead += uint32(len(tcpHdr.Payload()))
+ }
+
+ // Receive the TLP: segment #5 (offset 4*maxPayload) sent again.
+ c.ReceiveAndCheckPacketWithOptions(data, 4*maxPayload, maxPayload, tsOptionSize)
+
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ // SACK segments #4 and #5 (1-based; data bytes 3*maxPayload up to 5*maxPayload) to avoid entering TLP.
+ start := c.IRS.Add(3*maxPayload + 1)
+ end := start.Add(2 * maxPayload)
+ c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}})
+
+ c.SendAck(seq, 0 /* bytesReceived */)
+ c.SendAck(seq, 0 /* bytesReceived */)
+
+ // Receive the retransmission of segment #1 after three duplicate ACKs.
+ c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize)
+
+ info := tcpip.TCPInfoOption{}
+ if err := c.EP.GetSockOpt(&info); err != nil {
+ t.Fatalf("c.EP.GetSockOpt(&%T) = %s", info, err)
+ }
+
+ if info.CcState != tcpip.SACKRecovery {
+ t.Fatalf("Loss recovery did not happen, got: %v want: %v", info.CcState, tcpip.SACKRecovery)
+ }
+
+ // Acknowledge only segment #1, carrying the timestamp option captured from its first transmission.
+ rcvWnd := seqnum.Size(30000)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seq,
+ AckNum: c.IRS.Add(1 + seqnum.Size(maxPayload)),
+ RcvWnd: rcvWnd,
+ TCPOpts: options,
+ })
+
+ // Wait for the probe function to finish processing the ACK before the test completes.
+ // NOTE(review): if the probe runs off the test goroutine, a t.Fatalf in it skips close(probeDone) and this receive blocks — confirm.
+ <-probeDone
+
+ verifySpuriousRecoveryMetric(t, c, 1 /* numSpuriousRecovery */)
+}
+// TestNoSpuriousRecoveryWithDSACK triggers a retransmission via duplicate ACKs, then sends an ACK carrying a SACK block and expects the SpuriousRecovery metric to stay at zero.
+func TestNoSpuriousRecoveryWithDSACK(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+ setStackSACKPermitted(t, c, true)
+ createConnectedWithSACKAndTS(c)
+ numPackets := 5
+ data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */)
+
+ // Receive the TLP: segment #5 (offset 4*maxPayload) sent again.
+ c.ReceiveAndCheckPacketWithOptions(data, 4*maxPayload, maxPayload, tsOptionSize)
+
+ // SACK segments #4 and #5 (1-based; data bytes 3*maxPayload up to 5*maxPayload) to avoid entering TLP.
+ start := c.IRS.Add(3*maxPayload + 1)
+ end := start.Add(2 * maxPayload)
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}})
+
+ c.SendAck(seq, 0 /* bytesReceived */)
+ c.SendAck(seq, 0 /* bytesReceived */)
+
+ // Receive the retransmission of segment #1 after three duplicate ACKs.
+ c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize)
+
+ // Acknowledge the data with a DSACK block. NOTE(review): the block starts at offset maxPayload, not at segment #1 — confirm intent.
+ start = c.IRS.Add(maxPayload + 1)
+ end = start.Add(2 * maxPayload)
+ seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ c.SendAckWithSACK(seq, 6*maxPayload, []header.SACKBlock{{start, end}})
+
+ verifySpuriousRecoveryMetric(t, c, 0 /* numSpuriousRecovery */)
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index bc8708a5b..6f1ee3816 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -1382,8 +1382,12 @@ func TestListenerReadinessOnEvent(t *testing.T) {
if err := s.CreateNIC(id, ep); err != nil {
t.Fatalf("CreateNIC(%d, %T): %s", id, ep, err)
}
- if err := s.AddAddress(id, ipv4.ProtocolNumber, context.StackAddr); err != nil {
- t.Fatalf("AddAddress(%d, ipv4.ProtocolNumber, %s): %s", id, context.StackAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.Address(context.StackAddr).WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(id, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", id, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{
{Destination: header.IPv4EmptySubnet, NIC: id},
@@ -1652,6 +1656,71 @@ func TestConnectBindToDevice(t *testing.T) {
}
}
+func TestShutdownConnectingSocket(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ shutdownMode tcpip.ShutdownFlags
+ }{
+ {"ShutdownRead", tcpip.ShutdownRead},
+ {"ShutdownWrite", tcpip.ShutdownWrite},
+ {"ShutdownReadWrite", tcpip.ShutdownRead | tcpip.ShutdownWrite},
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create an endpoint, don't handshake because we want to interfere with
+ // the handshake process.
+ c.Create(-1)
+ // Register for hang-up events so the shutdown notification can be observed below.
+ waitEntry, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventHUp)
+ defer c.WQ.EventUnregister(&waitEntry)
+
+ // Start connection attempt.
+ addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
+ if d := cmp.Diff(&tcpip.ErrConnectStarted{}, c.EP.Connect(addr)); d != "" {
+ t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d)
+ }
+
+ // Check the SYN packet.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ ),
+ )
+ // Only the SYN has been sent and nothing was injected in reply, so the endpoint must still be in SYN-SENT.
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
+ t.Fatalf("got State() = %s, want %s", got, want)
+ }
+
+ if err := c.EP.Shutdown(test.shutdownMode); err != nil {
+ t.Fatalf("Shutdown failed: %s", err)
+ }
+
+ // The endpoint internal state is updated immediately.
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
+ t.Fatalf("got State() = %s, want %s", got, want)
+ }
+
+ select {
+ case <-ch:
+ default:
+ t.Fatal("endpoint was not notified")
+ }
+
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, &tcpip.ErrConnectionReset{})
+
+ // If the endpoint is not properly shutdown, it'll re-attempt to connect by
+ // retransmitting the SYN, so wait slightly longer than the initial RTO.
+ c.CheckNoPacketTimeout("got an unexpected packet", tcp.InitialRTO+(500*time.Millisecond))
+ })
+ }
+}
+
func TestSynSent(t *testing.T) {
for _, test := range []struct {
name string
@@ -1675,7 +1744,7 @@ func TestSynSent(t *testing.T) {
addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
err := c.EP.Connect(addr)
- if d := cmp.Diff(err, &tcpip.ErrConnectStarted{}); d != "" {
+ if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d)
}
@@ -1991,7 +2060,9 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
)
// Cause a FIN to be generated.
- c.EP.Shutdown(tcpip.ShutdownWrite)
+ if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
+ t.Fatalf("Shutdown failed: %s", err)
+ }
// Make sure we get the FIN but DON't ACK IT.
checker.IPv4(t, c.GetPacket(),
@@ -2007,7 +2078,9 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
// Cause a RST to be generated by closing the read end now since we have
// unread data.
- c.EP.Shutdown(tcpip.ShutdownRead)
+ if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
+ t.Fatalf("Shutdown failed: %s", err)
+ }
// Make sure we get the RST
checker.IPv4(t, c.GetPacket(),
@@ -2145,12 +2218,15 @@ func TestSmallReceiveBufferReadiness(t *testing.T) {
t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %s", nicOpts, err)
}
- addr := tcpip.AddressWithPrefix{
- Address: tcpip.Address("\x7f\x00\x00\x01"),
- PrefixLen: 8,
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address("\x7f\x00\x00\x01"),
+ PrefixLen: 8,
+ },
}
- if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil {
- t.Fatalf("AddAddressWithPrefix(_, _, %s) failed: %s", addr, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) failed: %s", nicID, protocolAddr, err)
}
{
@@ -4954,13 +5030,17 @@ func makeStack() (*stack.Stack, tcpip.Error) {
}
for _, ct := range []struct {
- number tcpip.NetworkProtocolNumber
- address tcpip.Address
+ number tcpip.NetworkProtocolNumber
+ addrWithPrefix tcpip.AddressWithPrefix
}{
- {ipv4.ProtocolNumber, context.StackAddr},
- {ipv6.ProtocolNumber, context.StackV6Addr},
+ {ipv4.ProtocolNumber, context.StackAddrWithPrefix},
+ {ipv6.ProtocolNumber, context.StackV6AddrWithPrefix},
} {
- if err := s.AddAddress(1, ct.number, ct.address); err != nil {
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ct.number,
+ AddressWithPrefix: ct.addrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
return nil, err
}
}
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 6e55a7a32..88bb99354 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -243,8 +243,8 @@ func NewWithOpts(t *testing.T, opts Options) *Context {
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: StackAddrWithPrefix,
}
- if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err)
+ if err := s.AddProtocolAddress(1, v4ProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", v4ProtocolAddr, err)
}
routeTable = append(routeTable, tcpip.Route{
Destination: header.IPv4EmptySubnet,
@@ -257,8 +257,8 @@ func NewWithOpts(t *testing.T, opts Options) *Context {
Protocol: ipv6.ProtocolNumber,
AddressWithPrefix: StackV6AddrWithPrefix,
}
- if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err)
+ if err := s.AddProtocolAddress(1, v6ProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", v6ProtocolAddr, err)
}
routeTable = append(routeTable, tcpip.Route{
Destination: header.IPv6EmptySubnet,
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index 5cc7a2886..d2c0963b0 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -63,5 +63,6 @@ go_test(
"//pkg/tcpip/transport/icmp",
"//pkg/waiter",
"@com_github_google_go_cmp//cmp:go_default_library",
+ "@org_golang_x_time//rate:go_default_library",
],
)
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 4255457f9..077a2325a 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -60,9 +60,8 @@ type endpoint struct {
waiterQueue *waiter.Queue
uniqueID uint64
net network.Endpoint
- // TODO(b/142022063): Add ability to save and restore per endpoint stats.
- stats tcpip.TransportEndpointStats `state:"nosave"`
- ops tcpip.SocketOptions
+ stats tcpip.TransportEndpointStats
+ ops tcpip.SocketOptions
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.
@@ -234,7 +233,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// Control Messages
cm := tcpip.ControlMessages{
HasTimestamp: true,
- Timestamp: p.receivedAt.UnixNano(),
+ Timestamp: p.receivedAt,
}
switch p.netProto {
@@ -243,19 +242,29 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
cm.HasTOS = true
cm.TOS = p.tos
}
+
+ if e.ops.GetReceivePacketInfo() {
+ cm.HasIPPacketInfo = true
+ cm.PacketInfo = p.packetInfo
+ }
case header.IPv6ProtocolNumber:
if e.ops.GetReceiveTClass() {
cm.HasTClass = true
// Although TClass is an 8-bit value it's read in the CMsg as a uint32.
cm.TClass = uint32(p.tos)
}
+
+ if e.ops.GetIPv6ReceivePacketInfo() {
+ cm.HasIPv6PacketInfo = true
+ cm.IPv6PacketInfo = tcpip.IPv6PacketInfo{
+ NIC: p.packetInfo.NIC,
+ Addr: p.packetInfo.DestinationAddr,
+ }
+ }
default:
panic(fmt.Sprintf("unrecognized network protocol = %d", p.netProto))
}
- if e.ops.GetReceivePacketInfo() {
- cm.HasIPPacketInfo = true
- cm.PacketInfo = p.packetInfo
- }
+
if e.ops.GetReceiveOriginalDstAddress() {
cm.HasOriginalDstAddress = true
cm.OriginalDstAddress = p.destinationAddress
@@ -283,7 +292,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// reacquire the mutex in exclusive mode.
//
// Returns true for retry if preparation should be retried.
-// +checklocks:e.mu
+// +checklocksread:e.mu
func (e *endpoint) prepareForWriteInner(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
switch e.net.State() {
case transport.DatagramEndpointStateInitial:
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 554ce1de4..b3199489c 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -22,6 +22,7 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
+ "golang.org/x/time/rate"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -313,6 +314,9 @@ func newDualTestContextWithHandleLocal(t *testing.T, mtu uint32, handleLocal boo
Clock: &faketime.NullClock{},
}
s := stack.New(options)
+ // Disable ICMP rate limiter because we're using Null clock, which never advances time and thus
+ // never allows ICMP messages.
+ s.SetICMPLimit(rate.Inf)
ep := channel.New(256, mtu, "")
wep := stack.LinkEndpoint(ep)
@@ -323,12 +327,20 @@ func newDualTestContextWithHandleLocal(t *testing.T, mtu uint32, handleLocal boo
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, stackAddr, err)
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.Address(stackAddr).WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV4, err)
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, stackV6Addr); err != nil {
- t.Fatalf("AddAddress((%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, stackV6Addr, err)
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: tcpip.Address(stackV6Addr).WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV6, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -1357,64 +1369,70 @@ func TestReadIncrementsPacketsReceived(t *testing.T) {
func TestReadIPPacketInfo(t *testing.T) {
tests := []struct {
- name string
- proto tcpip.NetworkProtocolNumber
- flow testFlow
- expectedLocalAddr tcpip.Address
- expectedDestAddr tcpip.Address
+ name string
+ proto tcpip.NetworkProtocolNumber
+ flow testFlow
+ checker func(tcpip.NICID) checker.ControlMessagesChecker
}{
{
- name: "IPv4 unicast",
- proto: header.IPv4ProtocolNumber,
- flow: unicastV4,
- expectedLocalAddr: stackAddr,
- expectedDestAddr: stackAddr,
+ name: "IPv4 unicast",
+ proto: header.IPv4ProtocolNumber,
+ flow: unicastV4,
+ checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
+ return checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
+ NIC: id,
+ LocalAddr: stackAddr,
+ DestinationAddr: stackAddr,
+ })
+ },
},
{
name: "IPv4 multicast",
proto: header.IPv4ProtocolNumber,
flow: multicastV4,
- // This should actually be a unicast address assigned to the interface.
- //
- // TODO(gvisor.dev/issue/3556): This check is validating incorrect
- // behaviour. We still include the test so that once the bug is
- // resolved, this test will start to fail and the individual tasked
- // with fixing this bug knows to also fix this test :).
- expectedLocalAddr: multicastAddr,
- expectedDestAddr: multicastAddr,
+ checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
+ return checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
+ NIC: id,
+ // TODO(gvisor.dev/issue/3556): Check for a unicast address.
+ LocalAddr: multicastAddr,
+ DestinationAddr: multicastAddr,
+ })
+ },
},
{
name: "IPv4 broadcast",
proto: header.IPv4ProtocolNumber,
flow: broadcast,
- // This should actually be a unicast address assigned to the interface.
- //
- // TODO(gvisor.dev/issue/3556): This check is validating incorrect
- // behaviour. We still include the test so that once the bug is
- // resolved, this test will start to fail and the individual tasked
- // with fixing this bug knows to also fix this test :).
- expectedLocalAddr: broadcastAddr,
- expectedDestAddr: broadcastAddr,
+ checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
+ return checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
+ NIC: id,
+ // TODO(gvisor.dev/issue/3556): Check for a unicast address.
+ LocalAddr: broadcastAddr,
+ DestinationAddr: broadcastAddr,
+ })
+ },
},
{
- name: "IPv6 unicast",
- proto: header.IPv6ProtocolNumber,
- flow: unicastV6,
- expectedLocalAddr: stackV6Addr,
- expectedDestAddr: stackV6Addr,
+ name: "IPv6 unicast",
+ proto: header.IPv6ProtocolNumber,
+ flow: unicastV6,
+ checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
+ return checker.ReceiveIPv6PacketInfo(tcpip.IPv6PacketInfo{
+ NIC: id,
+ Addr: stackV6Addr,
+ })
+ },
},
{
name: "IPv6 multicast",
proto: header.IPv6ProtocolNumber,
flow: multicastV6,
- // This should actually be a unicast address assigned to the interface.
- //
- // TODO(gvisor.dev/issue/3556): This check is validating incorrect
- // behaviour. We still include the test so that once the bug is
- // resolved, this test will start to fail and the individual tasked
- // with fixing this bug knows to also fix this test :).
- expectedLocalAddr: multicastV6Addr,
- expectedDestAddr: multicastV6Addr,
+ checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
+ return checker.ReceiveIPv6PacketInfo(tcpip.IPv6PacketInfo{
+ NIC: id,
+ Addr: multicastV6Addr,
+ })
+ },
},
}
@@ -1437,13 +1455,16 @@ func TestReadIPPacketInfo(t *testing.T) {
}
}
- c.ep.SocketOptions().SetReceivePacketInfo(true)
+ switch f := test.flow.netProto(); f {
+ case header.IPv4ProtocolNumber:
+ c.ep.SocketOptions().SetReceivePacketInfo(true)
+ case header.IPv6ProtocolNumber:
+ c.ep.SocketOptions().SetIPv6ReceivePacketInfo(true)
+ default:
+ t.Fatalf("unhandled protocol number = %d", f)
+ }
- testRead(c, test.flow, checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
- NIC: 1,
- LocalAddr: test.expectedLocalAddr,
- DestinationAddr: test.expectedDestAddr,
- }))
+ testRead(c, test.flow, test.checker(c.nicID))
if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 {
t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got)
@@ -2504,8 +2525,8 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
if err := s.CreateNIC(nicID1, e); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
}
- if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err)
+ if err := s.AddProtocolAddress(nicID1, test.nicAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, test.nicAddr, err)
}
s.SetRouteTable(test.routes)