summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--.buildkite/pipeline.yaml5
-rw-r--r--pkg/safecopy/safecopy.go2
-rw-r--r--pkg/sentry/kernel/msgqueue/msgqueue.go12
-rw-r--r--pkg/sentry/socket/netstack/netstack.go50
-rw-r--r--pkg/sentry/socket/netstack/stack.go2
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go112
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_amd64.s2
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_arm64.s2
-rw-r--r--pkg/tcpip/link/rawfile/rawfile_unsafe.go16
-rw-r--r--pkg/tcpip/network/BUILD2
-rw-r--r--pkg/tcpip/network/arp/arp.go1
-rw-r--r--pkg/tcpip/network/arp/arp_test.go16
-rw-r--r--pkg/tcpip/network/ip_test.go175
-rw-r--r--pkg/tcpip/network/ipv4/igmp_test.go28
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go14
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go82
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go68
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go16
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go92
-rw-r--r--pkg/tcpip/network/ipv6/mld_test.go33
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go6
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go48
-rw-r--r--pkg/tcpip/network/multicast_group_test.go21
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/main.go8
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/main.go16
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state.go24
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state_test.go4
-rw-r--r--pkg/tcpip/stack/forwarding_test.go26
-rw-r--r--pkg/tcpip/stack/ndp_test.go107
-rw-r--r--pkg/tcpip/stack/nic.go4
-rw-r--r--pkg/tcpip/stack/nic_test.go5
-rw-r--r--pkg/tcpip/stack/registration.go29
-rw-r--r--pkg/tcpip/stack/stack.go45
-rw-r--r--pkg/tcpip/stack/stack_test.go613
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go18
-rw-r--r--pkg/tcpip/stack/transport_test.go33
-rw-r--r--pkg/tcpip/tcpip.go2
-rw-r--r--pkg/tcpip/tests/integration/forward_test.go20
-rw-r--r--pkg/tcpip/tests/integration/iptables_test.go72
-rw-r--r--pkg/tcpip/tests/integration/link_resolution_test.go24
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go47
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go20
-rw-r--r--pkg/tcpip/tests/integration/route_test.go38
-rw-r--r--pkg/tcpip/tests/utils/utils.go32
-rw-r--r--pkg/tcpip/transport/icmp/icmp_test.go8
-rw-r--r--pkg/tcpip/transport/internal/network/BUILD2
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint.go131
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint_state.go4
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint_test.go137
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go59
-rw-r--r--pkg/tcpip/transport/packet/endpoint_state.go12
-rw-r--r--pkg/tcpip/transport/raw/BUILD2
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go430
-rw-r--r--pkg/tcpip/transport/raw/endpoint_state.go30
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go35
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go8
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go2
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go20
-rw-r--r--runsc/boot/network.go13
-rw-r--r--runsc/cmd/do.go1
-rw-r--r--runsc/cmd/run.go9
-rw-r--r--test/benchmarks/tcp/tcp_proxy.go8
-rw-r--r--test/syscalls/BUILD5
-rw-r--r--test/syscalls/linux/BUILD44
-rw-r--r--test/syscalls/linux/ip_socket_test_util.cc8
-rw-r--r--test/syscalls/linux/ip_socket_test_util.h7
-rw-r--r--test/syscalls/linux/network_namespace.cc19
-rw-r--r--test/syscalls/linux/packet_socket.cc9
-rw-r--r--test/syscalls/linux/packet_socket_raw.cc11
-rw-r--r--test/syscalls/linux/ping_socket.cc10
-rw-r--r--test/syscalls/linux/raw_socket.cc118
-rw-r--r--test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound.cc299
-rw-r--r--test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound.h31
-rw-r--r--test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound_loopback.cc28
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound.cc325
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound_loopback.cc4
-rw-r--r--test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.cc2
-rw-r--r--test/syscalls/linux/udp_socket.cc61
-rw-r--r--test/util/BUILD1
-rw-r--r--test/util/benchmark_main.cc64
-rw-r--r--tools/bazel.mk4
-rw-r--r--tools/show_paths.bzl2
82 files changed, 2410 insertions, 1545 deletions
diff --git a/.buildkite/pipeline.yaml b/.buildkite/pipeline.yaml
index c1b478dc3..adff70fee 100644
--- a/.buildkite/pipeline.yaml
+++ b/.buildkite/pipeline.yaml
@@ -169,6 +169,11 @@ steps:
label: ":mechanical_arm: ARM"
command: make arm-qemu-smoke-test
+ # Build everything.
+ - <<: *common
+ label: ":world_map: Build everything"
+ command: "make build OPTIONS=--build_tag_filters=-nogo TARGETS=//..."
+
# Run basic benchmarks smoke tests (no upload).
- <<: *common
label: ":fire: Benchmarks smoke test"
diff --git a/pkg/safecopy/safecopy.go b/pkg/safecopy/safecopy.go
index 5e6f903ff..a9711e63d 100644
--- a/pkg/safecopy/safecopy.go
+++ b/pkg/safecopy/safecopy.go
@@ -83,7 +83,7 @@ var (
// when we get a SIGSEGV that is not interesting to us.
savedSigSegVHandler uintptr
- // same a above, but for SIGBUS signals.
+ // Same as above, but for SIGBUS signals.
savedSigBusHandler uintptr
)
diff --git a/pkg/sentry/kernel/msgqueue/msgqueue.go b/pkg/sentry/kernel/msgqueue/msgqueue.go
index 7c459d076..c7c5e41fb 100644
--- a/pkg/sentry/kernel/msgqueue/msgqueue.go
+++ b/pkg/sentry/kernel/msgqueue/msgqueue.go
@@ -129,6 +129,16 @@ type Message struct {
Size uint64
}
+func (m *Message) makeCopy() *Message {
+ new := &Message{
+ Type: m.Type,
+ Size: m.Size,
+ }
+ new.Text = make([]byte, len(m.Text))
+ copy(new.Text, m.Text)
+ return new
+}
+
// Blocker is used for blocking Queue.Send, and Queue.Receive calls that serves
// as an abstracted version of kernel.Task. kernel.Task is not directly used to
// prevent circular dependencies.
@@ -455,7 +465,7 @@ func (q *Queue) Copy(mType int64) (*Message, error) {
if msg == nil {
return nil, linuxerr.ENOMSG
}
- return msg, nil
+ return msg.makeCopy(), nil
}
// msgOfType returns the first message with the specified type, nil if no
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 8cf2f29e4..f79bda922 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -419,6 +419,27 @@ func bytesToIPAddress(addr []byte) tcpip.Address {
return tcpip.Address(addr)
}
+// minSockAddrLen returns the minimum length in bytes of a socket address for
+// the socket's family.
+func (s *socketOpsCommon) minSockAddrLen() int {
+ const addressFamilySize = 2
+
+ switch s.family {
+ case linux.AF_UNIX:
+ return addressFamilySize
+ case linux.AF_INET:
+ return sockAddrInetSize
+ case linux.AF_INET6:
+ return sockAddrInet6Size
+ case linux.AF_PACKET:
+ return sockAddrLinkSize
+ case linux.AF_UNSPEC:
+ return addressFamilySize
+ default:
+ panic(fmt.Sprintf("s.family unrecognized = %d", s.family))
+ }
+}
+
func (s *socketOpsCommon) isPacketBased() bool {
return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM || s.skType == linux.SOCK_RAW
}
@@ -545,16 +566,21 @@ func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask {
return s.Endpoint.Readiness(mask)
}
-func (s *socketOpsCommon) checkFamily(family uint16, exact bool) *syserr.Error {
+// checkFamily returns true iff the specified address family may be used with
+// the socket.
+//
+// If exact is true, then the specified address family must be an exact match
+// with the socket's family.
+func (s *socketOpsCommon) checkFamily(family uint16, exact bool) bool {
if family == uint16(s.family) {
- return nil
+ return true
}
if !exact && family == linux.AF_INET && s.family == linux.AF_INET6 {
if !s.Endpoint.SocketOptions().GetV6Only() {
- return nil
+ return true
}
}
- return syserr.ErrInvalidArgument
+ return false
}
// mapFamily maps the AF_INET ANY address to the IPv4-mapped IPv6 ANY if the
@@ -587,8 +613,8 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool
return syserr.TranslateNetstackError(err)
}
- if err := s.checkFamily(family, false /* exact */); err != nil {
- return err
+ if !s.checkFamily(family, false /* exact */) {
+ return syserr.ErrInvalidArgument
}
addr = s.mapFamily(addr, family)
@@ -655,14 +681,18 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
}
} else {
+ if s.minSockAddrLen() > len(sockaddr) {
+ return syserr.ErrInvalidArgument
+ }
+
var err *syserr.Error
addr, family, err = socket.AddressAndFamily(sockaddr)
if err != nil {
return err
}
- if err = s.checkFamily(family, true /* exact */); err != nil {
- return err
+ if !s.checkFamily(family, true /* exact */) {
+ return syserr.ErrAddressFamilyNotSupported
}
addr = s.mapFamily(addr, family)
@@ -2872,8 +2902,8 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b
if err != nil {
return 0, err
}
- if err := s.checkFamily(family, false /* exact */); err != nil {
- return 0, err
+ if !s.checkFamily(family, false /* exact */) {
+ return 0, syserr.ErrInvalidArgument
}
addrBuf = s.mapFamily(addrBuf, family)
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
index 208ab9909..ea199f223 100644
--- a/pkg/sentry/socket/netstack/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -155,7 +155,7 @@ func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error {
// Attach address to interface.
nicID := tcpip.NICID(idx)
- if err := s.Stack.AddProtocolAddressWithOptions(nicID, protocolAddress, stack.CanBePrimaryEndpoint); err != nil {
+ if err := s.Stack.AddProtocolAddress(nicID, protocolAddress, stack.AddressProperties{}); err != nil {
return syserr.TranslateNetstackError(err).ToError()
}
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index 48b24692b..c8460e63c 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -137,7 +137,13 @@ func TestCloseReader(t *testing.T) {
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
l, e := ListenTCP(s, addr, ipv4.ProtocolNumber)
if e != nil {
@@ -190,7 +196,13 @@ func TestCloseReaderWithForwarder(t *testing.T) {
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
done := make(chan struct{})
@@ -244,7 +256,13 @@ func TestCloseRead(t *testing.T) {
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
var wq waiter.Queue
@@ -288,7 +306,13 @@ func TestCloseWrite(t *testing.T) {
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
var wq waiter.Queue
@@ -349,10 +373,22 @@ func TestUDPForwarder(t *testing.T) {
ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr1 := tcpip.FullAddress{NICID, ip1, 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip1)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip1.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr1, err)
+ }
ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4())
addr2 := tcpip.FullAddress{NICID, ip2, 11311}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip2)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr2, err)
+ }
done := make(chan struct{})
fwd := udp.NewForwarder(s, func(r *udp.ForwarderRequest) {
@@ -410,7 +446,13 @@ func TestDeadlineChange(t *testing.T) {
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
l, e := ListenTCP(s, addr, ipv4.ProtocolNumber)
if e != nil {
@@ -465,10 +507,22 @@ func TestPacketConnTransfer(t *testing.T) {
ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr1 := tcpip.FullAddress{NICID, ip1, 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip1)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip1.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr1, err)
+ }
ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4())
addr2 := tcpip.FullAddress{NICID, ip2, 11311}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip2)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr2, err)
+ }
c1, err := DialUDP(s, &addr1, nil, ipv4.ProtocolNumber)
if err != nil {
@@ -521,7 +575,13 @@ func TestConnectedPacketConnTransfer(t *testing.T) {
ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr := tcpip.FullAddress{NICID, ip, 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
c1, err := DialUDP(s, &addr, nil, ipv4.ProtocolNumber)
if err != nil {
@@ -565,24 +625,30 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) {
ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr := tcpip.FullAddress{NICID, ip, 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ return nil, nil, nil, fmt.Errorf("AddProtocolAddress(%d, %+v, {}): %w", NICID, protocolAddr, err)
+ }
l, err := ListenTCP(s, addr, ipv4.ProtocolNumber)
if err != nil {
- return nil, nil, nil, fmt.Errorf("NewListener: %v", err)
+ return nil, nil, nil, fmt.Errorf("NewListener: %w", err)
}
c1, err = DialTCP(s, addr, ipv4.ProtocolNumber)
if err != nil {
l.Close()
- return nil, nil, nil, fmt.Errorf("DialTCP: %v", err)
+ return nil, nil, nil, fmt.Errorf("DialTCP: %w", err)
}
c2, err = l.Accept()
if err != nil {
l.Close()
c1.Close()
- return nil, nil, nil, fmt.Errorf("l.Accept: %v", err)
+ return nil, nil, nil, fmt.Errorf("l.Accept: %w", err)
}
stop = func() {
@@ -594,7 +660,7 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) {
if err := l.Close(); err != nil {
stop()
- return nil, nil, nil, fmt.Errorf("l.Close(): %v", err)
+ return nil, nil, nil, fmt.Errorf("l.Close(): %w", err)
}
return c1, c2, stop, nil
@@ -681,7 +747,13 @@ func TestDialContextTCPCanceled(t *testing.T) {
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
@@ -703,7 +775,13 @@ func TestDialContextTCPTimeout(t *testing.T) {
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
time.Sleep(time.Second)
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_amd64.s b/pkg/tcpip/link/rawfile/blockingpoll_amd64.s
index 298bad55d..f2c230720 100644
--- a/pkg/tcpip/link/rawfile/blockingpoll_amd64.s
+++ b/pkg/tcpip/link/rawfile/blockingpoll_amd64.s
@@ -27,7 +27,7 @@ TEXT ·BlockingPoll(SB),NOSPLIT,$0-40
MOVQ $0x0, R10 // sigmask parameter which isn't used here
MOVQ $0x10f, AX // SYS_PPOLL
SYSCALL
- CMPQ AX, $0xfffffffffffff001
+ CMPQ AX, $0xfffffffffffff002
JLS ok
MOVQ $-1, n+24(FP)
NEGQ AX
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_arm64.s b/pkg/tcpip/link/rawfile/blockingpoll_arm64.s
index b62888b93..8807586c7 100644
--- a/pkg/tcpip/link/rawfile/blockingpoll_arm64.s
+++ b/pkg/tcpip/link/rawfile/blockingpoll_arm64.s
@@ -27,7 +27,7 @@ TEXT ·BlockingPoll(SB),NOSPLIT,$0-40
MOVD $0x0, R3 // sigmask parameter which isn't used here
MOVD $0x49, R8 // SYS_PPOLL
SVC
- CMP $0xfffffffffffff001, R0
+ CMP $0xfffffffffffff002, R0
BLS ok
MOVD $-1, R1
MOVD R1, n+24(FP)
diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
index e76fc55b6..87a0b9a62 100644
--- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go
+++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
@@ -181,7 +181,9 @@ func BlockingReadvUntilStopped(efd int, fd int, iovecs []unix.Iovec) (int, tcpip
if e == 0 {
return int(n), nil
}
-
+ if e != 0 && e != unix.EWOULDBLOCK {
+ return 0, TranslateErrno(e)
+ }
stopped, e := BlockingPollUntilStopped(efd, fd, unix.POLLIN)
if stopped {
return -1, nil
@@ -204,6 +206,10 @@ func BlockingRecvMMsgUntilStopped(efd int, fd int, msgHdrs []MMsgHdr) (int, tcpi
return int(n), nil
}
+ if e != 0 && e != unix.EWOULDBLOCK {
+ return 0, TranslateErrno(e)
+ }
+
stopped, e := BlockingPollUntilStopped(efd, fd, unix.POLLIN)
if stopped {
return -1, nil
@@ -228,5 +234,13 @@ func BlockingPollUntilStopped(efd int, fd int, events int16) (bool, unix.Errno)
},
}
_, errno := BlockingPoll(&pevents[0], len(pevents), nil)
+ if errno != 0 {
+ return pevents[0].Revents&unix.POLLIN != 0, errno
+ }
+
+ if pevents[1].Revents&unix.POLLHUP != 0 || pevents[1].Revents&unix.POLLERR != 0 {
+ errno = unix.ECONNRESET
+ }
+
return pevents[0].Revents&unix.POLLIN != 0, errno
}
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index 7b1ff44f4..c0179104a 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -23,8 +23,10 @@ go_test(
"//pkg/tcpip/stack",
"//pkg/tcpip/testutil",
"//pkg/tcpip/transport/icmp",
+ "//pkg/tcpip/transport/raw",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
+ "//pkg/waiter",
"@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 6515c31e5..e08243547 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -272,7 +272,6 @@ type protocol struct {
func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber }
func (p *protocol) MinimumPacketSize() int { return header.ARPSize }
-func (p *protocol) DefaultPrefixLen() int { return 0 }
func (*protocol) ParseAddresses(buffer.View) (src, dst tcpip.Address) {
return "", ""
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 5fcbfeaa2..061cc35ae 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -153,8 +153,12 @@ func makeTestContext(t *testing.T, eventDepth int, packetDepth int) testContext
t.Fatalf("CreateNIC failed: %s", err)
}
- if err := tc.s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress for ipv4 failed: %s", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: stackAddr.WithPrefix(),
+ }
+ if err := tc.s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
tc.s.SetRouteTable([]tcpip.Route{{
@@ -569,8 +573,12 @@ func TestLinkAddressRequest(t *testing.T) {
}
if len(test.nicAddr) != 0 {
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, test.nicAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, test.nicAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: test.nicAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 771b9173a..87f650661 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -15,6 +15,7 @@
package ip_test
import (
+ "bytes"
"fmt"
"strings"
"testing"
@@ -32,8 +33,10 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/raw"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
)
const nicID = 1
@@ -230,7 +233,13 @@ func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) {
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
})
s.CreateNIC(nicID, loopback.New())
- s.AddAddress(nicID, ipv4.ProtocolNumber, local)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: local.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ return nil, err
+ }
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv4EmptySubnet,
Gateway: ipv4Gateway,
@@ -246,7 +255,13 @@ func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) {
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
})
s.CreateNIC(nicID, loopback.New())
- s.AddAddress(nicID, ipv6.ProtocolNumber, local)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: local.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ return nil, err
+ }
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv6EmptySubnet,
Gateway: ipv6Gateway,
@@ -269,13 +284,13 @@ func buildDummyStackWithLinkEndpoint(t *testing.T, mtu uint32) (*stack.Stack, *c
}
v4Addr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: localIPv4AddrWithPrefix}
- if err := s.AddProtocolAddress(nicID, v4Addr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v4Addr, err)
+ if err := s.AddProtocolAddress(nicID, v4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, v4Addr, err)
}
v6Addr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: localIPv6AddrWithPrefix}
- if err := s.AddProtocolAddress(nicID, v6Addr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v6Addr, err)
+ if err := s.AddProtocolAddress(nicID, v6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, v6Addr, err)
}
return s, e
@@ -710,8 +725,8 @@ func TestReceive(t *testing.T) {
if !ok {
t.Fatalf("expected network endpoint with number = %d to implement stack.AddressableEndpoint", test.protoNum)
}
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", test.epAddr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", test.epAddr, err)
} else {
ep.DecRef()
}
@@ -882,8 +897,8 @@ func TestIPv4ReceiveControl(t *testing.T) {
t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint")
}
addr := localIPv4Addr.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
} else {
ep.DecRef()
}
@@ -968,8 +983,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint")
}
addr := localIPv4Addr.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
} else {
ep.DecRef()
}
@@ -1234,8 +1249,8 @@ func TestIPv6ReceiveControl(t *testing.T) {
t.Fatal("expected IPv6 network endpoint to implement stack.AddressableEndpoint")
}
addr := localIPv6Addr.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
} else {
ep.DecRef()
}
@@ -1301,7 +1316,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name string
protoFactory stack.NetworkProtocolFactory
protoNum tcpip.NetworkProtocolNumber
- nicAddr tcpip.Address
+ nicAddr tcpip.AddressWithPrefix
remoteAddr tcpip.Address
pktGen func(*testing.T, tcpip.Address) buffer.VectorisedView
checker func(*testing.T, *stack.PacketBuffer, tcpip.Address)
@@ -1311,7 +1326,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv4MinimumSize + len(data)
@@ -1352,7 +1367,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4 with IHL too small",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv4MinimumSize + len(data)
@@ -1376,7 +1391,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4 too small",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
@@ -1394,7 +1409,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4 minimum size",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
@@ -1430,7 +1445,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4 with options",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ipHdrLen := int(header.IPv4MinimumSize + ipv4Options.Length())
@@ -1475,7 +1490,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4 with options and data across views",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.Length()))
@@ -1516,7 +1531,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv6",
protoFactory: ipv6.NewProtocol,
protoNum: ipv6.ProtocolNumber,
- nicAddr: localIPv6Addr,
+ nicAddr: localIPv6AddrWithPrefix,
remoteAddr: remoteIPv6Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv6MinimumSize + len(data)
@@ -1556,7 +1571,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv6 with extension header",
protoFactory: ipv6.NewProtocol,
protoNum: ipv6.ProtocolNumber,
- nicAddr: localIPv6Addr,
+ nicAddr: localIPv6AddrWithPrefix,
remoteAddr: remoteIPv6Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data)
@@ -1601,7 +1616,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv6 minimum size",
protoFactory: ipv6.NewProtocol,
protoNum: ipv6.ProtocolNumber,
- nicAddr: localIPv6Addr,
+ nicAddr: localIPv6AddrWithPrefix,
remoteAddr: remoteIPv6Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
@@ -1636,7 +1651,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv6 too small",
protoFactory: ipv6.NewProtocol,
protoNum: ipv6.ProtocolNumber,
- nicAddr: localIPv6Addr,
+ nicAddr: localIPv6AddrWithPrefix,
remoteAddr: remoteIPv6Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
@@ -1660,11 +1675,11 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
}{
{
name: "unspecified source",
- srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr))),
+ srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr.Address))),
},
{
name: "random source",
- srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr))),
+ srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr.Address))),
},
}
@@ -1677,15 +1692,19 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, test.protoNum, test.nicAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.protoNum, test.nicAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: test.protoNum,
+ AddressWithPrefix: test.nicAddr,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{Destination: test.remoteAddr.WithPrefix().Subnet(), NIC: nicID}})
- r, err := s.FindRoute(nicID, test.nicAddr, test.remoteAddr, test.protoNum, false /* multicastLoop */)
+ r, err := s.FindRoute(nicID, test.nicAddr.Address, test.remoteAddr, test.protoNum, false /* multicastLoop */)
if err != nil {
- t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr, test.protoNum, err)
+ t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr.Address, test.protoNum, err)
}
defer r.Release()
@@ -2032,3 +2051,97 @@ func TestJoinLeaveAllRoutersGroup(t *testing.T) {
})
}
}
+
+func TestSetNICIDBeforeDeliveringToRawEndpoint(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ proto tcpip.NetworkProtocolNumber
+ addr tcpip.AddressWithPrefix
+ payloadOffset int
+ }{
+ {
+ name: "IPv4",
+ proto: header.IPv4ProtocolNumber,
+ addr: localIPv4AddrWithPrefix,
+ payloadOffset: header.IPv4MinimumSize,
+ },
+ {
+ name: "IPv6",
+ proto: header.IPv6ProtocolNumber,
+ addr: localIPv6AddrWithPrefix,
+ payloadOffset: 0,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ ipv4.NewProtocol,
+ ipv6.NewProtocol,
+ },
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ RawFactory: raw.EndpointFactory{},
+ })
+ if err := s.CreateNIC(nicID, loopback.New()); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: test.proto,
+ AddressWithPrefix: test.addr,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: test.addr.Subnet(),
+ NIC: nicID,
+ },
+ })
+
+ var wq waiter.Queue
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.ReadableEvents)
+ ep, err := s.NewRawEndpoint(udp.ProtocolNumber, test.proto, &wq, true /* associated */)
+ if err != nil {
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err)
+ }
+ defer ep.Close()
+
+ writeOpts := tcpip.WriteOptions{
+ To: &tcpip.FullAddress{
+ Addr: test.addr.Address,
+ },
+ }
+ data := []byte{1, 2, 3, 4}
+ var r bytes.Reader
+ r.Reset(data)
+ if n, err := ep.Write(&r, writeOpts); err != nil {
+ t.Fatalf("ep.Write(_, _): %s", err)
+ } else if want := int64(len(data)); n != want {
+ t.Fatalf("got ep.Write(_, _) = (%d, nil), want = (%d, nil)", n, want)
+ }
+
+ // Wait for the endpoint to become readable.
+ <-ch
+
+ var w bytes.Buffer
+ rr, err := ep.Read(&w, tcpip.ReadOptions{
+ NeedRemoteAddr: true,
+ })
+ if err != nil {
+ t.Fatalf("ep.Read(...): %s", err)
+ }
+ if diff := cmp.Diff(data, w.Bytes()[test.payloadOffset:]); diff != "" {
+ t.Errorf("payload mismatch (-want +got):\n%s", diff)
+ }
+ if diff := cmp.Diff(tcpip.FullAddress{Addr: test.addr.Address, NIC: nicID}, rr.RemoteAddr); diff != "" {
+ t.Errorf("remote addr mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go
index 4bd6f462e..c6576fcbc 100644
--- a/pkg/tcpip/network/ipv4/igmp_test.go
+++ b/pkg/tcpip/network/ipv4/igmp_test.go
@@ -120,9 +120,12 @@ func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, ma
// cycles.
func TestIGMPV1Present(t *testing.T) {
e, s, clock := createStack(t, true)
- addr := tcpip.AddressWithPrefix{Address: stackAddr, PrefixLen: defaultPrefixLength}
- if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{Address: stackAddr, PrefixLen: defaultPrefixLength},
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
@@ -215,8 +218,15 @@ func TestSendQueuedIGMPReports(t *testing.T) {
// The initial set of IGMP reports that were queued should be sent once an
// address is assigned.
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, stackAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: stackAddr,
+ PrefixLen: defaultPrefixLength,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if got := reportStat.Value(); got != 1 {
t.Errorf("got reportStat.Value() = %d, want = 1", got)
@@ -350,8 +360,12 @@ func TestIGMPPacketValidation(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
e, s, _ := createStack(t, true)
for _, address := range test.stackAddresses {
- if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, address); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, address, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: address,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
stats := s.Stats()
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 44c85bdb8..aef789b4c 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -240,7 +240,7 @@ func (e *endpoint) Enable() tcpip.Error {
}
// Create an endpoint to receive broadcast packets on this interface.
- ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(ipv4BroadcastAddr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */)
+ ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(ipv4BroadcastAddr, stack.AddressProperties{PEB: stack.NeverPrimaryEndpoint})
if err != nil {
return err
}
@@ -856,6 +856,8 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum
}
func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer, inNICName string) {
+ pkt.NICID = e.nic.ID()
+
// Raw socket packets are delivered based solely on the transport protocol
// number. We only require that the packet be valid IPv4, and that they not
// be fragmented.
@@ -863,7 +865,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer,
e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt)
}
- pkt.NICID = e.nic.ID()
stats := e.stats
stats.ip.ValidPacketsReceived.Increment()
@@ -1074,11 +1075,11 @@ func (e *endpoint) Close() {
}
// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint.
-func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
+func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties stack.AddressProperties) (stack.AddressEndpoint, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
+ ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, properties)
if err == nil {
e.mu.igmp.sendQueuedReports()
}
@@ -1225,11 +1226,6 @@ func (p *protocol) MinimumPacketSize() int {
return header.IPv4MinimumSize
}
-// DefaultPrefixLen returns the IPv4 default prefix length.
-func (p *protocol) DefaultPrefixLen() int {
- return header.IPv4AddressSize * 8
-}
-
// ParseAddresses implements stack.NetworkProtocol.
func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
h := header.IPv4(v)
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 73407be67..e7b5b3ea2 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -101,8 +101,12 @@ func TestExcludeBroadcast(t *testing.T) {
defer ep.Close()
// Add a valid primary endpoint address, now we can connect.
- if err := s.AddAddress(1, ipv4.ProtocolNumber, "\x0a\x00\x00\x02"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.Address("\x0a\x00\x00\x02").WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
if err := ep.Connect(randomAddr); err != nil {
t.Errorf("Connect failed: %v", err)
@@ -356,8 +360,8 @@ func TestForwarding(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err)
}
incomingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: incomingIPv4Addr}
- if err := s.AddProtocolAddress(incomingNICID, incomingIPv4ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv4ProtoAddr, err)
+ if err := s.AddProtocolAddress(incomingNICID, incomingIPv4ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingIPv4ProtoAddr, err)
}
expectedEmittedPacketCount := 1
@@ -369,8 +373,8 @@ func TestForwarding(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err)
}
outgoingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: outgoingIPv4Addr}
- if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv4ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv4ProtoAddr, err)
+ if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv4ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingIPv4ProtoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -1184,8 +1188,8 @@ func TestIPv4Sanity(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr}
- if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err)
+ if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtoAddr, err)
}
// Default routes for IPv4 so ICMP can find a route to the remote
@@ -1745,8 +1749,8 @@ func TestInvalidFragments(t *testing.T) {
const (
nicID = 1
linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
- addr1 = "\x0a\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x02"
+ addr1 = tcpip.Address("\x0a\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x02")
tos = 0
ident = 1
ttl = 48
@@ -2012,8 +2016,12 @@ func TestInvalidFragments(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
for _, f := range test.fragments {
@@ -2061,8 +2069,8 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
const (
nicID = 1
linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
- addr1 = "\x0a\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x02"
+ addr1 = tcpip.Address("\x0a\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x02")
tos = 0
ident = 1
ttl = 48
@@ -2237,8 +2245,12 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv4EmptySubnet,
@@ -2308,9 +2320,9 @@ func TestReceiveFragments(t *testing.T) {
const (
nicID = 1
- addr1 = "\x0c\xa8\x00\x01" // 192.168.0.1
- addr2 = "\x0c\xa8\x00\x02" // 192.168.0.2
- addr3 = "\x0c\xa8\x00\x03" // 192.168.0.3
+ addr1 = tcpip.Address("\x0c\xa8\x00\x01") // 192.168.0.1
+ addr2 = tcpip.Address("\x0c\xa8\x00\x02") // 192.168.0.2
+ addr3 = tcpip.Address("\x0c\xa8\x00\x03") // 192.168.0.3
)
// Build and return a UDP header containing payload.
@@ -2703,8 +2715,12 @@ func TestReceiveFragments(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
wq := waiter.Queue{}
@@ -2985,11 +3001,15 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route {
t.Fatalf("CreateNIC(1, _) failed: %s", err)
}
const (
- src = "\x10\x00\x00\x01"
- dst = "\x10\x00\x00\x02"
+ src = tcpip.Address("\x10\x00\x00\x01")
+ dst = tcpip.Address("\x10\x00\x00\x02")
)
- if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil {
- t.Fatalf("AddAddress(1, %d, %s) failed: %s", ipv4.ProtocolNumber, src, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: src.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
mask := tcpip.AddressMask(header.IPv4Broadcast)
@@ -3161,8 +3181,8 @@ func TestPacketQueuing(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddProtocolAddress(nicID, host1IPv4Addr); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv4Addr, err)
+ if err := s.AddProtocolAddress(nicID, host1IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv4Addr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -3285,8 +3305,12 @@ func TestCloseLocking(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
}
- if err := s.AddAddress(nicID1, ipv4.ProtocolNumber, src); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) failed: %s", nicID1, ipv4.ProtocolNumber, src, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: src.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 7c2a3e56b..3b4c235fa 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -225,8 +225,8 @@ func TestICMPCounts(t *testing.T) {
t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint")
}
addr := lladdr0.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
} else {
ep.DecRef()
}
@@ -407,8 +407,12 @@ func newTestContext(t *testing.T) *testContext {
if err := c.s0.CreateNIC(nicID, wrappedEP0); err != nil {
t.Fatalf("CreateNIC s0: %v", err)
}
- if err := c.s0.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress lladdr0: %v", err)
+ llProtocolAddr0 := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := c.s0.AddProtocolAddress(nicID, llProtocolAddr0, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, llProtocolAddr0, err)
}
c.linkEP1 = channel.New(defaultChannelSize, defaultMTU, linkAddr1)
@@ -416,8 +420,12 @@ func newTestContext(t *testing.T) *testContext {
if err := c.s1.CreateNIC(nicID, wrappedEP1); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
- if err := c.s1.AddAddress(nicID, ProtocolNumber, lladdr1); err != nil {
- t.Fatalf("AddAddress lladdr1: %v", err)
+ llProtocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr1.WithPrefix(),
+ }
+ if err := c.s1.AddProtocolAddress(nicID, llProtocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, llProtocolAddr1, err)
}
subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -690,8 +698,12 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
t.Fatalf("CreateNIC(_, _) = %s", err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -883,8 +895,12 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
t.Fatalf("CreateNIC(_, _) = %s", err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -1065,8 +1081,12 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -1240,8 +1260,12 @@ func TestLinkAddressRequest(t *testing.T) {
}
if len(test.nicAddr) != 0 {
- if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: test.nicAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
@@ -1415,8 +1439,8 @@ func TestPacketQueing(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddProtocolAddress(nicID, host1IPv6Addr); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv6Addr, err)
+ if err := s.AddProtocolAddress(nicID, host1IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv6Addr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -1669,8 +1693,12 @@ func TestCallsToNeighborCache(t *testing.T) {
if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
t.Fatalf("CreateNIC(_, _) = %s", err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
{
@@ -1704,8 +1732,8 @@ func TestCallsToNeighborCache(t *testing.T) {
t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint")
}
addr := lladdr0.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
} else {
ep.DecRef()
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index b1aec5312..c824e27fa 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -1127,11 +1127,12 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum
}
func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer, inNICName string) {
+ pkt.NICID = e.nic.ID()
+
// Raw socket packets are delivered based solely on the transport protocol
// number. We only require that the packet be valid IPv6.
e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt)
- pkt.NICID = e.nic.ID()
stats := e.stats.ip
stats.ValidPacketsReceived.Increment()
@@ -1627,12 +1628,12 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
}
// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint.
-func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
+func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties stack.AddressProperties) (stack.AddressEndpoint, tcpip.Error) {
// TODO(b/169350103): add checks here after making sure we no longer receive
// an empty address.
e.mu.Lock()
defer e.mu.Unlock()
- return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated)
+ return e.addAndAcquirePermanentAddressLocked(addr, properties)
}
// addAndAcquirePermanentAddressLocked is like AddAndAcquirePermanentAddress but
@@ -1642,8 +1643,8 @@ func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, p
// solicited-node multicast group and start duplicate address detection.
//
// Precondition: e.mu must be write locked.
-func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
- addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
+func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, properties stack.AddressProperties) (stack.AddressEndpoint, tcpip.Error) {
+ addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, properties)
if err != nil {
return nil, err
}
@@ -2010,11 +2011,6 @@ func (p *protocol) MinimumPacketSize() int {
return header.IPv6MinimumSize
}
-// DefaultPrefixLen returns the IPv6 default prefix length.
-func (p *protocol) DefaultPrefixLen() int {
- return header.IPv6AddressSize * 8
-}
-
// ParseAddresses implements stack.NetworkProtocol.
func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
h := header.IPv6(v)
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index d2a23fd4f..0735ebb23 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -41,12 +41,12 @@ import (
)
const (
- addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
// The least significant 3 bytes are the same as addr2 so both addr2 and
// addr3 will have the same solicited-node address.
- addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02"
- addr4 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x03"
+ addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02")
+ addr4 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x03")
// Tests use the extension header identifier values as uint8 instead of
// header.IPv6ExtensionHeaderIdentifier.
@@ -298,16 +298,24 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
// addr2/addr3 yet as we haven't added those addresses.
test.rxf(t, s, e, addr1, snmc, 0)
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr2, err)
}
// Should receive a packet destined to the solicited node address of
// addr2/addr3 now that we have added added addr2.
test.rxf(t, s, e, addr1, snmc, 1)
- if err := s.AddAddress(nicID, ProtocolNumber, addr3); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr3, err)
+ protocolAddr3 := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr3.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr3, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr3, err)
}
// Should still receive a packet destined to the solicited node address of
@@ -374,8 +382,12 @@ func TestAddIpv6Address(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, test.addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, nil) = %s", nicID, ProtocolNumber, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: test.addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if addr, err := s.GetMainNICAddress(nicID, ProtocolNumber); err != nil {
@@ -898,8 +910,12 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
// Add a default route so that a return packet knows where to go.
@@ -1992,8 +2008,12 @@ func TestReceiveIPv6Fragments(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
wq := waiter.Queue{}
@@ -2060,8 +2080,8 @@ func TestReceiveIPv6Fragments(t *testing.T) {
func TestInvalidIPv6Fragments(t *testing.T) {
const (
- addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
nicID = 1
hoplimit = 255
@@ -2150,8 +2170,12 @@ func TestInvalidIPv6Fragments(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv6EmptySubnet,
@@ -2216,8 +2240,8 @@ func TestInvalidIPv6Fragments(t *testing.T) {
func TestFragmentReassemblyTimeout(t *testing.T) {
const (
- addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
nicID = 1
hoplimit = 255
@@ -2402,8 +2426,12 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv6EmptySubnet,
@@ -2645,11 +2673,15 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route {
t.Fatalf("CreateNIC(1, _) failed: %s", err)
}
const (
- src = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ src = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ dst = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
)
- if err := s.AddAddress(1, ProtocolNumber, src); err != nil {
- t.Fatalf("AddAddress(1, %d, %s) failed: %s", ProtocolNumber, src, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: src.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
mask := tcpip.AddressMask("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff")
@@ -3297,8 +3329,8 @@ func TestForwarding(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err)
}
incomingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: incomingIPv6Addr}
- if err := s.AddProtocolAddress(incomingNICID, incomingIPv6ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv6ProtoAddr, err)
+ if err := s.AddProtocolAddress(incomingNICID, incomingIPv6ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingIPv6ProtoAddr, err)
}
outgoingEndpoint := channel.New(1, header.IPv6MinimumMTU, "")
@@ -3306,8 +3338,8 @@ func TestForwarding(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err)
}
outgoingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: outgoingIPv6Addr}
- if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv6ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv6ProtoAddr, err)
+ if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv6ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingIPv6ProtoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go
index bc9cf6999..3e5c438d3 100644
--- a/pkg/tcpip/network/ipv6/mld_test.go
+++ b/pkg/tcpip/network/ipv6/mld_test.go
@@ -75,8 +75,12 @@ func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) {
// The stack will join an address's solicited node multicast address when
// an address is added. An MLD report message should be sent for the
// solicited-node group.
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: linkLocalAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if p, ok := e.Read(); !ok {
t.Fatal("expected a report message to be sent")
@@ -216,8 +220,13 @@ func TestSendQueuedMLDReports(t *testing.T) {
// Note, we will still expect to send a report for the global address's
// solicited node address from the unspecified address as per RFC 3590
// section 4.
- if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint, err)
+ properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint}
+ globalProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: globalAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, globalProtocolAddr, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, globalProtocolAddr, properties, err)
}
reportCounter++
if got := reportStat.Value(); got != reportCounter {
@@ -252,8 +261,12 @@ func TestSendQueuedMLDReports(t *testing.T) {
// Adding a link-local address should send a report for its solicited node
// address and globalMulticastAddr.
- if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint, err)
+ linkLocalProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: linkLocalAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, linkLocalProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, linkLocalProtocolAddr, err)
}
if dadResolutionTime != 0 {
reportCounter++
@@ -567,8 +580,12 @@ func TestMLDSkipProtocol(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: linkLocalAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if p, ok := e.Read(); !ok {
t.Fatal("expected a report message to be sent")
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index 8837d66d8..938427420 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -1130,7 +1130,11 @@ func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, config
return nil
}
- addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.FirstPrimaryEndpoint, configType, deprecated)
+ addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.AddressProperties{
+ PEB: stack.FirstPrimaryEndpoint,
+ ConfigType: configType,
+ Deprecated: deprecated,
+ })
if err != nil {
panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", addr, err))
}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index f0186c64e..8297a7e10 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -144,8 +144,12 @@ func TestNeighborSolicitationWithSourceLinkLayerOption(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf)
@@ -406,8 +410,12 @@ func TestNeighborSolicitationResponse(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: nicAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -602,8 +610,12 @@ func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf)
@@ -831,8 +843,12 @@ func TestNDPValidation(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber)
@@ -962,8 +978,12 @@ func TestNeighborAdvertisementValidation(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
ndpNASize := header.ICMPv6NeighborAdvertMinimumSize
@@ -1283,8 +1303,12 @@ func TestCheckDuplicateAddress(t *testing.T) {
checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonces[dadPacketsSent])}),
))
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
checkDADMsg()
diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go
index 1b96b1fb8..26640b7ee 100644
--- a/pkg/tcpip/network/multicast_group_test.go
+++ b/pkg/tcpip/network/multicast_group_test.go
@@ -151,15 +151,22 @@ func createStackWithLinkEndpoint(t *testing.T, v4, mgpEnabled bool, e stack.Link
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- addr := tcpip.AddressWithPrefix{
- Address: stackIPv4Addr,
- PrefixLen: defaultIPv4PrefixLength,
+ addr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: stackIPv4Addr,
+ PrefixLen: defaultIPv4PrefixLength,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err)
}
- if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: linkLocalIPv6Addr1.WithPrefix(),
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalIPv6Addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, linkLocalIPv6Addr1, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
return s, clock
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
index 009cab643..05b879543 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -146,8 +146,12 @@ func main() {
log.Fatal(err)
}
- if err := s.AddAddress(1, ipv4.ProtocolNumber, addr); err != nil {
- log.Fatal(err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ log.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
// Add default route.
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
index c10b19aa0..a72afadda 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/main.go
+++ b/pkg/tcpip/sample/tun_tcp_echo/main.go
@@ -124,13 +124,13 @@ func main() {
log.Fatalf("Bad IP address: %v", addrName)
}
- var addr tcpip.Address
+ var addrWithPrefix tcpip.AddressWithPrefix
var proto tcpip.NetworkProtocolNumber
if parsedAddr.To4() != nil {
- addr = tcpip.Address(parsedAddr.To4())
+ addrWithPrefix = tcpip.Address(parsedAddr.To4()).WithPrefix()
proto = ipv4.ProtocolNumber
} else if parsedAddr.To16() != nil {
- addr = tcpip.Address(parsedAddr.To16())
+ addrWithPrefix = tcpip.Address(parsedAddr.To16()).WithPrefix()
proto = ipv6.ProtocolNumber
} else {
log.Fatalf("Unknown IP type: %v", addrName)
@@ -176,11 +176,15 @@ func main() {
log.Fatal(err)
}
- if err := s.AddAddress(1, proto, addr); err != nil {
- log.Fatal(err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: proto,
+ AddressWithPrefix: addrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ log.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
- subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addr))), tcpip.AddressMask(strings.Repeat("\x00", len(addr))))
+ subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addrWithPrefix.Address))), tcpip.AddressMask(strings.Repeat("\x00", len(addrWithPrefix.Address))))
if err != nil {
log.Fatal(err)
}
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
index ae0bb4ace..7e4b5bf74 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -117,10 +117,10 @@ func (a *AddressableEndpointState) releaseAddressStateLocked(addrState *addressS
}
// AddAndAcquirePermanentAddress implements AddressableEndpoint.
-func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, tcpip.Error) {
+func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties AddressProperties) (AddressEndpoint, tcpip.Error) {
a.mu.Lock()
defer a.mu.Unlock()
- ep, err := a.addAndAcquireAddressLocked(addr, peb, configType, deprecated, true /* permanent */)
+ ep, err := a.addAndAcquireAddressLocked(addr, properties, true /* permanent */)
// From https://golang.org/doc/faq#nil_error:
//
// Under the covers, interfaces are implemented as two elements, a type T and
@@ -149,7 +149,7 @@ func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.Addr
func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) (AddressEndpoint, tcpip.Error) {
a.mu.Lock()
defer a.mu.Unlock()
- ep, err := a.addAndAcquireAddressLocked(addr, peb, AddressConfigStatic, false /* deprecated */, false /* permanent */)
+ ep, err := a.addAndAcquireAddressLocked(addr, AddressProperties{PEB: peb}, false /* permanent */)
// From https://golang.org/doc/faq#nil_error:
//
// Under the covers, interfaces are implemented as two elements, a type T and
@@ -180,7 +180,7 @@ func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.Addr
// returned, regardless the kind of address that is being added.
//
// Precondition: a.mu must be write locked.
-func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated, permanent bool) (*addressState, tcpip.Error) {
+func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, properties AddressProperties, permanent bool) (*addressState, tcpip.Error) {
// attemptAddToPrimary is false when the address is already in the primary
// address list.
attemptAddToPrimary := true
@@ -208,7 +208,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
// We now promote the address.
for i, s := range a.mu.primary {
if s == addrState {
- switch peb {
+ switch properties.PEB {
case CanBePrimaryEndpoint:
// The address is already in the primary address list.
attemptAddToPrimary = false
@@ -222,7 +222,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
case NeverPrimaryEndpoint:
a.mu.primary = append(a.mu.primary[:i], a.mu.primary[i+1:]...)
default:
- panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb))
+ panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", properties.PEB))
}
break
}
@@ -262,11 +262,11 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
}
// Acquire the address before returning it.
addrState.mu.refs++
- addrState.mu.deprecated = deprecated
- addrState.mu.configType = configType
+ addrState.mu.deprecated = properties.Deprecated
+ addrState.mu.configType = properties.ConfigType
if attemptAddToPrimary {
- switch peb {
+ switch properties.PEB {
case NeverPrimaryEndpoint:
case CanBePrimaryEndpoint:
a.mu.primary = append(a.mu.primary, addrState)
@@ -285,7 +285,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
a.mu.primary[0] = addrState
}
default:
- panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb))
+ panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", properties.PEB))
}
}
@@ -489,12 +489,12 @@ func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tc
// Proceed to add a new temporary endpoint.
addr := localAddr.WithPrefix()
- ep, err := a.addAndAcquireAddressLocked(addr, tempPEB, AddressConfigStatic, false /* deprecated */, false /* permanent */)
+ ep, err := a.addAndAcquireAddressLocked(addr, AddressProperties{PEB: tempPEB}, false /* permanent */)
if err != nil {
// addAndAcquireAddressLocked only returns an error if the address is
// already assigned but we just checked above if the address exists so we
// expect no error.
- panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, %d, %d, false, false): %s", addr, tempPEB, AddressConfigStatic, err))
+ panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, AddressProperties{PEB: %s}, false): %s", addr, tempPEB, err))
}
// From https://golang.org/doc/faq#nil_error:
diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go
index 140f146f6..c55f85743 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state_test.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go
@@ -38,9 +38,9 @@ func TestAddressableEndpointStateCleanup(t *testing.T) {
}
{
- ep, err := s.AddAndAcquirePermanentAddress(addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */)
+ ep, err := s.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{PEB: stack.NeverPrimaryEndpoint})
if err != nil {
- t.Fatalf("s.AddAndAcquirePermanentAddress(%s, %d, %d, false): %s", addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, err)
+ t.Fatalf("s.AddAndAcquirePermanentAddress(%s, AddressProperties{PEB: NeverPrimaryEndpoint}): %s", addr, err)
}
// We don't need the address endpoint.
ep.DecRef()
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index ccb69393b..c2f1f4798 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -181,10 +181,6 @@ func (*fwdTestNetworkProtocol) MinimumPacketSize() int {
return fwdTestNetHeaderLen
}
-func (*fwdTestNetworkProtocol) DefaultPrefixLen() int {
- return fwdTestNetDefaultPrefixLen
-}
-
func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1])
}
@@ -384,8 +380,15 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (*faketime.M
if err := s.CreateNIC(1, ep1); err != nil {
t.Fatal("CreateNIC #1 failed:", err)
}
- if err := s.AddAddress(1, fwdTestNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress #1 failed:", err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fwdTestNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fwdTestNetDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr1, AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err)
}
// NIC 2 has the link address "b", and added the network address 2.
@@ -397,8 +400,15 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (*faketime.M
if err := s.CreateNIC(2, ep2); err != nil {
t.Fatal("CreateNIC #2 failed:", err)
}
- if err := s.AddAddress(2, fwdTestNetNumber, "\x02"); err != nil {
- t.Fatal("AddAddress #2 failed:", err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: fwdTestNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x02",
+ PrefixLen: fwdTestNetDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(2, protocolAddr2, AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err)
}
nic, ok := s.nics[2]
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 4d5431da1..40b33b6b5 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -333,8 +333,12 @@ func TestDADDisabled(t *testing.T) {
Address: addr1,
PrefixLen: defaultPrefixLen,
}
- if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err)
}
// Should get the address immediately since we should not have performed
@@ -379,12 +383,15 @@ func TestDADResolveLoopback(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- addrWithPrefix := tcpip.AddressWithPrefix{
- Address: addr1,
- PrefixLen: defaultPrefixLen,
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addr1,
+ PrefixLen: defaultPrefixLen,
+ },
}
- if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err)
}
// Address should not be considered bound to the NIC yet (DAD ongoing).
@@ -517,8 +524,12 @@ func TestDADResolve(t *testing.T) {
Address: addr1,
PrefixLen: defaultPrefixLen,
}
- if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err)
}
// Make sure the address does not resolve before the resolution time has
@@ -740,8 +751,12 @@ func TestDADFail(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr1.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
// Address should not be considered bound to the NIC yet
@@ -778,8 +793,8 @@ func TestDADFail(t *testing.T) {
// Attempting to add the address again should not fail if the address's
// state was cleaned up when DAD failed.
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
})
}
@@ -851,8 +866,12 @@ func TestDADStop(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr1.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
// Address should not be considered bound to the NIC yet (DAD ongoing).
@@ -975,17 +994,29 @@ func TestSetNDPConfigurations(t *testing.T) {
// Add addresses for each NIC.
addrWithPrefix1 := tcpip.AddressWithPrefix{Address: addr1, PrefixLen: defaultPrefixLen}
- if err := s.AddAddressWithPrefix(nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addrWithPrefix1, err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addrWithPrefix1,
+ }
+ if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID1, protocolAddr1, err)
}
addrWithPrefix2 := tcpip.AddressWithPrefix{Address: addr2, PrefixLen: defaultPrefixLen}
- if err := s.AddAddressWithPrefix(nicID2, header.IPv6ProtocolNumber, addrWithPrefix2); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addrWithPrefix2, err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addrWithPrefix2,
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID2, protocolAddr2, err)
}
expectDADEvent(nicID2, addr2)
addrWithPrefix3 := tcpip.AddressWithPrefix{Address: addr3, PrefixLen: defaultPrefixLen}
- if err := s.AddAddressWithPrefix(nicID3, header.IPv6ProtocolNumber, addrWithPrefix3); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addrWithPrefix3, err)
+ protocolAddr3 := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addrWithPrefix3,
+ }
+ if err := s.AddProtocolAddress(nicID3, protocolAddr3, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID3, protocolAddr3, err)
}
expectDADEvent(nicID3, addr3)
@@ -2788,8 +2819,12 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
continue
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, test.addrs[j].Address); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, test.addrs[j].Address, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: test.addrs[j].Address.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
manuallyAssignedAddresses[test.addrs[j].Address] = struct{}{}
@@ -3644,8 +3679,9 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) {
Protocol: header.IPv6ProtocolNumber,
AddressWithPrefix: addr2,
}
- if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil {
- t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err)
+ properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint}
+ if err := s.AddProtocolAddress(nicID, protoAddr2, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v) = %s", nicID, protoAddr2, properties, err)
}
// addr2 should be more preferred now since it is at the front of the primary
// list.
@@ -3733,8 +3769,9 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
}
// Add the address as a static address before SLAAC tries to add it.
- if err := s.AddProtocolAddress(1, tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr.Address, err)
+ protocolAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}) = %s", protocolAddr, err)
}
if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) {
t.Fatalf("Should have %s in the list of addresses", addr1)
@@ -4073,8 +4110,12 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
// Attempting to add the address manually should not fail if the
// address's state was cleaned up when DAD failed.
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr.Address); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr.Address, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if err := s.RemoveAddress(nicID, addr.Address); err != nil {
t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err)
@@ -5362,8 +5403,12 @@ func TestRouterSolicitation(t *testing.T) {
}
if addr := test.nicAddr; addr != "" {
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index a796942ab..ab7e1d859 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -514,7 +514,7 @@ func (n *nic) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber,
// addAddress adds a new address to n, so that it starts accepting packets
// targeted at the given address (and network protocol).
-func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error {
+func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, properties AddressProperties) tcpip.Error {
ep, ok := n.networkEndpoints[protocolAddress.Protocol]
if !ok {
return &tcpip.ErrUnknownProtocol{}
@@ -525,7 +525,7 @@ func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo
return &tcpip.ErrNotSupported{}
}
- addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */)
+ addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, properties)
if err == nil {
// We have no need for the address endpoint.
addressEndpoint.DecRef()
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index 5cb342f78..c8ad93f29 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -127,11 +127,6 @@ func (*testIPv6Protocol) MinimumPacketSize() int {
return header.IPv6MinimumSize
}
-// DefaultPrefixLen implements NetworkProtocol.DefaultPrefixLen.
-func (*testIPv6Protocol) DefaultPrefixLen() int {
- return header.IPv6AddressSize * 8
-}
-
// ParseAddresses implements NetworkProtocol.ParseAddresses.
func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
h := header.IPv6(v)
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 113baaaae..31b3a554d 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -318,8 +318,7 @@ type PrimaryEndpointBehavior int
const (
// CanBePrimaryEndpoint indicates the endpoint can be used as a primary
- // endpoint for new connections with no local address. This is the
- // default when calling NIC.AddAddress.
+ // endpoint for new connections with no local address.
CanBePrimaryEndpoint PrimaryEndpointBehavior = iota
// FirstPrimaryEndpoint indicates the endpoint should be the first
@@ -332,6 +331,19 @@ const (
NeverPrimaryEndpoint
)
+func (peb PrimaryEndpointBehavior) String() string {
+ switch peb {
+ case CanBePrimaryEndpoint:
+ return "CanBePrimaryEndpoint"
+ case FirstPrimaryEndpoint:
+ return "FirstPrimaryEndpoint"
+ case NeverPrimaryEndpoint:
+ return "NeverPrimaryEndpoint"
+ default:
+ panic(fmt.Sprintf("unknown primary endpoint behavior: %d", peb))
+ }
+}
+
// AddressConfigType is the method used to add an address.
type AddressConfigType int
@@ -351,6 +363,14 @@ const (
AddressConfigSlaacTemp
)
+// AddressProperties contains additional properties that can be configured when
+// adding an address.
+type AddressProperties struct {
+ PEB PrimaryEndpointBehavior
+ ConfigType AddressConfigType
+ Deprecated bool
+}
+
// AssignableAddressEndpoint is a reference counted address endpoint that may be
// assigned to a NetworkEndpoint.
type AssignableAddressEndpoint interface {
@@ -457,7 +477,7 @@ type AddressableEndpoint interface {
// Returns *tcpip.ErrDuplicateAddress if the address exists.
//
// Acquires and returns the AddressEndpoint for the added address.
- AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, tcpip.Error)
+ AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties AddressProperties) (AddressEndpoint, tcpip.Error)
// RemovePermanentAddress removes the passed address if it is a permanent
// address.
@@ -685,9 +705,6 @@ type NetworkProtocol interface {
// than this targeted at this protocol.
MinimumPacketSize() int
- // DefaultPrefixLen returns the protocol's default prefix length.
- DefaultPrefixLen() int
-
// ParseAddresses returns the source and destination addresses stored in a
// packet of this protocol.
ParseAddresses(v buffer.View) (src, dst tcpip.Address)
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index cb741e540..98867a828 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -916,46 +916,9 @@ type NICStateFlags struct {
Loopback bool
}
-// AddAddress adds a new network-layer address to the specified NIC.
-func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error {
- return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint)
-}
-
-// AddAddressWithPrefix is the same as AddAddress, but allows you to specify
-// the address prefix.
-func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.AddressWithPrefix) tcpip.Error {
- ap := tcpip.ProtocolAddress{
- Protocol: protocol,
- AddressWithPrefix: addr,
- }
- return s.AddProtocolAddressWithOptions(id, ap, CanBePrimaryEndpoint)
-}
-
-// AddProtocolAddress adds a new network-layer protocol address to the
-// specified NIC.
-func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) tcpip.Error {
- return s.AddProtocolAddressWithOptions(id, protocolAddress, CanBePrimaryEndpoint)
-}
-
-// AddAddressWithOptions is the same as AddAddress, but allows you to specify
-// whether the new endpoint can be primary or not.
-func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) tcpip.Error {
- netProto, ok := s.networkProtocols[protocol]
- if !ok {
- return &tcpip.ErrUnknownProtocol{}
- }
- return s.AddProtocolAddressWithOptions(id, tcpip.ProtocolAddress{
- Protocol: protocol,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: addr,
- PrefixLen: netProto.DefaultPrefixLen(),
- },
- }, peb)
-}
-
-// AddProtocolAddressWithOptions is the same as AddProtocolAddress, but allows
-// you to specify whether the new endpoint can be primary or not.
-func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error {
+// AddProtocolAddress adds an address to the specified NIC, possibly with extra
+// properties.
+func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, properties AddressProperties) tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -964,7 +927,7 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc
return &tcpip.ErrUnknownNICID{}
}
- return nic.addAddress(protocolAddress, peb)
+ return nic.addAddress(protocolAddress, properties)
}
// RemoveAddress removes an existing network-layer address from the specified
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 3089c0ef4..cd4137794 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -234,10 +234,6 @@ func (*fakeNetworkProtocol) MinimumPacketSize() int {
return fakeNetHeaderLen
}
-func (*fakeNetworkProtocol) DefaultPrefixLen() int {
- return fakeDefaultPrefixLen
-}
-
func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int {
return f.packetCount[int(intfAddr)%len(f.packetCount)]
}
@@ -349,12 +345,26 @@ func TestNetworkReceive(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x02"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x02",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr2, err)
}
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
@@ -517,8 +527,15 @@ func TestNetworkSend(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
// Make sure that the link-layer endpoint received the outbound packet.
@@ -538,12 +555,26 @@ func TestNetworkSendMultiRoute(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr3 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x03",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr3, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr3, err)
}
ep2 := channel.New(10, defaultMTU, "")
@@ -551,12 +582,26 @@ func TestNetworkSendMultiRoute(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x02",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(2, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err)
}
- if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr4 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x04",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(2, protocolAddr4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr4, err)
}
// Set a route table that sends all packets with odd destination
@@ -812,8 +857,15 @@ func TestRouteWithDownNIC(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
}
- if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addr1,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr1, err)
}
ep2 := channel.New(1, defaultMTU, "")
@@ -821,8 +873,15 @@ func TestRouteWithDownNIC(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
}
- if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addr2,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddr2, err)
}
// Set a route table that sends all packets with odd destination
@@ -978,12 +1037,26 @@ func TestRoutes(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr3 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x03",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr3, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr3, err)
}
ep2 := channel.New(10, defaultMTU, "")
@@ -991,12 +1064,26 @@ func TestRoutes(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x02",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(2, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err)
}
- if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr4 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x04",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(2, protocolAddr4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr4, err)
}
// Set a route table that sends all packets with odd destination
@@ -1058,8 +1145,15 @@ func TestAddressRemoval(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet("\x00", "\x00")
@@ -1108,8 +1202,15 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) {
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
buf := buffer.NewView(30)
- if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet("\x00", "\x00")
@@ -1242,8 +1343,15 @@ func TestEndpointExpiration(t *testing.T) {
// 2. Add Address, everything should work.
//-----------------------
- if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
verifyAddress(t, s, nicID, localAddr)
testRecv(t, fakeNet, localAddrByte, ep, buf)
@@ -1270,8 +1378,8 @@ func TestEndpointExpiration(t *testing.T) {
// 4. Add Address back, everything should work again.
//-----------------------
- if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
verifyAddress(t, s, nicID, localAddr)
testRecv(t, fakeNet, localAddrByte, ep, buf)
@@ -1310,8 +1418,8 @@ func TestEndpointExpiration(t *testing.T) {
// 7. Add Address back, everything should work again.
//-----------------------
- if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
verifyAddress(t, s, nicID, localAddr)
testRecv(t, fakeNet, localAddrByte, ep, buf)
@@ -1453,8 +1561,15 @@ func TestExternalSendWithHandleLocal(t *testing.T) {
if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: nicID}})
@@ -1510,8 +1625,15 @@ func TestSpoofingWithAddress(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
@@ -1633,8 +1755,8 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) {
}
protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{Address: header.IPv4Any}}
- if err := s.AddProtocolAddress(1, protoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err)
+ if err := s.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}) failed: %s", protoAddr, err)
}
r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
if err != nil {
@@ -1678,13 +1800,13 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
t.Fatalf("CreateNIC failed: %s", err)
}
nic1ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic1Addr}
- if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err)
+ if err := s.AddProtocolAddress(1, nic1ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}) failed: %s", nic1ProtoAddr, err)
}
nic2ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic2Addr}
- if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil {
- t.Fatalf("AddAddress(2, %v) failed: %v", nic2ProtoAddr, err)
+ if err := s.AddProtocolAddress(2, nic2ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(2, %+v, {}) failed: %s", nic2ProtoAddr, err)
}
// Set the initial route table.
@@ -1726,7 +1848,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
// 2. Case: Having an explicit route for broadcast will select that one.
rt = append(
[]tcpip.Route{
- {Destination: tcpip.AddressWithPrefix{Address: header.IPv4Broadcast, PrefixLen: 8 * header.IPv4AddressSize}.Subnet(), NIC: 1},
+ {Destination: header.IPv4Broadcast.WithPrefix().Subnet(), NIC: 1},
},
rt...,
)
@@ -1808,8 +1930,15 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, want)
}
- if err := s.AddAddress(1, fakeNetNumber, anyAddr); err != nil {
- t.Fatalf("AddAddress(%v, %v) failed: %v", fakeNetNumber, anyAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: anyAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
if r, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); tc.routeNeeded {
@@ -1886,22 +2015,27 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
// Add an address and in case of a primary one include a
// prefixLen.
address := tcpip.Address(bytes.Repeat([]byte{byte(i)}, addrLen))
+ properties := stack.AddressProperties{PEB: behavior}
if behavior == stack.CanBePrimaryEndpoint {
protocolAddress := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: address,
- PrefixLen: addrLen * 8,
- },
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: address.WithPrefix(),
}
- if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil {
- t.Fatalf("AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, protocolAddress, behavior, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddress, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddress, properties, err)
}
// Remember the address/prefix.
primaryAddrAdded[protocolAddress.AddressWithPrefix] = struct{}{}
} else {
- if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s:", nicID, fakeNetNumber, address, behavior, err)
+ protocolAddress := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddress, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddress, properties, err)
}
}
}
@@ -1996,8 +2130,8 @@ func TestGetMainNICAddressAddRemove(t *testing.T) {
PrefixLen: tc.prefixLen,
},
}
- if err := s.AddProtocolAddress(1, protocolAddress); err != nil {
- t.Fatal("AddProtocolAddress failed:", err)
+ if err := s.AddProtocolAddress(1, protocolAddress, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", protocolAddress, err)
}
// Check that we get the right initial address and prefix length.
@@ -2047,33 +2181,6 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto
}
}
-func TestAddAddress(t *testing.T) {
- const nicID = 1
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- var addrGen addressGenerator
- expectedAddresses := make([]tcpip.ProtocolAddress, 0, 2)
- for _, addrLen := range []int{4, 16} {
- address := addrGen.next(addrLen)
- if err := s.AddAddress(nicID, fakeNetNumber, address); err != nil {
- t.Fatalf("AddAddress(address=%s) failed: %s", address, err)
- }
- expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen},
- })
- }
-
- gotAddresses := s.AllAddresses()[nicID]
- verifyAddresses(t, expectedAddresses, gotAddresses)
-}
-
func TestAddProtocolAddress(t *testing.T) {
const nicID = 1
s := stack.New(stack.Options{
@@ -2084,96 +2191,43 @@ func TestAddProtocolAddress(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- var addrGen addressGenerator
- addrLenRange := []int{4, 16}
- prefixLenRange := []int{8, 13, 20, 32}
- expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange))
- for _, addrLen := range addrLenRange {
- for _, prefixLen := range prefixLenRange {
- protocolAddress := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: addrGen.next(addrLen),
- PrefixLen: prefixLen,
- },
- }
- if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil {
- t.Errorf("AddProtocolAddress(%+v) failed: %s", protocolAddress, err)
- }
- expectedAddresses = append(expectedAddresses, protocolAddress)
- }
- }
-
- gotAddresses := s.AllAddresses()[nicID]
- verifyAddresses(t, expectedAddresses, gotAddresses)
-}
-
-func TestAddAddressWithOptions(t *testing.T) {
- const nicID = 1
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
addrLenRange := []int{4, 16}
behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint}
- expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(behaviorRange))
+ configTypeRange := []stack.AddressConfigType{stack.AddressConfigStatic, stack.AddressConfigSlaac, stack.AddressConfigSlaacTemp}
+ deprecatedRange := []bool{false, true}
+ wantAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(behaviorRange)*len(configTypeRange)*len(deprecatedRange))
var addrGen addressGenerator
for _, addrLen := range addrLenRange {
for _, behavior := range behaviorRange {
- address := addrGen.next(addrLen)
- if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil {
- t.Fatalf("AddAddressWithOptions(address=%s, behavior=%d) failed: %s", address, behavior, err)
- }
- expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen},
- })
- }
- }
-
- gotAddresses := s.AllAddresses()[nicID]
- verifyAddresses(t, expectedAddresses, gotAddresses)
-}
-
-func TestAddProtocolAddressWithOptions(t *testing.T) {
- const nicID = 1
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- addrLenRange := []int{4, 16}
- prefixLenRange := []int{8, 13, 20, 32}
- behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint}
- expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)*len(behaviorRange))
- var addrGen addressGenerator
- for _, addrLen := range addrLenRange {
- for _, prefixLen := range prefixLenRange {
- for _, behavior := range behaviorRange {
- protocolAddress := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: addrGen.next(addrLen),
- PrefixLen: prefixLen,
- },
- }
- if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil {
- t.Fatalf("AddProtocolAddressWithOptions(%+v, %d) failed: %s", protocolAddress, behavior, err)
+ for _, configType := range configTypeRange {
+ for _, deprecated := range deprecatedRange {
+ address := addrGen.next(addrLen)
+ properties := stack.AddressProperties{
+ PEB: behavior,
+ ConfigType: configType,
+ Deprecated: deprecated,
+ }
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v) failed: %s", nicID, protocolAddr, properties, err)
+ }
+ wantAddresses = append(wantAddresses, tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen},
+ })
}
- expectedAddresses = append(expectedAddresses, protocolAddress)
}
}
}
gotAddresses := s.AllAddresses()[nicID]
- verifyAddresses(t, expectedAddresses, gotAddresses)
+ verifyAddresses(t, wantAddresses, gotAddresses)
}
func TestCreateNICWithOptions(t *testing.T) {
@@ -2290,8 +2344,15 @@ func TestNICStats(t *testing.T) {
if err := s.CreateNIC(nicid, ep); err != nil {
t.Fatal("CreateNIC failed: ", err)
}
- if err := s.AddAddress(nicid, fakeNetNumber, nic.addr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: nic.addr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicid, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicid, protocolAddr, err)
}
{
@@ -2735,8 +2796,16 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
// be returned by a call to GetMainNICAddress;
// else, it should.
const address1 = tcpip.Address("\x01")
- if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address1, pi); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address1, pi, err)
+ properties := stack.AddressProperties{PEB: pi}
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address1,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr, properties, err)
}
addr, err := s.GetMainNICAddress(nicID, fakeNetNumber)
if err != nil {
@@ -2785,16 +2854,31 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
// Add some other address with peb set to
// FirstPrimaryEndpoint.
const address3 = tcpip.Address("\x03")
- if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address3, stack.FirstPrimaryEndpoint); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address3, stack.FirstPrimaryEndpoint, err)
-
+ protocolAddr3 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address3,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ properties = stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint}
+ if err := s.AddProtocolAddress(nicID, protocolAddr3, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr3, properties, err)
}
// Add back the address we removed earlier and
// make sure the new peb was respected.
// (The address should just be promoted now).
- if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address1, ps); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address1, pi, err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address1,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ properties = stack.AddressProperties{PEB: ps}
+ if err := s.AddProtocolAddress(nicID, protocolAddr1, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr1, properties, err)
}
var primaryAddrs []tcpip.Address
for _, pa := range s.NICInfo()[nicID].ProtocolAddresses {
@@ -3096,8 +3180,12 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
}
for _, a := range test.nicAddrs {
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, a); err != nil {
- t.Errorf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, a, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: a.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
@@ -3203,8 +3291,12 @@ func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, addr1, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: addr1.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
// The NIC should have joined addr1's solicited node multicast address.
@@ -3359,8 +3451,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
PrefixLen: 128,
},
}
- if err := s.AddProtocolAddress(nicID, addr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err)
+ if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err)
}
// Address should be in the list of all addresses.
@@ -3687,8 +3779,8 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
if err := s.CreateNIC(nicID1, ep); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
}
- if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err)
+ if err := s.AddProtocolAddress(nicID1, test.nicAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, test.nicAddr, err)
}
s.SetRouteTable(test.routes)
@@ -3750,8 +3842,8 @@ func TestResolveWith(t *testing.T) {
PrefixLen: 24,
},
}
- if err := s.AddProtocolAddress(nicID, addr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, addr, err)
+ if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err)
}
s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}})
@@ -3792,8 +3884,15 @@ func TestRouteReleaseAfterAddrRemoval(t *testing.T) {
if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet("\x00", "\x00")
@@ -3881,8 +3980,8 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) {
PrefixLen: 8,
},
}
- if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protocolAddress, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddress, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddress, err)
}
// Check that we get the right initial address and prefix length.
@@ -3990,44 +4089,44 @@ func TestFindRouteWithForwarding(t *testing.T) {
)
type netCfg struct {
- proto tcpip.NetworkProtocolNumber
- factory stack.NetworkProtocolFactory
- nic1Addr tcpip.Address
- nic2Addr tcpip.Address
- remoteAddr tcpip.Address
+ proto tcpip.NetworkProtocolNumber
+ factory stack.NetworkProtocolFactory
+ nic1AddrWithPrefix tcpip.AddressWithPrefix
+ nic2AddrWithPrefix tcpip.AddressWithPrefix
+ remoteAddr tcpip.Address
}
fakeNetCfg := netCfg{
- proto: fakeNetNumber,
- factory: fakeNetFactory,
- nic1Addr: nic1Addr,
- nic2Addr: nic2Addr,
- remoteAddr: remoteAddr,
+ proto: fakeNetNumber,
+ factory: fakeNetFactory,
+ nic1AddrWithPrefix: tcpip.AddressWithPrefix{Address: nic1Addr, PrefixLen: fakeDefaultPrefixLen},
+ nic2AddrWithPrefix: tcpip.AddressWithPrefix{Address: nic2Addr, PrefixLen: fakeDefaultPrefixLen},
+ remoteAddr: remoteAddr,
}
globalIPv6Addr1 := tcpip.Address(net.ParseIP("a::1").To16())
globalIPv6Addr2 := tcpip.Address(net.ParseIP("a::2").To16())
ipv6LinkLocalNIC1WithGlobalRemote := netCfg{
- proto: ipv6.ProtocolNumber,
- factory: ipv6.NewProtocol,
- nic1Addr: llAddr1,
- nic2Addr: globalIPv6Addr2,
- remoteAddr: globalIPv6Addr1,
+ proto: ipv6.ProtocolNumber,
+ factory: ipv6.NewProtocol,
+ nic1AddrWithPrefix: llAddr1.WithPrefix(),
+ nic2AddrWithPrefix: globalIPv6Addr2.WithPrefix(),
+ remoteAddr: globalIPv6Addr1,
}
ipv6GlobalNIC1WithLinkLocalRemote := netCfg{
- proto: ipv6.ProtocolNumber,
- factory: ipv6.NewProtocol,
- nic1Addr: globalIPv6Addr1,
- nic2Addr: llAddr1,
- remoteAddr: llAddr2,
+ proto: ipv6.ProtocolNumber,
+ factory: ipv6.NewProtocol,
+ nic1AddrWithPrefix: globalIPv6Addr1.WithPrefix(),
+ nic2AddrWithPrefix: llAddr1.WithPrefix(),
+ remoteAddr: llAddr2,
}
ipv6GlobalNIC1WithLinkLocalMulticastRemote := netCfg{
- proto: ipv6.ProtocolNumber,
- factory: ipv6.NewProtocol,
- nic1Addr: globalIPv6Addr1,
- nic2Addr: globalIPv6Addr2,
- remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ proto: ipv6.ProtocolNumber,
+ factory: ipv6.NewProtocol,
+ nic1AddrWithPrefix: globalIPv6Addr1.WithPrefix(),
+ nic2AddrWithPrefix: globalIPv6Addr2.WithPrefix(),
+ remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
}
tests := []struct {
@@ -4036,8 +4135,8 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg netCfg
forwardingEnabled bool
- addrNIC tcpip.NICID
- localAddr tcpip.Address
+ addrNIC tcpip.NICID
+ localAddrWithPrefix tcpip.AddressWithPrefix
findRouteErr tcpip.Error
dependentOnForwarding bool
@@ -4047,7 +4146,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: false,
addrNIC: nicID1,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4056,7 +4155,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: true,
addrNIC: nicID1,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4065,7 +4164,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: false,
addrNIC: nicID1,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4074,7 +4173,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: true,
addrNIC: nicID1,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: true,
},
@@ -4083,7 +4182,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: false,
addrNIC: nicID2,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4092,7 +4191,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: true,
addrNIC: nicID2,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4101,7 +4200,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: false,
addrNIC: nicID2,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4110,7 +4209,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: true,
addrNIC: nicID2,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4118,7 +4217,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and localAddr on same NIC as route",
netCfg: fakeNetCfg,
forwardingEnabled: false,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4126,7 +4225,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and localAddr on same NIC as route",
netCfg: fakeNetCfg,
forwardingEnabled: false,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4134,7 +4233,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and localAddr on different NIC as route",
netCfg: fakeNetCfg,
forwardingEnabled: false,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4142,7 +4241,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and localAddr on different NIC as route",
netCfg: fakeNetCfg,
forwardingEnabled: true,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: true,
},
@@ -4166,7 +4265,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and link-local local addr with route on different NIC",
netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
forwardingEnabled: false,
- localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr,
+ localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4174,7 +4273,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and link-local local addr with route on same NIC",
netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
forwardingEnabled: true,
- localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr,
+ localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4182,7 +4281,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and global local addr with route on same NIC",
netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
forwardingEnabled: true,
- localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic2Addr,
+ localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4190,7 +4289,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and link-local local addr with route on same NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
forwardingEnabled: false,
- localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4198,7 +4297,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and link-local local addr with route on same NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
forwardingEnabled: true,
- localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4206,7 +4305,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and global local addr with link-local remote on different NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
forwardingEnabled: false,
- localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNetworkUnreachable{},
dependentOnForwarding: false,
},
@@ -4214,7 +4313,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and global local addr with link-local remote on different NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
forwardingEnabled: true,
- localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNetworkUnreachable{},
dependentOnForwarding: false,
},
@@ -4222,7 +4321,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and global local addr with link-local multicast remote on different NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
forwardingEnabled: false,
- localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNetworkUnreachable{},
dependentOnForwarding: false,
},
@@ -4230,7 +4329,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and global local addr with link-local multicast remote on different NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
forwardingEnabled: true,
- localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNetworkUnreachable{},
dependentOnForwarding: false,
},
@@ -4238,7 +4337,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and global local addr with link-local multicast remote on same NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
forwardingEnabled: false,
- localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4246,7 +4345,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and global local addr with link-local multicast remote on same NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
forwardingEnabled: true,
- localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4268,12 +4367,20 @@ func TestFindRouteWithForwarding(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s:", nicID2, err)
}
- if err := s.AddAddress(nicID1, test.netCfg.proto, test.netCfg.nic1Addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, test.netCfg.proto, test.netCfg.nic1Addr, err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: test.netCfg.proto,
+ AddressWithPrefix: test.netCfg.nic1AddrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr1, err)
}
- if err := s.AddAddress(nicID2, test.netCfg.proto, test.netCfg.nic2Addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, test.netCfg.proto, test.netCfg.nic2Addr, err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: test.netCfg.proto,
+ AddressWithPrefix: test.netCfg.nic2AddrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddr2, err)
}
if err := s.SetForwardingDefaultAndAllNICs(test.netCfg.proto, test.forwardingEnabled); err != nil {
@@ -4282,20 +4389,20 @@ func TestFindRouteWithForwarding(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}})
- r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */)
+ r, err := s.FindRoute(test.addrNIC, test.localAddrWithPrefix.Address, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */)
if err == nil {
defer r.Release()
}
if diff := cmp.Diff(test.findRouteErr, err); diff != "" {
- t.Fatalf("unexpected error from FindRoute(%d, %s, %s, %d, false), (-want, +got):\n%s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, diff)
+ t.Fatalf("unexpected error from FindRoute(%d, %s, %s, %d, false), (-want, +got):\n%s", test.addrNIC, test.localAddrWithPrefix.Address, test.netCfg.remoteAddr, test.netCfg.proto, diff)
}
if test.findRouteErr != nil {
return
}
- if r.LocalAddress() != test.localAddr {
- t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), test.localAddr)
+ if r.LocalAddress() != test.localAddrWithPrefix.Address {
+ t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), test.localAddrWithPrefix.Address)
}
if r.RemoteAddress() != test.netCfg.remoteAddr {
t.Errorf("got r.RemoteAddress() = %s, want = %s", r.RemoteAddress(), test.netCfg.remoteAddr)
@@ -4318,8 +4425,8 @@ func TestFindRouteWithForwarding(t *testing.T) {
if !ok {
t.Fatal("packet not sent through ep2")
}
- if pkt.Route.LocalAddress != test.localAddr {
- t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddr)
+ if pkt.Route.LocalAddress != test.localAddrWithPrefix.Address {
+ t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddrWithPrefix.Address)
}
if pkt.Route.RemoteAddress != test.netCfg.remoteAddr {
t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.netCfg.remoteAddr)
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 45b09110d..cd3a8c25a 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -35,7 +35,7 @@ import (
const (
testSrcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- testDstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ testDstAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
testSrcAddrV4 = "\x0a\x00\x00\x01"
testDstAddrV4 = "\x0a\x00\x00\x02"
@@ -64,12 +64,20 @@ func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICI
}
linkEps[linkEpID] = channelEp
- if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, testDstAddrV4); err != nil {
- t.Fatalf("AddAddress IPv4 failed: %s", err)
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.Address(testDstAddrV4).WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(linkEpID, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", linkEpID, protocolAddrV4, err)
}
- if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, testDstAddrV6); err != nil {
- t.Fatalf("AddAddress IPv6 failed: %s", err)
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: testDstAddrV6.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(linkEpID, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", linkEpID, protocolAddrV6, err)
}
}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 839178809..655931715 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -357,8 +357,15 @@ func TestTransportReceive(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
// Create endpoint and connect to remote address.
@@ -428,8 +435,15 @@ func TestTransportControlReceive(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
// Create endpoint and connect to remote address.
@@ -497,8 +511,15 @@ func TestTransportSend(t *testing.T) {
t.Fatalf("CreateNIC failed: %v", err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 55683b4fb..a9ce148b9 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -19,7 +19,7 @@
// The starting point is the creation and configuration of a stack. A stack can
// be created by calling the New() function of the tcpip/stack/stack package;
// configuring a stack involves creating NICs (via calls to Stack.CreateNIC()),
-// adding network addresses (via calls to Stack.AddAddress()), and
+// adding network addresses (via calls to Stack.AddProtocolAddress()), and
// setting a route table (via a call to Stack.SetRouteTable()).
//
// Once a stack is configured, endpoints can be created by calling
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
index 92fa6257d..6e1d4720d 100644
--- a/pkg/tcpip/tests/integration/forward_test.go
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -473,11 +473,19 @@ func TestMulticastForwarding(t *testing.T) {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err)
}
- if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err)
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv4Addr,
}
- if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err)
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err)
+ }
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv6Addr,
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err)
}
if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
@@ -612,8 +620,8 @@ func TestPerInterfaceForwarding(t *testing.T) {
addr: utils.RouterNIC2IPv6Addr,
},
} {
- if err := s.AddProtocolAddress(add.nicID, add.addr); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", add.nicID, add.addr, err)
+ if err := s.AddProtocolAddress(add.nicID, add.addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", add.nicID, add.addr, err)
}
}
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
index f9ab7d0af..28b49c6be 100644
--- a/pkg/tcpip/tests/integration/iptables_test.go
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -49,10 +49,10 @@ const (
nicName = "nic1"
anotherNicName = "nic2"
linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
- srcAddrV4 = "\x0a\x00\x00\x01"
- dstAddrV4 = "\x0a\x00\x00\x02"
- srcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- dstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ srcAddrV4 = tcpip.Address("\x0a\x00\x00\x01")
+ dstAddrV4 = tcpip.Address("\x0a\x00\x00\x02")
+ srcAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ dstAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
payloadSize = 20
)
@@ -66,8 +66,12 @@ func genStackV6(t *testing.T) (*stack.Stack, *channel.Endpoint) {
if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err)
}
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, dstAddrV6); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, dstAddrV6, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: dstAddrV6.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
return s, e
}
@@ -82,8 +86,12 @@ func genStackV4(t *testing.T) (*stack.Stack, *channel.Endpoint) {
if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err)
}
- if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, dstAddrV4); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, dstAddrV4, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: dstAddrV4.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
return s, e
}
@@ -601,11 +609,19 @@ func TestIPTableWritePackets(t *testing.T) {
if err := s.CreateNIC(nicID, &e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, srcAddrV6); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, srcAddrV6, err)
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: srcAddrV6.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV6, err)
}
- if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, srcAddrV4); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, srcAddrV4, err)
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: srcAddrV4.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV4, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -856,11 +872,19 @@ func TestForwardingHook(t *testing.T) {
t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err)
}
- if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err)
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv4Addr.Address.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err)
+ }
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv6Addr.Address.WithPrefix(),
}
- if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err)
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err)
}
if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
@@ -1037,22 +1061,22 @@ func TestInputHookWithLocalForwarding(t *testing.T) {
if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil {
t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err)
}
- if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv4Addr1, err)
+ if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID1, utils.Ipv4Addr1, err)
}
- if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv6Addr1, err)
+ if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID1, utils.Ipv6Addr1, err)
}
e2 := channel.New(1, header.IPv6MinimumMTU, "")
if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil {
t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err)
}
- if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv4Addr2, err)
+ if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID2, utils.Ipv4Addr2, err)
}
- if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv6Addr2, err)
+ if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID2, utils.Ipv6Addr2, err)
}
if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
index 27caa0c28..95ddd8ec3 100644
--- a/pkg/tcpip/tests/integration/link_resolution_test.go
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -56,17 +56,17 @@ func setupStack(t *testing.T, stackOpts stack.Options, host1NICID, host2NICID tc
t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err)
}
- if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv4Addr1); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, utils.Ipv4Addr1, err)
+ if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv4Addr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", host1NICID, utils.Ipv4Addr1, err)
}
- if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv4Addr2); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, utils.Ipv4Addr2, err)
+ if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv4Addr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", host2NICID, utils.Ipv4Addr2, err)
}
- if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv6Addr1); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, utils.Ipv6Addr1, err)
+ if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv6Addr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", host1NICID, utils.Ipv6Addr1, err)
}
- if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv6Addr2); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, utils.Ipv6Addr2, err)
+ if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv6Addr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", host2NICID, utils.Ipv6Addr2, err)
}
host1Stack.SetRouteTable([]tcpip.Route{
@@ -568,8 +568,8 @@ func TestForwardingWithLinkResolutionFailure(t *testing.T) {
Protocol: test.networkProtocolNumber,
AddressWithPrefix: test.incomingAddr,
}
- if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingProtoAddr, err)
+ if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingProtoAddr, err)
}
// Set up endpoint through which we will attempt to forward packets.
@@ -582,8 +582,8 @@ func TestForwardingWithLinkResolutionFailure(t *testing.T) {
Protocol: test.networkProtocolNumber,
AddressWithPrefix: test.outgoingAddr,
}
- if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingProtoAddr, err)
+ if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingProtoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index b2008f0b2..f33223e79 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -195,8 +195,8 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) {
if err := s.CreateNIC(nicID, loopback.New()); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddProtocolAddress(nicID, test.addAddress); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, test.addAddress, err)
+ if err := s.AddProtocolAddress(nicID, test.addAddress, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, test.addAddress, err)
}
s.SetRouteTable([]tcpip.Route{
{
@@ -290,8 +290,8 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) {
if err := s.CreateNIC(nicID, loopback.New()); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
+ if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
{
@@ -431,8 +431,8 @@ func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) {
if err := s.CreateNIC(nicID, loopback.New()); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddProtocolAddress(nicID, test.addAddress); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, test.addAddress, err)
+ if err := s.AddProtocolAddress(nicID, test.addAddress, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, test.addAddress, err)
}
s.SetRouteTable([]tcpip.Route{
{
@@ -693,21 +693,40 @@ func TestExternalLoopbackTraffic(t *testing.T) {
if err := s.CreateNIC(nicID1, e); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
}
- if err := s.AddAddressWithPrefix(nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr, err)
+ v4Addr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv4Addr,
}
- if err := s.AddAddressWithPrefix(nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr, err)
+ if err := s.AddProtocolAddress(nicID1, v4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, v4Addr, err)
+ }
+ v6Addr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv6Addr,
+ }
+ if err := s.AddProtocolAddress(nicID1, v6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, v6Addr, err)
}
if err := s.CreateNIC(nicID2, loopback.New()); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
}
- if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, ipv4Loopback); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, ipv4Loopback, err)
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: ipv4Loopback,
+ PrefixLen: 8,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err)
+ }
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: header.IPv6Loopback.WithPrefix(),
}
- if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, header.IPv6Loopback); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, header.IPv6Loopback, err)
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err)
}
if test.forwarding {
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index 2d0a6e6a7..7753e7d6e 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -119,12 +119,12 @@ func TestPingMulticastBroadcast(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: utils.Ipv4Addr}
- if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err)
+ if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtoAddr, err)
}
ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: utils.Ipv6Addr}
- if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv6ProtoAddr, err)
+ if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtoAddr, err)
}
// Default routes for IPv4 and IPv6 so ICMP can find a route to the remote
@@ -396,8 +396,8 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr}
- if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
+ if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
}
var wq waiter.Queue
@@ -474,8 +474,8 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
PrefixLen: 8,
},
}
- if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
+ if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -642,8 +642,8 @@ func TestUDPAddRemoveMembershipSocketOption(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr}
- if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
+ if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
}
// Set the route table so that UDP can find a NIC that is
diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go
index ac3c703d4..422eb8408 100644
--- a/pkg/tcpip/tests/integration/route_test.go
+++ b/pkg/tcpip/tests/integration/route_test.go
@@ -47,7 +47,10 @@ func TestLocalPing(t *testing.T) {
// request/reply packets.
icmpDataOffset = 8
)
- ipv4Loopback := testutil.MustParse4("127.0.0.1")
+ ipv4Loopback := tcpip.AddressWithPrefix{
+ Address: testutil.MustParse4("127.0.0.1"),
+ PrefixLen: 8,
+ }
channelEP := func() stack.LinkEndpoint { return channel.New(1, header.IPv6MinimumMTU, "") }
channelEPCheck := func(t *testing.T, e stack.LinkEndpoint) {
@@ -82,7 +85,7 @@ func TestLocalPing(t *testing.T) {
transProto tcpip.TransportProtocolNumber
netProto tcpip.NetworkProtocolNumber
linkEndpoint func() stack.LinkEndpoint
- localAddr tcpip.Address
+ localAddr tcpip.AddressWithPrefix
icmpBuf func(*testing.T) buffer.View
expectedConnectErr tcpip.Error
checkLinkEndpoint func(t *testing.T, e stack.LinkEndpoint)
@@ -101,7 +104,7 @@ func TestLocalPing(t *testing.T) {
transProto: icmp.ProtocolNumber6,
netProto: ipv6.ProtocolNumber,
linkEndpoint: loopback.New,
- localAddr: header.IPv6Loopback,
+ localAddr: header.IPv6Loopback.WithPrefix(),
icmpBuf: ipv6ICMPBuf,
checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {},
},
@@ -110,7 +113,7 @@ func TestLocalPing(t *testing.T) {
transProto: icmp.ProtocolNumber4,
netProto: ipv4.ProtocolNumber,
linkEndpoint: channelEP,
- localAddr: utils.Ipv4Addr.Address,
+ localAddr: utils.Ipv4Addr,
icmpBuf: ipv4ICMPBuf,
checkLinkEndpoint: channelEPCheck,
},
@@ -119,7 +122,7 @@ func TestLocalPing(t *testing.T) {
transProto: icmp.ProtocolNumber6,
netProto: ipv6.ProtocolNumber,
linkEndpoint: channelEP,
- localAddr: utils.Ipv6Addr.Address,
+ localAddr: utils.Ipv6Addr,
icmpBuf: ipv6ICMPBuf,
checkLinkEndpoint: channelEPCheck,
},
@@ -182,9 +185,13 @@ func TestLocalPing(t *testing.T) {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if len(test.localAddr) != 0 {
- if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err)
+ if len(test.localAddr.Address) != 0 {
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: test.netProto,
+ AddressWithPrefix: test.localAddr,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
@@ -197,7 +204,7 @@ func TestLocalPing(t *testing.T) {
}
defer ep.Close()
- connAddr := tcpip.FullAddress{Addr: test.localAddr}
+ connAddr := tcpip.FullAddress{Addr: test.localAddr.Address}
if err := ep.Connect(connAddr); err != test.expectedConnectErr {
t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr)
}
@@ -229,8 +236,8 @@ func TestLocalPing(t *testing.T) {
if diff := cmp.Diff(buffer.View(w.Bytes()[icmpDataOffset:]), payload[icmpDataOffset:]); diff != "" {
t.Errorf("received data mismatch (-want +got):\n%s", diff)
}
- if rr.RemoteAddr.Addr != test.localAddr {
- t.Errorf("got addr.Addr = %s, want = %s", rr.RemoteAddr.Addr, test.localAddr)
+ if rr.RemoteAddr.Addr != test.localAddr.Address {
+ t.Errorf("got addr.Addr = %s, want = %s", rr.RemoteAddr.Addr, test.localAddr.Address)
}
test.checkLinkEndpoint(t, e)
@@ -302,11 +309,12 @@ func TestLocalUDP(t *testing.T) {
}
if subTest.addAddress {
- if err := s.AddProtocolAddressWithOptions(nicID, test.canBePrimaryAddr, stack.CanBePrimaryEndpoint); err != nil {
- t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.canBePrimaryAddr, stack.FirstPrimaryEndpoint, err)
+ if err := s.AddProtocolAddress(nicID, test.canBePrimaryAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, test.canBePrimaryAddr, err)
}
- if err := s.AddProtocolAddressWithOptions(nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint); err != nil {
- t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint, err)
+ properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint}
+ if err := s.AddProtocolAddress(nicID, test.firstPrimaryAddr, properties); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, %+v): %s", nicID, test.firstPrimaryAddr, properties, err)
}
}
diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go
index 2e6ae55ea..947bcc7b1 100644
--- a/pkg/tcpip/tests/utils/utils.go
+++ b/pkg/tcpip/tests/utils/utils.go
@@ -231,29 +231,29 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.
t.Fatalf("routerStack.SetForwardingDefaultAndAllNICs(%d): %s", ipv6.ProtocolNumber, err)
}
- if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv4Addr); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", Host1NICID, Host1IPv4Addr, err)
+ if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", Host1NICID, Host1IPv4Addr, err)
}
- if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv4Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID1, RouterNIC1IPv4Addr, err)
+ if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID1, RouterNIC1IPv4Addr, err)
}
- if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv4Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID2, RouterNIC2IPv4Addr, err)
+ if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID2, RouterNIC2IPv4Addr, err)
}
- if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv4Addr); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", Host2NICID, Host2IPv4Addr, err)
+ if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", Host2NICID, Host2IPv4Addr, err)
}
- if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv6Addr); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", Host1NICID, Host1IPv6Addr, err)
+ if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", Host1NICID, Host1IPv6Addr, err)
}
- if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv6Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID1, RouterNIC1IPv6Addr, err)
+ if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID1, RouterNIC1IPv6Addr, err)
}
- if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv6Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID2, RouterNIC2IPv6Addr, err)
+ if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID2, RouterNIC2IPv6Addr, err)
}
- if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv6Addr); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", Host2NICID, Host2IPv6Addr, err)
+ if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", Host2NICID, Host2IPv6Addr, err)
}
host1Stack.SetRouteTable([]tcpip.Route{
diff --git a/pkg/tcpip/transport/icmp/icmp_test.go b/pkg/tcpip/transport/icmp/icmp_test.go
index cc950cbde..729f50e9a 100644
--- a/pkg/tcpip/transport/icmp/icmp_test.go
+++ b/pkg/tcpip/transport/icmp/icmp_test.go
@@ -55,8 +55,12 @@ func addNICWithDefaultRoute(t *testing.T, s *stack.Stack, id tcpip.NICID, name s
t.Fatalf("s.CreateNIC(%d, _) = %s", id, err)
}
- if err := s.AddAddress(id, ipv4.ProtocolNumber, addrV4); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s) = %s", id, ipv4.ProtocolNumber, addrV4, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addrV4.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(id, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", id, protocolAddr, err)
}
s.AddRoute(tcpip.Route{
diff --git a/pkg/tcpip/transport/internal/network/BUILD b/pkg/tcpip/transport/internal/network/BUILD
index d10e3f13a..b1edce39b 100644
--- a/pkg/tcpip/transport/internal/network/BUILD
+++ b/pkg/tcpip/transport/internal/network/BUILD
@@ -9,6 +9,7 @@ go_library(
"endpoint_state.go",
],
visibility = [
+ "//pkg/tcpip/transport/raw:__pkg__",
"//pkg/tcpip/transport/udp:__pkg__",
],
deps = [
@@ -32,6 +33,7 @@ go_test(
"//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/loopback",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
diff --git a/pkg/tcpip/transport/internal/network/endpoint.go b/pkg/tcpip/transport/internal/network/endpoint.go
index c5b575e1c..3cb821475 100644
--- a/pkg/tcpip/transport/internal/network/endpoint.go
+++ b/pkg/tcpip/transport/internal/network/endpoint.go
@@ -18,7 +18,6 @@ package network
import (
"fmt"
- "sync/atomic"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -38,30 +37,41 @@ type Endpoint struct {
netProto tcpip.NetworkProtocolNumber
transProto tcpip.TransportProtocolNumber
- // state holds a transport.DatagramBasedEndpointState.
- //
- // state must be read from/written to atomically.
- state uint32
-
- // The following fields are protected by mu.
- mu sync.RWMutex `state:"nosave"`
+ mu sync.RWMutex `state:"nosave"`
+ // +checklocks:mu
+ state transport.DatagramEndpointState
+ // +checklocks:mu
+ wasBound bool
+ // +checklocks:mu
info stack.TransportEndpointInfo
// owner is the owner of transmitted packets.
- owner tcpip.PacketOwner
- writeShutdown bool
- effectiveNetProto tcpip.NetworkProtocolNumber
- connectedRoute *stack.Route `state:"manual"`
+ //
+ // +checklocks:mu
+ owner tcpip.PacketOwner
+ // +checklocks:mu
+ writeShutdown bool
+ // +checklocks:mu
+ effectiveNetProto tcpip.NetworkProtocolNumber
+ // +checklocks:mu
+ connectedRoute *stack.Route `state:"manual"`
+ // +checklocks:mu
multicastMemberships map[multicastMembership]struct{}
// TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ // +checklocks:mu
ttl uint8
// TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ // +checklocks:mu
multicastTTL uint8
// TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ // +checklocks:mu
multicastAddr tcpip.Address
// TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ // +checklocks:mu
multicastNICID tcpip.NICID
- ipv4TOS uint8
- ipv6TClass uint8
+ // +checklocks:mu
+ ipv4TOS uint8
+ // +checklocks:mu
+ ipv6TClass uint8
}
// +stateify savable
@@ -72,8 +82,11 @@ type multicastMembership struct {
// Init initializes the endpoint.
func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ops *tcpip.SocketOptions) {
- if e.multicastMemberships != nil {
- panic(fmt.Sprintf("endpoint is already initialized; got e.multicastMemberships = %#v, want = nil", e.multicastMemberships))
+ e.mu.Lock()
+ memberships := e.multicastMemberships
+ e.mu.Unlock()
+ if memberships != nil {
+ panic(fmt.Sprintf("endpoint is already initialized; got e.multicastMemberships = %#v, want = nil", memberships))
}
switch netProto {
@@ -88,8 +101,7 @@ func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, tr
netProto: netProto,
transProto: transProto,
- state: uint32(transport.DatagramEndpointStateInitial),
-
+ state: transport.DatagramEndpointStateInitial,
info: stack.TransportEndpointInfo{
NetProto: netProto,
TransProto: transProto,
@@ -106,14 +118,11 @@ func (e *Endpoint) NetProto() tcpip.NetworkProtocolNumber {
return e.netProto
}
-// setState sets the state of the endpoint.
-func (e *Endpoint) setEndpointState(state transport.DatagramEndpointState) {
- atomic.StoreUint32(&e.state, uint32(state))
-}
-
// State returns the state of the endpoint.
func (e *Endpoint) State() transport.DatagramEndpointState {
- return transport.DatagramEndpointState(atomic.LoadUint32(&e.state))
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.state
}
// Close cleans the endpoint's resources and leaves the endpoint in a closed
@@ -122,7 +131,7 @@ func (e *Endpoint) Close() {
e.mu.Lock()
defer e.mu.Unlock()
- if e.State() == transport.DatagramEndpointStateClosed {
+ if e.state == transport.DatagramEndpointStateClosed {
return
}
@@ -136,7 +145,7 @@ func (e *Endpoint) Close() {
e.connectedRoute = nil
}
- e.setEndpointState(transport.DatagramEndpointStateClosed)
+ e.state = transport.DatagramEndpointStateClosed
}
// SetOwner sets the owner of transmitted packets.
@@ -217,7 +226,7 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext
return WriteContext{}, &tcpip.ErrInvalidOptionValue{}
}
- if e.State() == transport.DatagramEndpointStateClosed {
+ if e.state == transport.DatagramEndpointStateClosed {
return WriteContext{}, &tcpip.ErrInvalidEndpointState{}
}
@@ -229,7 +238,7 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext
if opts.To == nil {
// If the user doesn't specify a destination, they should have
// connected to another address.
- if e.State() != transport.DatagramEndpointStateConnected {
+ if e.state != transport.DatagramEndpointStateConnected {
return WriteContext{}, &tcpip.ErrDestinationRequired{}
}
@@ -248,13 +257,16 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext
nicID = e.info.BindNICID
}
+ if nicID == 0 {
+ nicID = e.info.RegisterNICID
+ }
- dst, netProto, err := e.checkV4MappedLocked(*opts.To)
+ dst, netProto, err := e.checkV4MappedRLocked(*opts.To)
if err != nil {
return WriteContext{}, err
}
- route, _, err = e.connectRoute(nicID, dst, netProto)
+ route, _, err = e.connectRouteRLocked(nicID, dst, netProto)
if err != nil {
return WriteContext{}, err
}
@@ -289,29 +301,32 @@ func (e *Endpoint) Disconnect() {
e.mu.Lock()
defer e.mu.Unlock()
- if e.State() != transport.DatagramEndpointStateConnected {
+ if e.state != transport.DatagramEndpointStateConnected {
return
}
// Exclude ephemerally bound endpoints.
- if e.info.BindNICID != 0 || e.info.ID.LocalAddress == "" {
+ if e.wasBound {
e.info.ID = stack.TransportEndpointID{
- LocalAddress: e.info.ID.LocalAddress,
+ LocalAddress: e.info.BindAddr,
}
- e.setEndpointState(transport.DatagramEndpointStateBound)
+ e.state = transport.DatagramEndpointStateBound
} else {
e.info.ID = stack.TransportEndpointID{}
- e.setEndpointState(transport.DatagramEndpointStateInitial)
+ e.state = transport.DatagramEndpointStateInitial
}
e.connectedRoute.Release()
e.connectedRoute = nil
}
-// connectRoute establishes a route to the specified interface or the
+// connectRouteRLocked establishes a route to the specified interface or the
// configured multicast interface if no interface is specified and the
// specified address is a multicast address.
-func (e *Endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) {
+//
+// TODO(https://gvisor.dev/issue/6590): Annotate read lock requirement.
+// +checklocks:e.mu
+func (e *Endpoint) connectRouteRLocked(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) {
localAddr := e.info.ID.LocalAddress
if e.isBroadcastOrMulticast(nicID, netProto, localAddr) {
// A packet can only originate from a unicast address (i.e., an interface).
@@ -356,7 +371,7 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip.
defer e.mu.Unlock()
nicID := addr.NIC
- switch e.State() {
+ switch e.state {
case transport.DatagramEndpointStateInitial:
case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
if e.info.BindNICID == 0 {
@@ -372,12 +387,12 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip.
return &tcpip.ErrInvalidEndpointState{}
}
- addr, netProto, err := e.checkV4MappedLocked(addr)
+ addr, netProto, err := e.checkV4MappedRLocked(addr)
if err != nil {
return err
}
- r, nicID, err := e.connectRoute(nicID, addr, netProto)
+ r, nicID, err := e.connectRouteRLocked(nicID, addr, netProto)
if err != nil {
return err
}
@@ -386,7 +401,7 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip.
LocalAddress: e.info.ID.LocalAddress,
RemoteAddress: r.RemoteAddress(),
}
- if e.State() == transport.DatagramEndpointStateInitial {
+ if e.state == transport.DatagramEndpointStateInitial {
id.LocalAddress = r.LocalAddress()
}
@@ -402,7 +417,7 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip.
e.info.ID = id
e.info.RegisterNICID = nicID
e.effectiveNetProto = netProto
- e.setEndpointState(transport.DatagramEndpointStateConnected)
+ e.state = transport.DatagramEndpointStateConnected
return nil
}
@@ -411,7 +426,7 @@ func (e *Endpoint) Shutdown() tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- switch state := e.State(); state {
+ switch state := e.state; state {
case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
return &tcpip.ErrNotConnected{}
case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
@@ -422,9 +437,12 @@ func (e *Endpoint) Shutdown() tcpip.Error {
}
}
-// checkV4MappedLocked determines the effective network protocol and converts
+// checkV4MappedRLocked determines the effective network protocol and converts
// addr to its canonical form.
-func (e *Endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
+//
+// TODO(https://gvisor.dev/issue/6590): Annotate read lock requirement.
+// +checklocks:e.mu
+func (e *Endpoint) checkV4MappedRLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
unwrapped, netProto, err := e.info.AddrNetProtoLocked(addr, e.ops.GetV6Only())
if err != nil {
return tcpip.FullAddress{}, 0, err
@@ -456,11 +474,11 @@ func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProto
// Don't allow binding once endpoint is not in the initial state
// anymore.
- if e.State() != transport.DatagramEndpointStateInitial {
+ if e.state != transport.DatagramEndpointStateInitial {
return &tcpip.ErrInvalidEndpointState{}
}
- addr, netProto, err := e.checkV4MappedLocked(addr)
+ addr, netProto, err := e.checkV4MappedRLocked(addr)
if err != nil {
return err
}
@@ -477,24 +495,33 @@ func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProto
return err
}
+ e.wasBound = true
+
e.info.ID = stack.TransportEndpointID{
LocalAddress: addr.Addr,
}
- e.info.BindNICID = nicID
+ e.info.BindNICID = addr.NIC
e.info.RegisterNICID = nicID
e.info.BindAddr = addr.Addr
e.effectiveNetProto = netProto
- e.setEndpointState(transport.DatagramEndpointStateBound)
+ e.state = transport.DatagramEndpointStateBound
return nil
}
+// WasBound returns true iff the endpoint was ever bound.
+func (e *Endpoint) WasBound() bool {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.wasBound
+}
+
// GetLocalAddress returns the address that the endpoint is bound to.
func (e *Endpoint) GetLocalAddress() tcpip.FullAddress {
e.mu.RLock()
defer e.mu.RUnlock()
addr := e.info.BindAddr
- if e.State() == transport.DatagramEndpointStateConnected {
+ if e.state == transport.DatagramEndpointStateConnected {
addr = e.connectedRoute.LocalAddress()
}
@@ -509,7 +536,7 @@ func (e *Endpoint) GetRemoteAddress() (tcpip.FullAddress, bool) {
e.mu.RLock()
defer e.mu.RUnlock()
- if e.State() != transport.DatagramEndpointStateConnected {
+ if e.state != transport.DatagramEndpointStateConnected {
return tcpip.FullAddress{}, false
}
@@ -597,7 +624,7 @@ func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
defer e.mu.Unlock()
fa := tcpip.FullAddress{Addr: v.InterfaceAddr}
- fa, netProto, err := e.checkV4MappedLocked(fa)
+ fa, netProto, err := e.checkV4MappedRLocked(fa)
if err != nil {
return err
}
diff --git a/pkg/tcpip/transport/internal/network/endpoint_state.go b/pkg/tcpip/transport/internal/network/endpoint_state.go
index 858007156..173197512 100644
--- a/pkg/tcpip/transport/internal/network/endpoint_state.go
+++ b/pkg/tcpip/transport/internal/network/endpoint_state.go
@@ -35,7 +35,7 @@ func (e *Endpoint) Resume(s *stack.Stack) {
}
}
- switch state := e.State(); state {
+ switch e.state {
case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
case transport.DatagramEndpointStateBound:
if len(e.info.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress) {
@@ -51,6 +51,6 @@ func (e *Endpoint) Resume(s *stack.Stack) {
panic(fmt.Sprintf("e.stack.FindRoute(%d, %s, %s, %d, %t): %s", e.info.RegisterNICID, e.info.ID.LocalAddress, e.info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop, err))
}
default:
- panic(fmt.Sprintf("unhandled state = %s", state))
+ panic(fmt.Sprintf("unhandled state = %s", e.state))
}
}
diff --git a/pkg/tcpip/transport/internal/network/endpoint_test.go b/pkg/tcpip/transport/internal/network/endpoint_test.go
index 2c43eb66a..f263a9ea2 100644
--- a/pkg/tcpip/transport/internal/network/endpoint_test.go
+++ b/pkg/tcpip/transport/internal/network/endpoint_test.go
@@ -15,6 +15,7 @@
package network_test
import (
+ "fmt"
"testing"
"github.com/google/go-cmp/cmp"
@@ -24,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -33,17 +35,15 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
-func TestEndpointStateTransitions(t *testing.T) {
- const (
- nicID = 1
- )
+var (
+ ipv4NICAddr = testutil.MustParse4("1.2.3.4")
+ ipv6NICAddr = testutil.MustParse6("a::1")
+ ipv4RemoteAddr = testutil.MustParse4("6.7.8.9")
+ ipv6RemoteAddr = testutil.MustParse6("b::1")
+)
- var (
- ipv4NICAddr = testutil.MustParse4("1.2.3.4")
- ipv6NICAddr = testutil.MustParse6("a::1")
- ipv4RemoteAddr = testutil.MustParse4("6.7.8.9")
- ipv6RemoteAddr = testutil.MustParse6("b::1")
- )
+func TestEndpointStateTransitions(t *testing.T) {
+ const nicID = 1
data := buffer.View([]byte{1, 2, 4, 5})
v4Checker := func(t *testing.T, b buffer.View) {
@@ -124,11 +124,20 @@ func TestEndpointStateTransitions(t *testing.T) {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, ipv4NICAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, ipv4NICAddr, err)
+ ipv4ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ipv4NICAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, ipv4ProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}: %s", nicID, ipv4ProtocolAddr, err)
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, ipv6NICAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, ipv6NICAddr, err)
+ ipv6ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: ipv6NICAddr.WithPrefix(),
+ }
+
+ if err := s.AddProtocolAddress(nicID, ipv6ProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -139,6 +148,7 @@ func TestEndpointStateTransitions(t *testing.T) {
var ops tcpip.SocketOptions
var ep network.Endpoint
ep.Init(s, test.netProto, udp.ProtocolNumber, &ops)
+ defer ep.Close()
if state := ep.State(); state != transport.DatagramEndpointStateInitial {
t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateInitial)
}
@@ -207,3 +217,102 @@ func TestEndpointStateTransitions(t *testing.T) {
})
}
}
+
+func TestBindNICID(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ bindAddr tcpip.Address
+ unicast bool
+ }{
+ {
+ name: "IPv4 multicast",
+ netProto: ipv4.ProtocolNumber,
+ bindAddr: header.IPv4AllSystems,
+ unicast: false,
+ },
+ {
+ name: "IPv6 multicast",
+ netProto: ipv6.ProtocolNumber,
+ bindAddr: header.IPv6AllNodesMulticastAddress,
+ unicast: false,
+ },
+ {
+ name: "IPv4 unicast",
+ netProto: ipv4.ProtocolNumber,
+ bindAddr: ipv4NICAddr,
+ unicast: true,
+ },
+ {
+ name: "IPv6 unicast",
+ netProto: ipv6.ProtocolNumber,
+ bindAddr: ipv6NICAddr,
+ unicast: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, testBindNICID := range []tcpip.NICID{0, nicID} {
+ t.Run(fmt.Sprintf("BindNICID=%d", testBindNICID), func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ Clock: &faketime.NullClock{},
+ })
+ if err := s.CreateNIC(nicID, loopback.New()); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ ipv4ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ipv4NICAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, ipv4ProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtocolAddr, err)
+ }
+ ipv6ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: ipv6NICAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, ipv6ProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtocolAddr, err)
+ }
+
+ var ops tcpip.SocketOptions
+ var ep network.Endpoint
+ ep.Init(s, test.netProto, udp.ProtocolNumber, &ops)
+ defer ep.Close()
+ if ep.WasBound() {
+ t.Fatal("got ep.WasBound() = true, want = false")
+ }
+ wantInfo := stack.TransportEndpointInfo{NetProto: test.netProto, TransProto: udp.ProtocolNumber}
+ if diff := cmp.Diff(wantInfo, ep.Info()); diff != "" {
+ t.Fatalf("ep.Info() mismatch (-want +got):\n%s", diff)
+ }
+
+ bindAddr := tcpip.FullAddress{Addr: test.bindAddr, NIC: testBindNICID}
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("ep.Bind(%#v): %s", bindAddr, err)
+ }
+ if !ep.WasBound() {
+ t.Error("got ep.WasBound() = false, want = true")
+ }
+ wantInfo.ID = stack.TransportEndpointID{LocalAddress: bindAddr.Addr}
+ wantInfo.BindAddr = bindAddr.Addr
+ wantInfo.BindNICID = bindAddr.NIC
+ if test.unicast {
+ wantInfo.RegisterNICID = nicID
+ } else {
+ wantInfo.RegisterNICID = bindAddr.NIC
+ }
+ if diff := cmp.Diff(wantInfo, ep.Info()); diff != "" {
+ t.Errorf("ep.Info() mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 0554d2f4a..2c9786175 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -68,31 +68,31 @@ type endpoint struct {
netProto tcpip.NetworkProtocolNumber
waiterQueue *waiter.Queue
cooked bool
-
- // The following fields are used to manage the receive queue and are
- // protected by rcvMu.
- rcvMu sync.Mutex `state:"nosave"`
- rcvList packetList
+ ops tcpip.SocketOptions
+ stats tcpip.TransportEndpointStats `state:"nosave"`
+
+ // The following fields are used to manage the receive queue.
+ rcvMu sync.Mutex `state:"nosave"`
+ // +checklocks:rcvMu
+ rcvList packetList
+ // +checklocks:rcvMu
rcvBufSize int
- rcvClosed bool
-
- // The following fields are protected by mu.
- mu sync.RWMutex `state:"nosave"`
- closed bool
- stats tcpip.TransportEndpointStats `state:"nosave"`
- bound bool
+ // +checklocks:rcvMu
+ rcvClosed bool
+ // +checklocks:rcvMu
+ rcvDisabled bool
+
+ mu sync.RWMutex `state:"nosave"`
+ // +checklocks:mu
+ closed bool
+ // +checklocks:mu
+ bound bool
+ // +checklocks:mu
boundNIC tcpip.NICID
- // lastErrorMu protects lastError.
lastErrorMu sync.Mutex `state:"nosave"`
- lastError tcpip.Error
-
- // ops is used to get socket level options.
- ops tcpip.SocketOptions
-
- // frozen indicates if the packets should be delivered to the endpoint
- // during restore.
- frozen bool
+ // +checklocks:lastErrorMu
+ lastError tcpip.Error
}
// NewEndpoint returns a new packet endpoint.
@@ -414,7 +414,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
}
rcvBufSize := ep.ops.GetReceiveBufferSize()
- if ep.frozen || ep.rcvBufSize >= int(rcvBufSize) {
+ if ep.rcvDisabled || ep.rcvBufSize >= int(rcvBufSize) {
ep.rcvMu.Unlock()
ep.stack.Stats().DroppedPackets.Increment()
ep.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
@@ -491,18 +491,3 @@ func (*endpoint) SetOwner(tcpip.PacketOwner) {}
func (ep *endpoint) SocketOptions() *tcpip.SocketOptions {
return &ep.ops
}
-
-// freeze prevents any more packets from being delivered to the endpoint.
-func (ep *endpoint) freeze() {
- ep.mu.Lock()
- ep.frozen = true
- ep.mu.Unlock()
-}
-
-// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows
-// new packets to be delivered again.
-func (ep *endpoint) thaw() {
- ep.mu.Lock()
- ep.frozen = false
- ep.mu.Unlock()
-}
diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go
index 5c688d286..d2768db7b 100644
--- a/pkg/tcpip/transport/packet/endpoint_state.go
+++ b/pkg/tcpip/transport/packet/endpoint_state.go
@@ -44,12 +44,16 @@ func (p *packet) loadData(data buffer.VectorisedView) {
// beforeSave is invoked by stateify.
func (ep *endpoint) beforeSave() {
- ep.freeze()
+ ep.rcvMu.Lock()
+ defer ep.rcvMu.Unlock()
+ ep.rcvDisabled = true
}
// afterLoad is invoked by stateify.
func (ep *endpoint) afterLoad() {
- ep.thaw()
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
ep.stack = stack.StackFromEnv
ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
@@ -57,4 +61,8 @@ func (ep *endpoint) afterLoad() {
if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil {
panic(err)
}
+
+ ep.rcvMu.Lock()
+ ep.rcvDisabled = false
+ ep.rcvMu.Unlock()
}
diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD
index 2eab09088..b7e97e218 100644
--- a/pkg/tcpip/transport/raw/BUILD
+++ b/pkg/tcpip/transport/raw/BUILD
@@ -33,6 +33,8 @@ go_library(
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport",
+ "//pkg/tcpip/transport/internal/network",
"//pkg/tcpip/transport/packet",
"//pkg/waiter",
],
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 3bf6c0a8f..3040a445b 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -26,6 +26,7 @@
package raw
import (
+ "fmt"
"io"
"time"
@@ -34,6 +35,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -57,15 +60,19 @@ type rawPacket struct {
//
// +stateify savable
type endpoint struct {
- stack.TransportEndpointInfo
tcpip.DefaultSocketOptionsHandler
// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
+ transProto tcpip.TransportProtocolNumber
waiterQueue *waiter.Queue
associated bool
+ net network.Endpoint
+ stats tcpip.TransportEndpointStats `state:"nosave"`
+ ops tcpip.SocketOptions
+
// The following fields are used to manage the receive queue and are
// protected by rcvMu.
rcvMu sync.Mutex `state:"nosave"`
@@ -74,20 +81,7 @@ type endpoint struct {
rcvClosed bool
// The following fields are protected by mu.
- mu sync.RWMutex `state:"nosave"`
- closed bool
- connected bool
- bound bool
- // route is the route to a remote network endpoint. It is set via
- // Connect(), and is valid only when conneted is true.
- route *stack.Route `state:"manual"`
- stats tcpip.TransportEndpointStats `state:"nosave"`
- // owner is used to get uid and gid of the packet.
- owner tcpip.PacketOwner
-
- // ops is used to get socket level options.
- ops tcpip.SocketOptions
-
+ mu sync.RWMutex `state:"nosave"`
// frozen indicates if the packets should be delivered to the endpoint
// during restore.
frozen bool
@@ -99,16 +93,9 @@ func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, tcpip.Error) {
- if netProto != header.IPv4ProtocolNumber && netProto != header.IPv6ProtocolNumber {
- return nil, &tcpip.ErrUnknownProtocol{}
- }
-
e := &endpoint{
- stack: s,
- TransportEndpointInfo: stack.TransportEndpointInfo{
- NetProto: netProto,
- TransProto: transProto,
- },
+ stack: s,
+ transProto: transProto,
waiterQueue: waiterQueue,
associated: associated,
}
@@ -116,6 +103,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
e.ops.SetHeaderIncluded(!associated)
e.ops.SetSendBufferSize(32*1024, false /* notify */)
e.ops.SetReceiveBufferSize(32*1024, false /* notify */)
+ e.net.Init(s, netProto, transProto, &e.ops)
// Override with stack defaults.
var ss tcpip.SendBufferSizeOption
@@ -137,7 +125,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
return e, nil
}
- if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil {
+ if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil {
return nil, err
}
@@ -154,11 +142,17 @@ func (e *endpoint) Close() {
e.mu.Lock()
defer e.mu.Unlock()
- if e.closed || !e.associated {
+ if e.net.State() == transport.DatagramEndpointStateClosed {
return
}
- e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e)
+ e.net.Close()
+
+ if !e.associated {
+ return
+ }
+
+ e.stack.UnregisterRawTransportEndpoint(e.net.NetProto(), e.transProto, e)
e.rcvMu.Lock()
defer e.rcvMu.Unlock()
@@ -170,15 +164,6 @@ func (e *endpoint) Close() {
e.rcvList.Remove(e.rcvList.Front())
}
- e.connected = false
-
- if e.route != nil {
- e.route.Release()
- e.route = nil
- }
-
- e.closed = true
-
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
}
@@ -186,9 +171,7 @@ func (e *endpoint) Close() {
func (*endpoint) ModerateRecvBuf(int) {}
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
- e.mu.Lock()
- defer e.mu.Unlock()
- e.owner = owner
+ e.net.SetOwner(owner)
}
// Read implements tcpip.Endpoint.Read.
@@ -236,14 +219,15 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// Write implements tcpip.Endpoint.Write.
func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
+ netProto := e.net.NetProto()
// We can create, but not write to, unassociated IPv6 endpoints.
- if !e.associated && e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber {
+ if !e.associated && netProto == header.IPv6ProtocolNumber {
return 0, &tcpip.ErrInvalidOptionValue{}
}
if opts.To != nil {
// Raw sockets do not support sending to a IPv4 address on a IPv6 endpoint.
- if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize {
+ if netProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize {
return 0, &tcpip.ErrInvalidOptionValue{}
}
}
@@ -269,79 +253,26 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
}
func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
- // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op.
- if opts.More {
- return 0, &tcpip.ErrInvalidOptionValue{}
- }
- payloadBytes, route, owner, err := func() ([]byte, *stack.Route, tcpip.PacketOwner, tcpip.Error) {
- e.mu.RLock()
- defer e.mu.RUnlock()
-
- if e.closed {
- return nil, nil, nil, &tcpip.ErrInvalidEndpointState{}
- }
-
- // TODO(https://gvisor.dev/issue/6538): Avoid this allocation.
- payloadBytes := make([]byte, p.Len())
- if _, err := io.ReadFull(p, payloadBytes); err != nil {
- return nil, nil, nil, &tcpip.ErrBadBuffer{}
- }
-
- // Did the user caller provide a destination? If not, use the connected
- // destination.
- if opts.To == nil {
- // If the user doesn't specify a destination, they should have
- // connected to another address.
- if !e.connected {
- return nil, nil, nil, &tcpip.ErrDestinationRequired{}
- }
-
- e.route.Acquire()
-
- return payloadBytes, e.route, e.owner, nil
- }
-
- // The caller provided a destination. Reject destination address if it
- // goes through a different NIC than the endpoint was bound to.
- nic := opts.To.NIC
- if e.bound && nic != 0 && nic != e.BindNICID {
- return nil, nil, nil, &tcpip.ErrNoRoute{}
- }
-
- // Find the route to the destination. If BindAddress is 0,
- // FindRoute will choose an appropriate source address.
- route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false)
- if err != nil {
- return nil, nil, nil, err
- }
-
- return payloadBytes, route, e.owner, nil
- }()
+ ctx, err := e.net.AcquireContextForWrite(opts)
if err != nil {
return 0, err
}
- defer route.Release()
+
+ // TODO(https://gvisor.dev/issue/6538): Avoid this allocation.
+ payloadBytes := make([]byte, p.Len())
+ if _, err := io.ReadFull(p, payloadBytes); err != nil {
+ return 0, &tcpip.ErrBadBuffer{}
+ }
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(route.MaxHeaderLength()),
+ ReserveHeaderBytes: int(ctx.PacketInfo().MaxHeaderLength),
Data: buffer.View(payloadBytes).ToVectorisedView(),
})
- pkt.Owner = owner
-
- if e.ops.GetHeaderIncluded() {
- if err := route.WriteHeaderIncludedPacket(pkt); err != nil {
- return 0, err
- }
- return int64(len(payloadBytes)), nil
- }
- if err := route.WritePacket(stack.NetworkHeaderParams{
- Protocol: e.TransProto,
- TTL: route.DefaultTTL(),
- TOS: stack.DefaultTOS,
- }, pkt); err != nil {
+ if err := ctx.WritePacket(pkt, e.ops.GetHeaderIncluded()); err != nil {
return 0, err
}
+
return int64(len(payloadBytes)), nil
}
@@ -352,66 +283,29 @@ func (*endpoint) Disconnect() tcpip.Error {
// Connect implements tcpip.Endpoint.Connect.
func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
+ netProto := e.net.NetProto()
+
// Raw sockets do not support connecting to a IPv4 address on a IPv6 endpoint.
- if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize {
+ if netProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize {
return &tcpip.ErrAddressFamilyNotSupported{}
}
- e.mu.Lock()
- defer e.mu.Unlock()
-
- if e.closed {
- return &tcpip.ErrInvalidEndpointState{}
- }
-
- nic := addr.NIC
- if e.bound {
- if e.BindNICID == 0 {
- // If we're bound, but not to a specific NIC, the NIC
- // in addr will be used. Nothing to do here.
- } else if addr.NIC == 0 {
- // If we're bound to a specific NIC, but addr doesn't
- // specify a NIC, use the bound NIC.
- nic = e.BindNICID
- } else if addr.NIC != e.BindNICID {
- // We're bound and addr specifies a NIC. They must be
- // the same.
- return &tcpip.ErrInvalidEndpointState{}
- }
- }
-
- // Find a route to the destination.
- route, err := e.stack.FindRoute(nic, "", addr.Addr, e.NetProto, false)
- if err != nil {
- return err
- }
-
- if e.associated {
- // Re-register the endpoint with the appropriate NIC.
- if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil {
- route.Release()
- return err
+ return e.net.ConnectAndThen(addr, func(_ tcpip.NetworkProtocolNumber, _, _ stack.TransportEndpointID) tcpip.Error {
+ if e.associated {
+ // Re-register the endpoint with the appropriate NIC.
+ if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil {
+ return err
+ }
+ e.stack.UnregisterRawTransportEndpoint(netProto, e.transProto, e)
}
- e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e)
- e.RegisterNICID = nic
- }
- if e.route != nil {
- // If the endpoint was previously connected then release any previous route.
- e.route.Release()
- }
- e.route = route
- e.connected = true
-
- return nil
+ return nil
+ })
}
// Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets.
func (e *endpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error {
- e.mu.Lock()
- defer e.mu.Unlock()
-
- if !e.connected {
+ if e.net.State() != transport.DatagramEndpointStateConnected {
return &tcpip.ErrNotConnected{}
}
return nil
@@ -429,46 +323,26 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi
// Bind implements tcpip.Endpoint.Bind.
func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
- e.mu.Lock()
- defer e.mu.Unlock()
-
- // If a local address was specified, verify that it's valid.
- if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, addr.Addr) == 0 {
- return &tcpip.ErrBadLocalAddress{}
- }
+ return e.net.BindAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, _ tcpip.Address) tcpip.Error {
+ if !e.associated {
+ return nil
+ }
- if e.associated {
// Re-register the endpoint with the appropriate NIC.
- if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil {
+ if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil {
return err
}
- e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e)
- e.RegisterNICID = addr.NIC
- e.BindNICID = addr.NIC
- }
-
- e.BindAddr = addr.Addr
- e.bound = true
-
- return nil
+ e.stack.UnregisterRawTransportEndpoint(netProto, e.transProto, e)
+ return nil
+ })
}
// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
- e.mu.RLock()
- defer e.mu.RUnlock()
-
- addr := e.BindAddr
- if e.connected {
- addr = e.route.LocalAddress()
- }
-
- return tcpip.FullAddress{
- NIC: e.RegisterNICID,
- Addr: addr,
- // Linux returns the protocol in the port field.
- Port: uint16(e.TransProto),
- }, nil
+ a := e.net.GetLocalAddress()
+ // Linux returns the protocol in the port field.
+ a.Port = uint16(e.transProto)
+ return a, nil
}
// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
@@ -501,17 +375,17 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
return nil
default:
- return &tcpip.ErrUnknownProtocolOption{}
+ return e.net.SetSockOpt(opt)
}
}
-func (*endpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error {
- return &tcpip.ErrUnknownProtocolOption{}
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
+ return e.net.SetSockOptInt(opt, v)
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error {
- return &tcpip.ErrUnknownProtocolOption{}
+func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
+ return e.net.GetSockOpt(opt)
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
@@ -528,103 +402,108 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
return v, nil
default:
- return -1, &tcpip.ErrUnknownProtocolOption{}
+ return e.net.GetSockOptInt(opt)
}
}
// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
- e.mu.RLock()
- e.rcvMu.Lock()
+ notifyReadableEvents := func() bool {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ e.rcvMu.Lock()
+ defer e.rcvMu.Unlock()
+
+ // Drop the packet if our buffer is currently full or if this is an unassociated
+ // endpoint (i.e endpoint created w/ IPPROTO_RAW). Such endpoints are send only
+ // See: https://man7.org/linux/man-pages/man7/raw.7.html
+ //
+ // An IPPROTO_RAW socket is send only. If you really want to receive
+ // all IP packets, use a packet(7) socket with the ETH_P_IP protocol.
+ // Note that packet sockets don't reassemble IP fragments, unlike raw
+ // sockets.
+ if e.rcvClosed || !e.associated {
+ e.stack.Stats().DroppedPackets.Increment()
+ e.stats.ReceiveErrors.ClosedReceiver.Increment()
+ return false
+ }
- // Drop the packet if our buffer is currently full or if this is an unassociated
- // endpoint (i.e endpoint created w/ IPPROTO_RAW). Such endpoints are send only
- // See: https://man7.org/linux/man-pages/man7/raw.7.html
- //
- // An IPPROTO_RAW socket is send only. If you really want to receive
- // all IP packets, use a packet(7) socket with the ETH_P_IP protocol.
- // Note that packet sockets don't reassemble IP fragments, unlike raw
- // sockets.
- if e.rcvClosed || !e.associated {
- e.rcvMu.Unlock()
- e.mu.RUnlock()
- e.stack.Stats().DroppedPackets.Increment()
- e.stats.ReceiveErrors.ClosedReceiver.Increment()
- return
- }
+ rcvBufSize := e.ops.GetReceiveBufferSize()
+ if e.frozen || e.rcvBufSize >= int(rcvBufSize) {
+ e.stack.Stats().DroppedPackets.Increment()
+ e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
+ return false
+ }
- rcvBufSize := e.ops.GetReceiveBufferSize()
- if e.frozen || e.rcvBufSize >= int(rcvBufSize) {
- e.rcvMu.Unlock()
- e.mu.RUnlock()
- e.stack.Stats().DroppedPackets.Increment()
- e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
- return
- }
+ srcAddr := pkt.Network().SourceAddress()
+ info := e.net.Info()
- remoteAddr := pkt.Network().SourceAddress()
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial:
+ case transport.DatagramEndpointStateConnected:
+ // If connected, only accept packets from the remote address we
+ // connected to.
+ if info.ID.RemoteAddress != srcAddr {
+ return false
+ }
- if e.bound {
- // If bound to a NIC, only accept data for that NIC.
- if e.BindNICID != 0 && e.BindNICID != pkt.NICID {
- e.rcvMu.Unlock()
- e.mu.RUnlock()
- return
- }
- // If bound to an address, only accept data for that address.
- if e.BindAddr != "" && e.BindAddr != remoteAddr {
- e.rcvMu.Unlock()
- e.mu.RUnlock()
- return
+ // Connected sockets may also have been bound to a specific
+ // address/NIC.
+ fallthrough
+ case transport.DatagramEndpointStateBound:
+ // If bound to a NIC, only accept data for that NIC.
+ if info.BindNICID != 0 && info.BindNICID != pkt.NICID {
+ return false
+ }
+
+ // If bound to an address, only accept data for that address.
+ if info.BindAddr != "" && info.BindAddr != pkt.Network().DestinationAddress() {
+ return false
+ }
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
}
- }
- // If connected, only accept packets from the remote address we
- // connected to.
- if e.connected && e.route.RemoteAddress() != remoteAddr {
- e.rcvMu.Unlock()
- e.mu.RUnlock()
- return
- }
+ wasEmpty := e.rcvBufSize == 0
- wasEmpty := e.rcvBufSize == 0
+ // Push new packet into receive list and increment the buffer size.
+ packet := &rawPacket{
+ senderAddr: tcpip.FullAddress{
+ NIC: pkt.NICID,
+ Addr: srcAddr,
+ },
+ }
- // Push new packet into receive list and increment the buffer size.
- packet := &rawPacket{
- senderAddr: tcpip.FullAddress{
- NIC: pkt.NICID,
- Addr: remoteAddr,
- },
- }
+ // Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not.
+ // We copy headers' underlying bytes because pkt.*Header may point to
+ // the middle of a slice, and another struct may point to the "outer"
+ // slice. Save/restore doesn't support overlapping slices and will fail.
+ //
+ // TODO(https://gvisor.dev/issue/6517): Avoid the copy once S/R supports
+ // overlapping slices.
+ var combinedVV buffer.VectorisedView
+ if info.NetProto == header.IPv4ProtocolNumber {
+ network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
+ headers := make(buffer.View, 0, len(network)+len(transport))
+ headers = append(headers, network...)
+ headers = append(headers, transport...)
+ combinedVV = headers.ToVectorisedView()
+ } else {
+ combinedVV = append(buffer.View(nil), pkt.TransportHeader().View()...).ToVectorisedView()
+ }
+ combinedVV.Append(pkt.Data().ExtractVV())
+ packet.data = combinedVV
+ packet.receivedAt = e.stack.Clock().Now()
- // Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not.
- // We copy headers' underlying bytes because pkt.*Header may point to
- // the middle of a slice, and another struct may point to the "outer"
- // slice. Save/restore doesn't support overlapping slices and will fail.
- //
- // TODO(https://gvisor.dev/issue/6517): Avoid the copy once S/R supports
- // overlapping slices.
- var combinedVV buffer.VectorisedView
- if e.TransportEndpointInfo.NetProto == header.IPv4ProtocolNumber {
- network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
- headers := make(buffer.View, 0, len(network)+len(transport))
- headers = append(headers, network...)
- headers = append(headers, transport...)
- combinedVV = headers.ToVectorisedView()
- } else {
- combinedVV = append(buffer.View(nil), pkt.TransportHeader().View()...).ToVectorisedView()
- }
- combinedVV.Append(pkt.Data().ExtractVV())
- packet.data = combinedVV
- packet.receivedAt = e.stack.Clock().Now()
+ e.rcvList.PushBack(packet)
+ e.rcvBufSize += packet.data.Size()
+ e.stats.PacketsReceived.Increment()
- e.rcvList.PushBack(packet)
- e.rcvBufSize += packet.data.Size()
- e.rcvMu.Unlock()
- e.mu.RUnlock()
- e.stats.PacketsReceived.Increment()
- // Notify waiters that there's data to be read.
- if wasEmpty {
+ // Notify waiters that there is data to be read now.
+ return wasEmpty
+ }()
+
+ if notifyReadableEvents {
e.waiterQueue.Notify(waiter.ReadableEvents)
}
}
@@ -636,10 +515,7 @@ func (e *endpoint) State() uint32 {
// Info returns a copy of the endpoint info.
func (e *endpoint) Info() tcpip.EndpointInfo {
- e.mu.RLock()
- // Make a copy of the endpoint info.
- ret := e.TransportEndpointInfo
- e.mu.RUnlock()
+ ret := e.net.Info()
return &ret
}
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
index 39669b445..e74713064 100644
--- a/pkg/tcpip/transport/raw/endpoint_state.go
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -15,6 +15,7 @@
package raw
import (
+ "fmt"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -60,35 +61,16 @@ func (e *endpoint) beforeSave() {
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
+ e.net.Resume(s)
+
e.thaw()
e.stack = s
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
- // If the endpoint is connected, re-connect.
- if e.connected {
- var err tcpip.Error
- // TODO(gvisor.dev/issue/4906): Properly restore the route with the right
- // remote address. We used to pass e.remote.RemoteAddress which was
- // effectively the empty address but since moving e.route to hold a pointer
- // to a route instead of the route by value, we pass the empty address
- // directly. Obviously this was always wrong since we should provide the
- // remote address we were connected to, to properly restore the route.
- e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, "", e.NetProto, false)
- if err != nil {
- panic(err)
- }
- }
-
- // If the endpoint is bound, re-bind.
- if e.bound {
- if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.BindAddr) == 0 {
- panic(&tcpip.ErrBadLocalAddress{})
- }
- }
-
if e.associated {
- if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil {
- panic(err)
+ netProto := e.net.NetProto()
+ if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil {
+ panic(fmt.Sprintf("e.stack.RegisterRawTransportEndpoint(%d, %d, _): %s", netProto, e.transProto, err))
}
}
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index bc8708a5b..58817371e 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -1382,8 +1382,12 @@ func TestListenerReadinessOnEvent(t *testing.T) {
if err := s.CreateNIC(id, ep); err != nil {
t.Fatalf("CreateNIC(%d, %T): %s", id, ep, err)
}
- if err := s.AddAddress(id, ipv4.ProtocolNumber, context.StackAddr); err != nil {
- t.Fatalf("AddAddress(%d, ipv4.ProtocolNumber, %s): %s", id, context.StackAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.Address(context.StackAddr).WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(id, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", id, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{
{Destination: header.IPv4EmptySubnet, NIC: id},
@@ -2145,12 +2149,15 @@ func TestSmallReceiveBufferReadiness(t *testing.T) {
t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %s", nicOpts, err)
}
- addr := tcpip.AddressWithPrefix{
- Address: tcpip.Address("\x7f\x00\x00\x01"),
- PrefixLen: 8,
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address("\x7f\x00\x00\x01"),
+ PrefixLen: 8,
+ },
}
- if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil {
- t.Fatalf("AddAddressWithPrefix(_, _, %s) failed: %s", addr, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) failed: %s", nicID, protocolAddr, err)
}
{
@@ -4954,13 +4961,17 @@ func makeStack() (*stack.Stack, tcpip.Error) {
}
for _, ct := range []struct {
- number tcpip.NetworkProtocolNumber
- address tcpip.Address
+ number tcpip.NetworkProtocolNumber
+ addrWithPrefix tcpip.AddressWithPrefix
}{
- {ipv4.ProtocolNumber, context.StackAddr},
- {ipv6.ProtocolNumber, context.StackV6Addr},
+ {ipv4.ProtocolNumber, context.StackAddrWithPrefix},
+ {ipv6.ProtocolNumber, context.StackV6AddrWithPrefix},
} {
- if err := s.AddAddress(1, ct.number, ct.address); err != nil {
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ct.number,
+ AddressWithPrefix: ct.addrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
return nil, err
}
}
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 6e55a7a32..88bb99354 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -243,8 +243,8 @@ func NewWithOpts(t *testing.T, opts Options) *Context {
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: StackAddrWithPrefix,
}
- if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err)
+ if err := s.AddProtocolAddress(1, v4ProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", v4ProtocolAddr, err)
}
routeTable = append(routeTable, tcpip.Route{
Destination: header.IPv4EmptySubnet,
@@ -257,8 +257,8 @@ func NewWithOpts(t *testing.T, opts Options) *Context {
Protocol: ipv6.ProtocolNumber,
AddressWithPrefix: StackV6AddrWithPrefix,
}
- if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err)
+ if err := s.AddProtocolAddress(1, v6ProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", v6ProtocolAddr, err)
}
routeTable = append(routeTable, tcpip.Route{
Destination: header.IPv6EmptySubnet,
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index f171a16f8..4255457f9 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -547,7 +547,7 @@ func (e *endpoint) Disconnect() tcpip.Error {
info := e.net.Info()
info.ID.LocalPort = e.localPort
info.ID.RemotePort = e.remotePort
- if info.BindNICID != 0 || info.ID.LocalAddress == "" {
+ if e.net.WasBound() {
var err tcpip.Error
id = stack.TransportEndpointID{
LocalPort: info.ID.LocalPort,
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 554ce1de4..2d15830a7 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -323,12 +323,20 @@ func newDualTestContextWithHandleLocal(t *testing.T, mtu uint32, handleLocal boo
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, stackAddr, err)
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.Address(stackAddr).WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV4, err)
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, stackV6Addr); err != nil {
- t.Fatalf("AddAddress((%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, stackV6Addr, err)
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: tcpip.Address(stackV6Addr).WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV6, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -2504,8 +2512,8 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
if err := s.CreateNIC(nicID1, e); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
}
- if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err)
+ if err := s.AddProtocolAddress(nicID1, test.nicAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, test.nicAddr, err)
}
s.SetRouteTable(test.routes)
diff --git a/runsc/boot/network.go b/runsc/boot/network.go
index 2257ecca7..9fb3ebd95 100644
--- a/runsc/boot/network.go
+++ b/runsc/boot/network.go
@@ -282,12 +282,15 @@ func (n *Network) createNICWithAddrs(id tcpip.NICID, name string, ep stack.LinkE
for _, addr := range addrs {
proto, tcpipAddr := ipToAddressAndProto(addr.Address)
- ap := tcpip.AddressWithPrefix{
- Address: tcpipAddr,
- PrefixLen: addr.PrefixLen,
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: proto,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpipAddr,
+ PrefixLen: addr.PrefixLen,
+ },
}
- if err := n.Stack.AddAddressWithPrefix(id, proto, ap); err != nil {
- return fmt.Errorf("AddAddress(%v, %v, %v) failed: %v", id, proto, tcpipAddr, err)
+ if err := n.Stack.AddProtocolAddress(id, protocolAddr, stack.AddressProperties{}); err != nil {
+ return fmt.Errorf("AddProtocolAddress(%d, %+v, {}) failed: %s", id, protocolAddr, err)
}
}
return nil
diff --git a/runsc/cmd/do.go b/runsc/cmd/do.go
index 6cf76f644..4eb5a96f1 100644
--- a/runsc/cmd/do.go
+++ b/runsc/cmd/do.go
@@ -130,7 +130,6 @@ func (c *Do) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) su
if conf.Network == config.NetworkNone {
addNamespace(spec, specs.LinuxNamespace{Type: specs.NetworkNamespace})
-
} else if conf.Rootless {
if conf.Network == config.NetworkSandbox {
c.notifyUser("*** Warning: sandbox network isn't supported with --rootless, switching to host ***")
diff --git a/runsc/cmd/run.go b/runsc/cmd/run.go
index 722181aff..da11c9d06 100644
--- a/runsc/cmd/run.go
+++ b/runsc/cmd/run.go
@@ -68,7 +68,14 @@ func (r *Run) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) s
waitStatus := args[1].(*unix.WaitStatus)
if conf.Rootless {
- return Errorf("Rootless mode not supported with %q", r.Name())
+ if conf.Network == config.NetworkSandbox {
+ return Errorf("sandbox network isn't supported with --rootless, use --network=none or --network=host")
+ }
+
+ if err := specutils.MaybeRunAsRoot(); err != nil {
+ return Errorf("Error executing inside namespace: %v", err)
+ }
+ // Execution will continue here if no more capabilities are needed...
}
bundleDir := r.bundleDir
diff --git a/test/benchmarks/tcp/tcp_proxy.go b/test/benchmarks/tcp/tcp_proxy.go
index be74e4d4a..bc0d8fd38 100644
--- a/test/benchmarks/tcp/tcp_proxy.go
+++ b/test/benchmarks/tcp/tcp_proxy.go
@@ -208,8 +208,12 @@ func newNetstackImpl(mode string) (impl, error) {
if err := s.CreateNIC(nicID, fifo.New(ep, runtime.GOMAXPROCS(0), 1000)); err != nil {
return nil, fmt.Errorf("error creating NIC %q: %v", *iface, err)
}
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, parsedAddr); err != nil {
- return nil, fmt.Errorf("error adding IP address to %q: %v", *iface, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: parsedAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ return nil, fmt.Errorf("error adding IP address %+v to %q: %w", protocolAddr, *iface, err)
}
subnet, err := tcpip.NewSubnet(parsedDest, parsedMask)
diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD
index 16c451786..3b55112e6 100644
--- a/test/syscalls/BUILD
+++ b/test/syscalls/BUILD
@@ -707,6 +707,11 @@ syscall_test(
syscall_test(
size = "medium",
+ test = "//test/syscalls/linux:socket_ipv4_datagram_based_socket_unbound_loopback_test",
+)
+
+syscall_test(
+ size = "medium",
test = "//test/syscalls/linux:socket_ipv4_udp_unbound_loopback_test",
)
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index b06b3d233..a4efd9486 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -26,6 +26,7 @@ exports_files(
"socket_ip_udp_loopback_blocking.cc",
"socket_ip_udp_loopback_nonblock.cc",
"socket_ip_unbound.cc",
+ "socket_ipv4_datagram_based_socket_unbound_loopback.cc",
"socket_ipv4_udp_unbound_external_networking_test.cc",
"socket_ipv6_udp_unbound_external_networking_test.cc",
"socket_ipv4_udp_unbound_loopback.cc",
@@ -1446,6 +1447,7 @@ cc_binary(
defines = select_system(),
linkstatic = 1,
deps = [
+ ":ip_socket_test_util",
":unix_domain_socket_test_util",
"//test/util:capability_util",
"//test/util:cleanup",
@@ -1465,6 +1467,7 @@ cc_binary(
srcs = ["packet_socket.cc"],
linkstatic = 1,
deps = [
+ ":ip_socket_test_util",
":unix_domain_socket_test_util",
"//test/util:capability_util",
"//test/util:cleanup",
@@ -2532,6 +2535,24 @@ cc_library(
)
cc_library(
+ name = "socket_ipv4_datagram_based_socket_unbound_test_cases",
+ testonly = 1,
+ srcs = [
+ "socket_ipv4_datagram_based_socket_unbound.cc",
+ ],
+ hdrs = [
+ "socket_ipv4_datagram_based_socket_unbound.h",
+ ],
+ deps = [
+ ":ip_socket_test_util",
+ "//test/util:capability_util",
+ "//test/util:socket_util",
+ gtest,
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
name = "socket_ipv4_udp_unbound_test_cases",
testonly = 1,
srcs = [
@@ -2545,9 +2566,6 @@ cc_library(
"//test/util:socket_util",
"@com_google_absl//absl/memory",
gtest,
- "//test/util:posix_error",
- "//test/util:save_util",
- "//test/util:test_util",
],
alwayslink = 1,
)
@@ -2954,9 +2972,21 @@ cc_binary(
deps = [
":ip_socket_test_util",
":socket_ipv4_udp_unbound_test_cases",
- "//test/util:socket_util",
"//test/util:test_main",
- "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_ipv4_datagram_based_socket_unbound_loopback_test",
+ testonly = 1,
+ srcs = [
+ "socket_ipv4_datagram_based_socket_unbound_loopback.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_ipv4_datagram_based_socket_unbound_test_cases",
+ "//test/util:test_main",
],
)
@@ -4132,12 +4162,10 @@ cc_binary(
srcs = ["network_namespace.cc"],
linkstatic = 1,
deps = [
- "//test/util:socket_util",
gtest,
+ ":ip_socket_test_util",
"//test/util:capability_util",
- "//test/util:posix_error",
"//test/util:test_main",
- "//test/util:test_util",
"//test/util:thread_util",
],
)
diff --git a/test/syscalls/linux/ip_socket_test_util.cc b/test/syscalls/linux/ip_socket_test_util.cc
index e90a7e411..a1216d23f 100644
--- a/test/syscalls/linux/ip_socket_test_util.cc
+++ b/test/syscalls/linux/ip_socket_test_util.cc
@@ -156,6 +156,14 @@ SocketKind ICMPv6UnboundSocket(int type) {
UnboundSocketCreator(AF_INET6, type | SOCK_DGRAM, IPPROTO_ICMPV6)};
}
+SocketKind IPv4RawUDPUnboundSocket(int type) {
+ std::string description =
+ absl::StrCat(DescribeSocketType(type), "IPv4 Raw UDP socket");
+ return SocketKind{
+ description, AF_INET, type | SOCK_RAW, IPPROTO_UDP,
+ UnboundSocketCreator(AF_INET, type | SOCK_RAW, IPPROTO_UDP)};
+}
+
SocketKind IPv4UDPUnboundSocket(int type) {
std::string description =
absl::StrCat(DescribeSocketType(type), "IPv4 UDP socket");
diff --git a/test/syscalls/linux/ip_socket_test_util.h b/test/syscalls/linux/ip_socket_test_util.h
index 8f26f1cd0..556838356 100644
--- a/test/syscalls/linux/ip_socket_test_util.h
+++ b/test/syscalls/linux/ip_socket_test_util.h
@@ -35,6 +35,9 @@ uint16_t PortFromInetSockaddr(const struct sockaddr* addr);
// InterfaceIndex returns the index of the named interface.
PosixErrorOr<int> InterfaceIndex(std::string name);
+// GetLoopbackIndex returns the index of the loopback interface.
+inline PosixErrorOr<int> GetLoopbackIndex() { return InterfaceIndex("lo"); }
+
// IPv6TCPAcceptBindSocketPair returns a SocketPairKind that represents
// SocketPairs created with bind() and accept() syscalls with AF_INET6 and the
// given type bound to the IPv6 loopback.
@@ -92,6 +95,10 @@ SocketKind ICMPUnboundSocket(int type);
// created with AF_INET6, SOCK_DGRAM, IPPROTO_ICMPV6, and the given type.
SocketKind ICMPv6UnboundSocket(int type);
+// IPv4RawUDPUnboundSocket returns a SocketKind that represents a SimpleSocket
+// created with AF_INET, SOCK_RAW, IPPROTO_UDP, and the given type.
+SocketKind IPv4RawUDPUnboundSocket(int type);
+
// IPv4UDPUnboundSocket returns a SocketKind that represents a SimpleSocket
// created with AF_INET, SOCK_DGRAM, IPPROTO_UDP, and the given type.
SocketKind IPv4UDPUnboundSocket(int type);
diff --git a/test/syscalls/linux/network_namespace.cc b/test/syscalls/linux/network_namespace.cc
index 1984feedd..aa9e1bc04 100644
--- a/test/syscalls/linux/network_namespace.cc
+++ b/test/syscalls/linux/network_namespace.cc
@@ -12,18 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include <net/if.h>
-#include <sched.h>
-#include <sys/ioctl.h>
-#include <sys/socket.h>
-#include <sys/types.h>
-
-#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
#include "test/util/capability_util.h"
-#include "test/util/posix_error.h"
-#include "test/util/socket_util.h"
-#include "test/util/test_util.h"
#include "test/util/thread_util.h"
namespace gvisor {
@@ -37,13 +28,7 @@ TEST(NetworkNamespaceTest, LoopbackExists) {
ASSERT_THAT(unshare(CLONE_NEWNET), SyscallSucceedsWithValue(0));
// TODO(gvisor.dev/issue/1833): Update this to test that only "lo" exists.
- // Check loopback device exists.
- int sock = socket(AF_INET, SOCK_DGRAM, 0);
- ASSERT_THAT(sock, SyscallSucceeds());
- struct ifreq ifr;
- strncpy(ifr.ifr_name, "lo", IFNAMSIZ);
- EXPECT_THAT(ioctl(sock, SIOCGIFINDEX, &ifr), SyscallSucceeds())
- << "lo cannot be found";
+ ASSERT_NE(ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex()), 0);
});
}
diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc
index bfa5d179a..9515a48f2 100644
--- a/test/syscalls/linux/packet_socket.cc
+++ b/test/syscalls/linux/packet_socket.cc
@@ -29,6 +29,7 @@
#include "gtest/gtest.h"
#include "absl/base/internal/endian.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
#include "test/syscalls/linux/unix_domain_socket_test_util.h"
#include "test/util/capability_util.h"
#include "test/util/cleanup.h"
@@ -156,11 +157,9 @@ void CookedPacketTest::TearDown() {
}
int CookedPacketTest::GetLoopbackIndex() {
- struct ifreq ifr;
- snprintf(ifr.ifr_name, IFNAMSIZ, "lo");
- EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds());
- EXPECT_NE(ifr.ifr_ifindex, 0);
- return ifr.ifr_ifindex;
+ int v = EXPECT_NO_ERRNO_AND_VALUE(gvisor::testing::GetLoopbackIndex());
+ EXPECT_NE(v, 0);
+ return v;
}
// Receive and verify the message via packet socket on interface.
diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc
index e57c60ffa..aa7182b01 100644
--- a/test/syscalls/linux/packet_socket_raw.cc
+++ b/test/syscalls/linux/packet_socket_raw.cc
@@ -14,14 +14,12 @@
#include <arpa/inet.h>
#include <net/ethernet.h>
-#include <net/if.h>
#include <net/if_arp.h>
#include <netinet/in.h>
#include <netinet/ip.h>
#include <netinet/udp.h>
#include <netpacket/packet.h>
#include <poll.h>
-#include <sys/ioctl.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
@@ -29,6 +27,7 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/base/internal/endian.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
#include "test/syscalls/linux/unix_domain_socket_test_util.h"
#include "test/util/capability_util.h"
#include "test/util/cleanup.h"
@@ -156,11 +155,9 @@ void RawPacketTest::TearDown() {
}
int RawPacketTest::GetLoopbackIndex() {
- struct ifreq ifr;
- snprintf(ifr.ifr_name, IFNAMSIZ, "lo");
- EXPECT_THAT(ioctl(s_, SIOCGIFINDEX, &ifr), SyscallSucceeds());
- EXPECT_NE(ifr.ifr_ifindex, 0);
- return ifr.ifr_ifindex;
+ int v = EXPECT_NO_ERRNO_AND_VALUE(gvisor::testing::GetLoopbackIndex());
+ EXPECT_NE(v, 0);
+ return v;
}
// Receive via a packet socket.
diff --git a/test/syscalls/linux/ping_socket.cc b/test/syscalls/linux/ping_socket.cc
index 7ec1938bf..684983a4c 100644
--- a/test/syscalls/linux/ping_socket.cc
+++ b/test/syscalls/linux/ping_socket.cc
@@ -155,43 +155,33 @@ std::vector<std::tuple<SocketKind, BindTestCase>> ICMPTestCases() {
.bind_to = V4AddrStr("IPv4UnknownUnicast", "192.168.1.1"),
.want = EADDRNOTAVAIL,
},
- // TODO(gvisor.dev/issue/6021): Remove want_gvisor from all the test
- // cases below once ICMP sockets return EAFNOSUPPORT when binding to
- // IPv6 addresses.
{
.bind_to = V6Any(),
.want = EAFNOSUPPORT,
- .want_gvisor = EINVAL,
},
{
.bind_to = V6Loopback(),
.want = EAFNOSUPPORT,
- .want_gvisor = EINVAL,
},
{
.bind_to = V6Multicast(),
.want = EAFNOSUPPORT,
- .want_gvisor = EINVAL,
},
{
.bind_to = V6MulticastInterfaceLocalAllNodes(),
.want = EAFNOSUPPORT,
- .want_gvisor = EINVAL,
},
{
.bind_to = V6MulticastLinkLocalAllNodes(),
.want = EAFNOSUPPORT,
- .want_gvisor = EINVAL,
},
{
.bind_to = V6MulticastLinkLocalAllRouters(),
.want = EAFNOSUPPORT,
- .want_gvisor = EINVAL,
},
{
.bind_to = V6AddrStr("IPv6UnknownUnicast", "fc00::1"),
.want = EAFNOSUPPORT,
- .want_gvisor = EINVAL,
},
});
}
diff --git a/test/syscalls/linux/raw_socket.cc b/test/syscalls/linux/raw_socket.cc
index 66f0e6ca4..ef1db47ee 100644
--- a/test/syscalls/linux/raw_socket.cc
+++ b/test/syscalls/linux/raw_socket.cc
@@ -939,6 +939,124 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Combine(::testing::Values(IPPROTO_TCP, IPPROTO_UDP),
::testing::Values(AF_INET, AF_INET6)));
+void TestRawSocketMaybeBindReceive(bool do_bind) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveRawIPSocketCapability()));
+
+ constexpr char payload[] = "abcdefgh";
+
+ const sockaddr_in addr = {
+ .sin_family = AF_INET,
+ .sin_addr = {.s_addr = htonl(INADDR_LOOPBACK)},
+ };
+
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+ sockaddr_in udp_sock_bind_addr = addr;
+ socklen_t udp_sock_bind_addr_len = sizeof(udp_sock_bind_addr);
+ ASSERT_THAT(bind(udp_sock.get(),
+ reinterpret_cast<const sockaddr*>(&udp_sock_bind_addr),
+ sizeof(udp_sock_bind_addr)),
+ SyscallSucceeds());
+ ASSERT_THAT(getsockname(udp_sock.get(),
+ reinterpret_cast<sockaddr*>(&udp_sock_bind_addr),
+ &udp_sock_bind_addr_len),
+ SyscallSucceeds());
+ ASSERT_EQ(udp_sock_bind_addr_len, sizeof(udp_sock_bind_addr));
+
+ FileDescriptor raw_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_UDP));
+
+ auto test_recv = [&](const char* scope, uint32_t expected_destination) {
+ SCOPED_TRACE(scope);
+
+ constexpr int kInfinitePollTimeout = -1;
+ pollfd pfd = {
+ .fd = raw_sock.get(),
+ .events = POLLIN,
+ };
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, kInfinitePollTimeout),
+ SyscallSucceedsWithValue(1));
+
+ struct ipv4_udp_packet {
+ iphdr ip;
+ udphdr udp;
+ char data[sizeof(payload)];
+
+ // Used to make sure only the required space is used.
+ char unused_space;
+ } ABSL_ATTRIBUTE_PACKED;
+ constexpr size_t kExpectedIPPacketSize =
+ offsetof(ipv4_udp_packet, unused_space);
+
+ // Receive the whole IPv4 packet on the raw socket.
+ ipv4_udp_packet read_raw_packet;
+ sockaddr_in peer;
+ socklen_t peerlen = sizeof(peer);
+ ASSERT_EQ(
+ recvfrom(raw_sock.get(), reinterpret_cast<char*>(&read_raw_packet),
+ sizeof(read_raw_packet), 0 /* flags */,
+ reinterpret_cast<sockaddr*>(&peer), &peerlen),
+ static_cast<ssize_t>(kExpectedIPPacketSize))
+ << strerror(errno);
+ ASSERT_EQ(peerlen, sizeof(peer));
+ EXPECT_EQ(read_raw_packet.ip.version, static_cast<unsigned int>(IPVERSION));
+ // IHL holds the number of header bytes in 4 byte units.
+ EXPECT_EQ(read_raw_packet.ip.ihl, sizeof(read_raw_packet.ip) / 4);
+ EXPECT_EQ(ntohs(read_raw_packet.ip.tot_len), kExpectedIPPacketSize);
+ EXPECT_EQ(ntohs(read_raw_packet.ip.frag_off) & IP_OFFMASK, 0);
+ EXPECT_EQ(read_raw_packet.ip.protocol, SOL_UDP);
+ EXPECT_EQ(ntohl(read_raw_packet.ip.saddr), INADDR_LOOPBACK);
+ EXPECT_EQ(ntohl(read_raw_packet.ip.daddr), expected_destination);
+ EXPECT_EQ(read_raw_packet.udp.source, udp_sock_bind_addr.sin_port);
+ EXPECT_EQ(read_raw_packet.udp.dest, udp_sock_bind_addr.sin_port);
+ EXPECT_EQ(ntohs(read_raw_packet.udp.len),
+ kExpectedIPPacketSize - sizeof(read_raw_packet.ip));
+ for (size_t i = 0; i < sizeof(payload); i++) {
+ EXPECT_EQ(read_raw_packet.data[i], payload[i])
+ << "byte mismatch @ idx=" << i;
+ }
+ EXPECT_EQ(peer.sin_family, AF_INET);
+ EXPECT_EQ(peer.sin_port, 0);
+ EXPECT_EQ(ntohl(peer.sin_addr.s_addr), INADDR_LOOPBACK);
+ };
+
+ if (do_bind) {
+ ASSERT_THAT(bind(raw_sock.get(), reinterpret_cast<const sockaddr*>(&addr),
+ sizeof(addr)),
+ SyscallSucceeds());
+ }
+
+ constexpr int kSendToFlags = 0;
+ sockaddr_in different_addr = udp_sock_bind_addr;
+ different_addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK + 1);
+ ASSERT_THAT(sendto(udp_sock.get(), payload, sizeof(payload), kSendToFlags,
+ reinterpret_cast<const sockaddr*>(&different_addr),
+ sizeof(different_addr)),
+ SyscallSucceedsWithValue(sizeof(payload)));
+ if (!do_bind) {
+ ASSERT_NO_FATAL_FAILURE(
+ test_recv("different_addr", ntohl(different_addr.sin_addr.s_addr)));
+ }
+ ASSERT_THAT(sendto(udp_sock.get(), payload, sizeof(payload), kSendToFlags,
+ reinterpret_cast<const sockaddr*>(&udp_sock_bind_addr),
+ sizeof(udp_sock_bind_addr)),
+ SyscallSucceedsWithValue(sizeof(payload)));
+ ASSERT_NO_FATAL_FAILURE(
+ test_recv("addr", ntohl(udp_sock_bind_addr.sin_addr.s_addr)));
+}
+
+TEST(RawSocketTest, UnboundReceive) {
+ // Test that a raw socket receives packets destined to any address if it is
+ // not bound to an address.
+ ASSERT_NO_FATAL_FAILURE(TestRawSocketMaybeBindReceive(false /* do_bind */));
+}
+
+TEST(RawSocketTest, BindReceive) {
+ // Test that a raw socket only receives packets destined to the address it is
+ // bound to.
+ ASSERT_NO_FATAL_FAILURE(TestRawSocketMaybeBindReceive(true /* do_bind */));
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound.cc b/test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound.cc
new file mode 100644
index 000000000..0e3e7057b
--- /dev/null
+++ b/test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound.cc
@@ -0,0 +1,299 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound.h"
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/util/capability_util.h"
+
+namespace gvisor {
+namespace testing {
+
+void IPv4DatagramBasedUnboundSocketTest::SetUp() {
+ if (GetParam().type & SOCK_RAW) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveRawIPSocketCapability()));
+ }
+}
+
+// Check that dropping a group membership that does not exist fails.
+TEST_P(IPv4DatagramBasedUnboundSocketTest, IpMulticastInvalidDrop) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Unregister from a membership that we didn't have.
+ ip_mreq group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ group.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
+ EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallFailsWithErrno(EADDRNOTAVAIL));
+}
+
+TEST_P(IPv4DatagramBasedUnboundSocketTest, IpMulticastIfZero) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ ip_mreqn iface = {};
+ EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
+ SyscallSucceeds());
+}
+
+TEST_P(IPv4DatagramBasedUnboundSocketTest, IpMulticastIfInvalidNic) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ ip_mreqn iface = {};
+ iface.imr_ifindex = -1;
+ EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
+ SyscallFailsWithErrno(EADDRNOTAVAIL));
+}
+
+TEST_P(IPv4DatagramBasedUnboundSocketTest, IpMulticastIfInvalidAddr) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ ip_mreq iface = {};
+ iface.imr_interface.s_addr = inet_addr("255.255.255");
+ EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
+ SyscallFailsWithErrno(EADDRNOTAVAIL));
+}
+
+TEST_P(IPv4DatagramBasedUnboundSocketTest, IpMulticastIfSetShort) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Create a valid full-sized request.
+ ip_mreqn iface = {};
+ iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex());
+
+ // Send an optlen of 1 to check that optlen is enforced.
+ EXPECT_THAT(
+ setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, 1),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(IPv4DatagramBasedUnboundSocketTest, IpMulticastIfDefault) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ in_addr get = {};
+ socklen_t size = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
+ SyscallSucceeds());
+ EXPECT_EQ(size, sizeof(get));
+ EXPECT_EQ(get.s_addr, 0);
+}
+
+TEST_P(IPv4DatagramBasedUnboundSocketTest, IpMulticastIfDefaultReqn) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ ip_mreqn get = {};
+ socklen_t size = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
+ SyscallSucceeds());
+
+ // getsockopt(IP_MULTICAST_IF) can only return an in_addr, so it treats the
+ // first sizeof(struct in_addr) bytes of struct ip_mreqn as a struct in_addr.
+ // Conveniently, this corresponds to the field ip_mreqn::imr_multiaddr.
+ EXPECT_EQ(size, sizeof(in_addr));
+
+ // getsockopt(IP_MULTICAST_IF) will only return the interface address which
+ // hasn't been set.
+ EXPECT_EQ(get.imr_multiaddr.s_addr, 0);
+ EXPECT_EQ(get.imr_address.s_addr, 0);
+ EXPECT_EQ(get.imr_ifindex, 0);
+}
+
+TEST_P(IPv4DatagramBasedUnboundSocketTest, IpMulticastIfSetAddrGetReqn) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ in_addr set = {};
+ set.s_addr = htonl(INADDR_LOOPBACK);
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set,
+ sizeof(set)),
+ SyscallSucceeds());
+
+ ip_mreqn get = {};
+ socklen_t size = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
+ SyscallSucceeds());
+
+ // getsockopt(IP_MULTICAST_IF) can only return an in_addr, so it treats the
+ // first sizeof(struct in_addr) bytes of struct ip_mreqn as a struct in_addr.
+ // Conveniently, this corresponds to the field ip_mreqn::imr_multiaddr.
+ EXPECT_EQ(size, sizeof(in_addr));
+ EXPECT_EQ(get.imr_multiaddr.s_addr, set.s_addr);
+ EXPECT_EQ(get.imr_address.s_addr, 0);
+ EXPECT_EQ(get.imr_ifindex, 0);
+}
+
+TEST_P(IPv4DatagramBasedUnboundSocketTest, IpMulticastIfSetReqAddrGetReqn) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ ip_mreq set = {};
+ set.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set,
+ sizeof(set)),
+ SyscallSucceeds());
+
+ ip_mreqn get = {};
+ socklen_t size = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
+ SyscallSucceeds());
+
+ // getsockopt(IP_MULTICAST_IF) can only return an in_addr, so it treats the
+ // first sizeof(struct in_addr) bytes of struct ip_mreqn as a struct in_addr.
+ // Conveniently, this corresponds to the field ip_mreqn::imr_multiaddr.
+ EXPECT_EQ(size, sizeof(in_addr));
+ EXPECT_EQ(get.imr_multiaddr.s_addr, set.imr_interface.s_addr);
+ EXPECT_EQ(get.imr_address.s_addr, 0);
+ EXPECT_EQ(get.imr_ifindex, 0);
+}
+
+TEST_P(IPv4DatagramBasedUnboundSocketTest, IpMulticastIfSetNicGetReqn) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ ip_mreqn set = {};
+ set.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex());
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set,
+ sizeof(set)),
+ SyscallSucceeds());
+
+ ip_mreqn get = {};
+ socklen_t size = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
+ SyscallSucceeds());
+ EXPECT_EQ(size, sizeof(in_addr));
+ EXPECT_EQ(get.imr_multiaddr.s_addr, 0);
+ EXPECT_EQ(get.imr_address.s_addr, 0);
+ EXPECT_EQ(get.imr_ifindex, 0);
+}
+
+TEST_P(IPv4DatagramBasedUnboundSocketTest, IpMulticastIfSetAddr) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ in_addr set = {};
+ set.s_addr = htonl(INADDR_LOOPBACK);
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set,
+ sizeof(set)),
+ SyscallSucceeds());
+
+ in_addr get = {};
+ socklen_t size = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
+ SyscallSucceeds());
+
+ EXPECT_EQ(size, sizeof(get));
+ EXPECT_EQ(get.s_addr, set.s_addr);
+}
+
+TEST_P(IPv4DatagramBasedUnboundSocketTest, IpMulticastIfSetReqAddr) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ ip_mreq set = {};
+ set.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set,
+ sizeof(set)),
+ SyscallSucceeds());
+
+ in_addr get = {};
+ socklen_t size = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
+ SyscallSucceeds());
+
+ EXPECT_EQ(size, sizeof(get));
+ EXPECT_EQ(get.s_addr, set.imr_interface.s_addr);
+}
+
+TEST_P(IPv4DatagramBasedUnboundSocketTest, IpMulticastIfSetNic) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ ip_mreqn set = {};
+ set.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex());
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set,
+ sizeof(set)),
+ SyscallSucceeds());
+
+ in_addr get = {};
+ socklen_t size = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
+ SyscallSucceeds());
+ EXPECT_EQ(size, sizeof(get));
+ EXPECT_EQ(get.s_addr, 0);
+}
+
+TEST_P(IPv4DatagramBasedUnboundSocketTest, TestJoinGroupNoIf) {
+ // TODO(b/185517803): Fix for native test.
+ SKIP_IF(!IsRunningOnGvisor());
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ ip_mreqn group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallFailsWithErrno(ENODEV));
+}
+
+TEST_P(IPv4DatagramBasedUnboundSocketTest, TestJoinGroupInvalidIf) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ ip_mreqn group = {};
+ group.imr_address.s_addr = inet_addr("255.255.255");
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
+ SyscallFailsWithErrno(ENODEV));
+}
+
+// Check that multiple memberships are not allowed on the same socket.
+TEST_P(IPv4DatagramBasedUnboundSocketTest, TestMultipleJoinsOnSingleSocket) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto fd = socket1->get();
+ ip_mreqn group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex());
+
+ EXPECT_THAT(
+ setsockopt(fd, IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, sizeof(group)),
+ SyscallSucceeds());
+
+ EXPECT_THAT(
+ setsockopt(fd, IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, sizeof(group)),
+ SyscallFailsWithErrno(EADDRINUSE));
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound.h b/test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound.h
new file mode 100644
index 000000000..7ab235cad
--- /dev/null
+++ b/test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound.h
@@ -0,0 +1,31 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_DATAGRAM_BASED_SOCKET_UNBOUND_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_DATAGRAM_BASED_SOCKET_UNBOUND_H_
+
+#include "test/util/socket_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Test fixture for tests that apply to IPv4 datagram-based sockets.
+class IPv4DatagramBasedUnboundSocketTest : public SimpleSocketTest {
+ void SetUp() override;
+};
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_DATAGRAM_BASED_SOCKET_UNBOUND_H_
diff --git a/test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound_loopback.cc b/test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound_loopback.cc
new file mode 100644
index 000000000..9ae68e015
--- /dev/null
+++ b/test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound_loopback.cc
@@ -0,0 +1,28 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound.h"
+
+namespace gvisor {
+namespace testing {
+
+INSTANTIATE_TEST_SUITE_P(IPv4Sockets, IPv4DatagramBasedUnboundSocketTest,
+ ::testing::ValuesIn(ApplyVecToVec<SocketKind>(
+ {IPv4UDPUnboundSocket, IPv4RawUDPUnboundSocket},
+ AllBitwiseCombinations(List<int>{
+ 0, SOCK_NONBLOCK}))));
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc
index 816d1181c..05773a29c 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc
@@ -15,8 +15,6 @@
#include "test/syscalls/linux/socket_ipv4_udp_unbound.h"
#include <arpa/inet.h>
-#include <net/if.h>
-#include <sys/ioctl.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/un.h>
@@ -26,10 +24,6 @@
#include "gtest/gtest.h"
#include "absl/memory/memory.h"
#include "test/syscalls/linux/ip_socket_test_util.h"
-#include "test/util/posix_error.h"
-#include "test/util/save_util.h"
-#include "test/util/socket_util.h"
-#include "test/util/test_util.h"
namespace gvisor {
namespace testing {
@@ -140,7 +134,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNicNoDefaultSendIf) {
// Register to receive multicast packets.
ip_mreqn group = {};
group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
- group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex());
EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
sizeof(group)),
SyscallSucceeds());
@@ -238,7 +232,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNic) {
// Register to receive multicast packets.
ip_mreqn group = {};
group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
- group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex());
ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
sizeof(group)),
SyscallSucceeds());
@@ -326,7 +320,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNic) {
// Set the default send interface.
ip_mreqn iface = {};
- iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex());
ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
sizeof(iface)),
SyscallSucceeds());
@@ -346,7 +340,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNic) {
// Register to receive multicast packets.
ip_mreqn group = {};
group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
- group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex());
ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
sizeof(group)),
SyscallSucceeds());
@@ -437,7 +431,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicConnect) {
// Set the default send interface.
ip_mreqn iface = {};
- iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex());
ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
sizeof(iface)),
SyscallSucceeds());
@@ -457,7 +451,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicConnect) {
// Register to receive multicast packets.
ip_mreqn group = {};
group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
- group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex());
ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
sizeof(group)),
SyscallSucceeds());
@@ -548,7 +542,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelf) {
// Set the default send interface.
ip_mreqn iface = {};
- iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex());
ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
sizeof(iface)),
SyscallSucceeds());
@@ -568,7 +562,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelf) {
// Register to receive multicast packets.
ip_mreqn group = {};
group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
- group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex());
ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
sizeof(group)),
SyscallSucceeds());
@@ -657,7 +651,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfConnect) {
// Set the default send interface.
ip_mreqn iface = {};
- iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex());
ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
sizeof(iface)),
SyscallSucceeds());
@@ -677,7 +671,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfConnect) {
// Register to receive multicast packets.
ip_mreqn group = {};
group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
- group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex());
ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
sizeof(group)),
SyscallSucceeds());
@@ -770,7 +764,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfNoLoop) {
// Set the default send interface.
ip_mreqn iface = {};
- iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex());
ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
sizeof(iface)),
SyscallSucceeds());
@@ -794,7 +788,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfNoLoop) {
// Register to receive multicast packets.
ip_mreqn group = {};
group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
- group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex());
EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
sizeof(group)),
SyscallSucceeds());
@@ -819,20 +813,6 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfNoLoop) {
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
-// Check that dropping a group membership that does not exist fails.
-TEST_P(IPv4UDPUnboundSocketTest, IpMulticastInvalidDrop) {
- auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
- auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
-
- // Unregister from a membership that we didn't have.
- ip_mreq group = {};
- group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
- group.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
- EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group,
- sizeof(group)),
- SyscallFailsWithErrno(EADDRNOTAVAIL));
-}
-
// Check that dropping a group membership prevents multicast packets from being
// delivered. Default send address configured by bind and group membership
// interface configured by address.
@@ -917,7 +897,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropNic) {
// Register and unregister to receive multicast packets.
ip_mreqn group = {};
group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
- group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex());
EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
sizeof(group)),
SyscallSucceeds());
@@ -943,260 +923,6 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropNic) {
PosixErrorIs(EAGAIN, ::testing::_));
}
-TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfZero) {
- auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
- auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
-
- ip_mreqn iface = {};
- EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
- sizeof(iface)),
- SyscallSucceeds());
-}
-
-TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfInvalidNic) {
- auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
- auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
-
- ip_mreqn iface = {};
- iface.imr_ifindex = -1;
- EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
- sizeof(iface)),
- SyscallFailsWithErrno(EADDRNOTAVAIL));
-}
-
-TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfInvalidAddr) {
- auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
- auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
-
- ip_mreq iface = {};
- iface.imr_interface.s_addr = inet_addr("255.255.255");
- EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
- sizeof(iface)),
- SyscallFailsWithErrno(EADDRNOTAVAIL));
-}
-
-TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetShort) {
- auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
- auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
-
- // Create a valid full-sized request.
- ip_mreqn iface = {};
- iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
-
- // Send an optlen of 1 to check that optlen is enforced.
- EXPECT_THAT(
- setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, 1),
- SyscallFailsWithErrno(EINVAL));
-}
-
-TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfDefault) {
- auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
- auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
-
- in_addr get = {};
- socklen_t size = sizeof(get);
- ASSERT_THAT(
- getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
- SyscallSucceeds());
- EXPECT_EQ(size, sizeof(get));
- EXPECT_EQ(get.s_addr, 0);
-}
-
-TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfDefaultReqn) {
- auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
- auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
-
- ip_mreqn get = {};
- socklen_t size = sizeof(get);
- ASSERT_THAT(
- getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
- SyscallSucceeds());
-
- // getsockopt(IP_MULTICAST_IF) can only return an in_addr, so it treats the
- // first sizeof(struct in_addr) bytes of struct ip_mreqn as a struct in_addr.
- // Conveniently, this corresponds to the field ip_mreqn::imr_multiaddr.
- EXPECT_EQ(size, sizeof(in_addr));
-
- // getsockopt(IP_MULTICAST_IF) will only return the interface address which
- // hasn't been set.
- EXPECT_EQ(get.imr_multiaddr.s_addr, 0);
- EXPECT_EQ(get.imr_address.s_addr, 0);
- EXPECT_EQ(get.imr_ifindex, 0);
-}
-
-TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetAddrGetReqn) {
- auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
- auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
-
- in_addr set = {};
- set.s_addr = htonl(INADDR_LOOPBACK);
- ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set,
- sizeof(set)),
- SyscallSucceeds());
-
- ip_mreqn get = {};
- socklen_t size = sizeof(get);
- ASSERT_THAT(
- getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
- SyscallSucceeds());
-
- // getsockopt(IP_MULTICAST_IF) can only return an in_addr, so it treats the
- // first sizeof(struct in_addr) bytes of struct ip_mreqn as a struct in_addr.
- // Conveniently, this corresponds to the field ip_mreqn::imr_multiaddr.
- EXPECT_EQ(size, sizeof(in_addr));
- EXPECT_EQ(get.imr_multiaddr.s_addr, set.s_addr);
- EXPECT_EQ(get.imr_address.s_addr, 0);
- EXPECT_EQ(get.imr_ifindex, 0);
-}
-
-TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetReqAddrGetReqn) {
- auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
- auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
-
- ip_mreq set = {};
- set.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
- ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set,
- sizeof(set)),
- SyscallSucceeds());
-
- ip_mreqn get = {};
- socklen_t size = sizeof(get);
- ASSERT_THAT(
- getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
- SyscallSucceeds());
-
- // getsockopt(IP_MULTICAST_IF) can only return an in_addr, so it treats the
- // first sizeof(struct in_addr) bytes of struct ip_mreqn as a struct in_addr.
- // Conveniently, this corresponds to the field ip_mreqn::imr_multiaddr.
- EXPECT_EQ(size, sizeof(in_addr));
- EXPECT_EQ(get.imr_multiaddr.s_addr, set.imr_interface.s_addr);
- EXPECT_EQ(get.imr_address.s_addr, 0);
- EXPECT_EQ(get.imr_ifindex, 0);
-}
-
-TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetNicGetReqn) {
- auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
- auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
-
- ip_mreqn set = {};
- set.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
- ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set,
- sizeof(set)),
- SyscallSucceeds());
-
- ip_mreqn get = {};
- socklen_t size = sizeof(get);
- ASSERT_THAT(
- getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
- SyscallSucceeds());
- EXPECT_EQ(size, sizeof(in_addr));
- EXPECT_EQ(get.imr_multiaddr.s_addr, 0);
- EXPECT_EQ(get.imr_address.s_addr, 0);
- EXPECT_EQ(get.imr_ifindex, 0);
-}
-
-TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetAddr) {
- auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
- auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
-
- in_addr set = {};
- set.s_addr = htonl(INADDR_LOOPBACK);
- ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set,
- sizeof(set)),
- SyscallSucceeds());
-
- in_addr get = {};
- socklen_t size = sizeof(get);
- ASSERT_THAT(
- getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
- SyscallSucceeds());
-
- EXPECT_EQ(size, sizeof(get));
- EXPECT_EQ(get.s_addr, set.s_addr);
-}
-
-TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetReqAddr) {
- auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
- auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
-
- ip_mreq set = {};
- set.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
- ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set,
- sizeof(set)),
- SyscallSucceeds());
-
- in_addr get = {};
- socklen_t size = sizeof(get);
- ASSERT_THAT(
- getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
- SyscallSucceeds());
-
- EXPECT_EQ(size, sizeof(get));
- EXPECT_EQ(get.s_addr, set.imr_interface.s_addr);
-}
-
-TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetNic) {
- auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
- auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
-
- ip_mreqn set = {};
- set.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
- ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set,
- sizeof(set)),
- SyscallSucceeds());
-
- in_addr get = {};
- socklen_t size = sizeof(get);
- ASSERT_THAT(
- getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
- SyscallSucceeds());
- EXPECT_EQ(size, sizeof(get));
- EXPECT_EQ(get.s_addr, 0);
-}
-
-TEST_P(IPv4UDPUnboundSocketTest, TestJoinGroupNoIf) {
- // TODO(b/185517803): Fix for native test.
- SKIP_IF(!IsRunningOnGvisor());
- auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
- auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
-
- ip_mreqn group = {};
- group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
- EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
- sizeof(group)),
- SyscallFailsWithErrno(ENODEV));
-}
-
-TEST_P(IPv4UDPUnboundSocketTest, TestJoinGroupInvalidIf) {
- auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
- auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
-
- ip_mreqn group = {};
- group.imr_address.s_addr = inet_addr("255.255.255");
- group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
- EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
- sizeof(group)),
- SyscallFailsWithErrno(ENODEV));
-}
-
-// Check that multiple memberships are not allowed on the same socket.
-TEST_P(IPv4UDPUnboundSocketTest, TestMultipleJoinsOnSingleSocket) {
- auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
- auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
- auto fd = socket1->get();
- ip_mreqn group = {};
- group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
- group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
-
- EXPECT_THAT(
- setsockopt(fd, IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, sizeof(group)),
- SyscallSucceeds());
-
- EXPECT_THAT(
- setsockopt(fd, IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, sizeof(group)),
- SyscallFailsWithErrno(EADDRINUSE));
-}
-
// Check that two sockets can join the same multicast group at the same time.
TEST_P(IPv4UDPUnboundSocketTest, TestTwoSocketsJoinSameMulticastGroup) {
auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
@@ -1204,7 +930,7 @@ TEST_P(IPv4UDPUnboundSocketTest, TestTwoSocketsJoinSameMulticastGroup) {
ip_mreqn group = {};
group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
- group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex());
EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
sizeof(group)),
SyscallSucceeds());
@@ -1419,7 +1145,7 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenJoinThenReceive) {
// Register to receive multicast packets.
ip_mreqn group = {};
group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
- group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex());
ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
sizeof(group)),
SyscallSucceeds());
@@ -2110,16 +1836,11 @@ TEST_P(IPv4UDPUnboundSocketTest, SetAndReceiveIPPKTINFO) {
EXPECT_EQ(cmsg->cmsg_level, level);
EXPECT_EQ(cmsg->cmsg_type, type);
- // Get loopback index.
- ifreq ifr = {};
- absl::SNPrintF(ifr.ifr_name, IFNAMSIZ, "lo");
- ASSERT_THAT(ioctl(sender->get(), SIOCGIFINDEX, &ifr), SyscallSucceeds());
- ASSERT_NE(ifr.ifr_ifindex, 0);
-
// Check the data
in_pktinfo received_pktinfo = {};
memcpy(&received_pktinfo, CMSG_DATA(cmsg), sizeof(in_pktinfo));
- EXPECT_EQ(received_pktinfo.ipi_ifindex, ifr.ifr_ifindex);
+ EXPECT_EQ(received_pktinfo.ipi_ifindex,
+ ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex()));
EXPECT_EQ(received_pktinfo.ipi_spec_dst.s_addr, htonl(INADDR_LOOPBACK));
EXPECT_EQ(received_pktinfo.ipi_addr.s_addr, htonl(INADDR_LOOPBACK));
}
@@ -2439,7 +2160,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIPPacketInfo) {
// Register to receive multicast packets.
ip_mreqn group = {};
group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
- group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex());
ASSERT_THAT(setsockopt(receiver_socket->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
&group, sizeof(group)),
SyscallSucceeds());
@@ -2484,16 +2205,10 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIPPacketInfo) {
EXPECT_EQ(cmsg->cmsg_level, IPPROTO_IP);
EXPECT_EQ(cmsg->cmsg_type, IP_PKTINFO);
- // Get loopback index.
- ifreq ifr = {};
- absl::SNPrintF(ifr.ifr_name, IFNAMSIZ, "lo");
- ASSERT_THAT(ioctl(receiver_socket->get(), SIOCGIFINDEX, &ifr),
- SyscallSucceeds());
- ASSERT_NE(ifr.ifr_ifindex, 0);
-
in_pktinfo received_pktinfo = {};
memcpy(&received_pktinfo, CMSG_DATA(cmsg), sizeof(in_pktinfo));
- EXPECT_EQ(received_pktinfo.ipi_ifindex, ifr.ifr_ifindex);
+ EXPECT_EQ(received_pktinfo.ipi_ifindex,
+ ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex()));
if (IsRunningOnGvisor()) {
// This should actually be a unicast address assigned to the interface.
//
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_loopback.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_loopback.cc
index 00930c544..bef284530 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound_loopback.cc
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound_loopback.cc
@@ -12,12 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include <vector>
-
#include "test/syscalls/linux/ip_socket_test_util.h"
#include "test/syscalls/linux/socket_ipv4_udp_unbound.h"
-#include "test/util/socket_util.h"
-#include "test/util/test_util.h"
namespace gvisor {
namespace testing {
diff --git a/test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.cc
index 09f070797..d72b6b53f 100644
--- a/test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.cc
+++ b/test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.cc
@@ -39,7 +39,7 @@ TEST_P(IPv6UDPUnboundExternalNetworkingSocketTest, TestJoinLeaveMulticast) {
.ipv6mr_multiaddr =
reinterpret_cast<sockaddr_in6*>(&multicast_addr.addr)->sin6_addr,
.ipv6mr_interface = static_cast<decltype(ipv6_mreq::ipv6mr_interface)>(
- ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"))),
+ ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex())),
};
ASSERT_THAT(setsockopt(receiver->get(), IPPROTO_IPV6, IPV6_ADD_MEMBERSHIP,
&group_req, sizeof(group_req)),
diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc
index d58b57c8b..0c0a16074 100644
--- a/test/syscalls/linux/udp_socket.cc
+++ b/test/syscalls/linux/udp_socket.cc
@@ -602,6 +602,67 @@ TEST_P(UdpSocketTest, DisconnectAfterBind) {
SyscallFailsWithErrno(ENOTCONN));
}
+void ConnectThenDisconnect(const FileDescriptor& sock,
+ const sockaddr* bind_addr,
+ const socklen_t expected_addrlen) {
+ // Connect the bound socket.
+ ASSERT_THAT(connect(sock.get(), bind_addr, expected_addrlen),
+ SyscallSucceeds());
+
+ // Disconnect.
+ {
+ sockaddr_storage unspec = {.ss_family = AF_UNSPEC};
+ ASSERT_THAT(connect(sock.get(), AsSockAddr(&unspec), sizeof(unspec)),
+ SyscallSucceeds());
+ }
+ {
+ // Check that we're not in a bound state.
+ sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ ASSERT_THAT(getsockname(sock.get(), AsSockAddr(&addr), &addrlen),
+ SyscallSucceeds());
+ ASSERT_EQ(addrlen, expected_addrlen);
+ // Everything should be the zero value except the address family.
+ sockaddr_storage expected = {
+ .ss_family = bind_addr->sa_family,
+ };
+ EXPECT_EQ(memcmp(&expected, &addr, expected_addrlen), 0);
+ }
+
+ {
+ // We are not connected so we have no peer.
+ sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(getpeername(sock.get(), AsSockAddr(&addr), &addrlen),
+ SyscallFailsWithErrno(ENOTCONN));
+ }
+}
+
+TEST_P(UdpSocketTest, DisconnectAfterBindToUnspecAndConnect) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ sockaddr_storage unspec = {.ss_family = AF_UNSPEC};
+ int bind_res = bind(sock_.get(), AsSockAddr(&unspec), sizeof(unspec));
+ if ((!IsRunningOnGvisor() || IsRunningWithHostinet()) &&
+ GetFamily() == AF_INET) {
+ // Linux allows this for undocumented compatibility reasons:
+ // https://github.com/torvalds/linux/commit/29c486df6a208432b370bd4be99ae1369ede28d8.
+ //
+ // TODO(https://gvisor.dev/issue/6575): Match Linux's behaviour.
+ ASSERT_THAT(bind_res, SyscallSucceeds());
+ } else {
+ ASSERT_THAT(bind_res, SyscallFailsWithErrno(EAFNOSUPPORT));
+ }
+
+ ASSERT_NO_FATAL_FAILURE(ConnectThenDisconnect(sock_, bind_addr_, addrlen_));
+}
+
+TEST_P(UdpSocketTest, DisconnectAfterConnectWithoutBind) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ ASSERT_NO_FATAL_FAILURE(ConnectThenDisconnect(sock_, bind_addr_, addrlen_));
+}
+
TEST_P(UdpSocketTest, BindToAnyConnnectToLocalhost) {
ASSERT_NO_ERRNO(BindAny());
diff --git a/test/util/BUILD b/test/util/BUILD
index 0c151c5a1..5b6bcb32c 100644
--- a/test/util/BUILD
+++ b/test/util/BUILD
@@ -352,7 +352,6 @@ cc_library(
cc_library(
name = "benchmark_main",
testonly = 1,
- srcs = ["benchmark_main.cc"],
linkstatic = 1,
deps = [
":test_util",
diff --git a/test/util/benchmark_main.cc b/test/util/benchmark_main.cc
deleted file mode 100644
index 2d4af6251..000000000
--- a/test/util/benchmark_main.cc
+++ /dev/null
@@ -1,64 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "gtest/gtest.h"
-#include "absl/flags/flag.h"
-#include "third_party/benchmark/src/commandlineflags.h"
-#include "test/util/test_util.h"
-
-DECLARE_bool(benchmark_list_internal);
-DECLARE_string(benchmark_filter_internal);
-ABSL_FLAG(bool, benchmark_enable_random_interleaving_internal, false,
- "forward");
-ABSL_FLAG(double, benchmark_min_time_internal, -1.0, "forward");
-ABSL_FLAG(int, benchmark_repetitions_internal, 1, "forward");
-
-// From //third_party/benchmark.
-//
-// These conflict with the internal definitions, but all the benchmark binaries
-// link against the external benchmark library for compatibility with the open
-// source build. We massage the internal-only flags into the external ones, and
-// call the function to actually run all registered external benchmarks.
-namespace benchmark {
-BM_DECLARE_bool(benchmark_list_tests);
-BM_DECLARE_string(benchmark_filter);
-BM_DECLARE_int32(benchmark_repetitions);
-BM_DECLARE_double(benchmark_min_time);
-BM_DECLARE_bool(benchmark_enable_random_interleaving);
-extern size_t RunSpecifiedBenchmarks();
-} // namespace benchmark
-
-using benchmark::FLAGS_benchmark_enable_random_interleaving;
-using benchmark::FLAGS_benchmark_filter;
-using benchmark::FLAGS_benchmark_list_tests;
-using benchmark::FLAGS_benchmark_min_time;
-using benchmark::FLAGS_benchmark_repetitions;
-
-int main(int argc, char** argv) {
- gvisor::testing::TestInit(&argc, &argv);
- absl::SetFlag(&FLAGS_benchmark_list_tests,
- absl::GetFlag(FLAGS_benchmark_list_internal));
- absl::SetFlag(&FLAGS_benchmark_filter,
- absl::GetFlag(FLAGS_benchmark_filter_internal));
- absl::SetFlag(&FLAGS_benchmark_repetitions,
- absl::GetFlag(FLAGS_benchmark_repetitions_internal));
- absl::SetFlag(
- &FLAGS_benchmark_enable_random_interleaving,
- absl::GetFlag(FLAGS_benchmark_enable_random_interleaving_internal));
- absl::SetFlag(&FLAGS_benchmark_min_time,
- absl::GetFlag(FLAGS_benchmark_min_time_internal));
-
- benchmark::RunSpecifiedBenchmarks();
- return 0;
-}
diff --git a/tools/bazel.mk b/tools/bazel.mk
index 4f979bbeb..1444423e4 100644
--- a/tools/bazel.mk
+++ b/tools/bazel.mk
@@ -186,8 +186,8 @@ build_paths = \
(set -euo pipefail; \
$(call wrapper,$(BAZEL) build $(BASE_OPTIONS) $(BAZEL_OPTIONS) $(1)) && \
$(call wrapper,$(BAZEL) cquery $(BASE_OPTIONS) $(BAZEL_OPTIONS) $(1) --output=starlark --starlark:file=tools/show_paths.bzl) \
- | xargs -r -n 1 -I {} readlink -f "{}" \
- | xargs -r -n 1 -I {} bash -c 'set -xeuo pipefail; $(2)')
+ | xargs -r -n 1 -I {} bash -c 'test -e "{}" || exit 0; readlink -f "{}"' \
+ | xargs -r -n 1 -I {} bash -c 'set -euo pipefail; $(2)')
clean = $(call header,CLEAN) && $(call wrapper,$(BAZEL) clean)
build = $(call header,BUILD $(1)) && $(call build_paths,$(1),echo {})
diff --git a/tools/show_paths.bzl b/tools/show_paths.bzl
index ba78d3494..f0126ac7b 100644
--- a/tools/show_paths.bzl
+++ b/tools/show_paths.bzl
@@ -2,6 +2,8 @@
def format(target):
provider_map = providers(target)
+ if not provider_map:
+ return ""
outputs = dict()
# Try to resolve in order.