diff options
Diffstat (limited to 'pkg')
57 files changed, 1755 insertions, 1112 deletions
diff --git a/pkg/safecopy/safecopy.go b/pkg/safecopy/safecopy.go index 5e6f903ff..a9711e63d 100644 --- a/pkg/safecopy/safecopy.go +++ b/pkg/safecopy/safecopy.go @@ -83,7 +83,7 @@ var ( // when we get a SIGSEGV that is not interesting to us. savedSigSegVHandler uintptr - // same a above, but for SIGBUS signals. + // Same as above, but for SIGBUS signals. savedSigBusHandler uintptr ) diff --git a/pkg/sentry/kernel/msgqueue/msgqueue.go b/pkg/sentry/kernel/msgqueue/msgqueue.go index 7c459d076..c7c5e41fb 100644 --- a/pkg/sentry/kernel/msgqueue/msgqueue.go +++ b/pkg/sentry/kernel/msgqueue/msgqueue.go @@ -129,6 +129,16 @@ type Message struct { Size uint64 } +func (m *Message) makeCopy() *Message { + new := &Message{ + Type: m.Type, + Size: m.Size, + } + new.Text = make([]byte, len(m.Text)) + copy(new.Text, m.Text) + return new +} + // Blocker is used for blocking Queue.Send, and Queue.Receive calls that serves // as an abstracted version of kernel.Task. kernel.Task is not directly used to // prevent circular dependencies. @@ -455,7 +465,7 @@ func (q *Queue) Copy(mType int64) (*Message, error) { if msg == nil { return nil, linuxerr.ENOMSG } - return msg, nil + return msg.makeCopy(), nil } // msgOfType returns the first message with the specified type, nil if no diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 8cf2f29e4..f79bda922 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -419,6 +419,27 @@ func bytesToIPAddress(addr []byte) tcpip.Address { return tcpip.Address(addr) } +// minSockAddrLen returns the minimum length in bytes of a socket address for +// the socket's family. +func (s *socketOpsCommon) minSockAddrLen() int { + const addressFamilySize = 2 + + switch s.family { + case linux.AF_UNIX: + return addressFamilySize + case linux.AF_INET: + return sockAddrInetSize + case linux.AF_INET6: + return sockAddrInet6Size + case linux.AF_PACKET: + return sockAddrLinkSize + case linux.AF_UNSPEC: + return addressFamilySize + default: + panic(fmt.Sprintf("s.family unrecognized = %d", s.family)) + } +} + func (s *socketOpsCommon) isPacketBased() bool { return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM || s.skType == linux.SOCK_RAW } @@ -545,16 +566,21 @@ func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask { return s.Endpoint.Readiness(mask) } -func (s *socketOpsCommon) checkFamily(family uint16, exact bool) *syserr.Error { +// checkFamily returns true iff the specified address family may be used with +// the socket. +// +// If exact is true, then the specified address family must be an exact match +// with the socket's family. +func (s *socketOpsCommon) checkFamily(family uint16, exact bool) bool { if family == uint16(s.family) { - return nil + return true } if !exact && family == linux.AF_INET && s.family == linux.AF_INET6 { if !s.Endpoint.SocketOptions().GetV6Only() { - return nil + return true } } - return syserr.ErrInvalidArgument + return false } // mapFamily maps the AF_INET ANY address to the IPv4-mapped IPv6 ANY if the @@ -587,8 +613,8 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool return syserr.TranslateNetstackError(err) } - if err := s.checkFamily(family, false /* exact */); err != nil { - return err + if !s.checkFamily(family, false /* exact */) { + return syserr.ErrInvalidArgument } addr = s.mapFamily(addr, family) @@ -655,14 +681,18 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), } } else { + if s.minSockAddrLen() > len(sockaddr) { + return syserr.ErrInvalidArgument + } + var err *syserr.Error addr, family, err = socket.AddressAndFamily(sockaddr) if err != nil { return err } - if err = s.checkFamily(family, true /* exact */); err != nil { - return err + if !s.checkFamily(family, true /* exact */) { + return syserr.ErrAddressFamilyNotSupported } addr = s.mapFamily(addr, family) @@ -2872,8 +2902,8 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b if err != nil { return 0, err } - if err := s.checkFamily(family, false /* exact */); err != nil { - return 0, err + if !s.checkFamily(family, false /* exact */) { + return 0, syserr.ErrInvalidArgument } addrBuf = s.mapFamily(addrBuf, family) diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index 208ab9909..ea199f223 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -155,7 +155,7 @@ func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { // Attach address to interface. nicID := tcpip.NICID(idx) - if err := s.Stack.AddProtocolAddressWithOptions(nicID, protocolAddress, stack.CanBePrimaryEndpoint); err != nil { + if err := s.Stack.AddProtocolAddress(nicID, protocolAddress, stack.AddressProperties{}); err != nil { return syserr.TranslateNetstackError(err).ToError() } diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index 48b24692b..c8460e63c 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -137,7 +137,13 @@ func TestCloseReader(t *testing.T) { addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr.Addr.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err) + } l, e := ListenTCP(s, addr, ipv4.ProtocolNumber) if e != nil { @@ -190,7 +196,13 @@ func TestCloseReaderWithForwarder(t *testing.T) { }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr.Addr.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err) + } done := make(chan struct{}) @@ -244,7 +256,13 @@ func TestCloseRead(t *testing.T) { }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr.Addr.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err) + } fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { var wq waiter.Queue @@ -288,7 +306,13 @@ func TestCloseWrite(t *testing.T) { }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr.Addr.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err) + } fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { var wq waiter.Queue @@ -349,10 +373,22 @@ func TestUDPForwarder(t *testing.T) { ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr1 := tcpip.FullAddress{NICID, ip1, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip1) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ip1.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr1, err) + } ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4()) addr2 := tcpip.FullAddress{NICID, ip2, 11311} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip2) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ip2.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr2, err) + } done := make(chan struct{}) fwd := udp.NewForwarder(s, func(r *udp.ForwarderRequest) { @@ -410,7 +446,13 @@ func TestDeadlineChange(t *testing.T) { addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr.Addr.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err) + } l, e := ListenTCP(s, addr, ipv4.ProtocolNumber) if e != nil { @@ -465,10 +507,22 @@ func TestPacketConnTransfer(t *testing.T) { ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr1 := tcpip.FullAddress{NICID, ip1, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip1) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ip1.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr1, err) + } ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4()) addr2 := tcpip.FullAddress{NICID, ip2, 11311} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip2) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ip2.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr2, err) + } c1, err := DialUDP(s, &addr1, nil, ipv4.ProtocolNumber) if err != nil { @@ -521,7 +575,13 @@ func TestConnectedPacketConnTransfer(t *testing.T) { ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr := tcpip.FullAddress{NICID, ip, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ip.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err) + } c1, err := DialUDP(s, &addr, nil, ipv4.ProtocolNumber) if err != nil { @@ -565,24 +625,30 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) { ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr := tcpip.FullAddress{NICID, ip, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ip.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + return nil, nil, nil, fmt.Errorf("AddProtocolAddress(%d, %+v, {}): %w", NICID, protocolAddr, err) + } l, err := ListenTCP(s, addr, ipv4.ProtocolNumber) if err != nil { - return nil, nil, nil, fmt.Errorf("NewListener: %v", err) + return nil, nil, nil, fmt.Errorf("NewListener: %w", err) } c1, err = DialTCP(s, addr, ipv4.ProtocolNumber) if err != nil { l.Close() - return nil, nil, nil, fmt.Errorf("DialTCP: %v", err) + return nil, nil, nil, fmt.Errorf("DialTCP: %w", err) } c2, err = l.Accept() if err != nil { l.Close() c1.Close() - return nil, nil, nil, fmt.Errorf("l.Accept: %v", err) + return nil, nil, nil, fmt.Errorf("l.Accept: %w", err) } stop = func() { @@ -594,7 +660,7 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) { if err := l.Close(); err != nil { stop() - return nil, nil, nil, fmt.Errorf("l.Close(): %v", err) + return nil, nil, nil, fmt.Errorf("l.Close(): %w", err) } return c1, c2, stop, nil @@ -681,7 +747,13 @@ func TestDialContextTCPCanceled(t *testing.T) { }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr.Addr.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err) + } ctx := context.Background() ctx, cancel := context.WithCancel(ctx) @@ -703,7 +775,13 @@ func TestDialContextTCPTimeout(t *testing.T) { }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr.Addr.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err) + } fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { time.Sleep(time.Second) diff --git a/pkg/tcpip/link/rawfile/blockingpoll_amd64.s b/pkg/tcpip/link/rawfile/blockingpoll_amd64.s index 298bad55d..f2c230720 100644 --- a/pkg/tcpip/link/rawfile/blockingpoll_amd64.s +++ b/pkg/tcpip/link/rawfile/blockingpoll_amd64.s @@ -27,7 +27,7 @@ TEXT ·BlockingPoll(SB),NOSPLIT,$0-40 MOVQ $0x0, R10 // sigmask parameter which isn't used here MOVQ $0x10f, AX // SYS_PPOLL SYSCALL - CMPQ AX, $0xfffffffffffff001 + CMPQ AX, $0xfffffffffffff002 JLS ok MOVQ $-1, n+24(FP) NEGQ AX diff --git a/pkg/tcpip/link/rawfile/blockingpoll_arm64.s b/pkg/tcpip/link/rawfile/blockingpoll_arm64.s index b62888b93..8807586c7 100644 --- a/pkg/tcpip/link/rawfile/blockingpoll_arm64.s +++ b/pkg/tcpip/link/rawfile/blockingpoll_arm64.s @@ -27,7 +27,7 @@ TEXT ·BlockingPoll(SB),NOSPLIT,$0-40 MOVD $0x0, R3 // sigmask parameter which isn't used here MOVD $0x49, R8 // SYS_PPOLL SVC - CMP $0xfffffffffffff001, R0 + CMP $0xfffffffffffff002, R0 BLS ok MOVD $-1, R1 MOVD R1, n+24(FP) diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index e76fc55b6..87a0b9a62 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -181,7 +181,9 @@ func BlockingReadvUntilStopped(efd int, fd int, iovecs []unix.Iovec) (int, tcpip if e == 0 { return int(n), nil } - + if e != 0 && e != unix.EWOULDBLOCK { + return 0, TranslateErrno(e) + } stopped, e := BlockingPollUntilStopped(efd, fd, unix.POLLIN) if stopped { return -1, nil @@ -204,6 +206,10 @@ func BlockingRecvMMsgUntilStopped(efd int, fd int, msgHdrs []MMsgHdr) (int, tcpi return int(n), nil } + if e != 0 && e != unix.EWOULDBLOCK { + return 0, TranslateErrno(e) + } + stopped, e := BlockingPollUntilStopped(efd, fd, unix.POLLIN) if stopped { return -1, nil @@ -228,5 +234,13 @@ func BlockingPollUntilStopped(efd int, fd int, events int16) (bool, unix.Errno) }, } _, errno := BlockingPoll(&pevents[0], len(pevents), nil) + if errno != 0 { + return pevents[0].Revents&unix.POLLIN != 0, errno + } + + if pevents[1].Revents&unix.POLLHUP != 0 || pevents[1].Revents&unix.POLLERR != 0 { + errno = unix.ECONNRESET + } + return pevents[0].Revents&unix.POLLIN != 0, errno } diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index 7b1ff44f4..c0179104a 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -23,8 +23,10 @@ go_test( "//pkg/tcpip/stack", "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", + "//pkg/tcpip/transport/raw", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", + "//pkg/waiter", "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 6515c31e5..e08243547 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -272,7 +272,6 @@ type protocol struct { func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber } func (p *protocol) MinimumPacketSize() int { return header.ARPSize } -func (p *protocol) DefaultPrefixLen() int { return 0 } func (*protocol) ParseAddresses(buffer.View) (src, dst tcpip.Address) { return "", "" diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 5fcbfeaa2..061cc35ae 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -153,8 +153,12 @@ func makeTestContext(t *testing.T, eventDepth int, packetDepth int) testContext t.Fatalf("CreateNIC failed: %s", err) } - if err := tc.s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress for ipv4 failed: %s", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: stackAddr.WithPrefix(), + } + if err := tc.s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } tc.s.SetRouteTable([]tcpip.Route{{ @@ -569,8 +573,12 @@ func TestLinkAddressRequest(t *testing.T) { } if len(test.nicAddr) != 0 { - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, test.nicAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, test.nicAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: test.nicAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } } diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 771b9173a..87f650661 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -15,6 +15,7 @@ package ip_test import ( + "bytes" "fmt" "strings" "testing" @@ -32,8 +33,10 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/raw" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" ) const nicID = 1 @@ -230,7 +233,13 @@ func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) { TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, }) s.CreateNIC(nicID, loopback.New()) - s.AddAddress(nicID, ipv4.ProtocolNumber, local) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: local.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + return nil, err + } s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv4EmptySubnet, Gateway: ipv4Gateway, @@ -246,7 +255,13 @@ func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) { TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, }) s.CreateNIC(nicID, loopback.New()) - s.AddAddress(nicID, ipv6.ProtocolNumber, local) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: local.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + return nil, err + } s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv6EmptySubnet, Gateway: ipv6Gateway, @@ -269,13 +284,13 @@ func buildDummyStackWithLinkEndpoint(t *testing.T, mtu uint32) (*stack.Stack, *c } v4Addr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: localIPv4AddrWithPrefix} - if err := s.AddProtocolAddress(nicID, v4Addr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v4Addr, err) + if err := s.AddProtocolAddress(nicID, v4Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, v4Addr, err) } v6Addr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: localIPv6AddrWithPrefix} - if err := s.AddProtocolAddress(nicID, v6Addr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v6Addr, err) + if err := s.AddProtocolAddress(nicID, v6Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, v6Addr, err) } return s, e @@ -710,8 +725,8 @@ func TestReceive(t *testing.T) { if !ok { t.Fatalf("expected network endpoint with number = %d to implement stack.AddressableEndpoint", test.protoNum) } - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", test.epAddr, err) + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", test.epAddr, err) } else { ep.DecRef() } @@ -882,8 +897,8 @@ func TestIPv4ReceiveControl(t *testing.T) { t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint") } addr := localIPv4Addr.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err) } else { ep.DecRef() } @@ -968,8 +983,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint") } addr := localIPv4Addr.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err) } else { ep.DecRef() } @@ -1234,8 +1249,8 @@ func TestIPv6ReceiveControl(t *testing.T) { t.Fatal("expected IPv6 network endpoint to implement stack.AddressableEndpoint") } addr := localIPv6Addr.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err) } else { ep.DecRef() } @@ -1301,7 +1316,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name string protoFactory stack.NetworkProtocolFactory protoNum tcpip.NetworkProtocolNumber - nicAddr tcpip.Address + nicAddr tcpip.AddressWithPrefix remoteAddr tcpip.Address pktGen func(*testing.T, tcpip.Address) buffer.VectorisedView checker func(*testing.T, *stack.PacketBuffer, tcpip.Address) @@ -1311,7 +1326,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv4", protoFactory: ipv4.NewProtocol, protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, + nicAddr: localIPv4AddrWithPrefix, remoteAddr: remoteIPv4Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv4MinimumSize + len(data) @@ -1352,7 +1367,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv4 with IHL too small", protoFactory: ipv4.NewProtocol, protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, + nicAddr: localIPv4AddrWithPrefix, remoteAddr: remoteIPv4Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv4MinimumSize + len(data) @@ -1376,7 +1391,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv4 too small", protoFactory: ipv4.NewProtocol, protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, + nicAddr: localIPv4AddrWithPrefix, remoteAddr: remoteIPv4Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) @@ -1394,7 +1409,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv4 minimum size", protoFactory: ipv4.NewProtocol, protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, + nicAddr: localIPv4AddrWithPrefix, remoteAddr: remoteIPv4Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) @@ -1430,7 +1445,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv4 with options", protoFactory: ipv4.NewProtocol, protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, + nicAddr: localIPv4AddrWithPrefix, remoteAddr: remoteIPv4Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ipHdrLen := int(header.IPv4MinimumSize + ipv4Options.Length()) @@ -1475,7 +1490,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv4 with options and data across views", protoFactory: ipv4.NewProtocol, protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, + nicAddr: localIPv4AddrWithPrefix, remoteAddr: remoteIPv4Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.Length())) @@ -1516,7 +1531,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv6", protoFactory: ipv6.NewProtocol, protoNum: ipv6.ProtocolNumber, - nicAddr: localIPv6Addr, + nicAddr: localIPv6AddrWithPrefix, remoteAddr: remoteIPv6Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv6MinimumSize + len(data) @@ -1556,7 +1571,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv6 with extension header", protoFactory: ipv6.NewProtocol, protoNum: ipv6.ProtocolNumber, - nicAddr: localIPv6Addr, + nicAddr: localIPv6AddrWithPrefix, remoteAddr: remoteIPv6Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data) @@ -1601,7 +1616,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv6 minimum size", protoFactory: ipv6.NewProtocol, protoNum: ipv6.ProtocolNumber, - nicAddr: localIPv6Addr, + nicAddr: localIPv6AddrWithPrefix, remoteAddr: remoteIPv6Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) @@ -1636,7 +1651,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv6 too small", protoFactory: ipv6.NewProtocol, protoNum: ipv6.ProtocolNumber, - nicAddr: localIPv6Addr, + nicAddr: localIPv6AddrWithPrefix, remoteAddr: remoteIPv6Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) @@ -1660,11 +1675,11 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { }{ { name: "unspecified source", - srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr))), + srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr.Address))), }, { name: "random source", - srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr))), + srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr.Address))), }, } @@ -1677,15 +1692,19 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, test.protoNum, test.nicAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.protoNum, test.nicAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: test.protoNum, + AddressWithPrefix: test.nicAddr, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{{Destination: test.remoteAddr.WithPrefix().Subnet(), NIC: nicID}}) - r, err := s.FindRoute(nicID, test.nicAddr, test.remoteAddr, test.protoNum, false /* multicastLoop */) + r, err := s.FindRoute(nicID, test.nicAddr.Address, test.remoteAddr, test.protoNum, false /* multicastLoop */) if err != nil { - t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr, test.protoNum, err) + t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr.Address, test.protoNum, err) } defer r.Release() @@ -2032,3 +2051,97 @@ func TestJoinLeaveAllRoutersGroup(t *testing.T) { }) } } + +func TestSetNICIDBeforeDeliveringToRawEndpoint(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + proto tcpip.NetworkProtocolNumber + addr tcpip.AddressWithPrefix + payloadOffset int + }{ + { + name: "IPv4", + proto: header.IPv4ProtocolNumber, + addr: localIPv4AddrWithPrefix, + payloadOffset: header.IPv4MinimumSize, + }, + { + name: "IPv6", + proto: header.IPv6ProtocolNumber, + addr: localIPv6AddrWithPrefix, + payloadOffset: 0, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocol, + ipv6.NewProtocol, + }, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + RawFactory: raw.EndpointFactory{}, + }) + if err := s.CreateNIC(nicID, loopback.New()); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + protocolAddr := tcpip.ProtocolAddress{ + Protocol: test.proto, + AddressWithPrefix: test.addr, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: test.addr.Subnet(), + NIC: nicID, + }, + }) + + var wq waiter.Queue + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.ReadableEvents) + ep, err := s.NewRawEndpoint(udp.ProtocolNumber, test.proto, &wq, true /* associated */) + if err != nil { + t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err) + } + defer ep.Close() + + writeOpts := tcpip.WriteOptions{ + To: &tcpip.FullAddress{ + Addr: test.addr.Address, + }, + } + data := []byte{1, 2, 3, 4} + var r bytes.Reader + r.Reset(data) + if n, err := ep.Write(&r, writeOpts); err != nil { + t.Fatalf("ep.Write(_, _): %s", err) + } else if want := int64(len(data)); n != want { + t.Fatalf("got ep.Write(_, _) = (%d, nil), want = (%d, nil)", n, want) + } + + // Wait for the endpoint to become readable. + <-ch + + var w bytes.Buffer + rr, err := ep.Read(&w, tcpip.ReadOptions{ + NeedRemoteAddr: true, + }) + if err != nil { + t.Fatalf("ep.Read(...): %s", err) + } + if diff := cmp.Diff(data, w.Bytes()[test.payloadOffset:]); diff != "" { + t.Errorf("payload mismatch (-want +got):\n%s", diff) + } + if diff := cmp.Diff(tcpip.FullAddress{Addr: test.addr.Address, NIC: nicID}, rr.RemoteAddr); diff != "" { + t.Errorf("remote addr mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go index 4bd6f462e..c6576fcbc 100644 --- a/pkg/tcpip/network/ipv4/igmp_test.go +++ b/pkg/tcpip/network/ipv4/igmp_test.go @@ -120,9 +120,12 @@ func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, ma // cycles. func TestIGMPV1Present(t *testing.T) { e, s, clock := createStack(t, true) - addr := tcpip.AddressWithPrefix{Address: stackAddr, PrefixLen: defaultPrefixLength} - if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{Address: stackAddr, PrefixLen: defaultPrefixLength}, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { @@ -215,8 +218,15 @@ func TestSendQueuedIGMPReports(t *testing.T) { // The initial set of IGMP reports that were queued should be sent once an // address is assigned. - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, stackAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: stackAddr, + PrefixLen: defaultPrefixLength, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } if got := reportStat.Value(); got != 1 { t.Errorf("got reportStat.Value() = %d, want = 1", got) @@ -350,8 +360,12 @@ func TestIGMPPacketValidation(t *testing.T) { t.Run(test.name, func(t *testing.T) { e, s, _ := createStack(t, true) for _, address := range test.stackAddresses { - if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, address); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, address, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: address, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } } stats := s.Stats() diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 44c85bdb8..aef789b4c 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -240,7 +240,7 @@ func (e *endpoint) Enable() tcpip.Error { } // Create an endpoint to receive broadcast packets on this interface. - ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(ipv4BroadcastAddr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */) + ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(ipv4BroadcastAddr, stack.AddressProperties{PEB: stack.NeverPrimaryEndpoint}) if err != nil { return err } @@ -856,6 +856,8 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum } func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer, inNICName string) { + pkt.NICID = e.nic.ID() + // Raw socket packets are delivered based solely on the transport protocol // number. We only require that the packet be valid IPv4, and that they not // be fragmented. @@ -863,7 +865,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer, e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt) } - pkt.NICID = e.nic.ID() stats := e.stats stats.ip.ValidPacketsReceived.Increment() @@ -1074,11 +1075,11 @@ func (e *endpoint) Close() { } // AddAndAcquirePermanentAddress implements stack.AddressableEndpoint. -func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { +func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties stack.AddressProperties) (stack.AddressEndpoint, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) + ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, properties) if err == nil { e.mu.igmp.sendQueuedReports() } @@ -1225,11 +1226,6 @@ func (p *protocol) MinimumPacketSize() int { return header.IPv4MinimumSize } -// DefaultPrefixLen returns the IPv4 default prefix length. -func (p *protocol) DefaultPrefixLen() int { - return header.IPv4AddressSize * 8 -} - // ParseAddresses implements stack.NetworkProtocol. func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { h := header.IPv4(v) diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 73407be67..e7b5b3ea2 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -101,8 +101,12 @@ func TestExcludeBroadcast(t *testing.T) { defer ep.Close() // Add a valid primary endpoint address, now we can connect. - if err := s.AddAddress(1, ipv4.ProtocolNumber, "\x0a\x00\x00\x02"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.Address("\x0a\x00\x00\x02").WithPrefix(), + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } if err := ep.Connect(randomAddr); err != nil { t.Errorf("Connect failed: %v", err) @@ -356,8 +360,8 @@ func TestForwarding(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err) } incomingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: incomingIPv4Addr} - if err := s.AddProtocolAddress(incomingNICID, incomingIPv4ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv4ProtoAddr, err) + if err := s.AddProtocolAddress(incomingNICID, incomingIPv4ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingIPv4ProtoAddr, err) } expectedEmittedPacketCount := 1 @@ -369,8 +373,8 @@ func TestForwarding(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err) } outgoingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: outgoingIPv4Addr} - if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv4ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv4ProtoAddr, err) + if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv4ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingIPv4ProtoAddr, err) } s.SetRouteTable([]tcpip.Route{ @@ -1184,8 +1188,8 @@ func TestIPv4Sanity(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr} - if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err) + if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtoAddr, err) } // Default routes for IPv4 so ICMP can find a route to the remote @@ -1745,8 +1749,8 @@ func TestInvalidFragments(t *testing.T) { const ( nicID = 1 linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") - addr1 = "\x0a\x00\x00\x01" - addr2 = "\x0a\x00\x00\x02" + addr1 = tcpip.Address("\x0a\x00\x00\x01") + addr2 = tcpip.Address("\x0a\x00\x00\x02") tos = 0 ident = 1 ttl = 48 @@ -2012,8 +2016,12 @@ func TestInvalidFragments(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr2.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } for _, f := range test.fragments { @@ -2061,8 +2069,8 @@ func TestFragmentReassemblyTimeout(t *testing.T) { const ( nicID = 1 linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") - addr1 = "\x0a\x00\x00\x01" - addr2 = "\x0a\x00\x00\x02" + addr1 = tcpip.Address("\x0a\x00\x00\x01") + addr2 = tcpip.Address("\x0a\x00\x00\x02") tos = 0 ident = 1 ttl = 48 @@ -2237,8 +2245,12 @@ func TestFragmentReassemblyTimeout(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr2.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv4EmptySubnet, @@ -2308,9 +2320,9 @@ func TestReceiveFragments(t *testing.T) { const ( nicID = 1 - addr1 = "\x0c\xa8\x00\x01" // 192.168.0.1 - addr2 = "\x0c\xa8\x00\x02" // 192.168.0.2 - addr3 = "\x0c\xa8\x00\x03" // 192.168.0.3 + addr1 = tcpip.Address("\x0c\xa8\x00\x01") // 192.168.0.1 + addr2 = tcpip.Address("\x0c\xa8\x00\x02") // 192.168.0.2 + addr3 = tcpip.Address("\x0c\xa8\x00\x03") // 192.168.0.3 ) // Build and return a UDP header containing payload. @@ -2703,8 +2715,12 @@ func TestReceiveFragments(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: addr2.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } wq := waiter.Queue{} @@ -2985,11 +3001,15 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route { t.Fatalf("CreateNIC(1, _) failed: %s", err) } const ( - src = "\x10\x00\x00\x01" - dst = "\x10\x00\x00\x02" + src = tcpip.Address("\x10\x00\x00\x01") + dst = tcpip.Address("\x10\x00\x00\x02") ) - if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil { - t.Fatalf("AddAddress(1, %d, %s) failed: %s", ipv4.ProtocolNumber, src, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: src.WithPrefix(), + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } { mask := tcpip.AddressMask(header.IPv4Broadcast) @@ -3161,8 +3181,8 @@ func TestPacketQueuing(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddProtocolAddress(nicID, host1IPv4Addr); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv4Addr, err) + if err := s.AddProtocolAddress(nicID, host1IPv4Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv4Addr, err) } s.SetRouteTable([]tcpip.Route{ @@ -3285,8 +3305,12 @@ func TestCloseLocking(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) } - if err := s.AddAddress(nicID1, ipv4.ProtocolNumber, src); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) failed: %s", nicID1, ipv4.ProtocolNumber, src, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: src.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{{ diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 7c2a3e56b..3b4c235fa 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -225,8 +225,8 @@ func TestICMPCounts(t *testing.T) { t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") } addr := lladdr0.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err) } else { ep.DecRef() } @@ -407,8 +407,12 @@ func newTestContext(t *testing.T) *testContext { if err := c.s0.CreateNIC(nicID, wrappedEP0); err != nil { t.Fatalf("CreateNIC s0: %v", err) } - if err := c.s0.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress lladdr0: %v", err) + llProtocolAddr0 := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := c.s0.AddProtocolAddress(nicID, llProtocolAddr0, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, llProtocolAddr0, err) } c.linkEP1 = channel.New(defaultChannelSize, defaultMTU, linkAddr1) @@ -416,8 +420,12 @@ func newTestContext(t *testing.T) *testContext { if err := c.s1.CreateNIC(nicID, wrappedEP1); err != nil { t.Fatalf("CreateNIC failed: %v", err) } - if err := c.s1.AddAddress(nicID, ProtocolNumber, lladdr1); err != nil { - t.Fatalf("AddAddress lladdr1: %v", err) + llProtocolAddr1 := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr1.WithPrefix(), + } + if err := c.s1.AddProtocolAddress(nicID, llProtocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, llProtocolAddr1, err) } subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) @@ -690,8 +698,12 @@ func TestICMPChecksumValidationSimple(t *testing.T) { t.Fatalf("CreateNIC(_, _) = %s", err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } { subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) @@ -883,8 +895,12 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { t.Fatalf("CreateNIC(_, _) = %s", err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } { subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) @@ -1065,8 +1081,12 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } { subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) @@ -1240,8 +1260,12 @@ func TestLinkAddressRequest(t *testing.T) { } if len(test.nicAddr) != 0 { - if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: test.nicAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } } @@ -1415,8 +1439,8 @@ func TestPacketQueing(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddProtocolAddress(nicID, host1IPv6Addr); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv6Addr, err) + if err := s.AddProtocolAddress(nicID, host1IPv6Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv6Addr, err) } s.SetRouteTable([]tcpip.Route{ @@ -1669,8 +1693,12 @@ func TestCallsToNeighborCache(t *testing.T) { if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { t.Fatalf("CreateNIC(_, _) = %s", err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } } { @@ -1704,8 +1732,8 @@ func TestCallsToNeighborCache(t *testing.T) { t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") } addr := lladdr0.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err) } else { ep.DecRef() } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index b1aec5312..c824e27fa 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -1127,11 +1127,12 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum } func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer, inNICName string) { + pkt.NICID = e.nic.ID() + // Raw socket packets are delivered based solely on the transport protocol // number. We only require that the packet be valid IPv6. e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt) - pkt.NICID = e.nic.ID() stats := e.stats.ip stats.ValidPacketsReceived.Increment() @@ -1627,12 +1628,12 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { } // AddAndAcquirePermanentAddress implements stack.AddressableEndpoint. -func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { +func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties stack.AddressProperties) (stack.AddressEndpoint, tcpip.Error) { // TODO(b/169350103): add checks here after making sure we no longer receive // an empty address. e.mu.Lock() defer e.mu.Unlock() - return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated) + return e.addAndAcquirePermanentAddressLocked(addr, properties) } // addAndAcquirePermanentAddressLocked is like AddAndAcquirePermanentAddress but @@ -1642,8 +1643,8 @@ func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, p // solicited-node multicast group and start duplicate address detection. // // Precondition: e.mu must be write locked. -func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { - addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) +func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, properties stack.AddressProperties) (stack.AddressEndpoint, tcpip.Error) { + addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, properties) if err != nil { return nil, err } @@ -2010,11 +2011,6 @@ func (p *protocol) MinimumPacketSize() int { return header.IPv6MinimumSize } -// DefaultPrefixLen returns the IPv6 default prefix length. -func (p *protocol) DefaultPrefixLen() int { - return header.IPv6AddressSize * 8 -} - // ParseAddresses implements stack.NetworkProtocol. func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { h := header.IPv6(v) diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index d2a23fd4f..0735ebb23 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -41,12 +41,12 @@ import ( ) const ( - addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") // The least significant 3 bytes are the same as addr2 so both addr2 and // addr3 will have the same solicited-node address. - addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02" - addr4 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x03" + addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02") + addr4 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x03") // Tests use the extension header identifier values as uint8 instead of // header.IPv6ExtensionHeaderIdentifier. @@ -298,16 +298,24 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) { // addr2/addr3 yet as we haven't added those addresses. test.rxf(t, s, e, addr1, snmc, 0) - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: addr2.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr2, err) } // Should receive a packet destined to the solicited node address of // addr2/addr3 now that we have added added addr2. test.rxf(t, s, e, addr1, snmc, 1) - if err := s.AddAddress(nicID, ProtocolNumber, addr3); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr3, err) + protocolAddr3 := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: addr3.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr3, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr3, err) } // Should still receive a packet destined to the solicited node address of @@ -374,8 +382,12 @@ func TestAddIpv6Address(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, test.addr); err != nil { - t.Fatalf("AddAddress(%d, %d, nil) = %s", nicID, ProtocolNumber, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: test.addr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } if addr, err := s.GetMainNICAddress(nicID, ProtocolNumber); err != nil { @@ -898,8 +910,12 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: addr2.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } // Add a default route so that a return packet knows where to go. @@ -1992,8 +2008,12 @@ func TestReceiveIPv6Fragments(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: addr2.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } wq := waiter.Queue{} @@ -2060,8 +2080,8 @@ func TestReceiveIPv6Fragments(t *testing.T) { func TestInvalidIPv6Fragments(t *testing.T) { const ( - addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") nicID = 1 hoplimit = 255 @@ -2150,8 +2170,12 @@ func TestInvalidIPv6Fragments(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: addr2.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv6EmptySubnet, @@ -2216,8 +2240,8 @@ func TestInvalidIPv6Fragments(t *testing.T) { func TestFragmentReassemblyTimeout(t *testing.T) { const ( - addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") nicID = 1 hoplimit = 255 @@ -2402,8 +2426,12 @@ func TestFragmentReassemblyTimeout(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr2, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: addr2.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv6EmptySubnet, @@ -2645,11 +2673,15 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route { t.Fatalf("CreateNIC(1, _) failed: %s", err) } const ( - src = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + src = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + dst = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") ) - if err := s.AddAddress(1, ProtocolNumber, src); err != nil { - t.Fatalf("AddAddress(1, %d, %s) failed: %s", ProtocolNumber, src, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: src.WithPrefix(), + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } { mask := tcpip.AddressMask("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff") @@ -3297,8 +3329,8 @@ func TestForwarding(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err) } incomingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: incomingIPv6Addr} - if err := s.AddProtocolAddress(incomingNICID, incomingIPv6ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv6ProtoAddr, err) + if err := s.AddProtocolAddress(incomingNICID, incomingIPv6ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingIPv6ProtoAddr, err) } outgoingEndpoint := channel.New(1, header.IPv6MinimumMTU, "") @@ -3306,8 +3338,8 @@ func TestForwarding(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err) } outgoingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: outgoingIPv6Addr} - if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv6ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv6ProtoAddr, err) + if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv6ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingIPv6ProtoAddr, err) } s.SetRouteTable([]tcpip.Route{ diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go index bc9cf6999..3e5c438d3 100644 --- a/pkg/tcpip/network/ipv6/mld_test.go +++ b/pkg/tcpip/network/ipv6/mld_test.go @@ -75,8 +75,12 @@ func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) { // The stack will join an address's solicited node multicast address when // an address is added. An MLD report message should be sent for the // solicited-node group. - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: linkLocalAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") @@ -216,8 +220,13 @@ func TestSendQueuedMLDReports(t *testing.T) { // Note, we will still expect to send a report for the global address's // solicited node address from the unspecified address as per RFC 3590 // section 4. - if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint, err) + properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint} + globalProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: globalAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, globalProtocolAddr, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, globalProtocolAddr, properties, err) } reportCounter++ if got := reportStat.Value(); got != reportCounter { @@ -252,8 +261,12 @@ func TestSendQueuedMLDReports(t *testing.T) { // Adding a link-local address should send a report for its solicited node // address and globalMulticastAddr. - if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint, err) + linkLocalProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: linkLocalAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, linkLocalProtocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, linkLocalProtocolAddr, err) } if dadResolutionTime != 0 { reportCounter++ @@ -567,8 +580,12 @@ func TestMLDSkipProtocol(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: linkLocalAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index 8837d66d8..938427420 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -1130,7 +1130,11 @@ func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, config return nil } - addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.FirstPrimaryEndpoint, configType, deprecated) + addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.AddressProperties{ + PEB: stack.FirstPrimaryEndpoint, + ConfigType: configType, + Deprecated: deprecated, + }) if err != nil { panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", addr, err)) } diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index f0186c64e..8297a7e10 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -144,8 +144,12 @@ func TestNeighborSolicitationWithSourceLinkLayerOption(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf) @@ -406,8 +410,12 @@ func TestNeighborSolicitationResponse(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: nicAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{ @@ -602,8 +610,12 @@ func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf) @@ -831,8 +843,12 @@ func TestNDPValidation(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber) @@ -962,8 +978,12 @@ func TestNeighborAdvertisementValidation(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } ndpNASize := header.ICMPv6NeighborAdvertMinimumSize @@ -1283,8 +1303,12 @@ func TestCheckDuplicateAddress(t *testing.T) { checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonces[dadPacketsSent])}), )) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } checkDADMsg() diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go index 1b96b1fb8..26640b7ee 100644 --- a/pkg/tcpip/network/multicast_group_test.go +++ b/pkg/tcpip/network/multicast_group_test.go @@ -151,15 +151,22 @@ func createStackWithLinkEndpoint(t *testing.T, v4, mgpEnabled bool, e stack.Link if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - addr := tcpip.AddressWithPrefix{ - Address: stackIPv4Addr, - PrefixLen: defaultIPv4PrefixLength, + addr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: stackIPv4Addr, + PrefixLen: defaultIPv4PrefixLength, + }, + } + if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err) } - if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: linkLocalIPv6Addr1.WithPrefix(), } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalIPv6Addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, linkLocalIPv6Addr1, err) + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } return s, clock diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go index 009cab643..05b879543 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ b/pkg/tcpip/sample/tun_tcp_connect/main.go @@ -146,8 +146,12 @@ func main() { log.Fatal(err) } - if err := s.AddAddress(1, ipv4.ProtocolNumber, addr); err != nil { - log.Fatal(err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr.WithPrefix(), + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + log.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } // Add default route. diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go index c10b19aa0..a72afadda 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ b/pkg/tcpip/sample/tun_tcp_echo/main.go @@ -124,13 +124,13 @@ func main() { log.Fatalf("Bad IP address: %v", addrName) } - var addr tcpip.Address + var addrWithPrefix tcpip.AddressWithPrefix var proto tcpip.NetworkProtocolNumber if parsedAddr.To4() != nil { - addr = tcpip.Address(parsedAddr.To4()) + addrWithPrefix = tcpip.Address(parsedAddr.To4()).WithPrefix() proto = ipv4.ProtocolNumber } else if parsedAddr.To16() != nil { - addr = tcpip.Address(parsedAddr.To16()) + addrWithPrefix = tcpip.Address(parsedAddr.To16()).WithPrefix() proto = ipv6.ProtocolNumber } else { log.Fatalf("Unknown IP type: %v", addrName) @@ -176,11 +176,15 @@ func main() { log.Fatal(err) } - if err := s.AddAddress(1, proto, addr); err != nil { - log.Fatal(err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: proto, + AddressWithPrefix: addrWithPrefix, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + log.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } - subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addr))), tcpip.AddressMask(strings.Repeat("\x00", len(addr)))) + subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addrWithPrefix.Address))), tcpip.AddressMask(strings.Repeat("\x00", len(addrWithPrefix.Address)))) if err != nil { log.Fatal(err) } diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index ae0bb4ace..7e4b5bf74 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -117,10 +117,10 @@ func (a *AddressableEndpointState) releaseAddressStateLocked(addrState *addressS } // AddAndAcquirePermanentAddress implements AddressableEndpoint. -func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, tcpip.Error) { +func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties AddressProperties) (AddressEndpoint, tcpip.Error) { a.mu.Lock() defer a.mu.Unlock() - ep, err := a.addAndAcquireAddressLocked(addr, peb, configType, deprecated, true /* permanent */) + ep, err := a.addAndAcquireAddressLocked(addr, properties, true /* permanent */) // From https://golang.org/doc/faq#nil_error: // // Under the covers, interfaces are implemented as two elements, a type T and @@ -149,7 +149,7 @@ func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.Addr func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) (AddressEndpoint, tcpip.Error) { a.mu.Lock() defer a.mu.Unlock() - ep, err := a.addAndAcquireAddressLocked(addr, peb, AddressConfigStatic, false /* deprecated */, false /* permanent */) + ep, err := a.addAndAcquireAddressLocked(addr, AddressProperties{PEB: peb}, false /* permanent */) // From https://golang.org/doc/faq#nil_error: // // Under the covers, interfaces are implemented as two elements, a type T and @@ -180,7 +180,7 @@ func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.Addr // returned, regardless the kind of address that is being added. // // Precondition: a.mu must be write locked. -func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated, permanent bool) (*addressState, tcpip.Error) { +func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, properties AddressProperties, permanent bool) (*addressState, tcpip.Error) { // attemptAddToPrimary is false when the address is already in the primary // address list. attemptAddToPrimary := true @@ -208,7 +208,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address // We now promote the address. for i, s := range a.mu.primary { if s == addrState { - switch peb { + switch properties.PEB { case CanBePrimaryEndpoint: // The address is already in the primary address list. attemptAddToPrimary = false @@ -222,7 +222,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address case NeverPrimaryEndpoint: a.mu.primary = append(a.mu.primary[:i], a.mu.primary[i+1:]...) default: - panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb)) + panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", properties.PEB)) } break } @@ -262,11 +262,11 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address } // Acquire the address before returning it. addrState.mu.refs++ - addrState.mu.deprecated = deprecated - addrState.mu.configType = configType + addrState.mu.deprecated = properties.Deprecated + addrState.mu.configType = properties.ConfigType if attemptAddToPrimary { - switch peb { + switch properties.PEB { case NeverPrimaryEndpoint: case CanBePrimaryEndpoint: a.mu.primary = append(a.mu.primary, addrState) @@ -285,7 +285,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address a.mu.primary[0] = addrState } default: - panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb)) + panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", properties.PEB)) } } @@ -489,12 +489,12 @@ func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tc // Proceed to add a new temporary endpoint. addr := localAddr.WithPrefix() - ep, err := a.addAndAcquireAddressLocked(addr, tempPEB, AddressConfigStatic, false /* deprecated */, false /* permanent */) + ep, err := a.addAndAcquireAddressLocked(addr, AddressProperties{PEB: tempPEB}, false /* permanent */) if err != nil { // addAndAcquireAddressLocked only returns an error if the address is // already assigned but we just checked above if the address exists so we // expect no error. - panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, %d, %d, false, false): %s", addr, tempPEB, AddressConfigStatic, err)) + panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, AddressProperties{PEB: %s}, false): %s", addr, tempPEB, err)) } // From https://golang.org/doc/faq#nil_error: diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go index 140f146f6..c55f85743 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state_test.go +++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go @@ -38,9 +38,9 @@ func TestAddressableEndpointStateCleanup(t *testing.T) { } { - ep, err := s.AddAndAcquirePermanentAddress(addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */) + ep, err := s.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{PEB: stack.NeverPrimaryEndpoint}) if err != nil { - t.Fatalf("s.AddAndAcquirePermanentAddress(%s, %d, %d, false): %s", addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, err) + t.Fatalf("s.AddAndAcquirePermanentAddress(%s, AddressProperties{PEB: NeverPrimaryEndpoint}): %s", addr, err) } // We don't need the address endpoint. ep.DecRef() diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index ccb69393b..c2f1f4798 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -181,10 +181,6 @@ func (*fwdTestNetworkProtocol) MinimumPacketSize() int { return fwdTestNetHeaderLen } -func (*fwdTestNetworkProtocol) DefaultPrefixLen() int { - return fwdTestNetDefaultPrefixLen -} - func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) } @@ -384,8 +380,15 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (*faketime.M if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC #1 failed:", err) } - if err := s.AddAddress(1, fwdTestNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress #1 failed:", err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: fwdTestNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fwdTestNetDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr1, AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err) } // NIC 2 has the link address "b", and added the network address 2. @@ -397,8 +400,15 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (*faketime.M if err := s.CreateNIC(2, ep2); err != nil { t.Fatal("CreateNIC #2 failed:", err) } - if err := s.AddAddress(2, fwdTestNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress #2 failed:", err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: fwdTestNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x02", + PrefixLen: fwdTestNetDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(2, protocolAddr2, AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err) } nic, ok := s.nics[2] diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 4d5431da1..40b33b6b5 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -333,8 +333,12 @@ func TestDADDisabled(t *testing.T) { Address: addr1, PrefixLen: defaultPrefixLen, } - if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addrWithPrefix, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err) } // Should get the address immediately since we should not have performed @@ -379,12 +383,15 @@ func TestDADResolveLoopback(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - addrWithPrefix := tcpip.AddressWithPrefix{ - Address: addr1, - PrefixLen: defaultPrefixLen, + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: addr1, + PrefixLen: defaultPrefixLen, + }, } - if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err) } // Address should not be considered bound to the NIC yet (DAD ongoing). @@ -517,8 +524,12 @@ func TestDADResolve(t *testing.T) { Address: addr1, PrefixLen: defaultPrefixLen, } - if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addrWithPrefix, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err) } // Make sure the address does not resolve before the resolution time has @@ -740,8 +751,12 @@ func TestDADFail(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addr1.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } // Address should not be considered bound to the NIC yet @@ -778,8 +793,8 @@ func TestDADFail(t *testing.T) { // Attempting to add the address again should not fail if the address's // state was cleaned up when DAD failed. - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } }) } @@ -851,8 +866,12 @@ func TestDADStop(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, header.IPv6ProtocolNumber, addr1, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addr1.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } // Address should not be considered bound to the NIC yet (DAD ongoing). @@ -975,17 +994,29 @@ func TestSetNDPConfigurations(t *testing.T) { // Add addresses for each NIC. addrWithPrefix1 := tcpip.AddressWithPrefix{Address: addr1, PrefixLen: defaultPrefixLen} - if err := s.AddAddressWithPrefix(nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addrWithPrefix1, err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addrWithPrefix1, + } + if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID1, protocolAddr1, err) } addrWithPrefix2 := tcpip.AddressWithPrefix{Address: addr2, PrefixLen: defaultPrefixLen} - if err := s.AddAddressWithPrefix(nicID2, header.IPv6ProtocolNumber, addrWithPrefix2); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addrWithPrefix2, err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addrWithPrefix2, + } + if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID2, protocolAddr2, err) } expectDADEvent(nicID2, addr2) addrWithPrefix3 := tcpip.AddressWithPrefix{Address: addr3, PrefixLen: defaultPrefixLen} - if err := s.AddAddressWithPrefix(nicID3, header.IPv6ProtocolNumber, addrWithPrefix3); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addrWithPrefix3, err) + protocolAddr3 := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addrWithPrefix3, + } + if err := s.AddProtocolAddress(nicID3, protocolAddr3, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID3, protocolAddr3, err) } expectDADEvent(nicID3, addr3) @@ -2788,8 +2819,12 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { continue } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, test.addrs[j].Address); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, test.addrs[j].Address, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: test.addrs[j].Address.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } manuallyAssignedAddresses[test.addrs[j].Address] = struct{}{} @@ -3644,8 +3679,9 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) { Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr2, } - if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err) + properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint} + if err := s.AddProtocolAddress(nicID, protoAddr2, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v) = %s", nicID, protoAddr2, properties, err) } // addr2 should be more preferred now since it is at the front of the primary // list. @@ -3733,8 +3769,9 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { } // Add the address as a static address before SLAAC tries to add it. - if err := s.AddProtocolAddress(1, tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr.Address, err) + protocolAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr} + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(1, %+v, {}) = %s", protocolAddr, err) } if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) { t.Fatalf("Should have %s in the list of addresses", addr1) @@ -4073,8 +4110,12 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { // Attempting to add the address manually should not fail if the // address's state was cleaned up when DAD failed. - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr.Address); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr.Address, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addr, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } if err := s.RemoveAddress(nicID, addr.Address); err != nil { t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err) @@ -5362,8 +5403,12 @@ func TestRouterSolicitation(t *testing.T) { } if addr := test.nicAddr; addr != "" { - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index a796942ab..ab7e1d859 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -514,7 +514,7 @@ func (n *nic) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber, // addAddress adds a new address to n, so that it starts accepting packets // targeted at the given address (and network protocol). -func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error { +func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, properties AddressProperties) tcpip.Error { ep, ok := n.networkEndpoints[protocolAddress.Protocol] if !ok { return &tcpip.ErrUnknownProtocol{} @@ -525,7 +525,7 @@ func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo return &tcpip.ErrNotSupported{} } - addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */) + addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, properties) if err == nil { // We have no need for the address endpoint. addressEndpoint.DecRef() diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 5cb342f78..c8ad93f29 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -127,11 +127,6 @@ func (*testIPv6Protocol) MinimumPacketSize() int { return header.IPv6MinimumSize } -// DefaultPrefixLen implements NetworkProtocol.DefaultPrefixLen. -func (*testIPv6Protocol) DefaultPrefixLen() int { - return header.IPv6AddressSize * 8 -} - // ParseAddresses implements NetworkProtocol.ParseAddresses. func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { h := header.IPv6(v) diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 113baaaae..31b3a554d 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -318,8 +318,7 @@ type PrimaryEndpointBehavior int const ( // CanBePrimaryEndpoint indicates the endpoint can be used as a primary - // endpoint for new connections with no local address. This is the - // default when calling NIC.AddAddress. + // endpoint for new connections with no local address. CanBePrimaryEndpoint PrimaryEndpointBehavior = iota // FirstPrimaryEndpoint indicates the endpoint should be the first @@ -332,6 +331,19 @@ const ( NeverPrimaryEndpoint ) +func (peb PrimaryEndpointBehavior) String() string { + switch peb { + case CanBePrimaryEndpoint: + return "CanBePrimaryEndpoint" + case FirstPrimaryEndpoint: + return "FirstPrimaryEndpoint" + case NeverPrimaryEndpoint: + return "NeverPrimaryEndpoint" + default: + panic(fmt.Sprintf("unknown primary endpoint behavior: %d", peb)) + } +} + // AddressConfigType is the method used to add an address. type AddressConfigType int @@ -351,6 +363,14 @@ const ( AddressConfigSlaacTemp ) +// AddressProperties contains additional properties that can be configured when +// adding an address. +type AddressProperties struct { + PEB PrimaryEndpointBehavior + ConfigType AddressConfigType + Deprecated bool +} + // AssignableAddressEndpoint is a reference counted address endpoint that may be // assigned to a NetworkEndpoint. type AssignableAddressEndpoint interface { @@ -457,7 +477,7 @@ type AddressableEndpoint interface { // Returns *tcpip.ErrDuplicateAddress if the address exists. // // Acquires and returns the AddressEndpoint for the added address. - AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, tcpip.Error) + AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties AddressProperties) (AddressEndpoint, tcpip.Error) // RemovePermanentAddress removes the passed address if it is a permanent // address. @@ -685,9 +705,6 @@ type NetworkProtocol interface { // than this targeted at this protocol. MinimumPacketSize() int - // DefaultPrefixLen returns the protocol's default prefix length. - DefaultPrefixLen() int - // ParseAddresses returns the source and destination addresses stored in a // packet of this protocol. ParseAddresses(v buffer.View) (src, dst tcpip.Address) diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index cb741e540..98867a828 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -916,46 +916,9 @@ type NICStateFlags struct { Loopback bool } -// AddAddress adds a new network-layer address to the specified NIC. -func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { - return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint) -} - -// AddAddressWithPrefix is the same as AddAddress, but allows you to specify -// the address prefix. -func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.AddressWithPrefix) tcpip.Error { - ap := tcpip.ProtocolAddress{ - Protocol: protocol, - AddressWithPrefix: addr, - } - return s.AddProtocolAddressWithOptions(id, ap, CanBePrimaryEndpoint) -} - -// AddProtocolAddress adds a new network-layer protocol address to the -// specified NIC. -func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) tcpip.Error { - return s.AddProtocolAddressWithOptions(id, protocolAddress, CanBePrimaryEndpoint) -} - -// AddAddressWithOptions is the same as AddAddress, but allows you to specify -// whether the new endpoint can be primary or not. -func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) tcpip.Error { - netProto, ok := s.networkProtocols[protocol] - if !ok { - return &tcpip.ErrUnknownProtocol{} - } - return s.AddProtocolAddressWithOptions(id, tcpip.ProtocolAddress{ - Protocol: protocol, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addr, - PrefixLen: netProto.DefaultPrefixLen(), - }, - }, peb) -} - -// AddProtocolAddressWithOptions is the same as AddProtocolAddress, but allows -// you to specify whether the new endpoint can be primary or not. -func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error { +// AddProtocolAddress adds an address to the specified NIC, possibly with extra +// properties. +func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, properties AddressProperties) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() @@ -964,7 +927,7 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc return &tcpip.ErrUnknownNICID{} } - return nic.addAddress(protocolAddress, peb) + return nic.addAddress(protocolAddress, properties) } // RemoveAddress removes an existing network-layer address from the specified diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 3089c0ef4..cd4137794 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -234,10 +234,6 @@ func (*fakeNetworkProtocol) MinimumPacketSize() int { return fakeNetHeaderLen } -func (*fakeNetworkProtocol) DefaultPrefixLen() int { - return fakeDefaultPrefixLen -} - func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int { return f.packetCount[int(intfAddr)%len(f.packetCount)] } @@ -349,12 +345,26 @@ func TestNetworkReceive(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err) } - if err := s.AddAddress(1, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x02", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr2, err) } fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) @@ -517,8 +527,15 @@ func TestNetworkSend(t *testing.T) { s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } // Make sure that the link-layer endpoint received the outbound packet. @@ -538,12 +555,26 @@ func TestNetworkSendMultiRoute(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err) } - if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr3 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x03", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr3, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr3, err) } ep2 := channel.New(10, defaultMTU, "") @@ -551,12 +582,26 @@ func TestNetworkSendMultiRoute(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x02", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(2, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err) } - if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr4 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x04", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(2, protocolAddr4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr4, err) } // Set a route table that sends all packets with odd destination @@ -812,8 +857,15 @@ func TestRouteWithDownNIC(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) } - if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: addr1, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr1, err) } ep2 := channel.New(1, defaultMTU, "") @@ -821,8 +873,15 @@ func TestRouteWithDownNIC(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) } - if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: addr2, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddr2, err) } // Set a route table that sends all packets with odd destination @@ -978,12 +1037,26 @@ func TestRoutes(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err) } - if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr3 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x03", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr3, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr3, err) } ep2 := channel.New(10, defaultMTU, "") @@ -991,12 +1064,26 @@ func TestRoutes(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x02", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(2, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err) } - if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr4 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x04", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(2, protocolAddr4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr4, err) } // Set a route table that sends all packets with odd destination @@ -1058,8 +1145,15 @@ func TestAddressRemoval(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: localAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } { subnet, err := tcpip.NewSubnet("\x00", "\x00") @@ -1108,8 +1202,15 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) buf := buffer.NewView(30) - if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: localAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } { subnet, err := tcpip.NewSubnet("\x00", "\x00") @@ -1242,8 +1343,15 @@ func TestEndpointExpiration(t *testing.T) { // 2. Add Address, everything should work. //----------------------- - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: localAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) @@ -1270,8 +1378,8 @@ func TestEndpointExpiration(t *testing.T) { // 4. Add Address back, everything should work again. //----------------------- - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) @@ -1310,8 +1418,8 @@ func TestEndpointExpiration(t *testing.T) { // 7. Add Address back, everything should work again. //----------------------- - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) @@ -1453,8 +1561,15 @@ func TestExternalSendWithHandleLocal(t *testing.T) { if err := s.CreateNIC(nicID, ep); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: localAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: nicID}}) @@ -1510,8 +1625,15 @@ func TestSpoofingWithAddress(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: localAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } { @@ -1633,8 +1755,8 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { } protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{Address: header.IPv4Any}} - if err := s.AddProtocolAddress(1, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err) + if err := s.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(1, %+v, {}) failed: %s", protoAddr, err) } r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) if err != nil { @@ -1678,13 +1800,13 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { t.Fatalf("CreateNIC failed: %s", err) } nic1ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic1Addr} - if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err) + if err := s.AddProtocolAddress(1, nic1ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(1, %+v, {}) failed: %s", nic1ProtoAddr, err) } nic2ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic2Addr} - if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil { - t.Fatalf("AddAddress(2, %v) failed: %v", nic2ProtoAddr, err) + if err := s.AddProtocolAddress(2, nic2ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(2, %+v, {}) failed: %s", nic2ProtoAddr, err) } // Set the initial route table. @@ -1726,7 +1848,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { // 2. Case: Having an explicit route for broadcast will select that one. rt = append( []tcpip.Route{ - {Destination: tcpip.AddressWithPrefix{Address: header.IPv4Broadcast, PrefixLen: 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, + {Destination: header.IPv4Broadcast.WithPrefix().Subnet(), NIC: 1}, }, rt..., ) @@ -1808,8 +1930,15 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, want) } - if err := s.AddAddress(1, fakeNetNumber, anyAddr); err != nil { - t.Fatalf("AddAddress(%v, %v) failed: %v", fakeNetNumber, anyAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: anyAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } if r, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); tc.routeNeeded { @@ -1886,22 +2015,27 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { // Add an address and in case of a primary one include a // prefixLen. address := tcpip.Address(bytes.Repeat([]byte{byte(i)}, addrLen)) + properties := stack.AddressProperties{PEB: behavior} if behavior == stack.CanBePrimaryEndpoint { protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: address, - PrefixLen: addrLen * 8, - }, + Protocol: fakeNetNumber, + AddressWithPrefix: address.WithPrefix(), } - if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, protocolAddress, behavior, err) + if err := s.AddProtocolAddress(nicID, protocolAddress, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddress, properties, err) } // Remember the address/prefix. primaryAddrAdded[protocolAddress.AddressWithPrefix] = struct{}{} } else { - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s:", nicID, fakeNetNumber, address, behavior, err) + protocolAddress := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddress, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddress, properties, err) } } } @@ -1996,8 +2130,8 @@ func TestGetMainNICAddressAddRemove(t *testing.T) { PrefixLen: tc.prefixLen, }, } - if err := s.AddProtocolAddress(1, protocolAddress); err != nil { - t.Fatal("AddProtocolAddress failed:", err) + if err := s.AddProtocolAddress(1, protocolAddress, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", protocolAddress, err) } // Check that we get the right initial address and prefix length. @@ -2047,33 +2181,6 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto } } -func TestAddAddress(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - var addrGen addressGenerator - expectedAddresses := make([]tcpip.ProtocolAddress, 0, 2) - for _, addrLen := range []int{4, 16} { - address := addrGen.next(addrLen) - if err := s.AddAddress(nicID, fakeNetNumber, address); err != nil { - t.Fatalf("AddAddress(address=%s) failed: %s", address, err) - } - expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, - }) - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - func TestAddProtocolAddress(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ @@ -2084,96 +2191,43 @@ func TestAddProtocolAddress(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - var addrGen addressGenerator - addrLenRange := []int{4, 16} - prefixLenRange := []int{8, 13, 20, 32} - expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)) - for _, addrLen := range addrLenRange { - for _, prefixLen := range prefixLenRange { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addrGen.next(addrLen), - PrefixLen: prefixLen, - }, - } - if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil { - t.Errorf("AddProtocolAddress(%+v) failed: %s", protocolAddress, err) - } - expectedAddresses = append(expectedAddresses, protocolAddress) - } - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - -func TestAddAddressWithOptions(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - addrLenRange := []int{4, 16} behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint} - expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(behaviorRange)) + configTypeRange := []stack.AddressConfigType{stack.AddressConfigStatic, stack.AddressConfigSlaac, stack.AddressConfigSlaacTemp} + deprecatedRange := []bool{false, true} + wantAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(behaviorRange)*len(configTypeRange)*len(deprecatedRange)) var addrGen addressGenerator for _, addrLen := range addrLenRange { for _, behavior := range behaviorRange { - address := addrGen.next(addrLen) - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil { - t.Fatalf("AddAddressWithOptions(address=%s, behavior=%d) failed: %s", address, behavior, err) - } - expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, - }) - } - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - -func TestAddProtocolAddressWithOptions(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - addrLenRange := []int{4, 16} - prefixLenRange := []int{8, 13, 20, 32} - behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint} - expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)*len(behaviorRange)) - var addrGen addressGenerator - for _, addrLen := range addrLenRange { - for _, prefixLen := range prefixLenRange { - for _, behavior := range behaviorRange { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addrGen.next(addrLen), - PrefixLen: prefixLen, - }, - } - if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%+v, %d) failed: %s", protocolAddress, behavior, err) + for _, configType := range configTypeRange { + for _, deprecated := range deprecatedRange { + address := addrGen.next(addrLen) + properties := stack.AddressProperties{ + PEB: behavior, + ConfigType: configType, + Deprecated: deprecated, + } + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v) failed: %s", nicID, protocolAddr, properties, err) + } + wantAddresses = append(wantAddresses, tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, + }) } - expectedAddresses = append(expectedAddresses, protocolAddress) } } } gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) + verifyAddresses(t, wantAddresses, gotAddresses) } func TestCreateNICWithOptions(t *testing.T) { @@ -2290,8 +2344,15 @@ func TestNICStats(t *testing.T) { if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed: ", err) } - if err := s.AddAddress(nicid, fakeNetNumber, nic.addr); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: nic.addr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicid, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicid, protocolAddr, err) } { @@ -2735,8 +2796,16 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // be returned by a call to GetMainNICAddress; // else, it should. const address1 = tcpip.Address("\x01") - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address1, pi); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address1, pi, err) + properties := stack.AddressProperties{PEB: pi} + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address1, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr, properties, err) } addr, err := s.GetMainNICAddress(nicID, fakeNetNumber) if err != nil { @@ -2785,16 +2854,31 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // Add some other address with peb set to // FirstPrimaryEndpoint. const address3 = tcpip.Address("\x03") - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address3, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address3, stack.FirstPrimaryEndpoint, err) - + protocolAddr3 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address3, + PrefixLen: fakeDefaultPrefixLen, + }, + } + properties = stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint} + if err := s.AddProtocolAddress(nicID, protocolAddr3, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr3, properties, err) } // Add back the address we removed earlier and // make sure the new peb was respected. // (The address should just be promoted now). - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address1, ps); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address1, pi, err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address1, + PrefixLen: fakeDefaultPrefixLen, + }, + } + properties = stack.AddressProperties{PEB: ps} + if err := s.AddProtocolAddress(nicID, protocolAddr1, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr1, properties, err) } var primaryAddrs []tcpip.Address for _, pa := range s.NICInfo()[nicID].ProtocolAddresses { @@ -3096,8 +3180,12 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { } for _, a := range test.nicAddrs { - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, a); err != nil { - t.Errorf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, a, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: a.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } } @@ -3203,8 +3291,12 @@ func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, addr1, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: addr1.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } // The NIC should have joined addr1's solicited node multicast address. @@ -3359,8 +3451,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) { PrefixLen: 128, }, } - if err := s.AddProtocolAddress(nicID, addr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err) + if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err) } // Address should be in the list of all addresses. @@ -3687,8 +3779,8 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { if err := s.CreateNIC(nicID1, ep); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) } - if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err) + if err := s.AddProtocolAddress(nicID1, test.nicAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, test.nicAddr, err) } s.SetRouteTable(test.routes) @@ -3750,8 +3842,8 @@ func TestResolveWith(t *testing.T) { PrefixLen: 24, }, } - if err := s.AddProtocolAddress(nicID, addr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, addr, err) + if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err) } s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}}) @@ -3792,8 +3884,15 @@ func TestRouteReleaseAfterAddrRemoval(t *testing.T) { if err := s.CreateNIC(nicID, ep); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: localAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } { subnet, err := tcpip.NewSubnet("\x00", "\x00") @@ -3881,8 +3980,8 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) { PrefixLen: 8, }, } - if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protocolAddress, err) + if err := s.AddProtocolAddress(nicID, protocolAddress, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddress, err) } // Check that we get the right initial address and prefix length. @@ -3990,44 +4089,44 @@ func TestFindRouteWithForwarding(t *testing.T) { ) type netCfg struct { - proto tcpip.NetworkProtocolNumber - factory stack.NetworkProtocolFactory - nic1Addr tcpip.Address - nic2Addr tcpip.Address - remoteAddr tcpip.Address + proto tcpip.NetworkProtocolNumber + factory stack.NetworkProtocolFactory + nic1AddrWithPrefix tcpip.AddressWithPrefix + nic2AddrWithPrefix tcpip.AddressWithPrefix + remoteAddr tcpip.Address } fakeNetCfg := netCfg{ - proto: fakeNetNumber, - factory: fakeNetFactory, - nic1Addr: nic1Addr, - nic2Addr: nic2Addr, - remoteAddr: remoteAddr, + proto: fakeNetNumber, + factory: fakeNetFactory, + nic1AddrWithPrefix: tcpip.AddressWithPrefix{Address: nic1Addr, PrefixLen: fakeDefaultPrefixLen}, + nic2AddrWithPrefix: tcpip.AddressWithPrefix{Address: nic2Addr, PrefixLen: fakeDefaultPrefixLen}, + remoteAddr: remoteAddr, } globalIPv6Addr1 := tcpip.Address(net.ParseIP("a::1").To16()) globalIPv6Addr2 := tcpip.Address(net.ParseIP("a::2").To16()) ipv6LinkLocalNIC1WithGlobalRemote := netCfg{ - proto: ipv6.ProtocolNumber, - factory: ipv6.NewProtocol, - nic1Addr: llAddr1, - nic2Addr: globalIPv6Addr2, - remoteAddr: globalIPv6Addr1, + proto: ipv6.ProtocolNumber, + factory: ipv6.NewProtocol, + nic1AddrWithPrefix: llAddr1.WithPrefix(), + nic2AddrWithPrefix: globalIPv6Addr2.WithPrefix(), + remoteAddr: globalIPv6Addr1, } ipv6GlobalNIC1WithLinkLocalRemote := netCfg{ - proto: ipv6.ProtocolNumber, - factory: ipv6.NewProtocol, - nic1Addr: globalIPv6Addr1, - nic2Addr: llAddr1, - remoteAddr: llAddr2, + proto: ipv6.ProtocolNumber, + factory: ipv6.NewProtocol, + nic1AddrWithPrefix: globalIPv6Addr1.WithPrefix(), + nic2AddrWithPrefix: llAddr1.WithPrefix(), + remoteAddr: llAddr2, } ipv6GlobalNIC1WithLinkLocalMulticastRemote := netCfg{ - proto: ipv6.ProtocolNumber, - factory: ipv6.NewProtocol, - nic1Addr: globalIPv6Addr1, - nic2Addr: globalIPv6Addr2, - remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + proto: ipv6.ProtocolNumber, + factory: ipv6.NewProtocol, + nic1AddrWithPrefix: globalIPv6Addr1.WithPrefix(), + nic2AddrWithPrefix: globalIPv6Addr2.WithPrefix(), + remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", } tests := []struct { @@ -4036,8 +4135,8 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg netCfg forwardingEnabled bool - addrNIC tcpip.NICID - localAddr tcpip.Address + addrNIC tcpip.NICID + localAddrWithPrefix tcpip.AddressWithPrefix findRouteErr tcpip.Error dependentOnForwarding bool @@ -4047,7 +4146,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: false, addrNIC: nicID1, - localAddr: fakeNetCfg.nic2Addr, + localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4056,7 +4155,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: true, addrNIC: nicID1, - localAddr: fakeNetCfg.nic2Addr, + localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4065,7 +4164,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: false, addrNIC: nicID1, - localAddr: fakeNetCfg.nic1Addr, + localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4074,7 +4173,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: true, addrNIC: nicID1, - localAddr: fakeNetCfg.nic1Addr, + localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: true, }, @@ -4083,7 +4182,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: false, addrNIC: nicID2, - localAddr: fakeNetCfg.nic2Addr, + localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4092,7 +4191,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: true, addrNIC: nicID2, - localAddr: fakeNetCfg.nic2Addr, + localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4101,7 +4200,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: false, addrNIC: nicID2, - localAddr: fakeNetCfg.nic1Addr, + localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4110,7 +4209,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: true, addrNIC: nicID2, - localAddr: fakeNetCfg.nic1Addr, + localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4118,7 +4217,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and localAddr on same NIC as route", netCfg: fakeNetCfg, forwardingEnabled: false, - localAddr: fakeNetCfg.nic2Addr, + localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4126,7 +4225,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and localAddr on same NIC as route", netCfg: fakeNetCfg, forwardingEnabled: false, - localAddr: fakeNetCfg.nic2Addr, + localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4134,7 +4233,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and localAddr on different NIC as route", netCfg: fakeNetCfg, forwardingEnabled: false, - localAddr: fakeNetCfg.nic1Addr, + localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4142,7 +4241,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and localAddr on different NIC as route", netCfg: fakeNetCfg, forwardingEnabled: true, - localAddr: fakeNetCfg.nic1Addr, + localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: true, }, @@ -4166,7 +4265,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and link-local local addr with route on different NIC", netCfg: ipv6LinkLocalNIC1WithGlobalRemote, forwardingEnabled: false, - localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, + localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4174,7 +4273,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and link-local local addr with route on same NIC", netCfg: ipv6LinkLocalNIC1WithGlobalRemote, forwardingEnabled: true, - localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, + localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4182,7 +4281,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and global local addr with route on same NIC", netCfg: ipv6LinkLocalNIC1WithGlobalRemote, forwardingEnabled: true, - localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic2Addr, + localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4190,7 +4289,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and link-local local addr with route on same NIC", netCfg: ipv6GlobalNIC1WithLinkLocalRemote, forwardingEnabled: false, - localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4198,7 +4297,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and link-local local addr with route on same NIC", netCfg: ipv6GlobalNIC1WithLinkLocalRemote, forwardingEnabled: true, - localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4206,7 +4305,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and global local addr with link-local remote on different NIC", netCfg: ipv6GlobalNIC1WithLinkLocalRemote, forwardingEnabled: false, - localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, @@ -4214,7 +4313,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and global local addr with link-local remote on different NIC", netCfg: ipv6GlobalNIC1WithLinkLocalRemote, forwardingEnabled: true, - localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, @@ -4222,7 +4321,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and global local addr with link-local multicast remote on different NIC", netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, forwardingEnabled: false, - localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, @@ -4230,7 +4329,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and global local addr with link-local multicast remote on different NIC", netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, forwardingEnabled: true, - localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, @@ -4238,7 +4337,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and global local addr with link-local multicast remote on same NIC", netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, forwardingEnabled: false, - localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4246,7 +4345,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and global local addr with link-local multicast remote on same NIC", netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, forwardingEnabled: true, - localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4268,12 +4367,20 @@ func TestFindRouteWithForwarding(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s:", nicID2, err) } - if err := s.AddAddress(nicID1, test.netCfg.proto, test.netCfg.nic1Addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, test.netCfg.proto, test.netCfg.nic1Addr, err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: test.netCfg.proto, + AddressWithPrefix: test.netCfg.nic1AddrWithPrefix, + } + if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr1, err) } - if err := s.AddAddress(nicID2, test.netCfg.proto, test.netCfg.nic2Addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, test.netCfg.proto, test.netCfg.nic2Addr, err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: test.netCfg.proto, + AddressWithPrefix: test.netCfg.nic2AddrWithPrefix, + } + if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddr2, err) } if err := s.SetForwardingDefaultAndAllNICs(test.netCfg.proto, test.forwardingEnabled); err != nil { @@ -4282,20 +4389,20 @@ func TestFindRouteWithForwarding(t *testing.T) { s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}}) - r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */) + r, err := s.FindRoute(test.addrNIC, test.localAddrWithPrefix.Address, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */) if err == nil { defer r.Release() } if diff := cmp.Diff(test.findRouteErr, err); diff != "" { - t.Fatalf("unexpected error from FindRoute(%d, %s, %s, %d, false), (-want, +got):\n%s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, diff) + t.Fatalf("unexpected error from FindRoute(%d, %s, %s, %d, false), (-want, +got):\n%s", test.addrNIC, test.localAddrWithPrefix.Address, test.netCfg.remoteAddr, test.netCfg.proto, diff) } if test.findRouteErr != nil { return } - if r.LocalAddress() != test.localAddr { - t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), test.localAddr) + if r.LocalAddress() != test.localAddrWithPrefix.Address { + t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), test.localAddrWithPrefix.Address) } if r.RemoteAddress() != test.netCfg.remoteAddr { t.Errorf("got r.RemoteAddress() = %s, want = %s", r.RemoteAddress(), test.netCfg.remoteAddr) @@ -4318,8 +4425,8 @@ func TestFindRouteWithForwarding(t *testing.T) { if !ok { t.Fatal("packet not sent through ep2") } - if pkt.Route.LocalAddress != test.localAddr { - t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddr) + if pkt.Route.LocalAddress != test.localAddrWithPrefix.Address { + t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddrWithPrefix.Address) } if pkt.Route.RemoteAddress != test.netCfg.remoteAddr { t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.netCfg.remoteAddr) diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 45b09110d..cd3a8c25a 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -35,7 +35,7 @@ import ( const ( testSrcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - testDstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + testDstAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") testSrcAddrV4 = "\x0a\x00\x00\x01" testDstAddrV4 = "\x0a\x00\x00\x02" @@ -64,12 +64,20 @@ func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICI } linkEps[linkEpID] = channelEp - if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, testDstAddrV4); err != nil { - t.Fatalf("AddAddress IPv4 failed: %s", err) + protocolAddrV4 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.Address(testDstAddrV4).WithPrefix(), + } + if err := s.AddProtocolAddress(linkEpID, protocolAddrV4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", linkEpID, protocolAddrV4, err) } - if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, testDstAddrV6); err != nil { - t.Fatalf("AddAddress IPv6 failed: %s", err) + protocolAddrV6 := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: testDstAddrV6.WithPrefix(), + } + if err := s.AddProtocolAddress(linkEpID, protocolAddrV6, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", linkEpID, protocolAddrV6, err) } } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 839178809..655931715 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -357,8 +357,15 @@ func TestTransportReceive(t *testing.T) { s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } // Create endpoint and connect to remote address. @@ -428,8 +435,15 @@ func TestTransportControlReceive(t *testing.T) { s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } // Create endpoint and connect to remote address. @@ -497,8 +511,15 @@ func TestTransportSend(t *testing.T) { t.Fatalf("CreateNIC failed: %v", err) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 55683b4fb..a9ce148b9 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -19,7 +19,7 @@ // The starting point is the creation and configuration of a stack. A stack can // be created by calling the New() function of the tcpip/stack/stack package; // configuring a stack involves creating NICs (via calls to Stack.CreateNIC()), -// adding network addresses (via calls to Stack.AddAddress()), and +// adding network addresses (via calls to Stack.AddProtocolAddress()), and // setting a route table (via a call to Stack.SetRouteTable()). // // Once a stack is configured, endpoints can be created by calling diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index 92fa6257d..6e1d4720d 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -473,11 +473,19 @@ func TestMulticastForwarding(t *testing.T) { t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err) } - if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err) + protocolAddrV4 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: utils.Ipv4Addr, } - if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err) + if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err) + } + protocolAddrV6 := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: utils.Ipv6Addr, + } + if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err) } if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { @@ -612,8 +620,8 @@ func TestPerInterfaceForwarding(t *testing.T) { addr: utils.RouterNIC2IPv6Addr, }, } { - if err := s.AddProtocolAddress(add.nicID, add.addr); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", add.nicID, add.addr, err) + if err := s.AddProtocolAddress(add.nicID, add.addr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", add.nicID, add.addr, err) } } diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go index f9ab7d0af..28b49c6be 100644 --- a/pkg/tcpip/tests/integration/iptables_test.go +++ b/pkg/tcpip/tests/integration/iptables_test.go @@ -49,10 +49,10 @@ const ( nicName = "nic1" anotherNicName = "nic2" linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") - srcAddrV4 = "\x0a\x00\x00\x01" - dstAddrV4 = "\x0a\x00\x00\x02" - srcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - dstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + srcAddrV4 = tcpip.Address("\x0a\x00\x00\x01") + dstAddrV4 = tcpip.Address("\x0a\x00\x00\x02") + srcAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + dstAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") payloadSize = 20 ) @@ -66,8 +66,12 @@ func genStackV6(t *testing.T) (*stack.Stack, *channel.Endpoint) { if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err) } - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, dstAddrV6); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, dstAddrV6, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: dstAddrV6.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } return s, e } @@ -82,8 +86,12 @@ func genStackV4(t *testing.T) (*stack.Stack, *channel.Endpoint) { if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err) } - if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, dstAddrV4); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, dstAddrV4, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: dstAddrV4.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } return s, e } @@ -601,11 +609,19 @@ func TestIPTableWritePackets(t *testing.T) { if err := s.CreateNIC(nicID, &e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, srcAddrV6); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, srcAddrV6, err) + protocolAddrV6 := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: srcAddrV6.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddrV6, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV6, err) } - if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, srcAddrV4); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, srcAddrV4, err) + protocolAddrV4 := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: srcAddrV4.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddrV4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV4, err) } s.SetRouteTable([]tcpip.Route{ @@ -856,11 +872,19 @@ func TestForwardingHook(t *testing.T) { t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err) } - if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err) + protocolAddrV4 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: utils.Ipv4Addr.Address.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err) + } + protocolAddrV6 := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: utils.Ipv6Addr.Address.WithPrefix(), } - if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err) + if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err) } if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { @@ -1037,22 +1061,22 @@ func TestInputHookWithLocalForwarding(t *testing.T) { if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil { t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err) } - if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv4Addr1, err) + if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID1, utils.Ipv4Addr1, err) } - if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv6Addr1, err) + if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID1, utils.Ipv6Addr1, err) } e2 := channel.New(1, header.IPv6MinimumMTU, "") if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil { t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err) } - if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv4Addr2, err) + if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID2, utils.Ipv4Addr2, err) } - if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv6Addr2, err) + if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID2, utils.Ipv6Addr2, err) } if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index 27caa0c28..95ddd8ec3 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -56,17 +56,17 @@ func setupStack(t *testing.T, stackOpts stack.Options, host1NICID, host2NICID tc t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) } - if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv4Addr1); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, utils.Ipv4Addr1, err) + if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv4Addr1, stack.AddressProperties{}); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", host1NICID, utils.Ipv4Addr1, err) } - if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv4Addr2); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, utils.Ipv4Addr2, err) + if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv4Addr2, stack.AddressProperties{}); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", host2NICID, utils.Ipv4Addr2, err) } - if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv6Addr1); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, utils.Ipv6Addr1, err) + if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv6Addr1, stack.AddressProperties{}); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", host1NICID, utils.Ipv6Addr1, err) } - if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv6Addr2); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, utils.Ipv6Addr2, err) + if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv6Addr2, stack.AddressProperties{}); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", host2NICID, utils.Ipv6Addr2, err) } host1Stack.SetRouteTable([]tcpip.Route{ @@ -568,8 +568,8 @@ func TestForwardingWithLinkResolutionFailure(t *testing.T) { Protocol: test.networkProtocolNumber, AddressWithPrefix: test.incomingAddr, } - if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingProtoAddr, err) + if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingProtoAddr, err) } // Set up endpoint through which we will attempt to forward packets. @@ -582,8 +582,8 @@ func TestForwardingWithLinkResolutionFailure(t *testing.T) { Protocol: test.networkProtocolNumber, AddressWithPrefix: test.outgoingAddr, } - if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingProtoAddr, err) + if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingProtoAddr, err) } s.SetRouteTable([]tcpip.Route{ diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index b2008f0b2..f33223e79 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -195,8 +195,8 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) { if err := s.CreateNIC(nicID, loopback.New()); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddProtocolAddress(nicID, test.addAddress); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, test.addAddress, err) + if err := s.AddProtocolAddress(nicID, test.addAddress, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, test.addAddress, err) } s.SetRouteTable([]tcpip.Route{ { @@ -290,8 +290,8 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) { if err := s.CreateNIC(nicID, loopback.New()); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) + if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err) } s.SetRouteTable([]tcpip.Route{ { @@ -431,8 +431,8 @@ func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) { if err := s.CreateNIC(nicID, loopback.New()); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddProtocolAddress(nicID, test.addAddress); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, test.addAddress, err) + if err := s.AddProtocolAddress(nicID, test.addAddress, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, test.addAddress, err) } s.SetRouteTable([]tcpip.Route{ { @@ -693,21 +693,40 @@ func TestExternalLoopbackTraffic(t *testing.T) { if err := s.CreateNIC(nicID1, e); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) } - if err := s.AddAddressWithPrefix(nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr, err) + v4Addr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: utils.Ipv4Addr, } - if err := s.AddAddressWithPrefix(nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr, err) + if err := s.AddProtocolAddress(nicID1, v4Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, v4Addr, err) + } + v6Addr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: utils.Ipv6Addr, + } + if err := s.AddProtocolAddress(nicID1, v6Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, v6Addr, err) } if err := s.CreateNIC(nicID2, loopback.New()); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) } - if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, ipv4Loopback); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, ipv4Loopback, err) + protocolAddrV4 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: ipv4Loopback, + PrefixLen: 8, + }, + } + if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err) + } + protocolAddrV6 := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: header.IPv6Loopback.WithPrefix(), } - if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, header.IPv6Loopback); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, header.IPv6Loopback, err) + if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err) } if test.forwarding { diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 2d0a6e6a7..7753e7d6e 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -119,12 +119,12 @@ func TestPingMulticastBroadcast(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: utils.Ipv4Addr} - if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err) + if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtoAddr, err) } ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: utils.Ipv6Addr} - if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv6ProtoAddr, err) + if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtoAddr, err) } // Default routes for IPv4 and IPv6 so ICMP can find a route to the remote @@ -396,8 +396,8 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr} - if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) + if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err) } var wq waiter.Queue @@ -474,8 +474,8 @@ func TestReuseAddrAndBroadcast(t *testing.T) { PrefixLen: 8, }, } - if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) + if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err) } s.SetRouteTable([]tcpip.Route{ @@ -642,8 +642,8 @@ func TestUDPAddRemoveMembershipSocketOption(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr} - if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) + if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err) } // Set the route table so that UDP can find a NIC that is diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go index ac3c703d4..422eb8408 100644 --- a/pkg/tcpip/tests/integration/route_test.go +++ b/pkg/tcpip/tests/integration/route_test.go @@ -47,7 +47,10 @@ func TestLocalPing(t *testing.T) { // request/reply packets. icmpDataOffset = 8 ) - ipv4Loopback := testutil.MustParse4("127.0.0.1") + ipv4Loopback := tcpip.AddressWithPrefix{ + Address: testutil.MustParse4("127.0.0.1"), + PrefixLen: 8, + } channelEP := func() stack.LinkEndpoint { return channel.New(1, header.IPv6MinimumMTU, "") } channelEPCheck := func(t *testing.T, e stack.LinkEndpoint) { @@ -82,7 +85,7 @@ func TestLocalPing(t *testing.T) { transProto tcpip.TransportProtocolNumber netProto tcpip.NetworkProtocolNumber linkEndpoint func() stack.LinkEndpoint - localAddr tcpip.Address + localAddr tcpip.AddressWithPrefix icmpBuf func(*testing.T) buffer.View expectedConnectErr tcpip.Error checkLinkEndpoint func(t *testing.T, e stack.LinkEndpoint) @@ -101,7 +104,7 @@ func TestLocalPing(t *testing.T) { transProto: icmp.ProtocolNumber6, netProto: ipv6.ProtocolNumber, linkEndpoint: loopback.New, - localAddr: header.IPv6Loopback, + localAddr: header.IPv6Loopback.WithPrefix(), icmpBuf: ipv6ICMPBuf, checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, }, @@ -110,7 +113,7 @@ func TestLocalPing(t *testing.T) { transProto: icmp.ProtocolNumber4, netProto: ipv4.ProtocolNumber, linkEndpoint: channelEP, - localAddr: utils.Ipv4Addr.Address, + localAddr: utils.Ipv4Addr, icmpBuf: ipv4ICMPBuf, checkLinkEndpoint: channelEPCheck, }, @@ -119,7 +122,7 @@ func TestLocalPing(t *testing.T) { transProto: icmp.ProtocolNumber6, netProto: ipv6.ProtocolNumber, linkEndpoint: channelEP, - localAddr: utils.Ipv6Addr.Address, + localAddr: utils.Ipv6Addr, icmpBuf: ipv6ICMPBuf, checkLinkEndpoint: channelEPCheck, }, @@ -182,9 +185,13 @@ func TestLocalPing(t *testing.T) { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } - if len(test.localAddr) != 0 { - if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err) + if len(test.localAddr.Address) != 0 { + protocolAddr := tcpip.ProtocolAddress{ + Protocol: test.netProto, + AddressWithPrefix: test.localAddr, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } } @@ -197,7 +204,7 @@ func TestLocalPing(t *testing.T) { } defer ep.Close() - connAddr := tcpip.FullAddress{Addr: test.localAddr} + connAddr := tcpip.FullAddress{Addr: test.localAddr.Address} if err := ep.Connect(connAddr); err != test.expectedConnectErr { t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr) } @@ -229,8 +236,8 @@ func TestLocalPing(t *testing.T) { if diff := cmp.Diff(buffer.View(w.Bytes()[icmpDataOffset:]), payload[icmpDataOffset:]); diff != "" { t.Errorf("received data mismatch (-want +got):\n%s", diff) } - if rr.RemoteAddr.Addr != test.localAddr { - t.Errorf("got addr.Addr = %s, want = %s", rr.RemoteAddr.Addr, test.localAddr) + if rr.RemoteAddr.Addr != test.localAddr.Address { + t.Errorf("got addr.Addr = %s, want = %s", rr.RemoteAddr.Addr, test.localAddr.Address) } test.checkLinkEndpoint(t, e) @@ -302,11 +309,12 @@ func TestLocalUDP(t *testing.T) { } if subTest.addAddress { - if err := s.AddProtocolAddressWithOptions(nicID, test.canBePrimaryAddr, stack.CanBePrimaryEndpoint); err != nil { - t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.canBePrimaryAddr, stack.FirstPrimaryEndpoint, err) + if err := s.AddProtocolAddress(nicID, test.canBePrimaryAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, test.canBePrimaryAddr, err) } - if err := s.AddProtocolAddressWithOptions(nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint, err) + properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint} + if err := s.AddProtocolAddress(nicID, test.firstPrimaryAddr, properties); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, %+v): %s", nicID, test.firstPrimaryAddr, properties, err) } } diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go index 2e6ae55ea..947bcc7b1 100644 --- a/pkg/tcpip/tests/utils/utils.go +++ b/pkg/tcpip/tests/utils/utils.go @@ -231,29 +231,29 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack. t.Fatalf("routerStack.SetForwardingDefaultAndAllNICs(%d): %s", ipv6.ProtocolNumber, err) } - if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv4Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", Host1NICID, Host1IPv4Addr, err) + if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv4Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", Host1NICID, Host1IPv4Addr, err) } - if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv4Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID1, RouterNIC1IPv4Addr, err) + if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv4Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID1, RouterNIC1IPv4Addr, err) } - if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv4Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID2, RouterNIC2IPv4Addr, err) + if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv4Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID2, RouterNIC2IPv4Addr, err) } - if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv4Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", Host2NICID, Host2IPv4Addr, err) + if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv4Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", Host2NICID, Host2IPv4Addr, err) } - if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv6Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", Host1NICID, Host1IPv6Addr, err) + if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv6Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", Host1NICID, Host1IPv6Addr, err) } - if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv6Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID1, RouterNIC1IPv6Addr, err) + if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv6Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID1, RouterNIC1IPv6Addr, err) } - if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv6Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID2, RouterNIC2IPv6Addr, err) + if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv6Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID2, RouterNIC2IPv6Addr, err) } - if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv6Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", Host2NICID, Host2IPv6Addr, err) + if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv6Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", Host2NICID, Host2IPv6Addr, err) } host1Stack.SetRouteTable([]tcpip.Route{ diff --git a/pkg/tcpip/transport/icmp/icmp_test.go b/pkg/tcpip/transport/icmp/icmp_test.go index cc950cbde..729f50e9a 100644 --- a/pkg/tcpip/transport/icmp/icmp_test.go +++ b/pkg/tcpip/transport/icmp/icmp_test.go @@ -55,8 +55,12 @@ func addNICWithDefaultRoute(t *testing.T, s *stack.Stack, id tcpip.NICID, name s t.Fatalf("s.CreateNIC(%d, _) = %s", id, err) } - if err := s.AddAddress(id, ipv4.ProtocolNumber, addrV4); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s) = %s", id, ipv4.ProtocolNumber, addrV4, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addrV4.WithPrefix(), + } + if err := s.AddProtocolAddress(id, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", id, protocolAddr, err) } s.AddRoute(tcpip.Route{ diff --git a/pkg/tcpip/transport/internal/network/BUILD b/pkg/tcpip/transport/internal/network/BUILD index d10e3f13a..b1edce39b 100644 --- a/pkg/tcpip/transport/internal/network/BUILD +++ b/pkg/tcpip/transport/internal/network/BUILD @@ -9,6 +9,7 @@ go_library( "endpoint_state.go", ], visibility = [ + "//pkg/tcpip/transport/raw:__pkg__", "//pkg/tcpip/transport/udp:__pkg__", ], deps = [ @@ -32,6 +33,7 @@ go_test( "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/loopback", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", diff --git a/pkg/tcpip/transport/internal/network/endpoint.go b/pkg/tcpip/transport/internal/network/endpoint.go index c5b575e1c..3cb821475 100644 --- a/pkg/tcpip/transport/internal/network/endpoint.go +++ b/pkg/tcpip/transport/internal/network/endpoint.go @@ -18,7 +18,6 @@ package network import ( "fmt" - "sync/atomic" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -38,30 +37,41 @@ type Endpoint struct { netProto tcpip.NetworkProtocolNumber transProto tcpip.TransportProtocolNumber - // state holds a transport.DatagramBasedEndpointState. - // - // state must be read from/written to atomically. - state uint32 - - // The following fields are protected by mu. - mu sync.RWMutex `state:"nosave"` + mu sync.RWMutex `state:"nosave"` + // +checklocks:mu + state transport.DatagramEndpointState + // +checklocks:mu + wasBound bool + // +checklocks:mu info stack.TransportEndpointInfo // owner is the owner of transmitted packets. - owner tcpip.PacketOwner - writeShutdown bool - effectiveNetProto tcpip.NetworkProtocolNumber - connectedRoute *stack.Route `state:"manual"` + // + // +checklocks:mu + owner tcpip.PacketOwner + // +checklocks:mu + writeShutdown bool + // +checklocks:mu + effectiveNetProto tcpip.NetworkProtocolNumber + // +checklocks:mu + connectedRoute *stack.Route `state:"manual"` + // +checklocks:mu multicastMemberships map[multicastMembership]struct{} // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. + // +checklocks:mu ttl uint8 // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. + // +checklocks:mu multicastTTL uint8 // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. + // +checklocks:mu multicastAddr tcpip.Address // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. + // +checklocks:mu multicastNICID tcpip.NICID - ipv4TOS uint8 - ipv6TClass uint8 + // +checklocks:mu + ipv4TOS uint8 + // +checklocks:mu + ipv6TClass uint8 } // +stateify savable @@ -72,8 +82,11 @@ type multicastMembership struct { // Init initializes the endpoint. func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ops *tcpip.SocketOptions) { - if e.multicastMemberships != nil { - panic(fmt.Sprintf("endpoint is already initialized; got e.multicastMemberships = %#v, want = nil", e.multicastMemberships)) + e.mu.Lock() + memberships := e.multicastMemberships + e.mu.Unlock() + if memberships != nil { + panic(fmt.Sprintf("endpoint is already initialized; got e.multicastMemberships = %#v, want = nil", memberships)) } switch netProto { @@ -88,8 +101,7 @@ func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, tr netProto: netProto, transProto: transProto, - state: uint32(transport.DatagramEndpointStateInitial), - + state: transport.DatagramEndpointStateInitial, info: stack.TransportEndpointInfo{ NetProto: netProto, TransProto: transProto, @@ -106,14 +118,11 @@ func (e *Endpoint) NetProto() tcpip.NetworkProtocolNumber { return e.netProto } -// setState sets the state of the endpoint. -func (e *Endpoint) setEndpointState(state transport.DatagramEndpointState) { - atomic.StoreUint32(&e.state, uint32(state)) -} - // State returns the state of the endpoint. func (e *Endpoint) State() transport.DatagramEndpointState { - return transport.DatagramEndpointState(atomic.LoadUint32(&e.state)) + e.mu.RLock() + defer e.mu.RUnlock() + return e.state } // Close cleans the endpoint's resources and leaves the endpoint in a closed @@ -122,7 +131,7 @@ func (e *Endpoint) Close() { e.mu.Lock() defer e.mu.Unlock() - if e.State() == transport.DatagramEndpointStateClosed { + if e.state == transport.DatagramEndpointStateClosed { return } @@ -136,7 +145,7 @@ func (e *Endpoint) Close() { e.connectedRoute = nil } - e.setEndpointState(transport.DatagramEndpointStateClosed) + e.state = transport.DatagramEndpointStateClosed } // SetOwner sets the owner of transmitted packets. @@ -217,7 +226,7 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext return WriteContext{}, &tcpip.ErrInvalidOptionValue{} } - if e.State() == transport.DatagramEndpointStateClosed { + if e.state == transport.DatagramEndpointStateClosed { return WriteContext{}, &tcpip.ErrInvalidEndpointState{} } @@ -229,7 +238,7 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext if opts.To == nil { // If the user doesn't specify a destination, they should have // connected to another address. - if e.State() != transport.DatagramEndpointStateConnected { + if e.state != transport.DatagramEndpointStateConnected { return WriteContext{}, &tcpip.ErrDestinationRequired{} } @@ -248,13 +257,16 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext nicID = e.info.BindNICID } + if nicID == 0 { + nicID = e.info.RegisterNICID + } - dst, netProto, err := e.checkV4MappedLocked(*opts.To) + dst, netProto, err := e.checkV4MappedRLocked(*opts.To) if err != nil { return WriteContext{}, err } - route, _, err = e.connectRoute(nicID, dst, netProto) + route, _, err = e.connectRouteRLocked(nicID, dst, netProto) if err != nil { return WriteContext{}, err } @@ -289,29 +301,32 @@ func (e *Endpoint) Disconnect() { e.mu.Lock() defer e.mu.Unlock() - if e.State() != transport.DatagramEndpointStateConnected { + if e.state != transport.DatagramEndpointStateConnected { return } // Exclude ephemerally bound endpoints. - if e.info.BindNICID != 0 || e.info.ID.LocalAddress == "" { + if e.wasBound { e.info.ID = stack.TransportEndpointID{ - LocalAddress: e.info.ID.LocalAddress, + LocalAddress: e.info.BindAddr, } - e.setEndpointState(transport.DatagramEndpointStateBound) + e.state = transport.DatagramEndpointStateBound } else { e.info.ID = stack.TransportEndpointID{} - e.setEndpointState(transport.DatagramEndpointStateInitial) + e.state = transport.DatagramEndpointStateInitial } e.connectedRoute.Release() e.connectedRoute = nil } -// connectRoute establishes a route to the specified interface or the +// connectRouteRLocked establishes a route to the specified interface or the // configured multicast interface if no interface is specified and the // specified address is a multicast address. -func (e *Endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) { +// +// TODO(https://gvisor.dev/issue/6590): Annotate read lock requirement. +// +checklocks:e.mu +func (e *Endpoint) connectRouteRLocked(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) { localAddr := e.info.ID.LocalAddress if e.isBroadcastOrMulticast(nicID, netProto, localAddr) { // A packet can only originate from a unicast address (i.e., an interface). @@ -356,7 +371,7 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip. defer e.mu.Unlock() nicID := addr.NIC - switch e.State() { + switch e.state { case transport.DatagramEndpointStateInitial: case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: if e.info.BindNICID == 0 { @@ -372,12 +387,12 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip. return &tcpip.ErrInvalidEndpointState{} } - addr, netProto, err := e.checkV4MappedLocked(addr) + addr, netProto, err := e.checkV4MappedRLocked(addr) if err != nil { return err } - r, nicID, err := e.connectRoute(nicID, addr, netProto) + r, nicID, err := e.connectRouteRLocked(nicID, addr, netProto) if err != nil { return err } @@ -386,7 +401,7 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip. LocalAddress: e.info.ID.LocalAddress, RemoteAddress: r.RemoteAddress(), } - if e.State() == transport.DatagramEndpointStateInitial { + if e.state == transport.DatagramEndpointStateInitial { id.LocalAddress = r.LocalAddress() } @@ -402,7 +417,7 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip. e.info.ID = id e.info.RegisterNICID = nicID e.effectiveNetProto = netProto - e.setEndpointState(transport.DatagramEndpointStateConnected) + e.state = transport.DatagramEndpointStateConnected return nil } @@ -411,7 +426,7 @@ func (e *Endpoint) Shutdown() tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - switch state := e.State(); state { + switch state := e.state; state { case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: return &tcpip.ErrNotConnected{} case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: @@ -422,9 +437,12 @@ func (e *Endpoint) Shutdown() tcpip.Error { } } -// checkV4MappedLocked determines the effective network protocol and converts +// checkV4MappedRLocked determines the effective network protocol and converts // addr to its canonical form. -func (e *Endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { +// +// TODO(https://gvisor.dev/issue/6590): Annotate read lock requirement. +// +checklocks:e.mu +func (e *Endpoint) checkV4MappedRLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { unwrapped, netProto, err := e.info.AddrNetProtoLocked(addr, e.ops.GetV6Only()) if err != nil { return tcpip.FullAddress{}, 0, err @@ -456,11 +474,11 @@ func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProto // Don't allow binding once endpoint is not in the initial state // anymore. - if e.State() != transport.DatagramEndpointStateInitial { + if e.state != transport.DatagramEndpointStateInitial { return &tcpip.ErrInvalidEndpointState{} } - addr, netProto, err := e.checkV4MappedLocked(addr) + addr, netProto, err := e.checkV4MappedRLocked(addr) if err != nil { return err } @@ -477,24 +495,33 @@ func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProto return err } + e.wasBound = true + e.info.ID = stack.TransportEndpointID{ LocalAddress: addr.Addr, } - e.info.BindNICID = nicID + e.info.BindNICID = addr.NIC e.info.RegisterNICID = nicID e.info.BindAddr = addr.Addr e.effectiveNetProto = netProto - e.setEndpointState(transport.DatagramEndpointStateBound) + e.state = transport.DatagramEndpointStateBound return nil } +// WasBound returns true iff the endpoint was ever bound. +func (e *Endpoint) WasBound() bool { + e.mu.RLock() + defer e.mu.RUnlock() + return e.wasBound +} + // GetLocalAddress returns the address that the endpoint is bound to. func (e *Endpoint) GetLocalAddress() tcpip.FullAddress { e.mu.RLock() defer e.mu.RUnlock() addr := e.info.BindAddr - if e.State() == transport.DatagramEndpointStateConnected { + if e.state == transport.DatagramEndpointStateConnected { addr = e.connectedRoute.LocalAddress() } @@ -509,7 +536,7 @@ func (e *Endpoint) GetRemoteAddress() (tcpip.FullAddress, bool) { e.mu.RLock() defer e.mu.RUnlock() - if e.State() != transport.DatagramEndpointStateConnected { + if e.state != transport.DatagramEndpointStateConnected { return tcpip.FullAddress{}, false } @@ -597,7 +624,7 @@ func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { defer e.mu.Unlock() fa := tcpip.FullAddress{Addr: v.InterfaceAddr} - fa, netProto, err := e.checkV4MappedLocked(fa) + fa, netProto, err := e.checkV4MappedRLocked(fa) if err != nil { return err } diff --git a/pkg/tcpip/transport/internal/network/endpoint_state.go b/pkg/tcpip/transport/internal/network/endpoint_state.go index 858007156..173197512 100644 --- a/pkg/tcpip/transport/internal/network/endpoint_state.go +++ b/pkg/tcpip/transport/internal/network/endpoint_state.go @@ -35,7 +35,7 @@ func (e *Endpoint) Resume(s *stack.Stack) { } } - switch state := e.State(); state { + switch e.state { case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: case transport.DatagramEndpointStateBound: if len(e.info.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress) { @@ -51,6 +51,6 @@ func (e *Endpoint) Resume(s *stack.Stack) { panic(fmt.Sprintf("e.stack.FindRoute(%d, %s, %s, %d, %t): %s", e.info.RegisterNICID, e.info.ID.LocalAddress, e.info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop, err)) } default: - panic(fmt.Sprintf("unhandled state = %s", state)) + panic(fmt.Sprintf("unhandled state = %s", e.state)) } } diff --git a/pkg/tcpip/transport/internal/network/endpoint_test.go b/pkg/tcpip/transport/internal/network/endpoint_test.go index 2c43eb66a..f263a9ea2 100644 --- a/pkg/tcpip/transport/internal/network/endpoint_test.go +++ b/pkg/tcpip/transport/internal/network/endpoint_test.go @@ -15,6 +15,7 @@ package network_test import ( + "fmt" "testing" "github.com/google/go-cmp/cmp" @@ -24,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -33,17 +35,15 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) -func TestEndpointStateTransitions(t *testing.T) { - const ( - nicID = 1 - ) +var ( + ipv4NICAddr = testutil.MustParse4("1.2.3.4") + ipv6NICAddr = testutil.MustParse6("a::1") + ipv4RemoteAddr = testutil.MustParse4("6.7.8.9") + ipv6RemoteAddr = testutil.MustParse6("b::1") +) - var ( - ipv4NICAddr = testutil.MustParse4("1.2.3.4") - ipv6NICAddr = testutil.MustParse6("a::1") - ipv4RemoteAddr = testutil.MustParse4("6.7.8.9") - ipv6RemoteAddr = testutil.MustParse6("b::1") - ) +func TestEndpointStateTransitions(t *testing.T) { + const nicID = 1 data := buffer.View([]byte{1, 2, 4, 5}) v4Checker := func(t *testing.T, b buffer.View) { @@ -124,11 +124,20 @@ func TestEndpointStateTransitions(t *testing.T) { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, ipv4NICAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, ipv4NICAddr, err) + ipv4ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ipv4NICAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, ipv4ProtocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}: %s", nicID, ipv4ProtocolAddr, err) } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, ipv6NICAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, ipv6NICAddr, err) + ipv6ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: ipv6NICAddr.WithPrefix(), + } + + if err := s.AddProtocolAddress(nicID, ipv6ProtocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtocolAddr, err) } s.SetRouteTable([]tcpip.Route{ @@ -139,6 +148,7 @@ func TestEndpointStateTransitions(t *testing.T) { var ops tcpip.SocketOptions var ep network.Endpoint ep.Init(s, test.netProto, udp.ProtocolNumber, &ops) + defer ep.Close() if state := ep.State(); state != transport.DatagramEndpointStateInitial { t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateInitial) } @@ -207,3 +217,102 @@ func TestEndpointStateTransitions(t *testing.T) { }) } } + +func TestBindNICID(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + bindAddr tcpip.Address + unicast bool + }{ + { + name: "IPv4 multicast", + netProto: ipv4.ProtocolNumber, + bindAddr: header.IPv4AllSystems, + unicast: false, + }, + { + name: "IPv6 multicast", + netProto: ipv6.ProtocolNumber, + bindAddr: header.IPv6AllNodesMulticastAddress, + unicast: false, + }, + { + name: "IPv4 unicast", + netProto: ipv4.ProtocolNumber, + bindAddr: ipv4NICAddr, + unicast: true, + }, + { + name: "IPv6 unicast", + netProto: ipv6.ProtocolNumber, + bindAddr: ipv6NICAddr, + unicast: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, testBindNICID := range []tcpip.NICID{0, nicID} { + t.Run(fmt.Sprintf("BindNICID=%d", testBindNICID), func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: &faketime.NullClock{}, + }) + if err := s.CreateNIC(nicID, loopback.New()); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + + ipv4ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ipv4NICAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, ipv4ProtocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtocolAddr, err) + } + ipv6ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: ipv6NICAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, ipv6ProtocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtocolAddr, err) + } + + var ops tcpip.SocketOptions + var ep network.Endpoint + ep.Init(s, test.netProto, udp.ProtocolNumber, &ops) + defer ep.Close() + if ep.WasBound() { + t.Fatal("got ep.WasBound() = true, want = false") + } + wantInfo := stack.TransportEndpointInfo{NetProto: test.netProto, TransProto: udp.ProtocolNumber} + if diff := cmp.Diff(wantInfo, ep.Info()); diff != "" { + t.Fatalf("ep.Info() mismatch (-want +got):\n%s", diff) + } + + bindAddr := tcpip.FullAddress{Addr: test.bindAddr, NIC: testBindNICID} + if err := ep.Bind(bindAddr); err != nil { + t.Fatalf("ep.Bind(%#v): %s", bindAddr, err) + } + if !ep.WasBound() { + t.Error("got ep.WasBound() = false, want = true") + } + wantInfo.ID = stack.TransportEndpointID{LocalAddress: bindAddr.Addr} + wantInfo.BindAddr = bindAddr.Addr + wantInfo.BindNICID = bindAddr.NIC + if test.unicast { + wantInfo.RegisterNICID = nicID + } else { + wantInfo.RegisterNICID = bindAddr.NIC + } + if diff := cmp.Diff(wantInfo, ep.Info()); diff != "" { + t.Errorf("ep.Info() mismatch (-want +got):\n%s", diff) + } + }) + } + }) + } +} diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 0554d2f4a..2c9786175 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -68,31 +68,31 @@ type endpoint struct { netProto tcpip.NetworkProtocolNumber waiterQueue *waiter.Queue cooked bool - - // The following fields are used to manage the receive queue and are - // protected by rcvMu. - rcvMu sync.Mutex `state:"nosave"` - rcvList packetList + ops tcpip.SocketOptions + stats tcpip.TransportEndpointStats `state:"nosave"` + + // The following fields are used to manage the receive queue. + rcvMu sync.Mutex `state:"nosave"` + // +checklocks:rcvMu + rcvList packetList + // +checklocks:rcvMu rcvBufSize int - rcvClosed bool - - // The following fields are protected by mu. - mu sync.RWMutex `state:"nosave"` - closed bool - stats tcpip.TransportEndpointStats `state:"nosave"` - bound bool + // +checklocks:rcvMu + rcvClosed bool + // +checklocks:rcvMu + rcvDisabled bool + + mu sync.RWMutex `state:"nosave"` + // +checklocks:mu + closed bool + // +checklocks:mu + bound bool + // +checklocks:mu boundNIC tcpip.NICID - // lastErrorMu protects lastError. lastErrorMu sync.Mutex `state:"nosave"` - lastError tcpip.Error - - // ops is used to get socket level options. - ops tcpip.SocketOptions - - // frozen indicates if the packets should be delivered to the endpoint - // during restore. - frozen bool + // +checklocks:lastErrorMu + lastError tcpip.Error } // NewEndpoint returns a new packet endpoint. @@ -414,7 +414,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, } rcvBufSize := ep.ops.GetReceiveBufferSize() - if ep.frozen || ep.rcvBufSize >= int(rcvBufSize) { + if ep.rcvDisabled || ep.rcvBufSize >= int(rcvBufSize) { ep.rcvMu.Unlock() ep.stack.Stats().DroppedPackets.Increment() ep.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() @@ -491,18 +491,3 @@ func (*endpoint) SetOwner(tcpip.PacketOwner) {} func (ep *endpoint) SocketOptions() *tcpip.SocketOptions { return &ep.ops } - -// freeze prevents any more packets from being delivered to the endpoint. -func (ep *endpoint) freeze() { - ep.mu.Lock() - ep.frozen = true - ep.mu.Unlock() -} - -// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows -// new packets to be delivered again. -func (ep *endpoint) thaw() { - ep.mu.Lock() - ep.frozen = false - ep.mu.Unlock() -} diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go index 5c688d286..d2768db7b 100644 --- a/pkg/tcpip/transport/packet/endpoint_state.go +++ b/pkg/tcpip/transport/packet/endpoint_state.go @@ -44,12 +44,16 @@ func (p *packet) loadData(data buffer.VectorisedView) { // beforeSave is invoked by stateify. func (ep *endpoint) beforeSave() { - ep.freeze() + ep.rcvMu.Lock() + defer ep.rcvMu.Unlock() + ep.rcvDisabled = true } // afterLoad is invoked by stateify. func (ep *endpoint) afterLoad() { - ep.thaw() + ep.mu.Lock() + defer ep.mu.Unlock() + ep.stack = stack.StackFromEnv ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) @@ -57,4 +61,8 @@ func (ep *endpoint) afterLoad() { if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil { panic(err) } + + ep.rcvMu.Lock() + ep.rcvDisabled = false + ep.rcvMu.Unlock() } diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD index 2eab09088..b7e97e218 100644 --- a/pkg/tcpip/transport/raw/BUILD +++ b/pkg/tcpip/transport/raw/BUILD @@ -33,6 +33,8 @@ go_library( "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/stack", + "//pkg/tcpip/transport", + "//pkg/tcpip/transport/internal/network", "//pkg/tcpip/transport/packet", "//pkg/waiter", ], diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 3bf6c0a8f..3040a445b 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -26,6 +26,7 @@ package raw import ( + "fmt" "io" "time" @@ -34,6 +35,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport" + "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network" "gvisor.dev/gvisor/pkg/waiter" ) @@ -57,15 +60,19 @@ type rawPacket struct { // // +stateify savable type endpoint struct { - stack.TransportEndpointInfo tcpip.DefaultSocketOptionsHandler // The following fields are initialized at creation time and are // immutable. stack *stack.Stack `state:"manual"` + transProto tcpip.TransportProtocolNumber waiterQueue *waiter.Queue associated bool + net network.Endpoint + stats tcpip.TransportEndpointStats `state:"nosave"` + ops tcpip.SocketOptions + // The following fields are used to manage the receive queue and are // protected by rcvMu. rcvMu sync.Mutex `state:"nosave"` @@ -74,20 +81,7 @@ type endpoint struct { rcvClosed bool // The following fields are protected by mu. - mu sync.RWMutex `state:"nosave"` - closed bool - connected bool - bound bool - // route is the route to a remote network endpoint. It is set via - // Connect(), and is valid only when conneted is true. - route *stack.Route `state:"manual"` - stats tcpip.TransportEndpointStats `state:"nosave"` - // owner is used to get uid and gid of the packet. - owner tcpip.PacketOwner - - // ops is used to get socket level options. - ops tcpip.SocketOptions - + mu sync.RWMutex `state:"nosave"` // frozen indicates if the packets should be delivered to the endpoint // during restore. frozen bool @@ -99,16 +93,9 @@ func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans } func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, tcpip.Error) { - if netProto != header.IPv4ProtocolNumber && netProto != header.IPv6ProtocolNumber { - return nil, &tcpip.ErrUnknownProtocol{} - } - e := &endpoint{ - stack: s, - TransportEndpointInfo: stack.TransportEndpointInfo{ - NetProto: netProto, - TransProto: transProto, - }, + stack: s, + transProto: transProto, waiterQueue: waiterQueue, associated: associated, } @@ -116,6 +103,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt e.ops.SetHeaderIncluded(!associated) e.ops.SetSendBufferSize(32*1024, false /* notify */) e.ops.SetReceiveBufferSize(32*1024, false /* notify */) + e.net.Init(s, netProto, transProto, &e.ops) // Override with stack defaults. var ss tcpip.SendBufferSizeOption @@ -137,7 +125,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt return e, nil } - if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil { + if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil { return nil, err } @@ -154,11 +142,17 @@ func (e *endpoint) Close() { e.mu.Lock() defer e.mu.Unlock() - if e.closed || !e.associated { + if e.net.State() == transport.DatagramEndpointStateClosed { return } - e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e) + e.net.Close() + + if !e.associated { + return + } + + e.stack.UnregisterRawTransportEndpoint(e.net.NetProto(), e.transProto, e) e.rcvMu.Lock() defer e.rcvMu.Unlock() @@ -170,15 +164,6 @@ func (e *endpoint) Close() { e.rcvList.Remove(e.rcvList.Front()) } - e.connected = false - - if e.route != nil { - e.route.Release() - e.route = nil - } - - e.closed = true - e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) } @@ -186,9 +171,7 @@ func (e *endpoint) Close() { func (*endpoint) ModerateRecvBuf(int) {} func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { - e.mu.Lock() - defer e.mu.Unlock() - e.owner = owner + e.net.SetOwner(owner) } // Read implements tcpip.Endpoint.Read. @@ -236,14 +219,15 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // Write implements tcpip.Endpoint.Write. func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { + netProto := e.net.NetProto() // We can create, but not write to, unassociated IPv6 endpoints. - if !e.associated && e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber { + if !e.associated && netProto == header.IPv6ProtocolNumber { return 0, &tcpip.ErrInvalidOptionValue{} } if opts.To != nil { // Raw sockets do not support sending to a IPv4 address on a IPv6 endpoint. - if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize { + if netProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize { return 0, &tcpip.ErrInvalidOptionValue{} } } @@ -269,79 +253,26 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp } func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { - // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op. - if opts.More { - return 0, &tcpip.ErrInvalidOptionValue{} - } - payloadBytes, route, owner, err := func() ([]byte, *stack.Route, tcpip.PacketOwner, tcpip.Error) { - e.mu.RLock() - defer e.mu.RUnlock() - - if e.closed { - return nil, nil, nil, &tcpip.ErrInvalidEndpointState{} - } - - // TODO(https://gvisor.dev/issue/6538): Avoid this allocation. - payloadBytes := make([]byte, p.Len()) - if _, err := io.ReadFull(p, payloadBytes); err != nil { - return nil, nil, nil, &tcpip.ErrBadBuffer{} - } - - // Did the user caller provide a destination? If not, use the connected - // destination. - if opts.To == nil { - // If the user doesn't specify a destination, they should have - // connected to another address. - if !e.connected { - return nil, nil, nil, &tcpip.ErrDestinationRequired{} - } - - e.route.Acquire() - - return payloadBytes, e.route, e.owner, nil - } - - // The caller provided a destination. Reject destination address if it - // goes through a different NIC than the endpoint was bound to. - nic := opts.To.NIC - if e.bound && nic != 0 && nic != e.BindNICID { - return nil, nil, nil, &tcpip.ErrNoRoute{} - } - - // Find the route to the destination. If BindAddress is 0, - // FindRoute will choose an appropriate source address. - route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false) - if err != nil { - return nil, nil, nil, err - } - - return payloadBytes, route, e.owner, nil - }() + ctx, err := e.net.AcquireContextForWrite(opts) if err != nil { return 0, err } - defer route.Release() + + // TODO(https://gvisor.dev/issue/6538): Avoid this allocation. + payloadBytes := make([]byte, p.Len()) + if _, err := io.ReadFull(p, payloadBytes); err != nil { + return 0, &tcpip.ErrBadBuffer{} + } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(route.MaxHeaderLength()), + ReserveHeaderBytes: int(ctx.PacketInfo().MaxHeaderLength), Data: buffer.View(payloadBytes).ToVectorisedView(), }) - pkt.Owner = owner - - if e.ops.GetHeaderIncluded() { - if err := route.WriteHeaderIncludedPacket(pkt); err != nil { - return 0, err - } - return int64(len(payloadBytes)), nil - } - if err := route.WritePacket(stack.NetworkHeaderParams{ - Protocol: e.TransProto, - TTL: route.DefaultTTL(), - TOS: stack.DefaultTOS, - }, pkt); err != nil { + if err := ctx.WritePacket(pkt, e.ops.GetHeaderIncluded()); err != nil { return 0, err } + return int64(len(payloadBytes)), nil } @@ -352,66 +283,29 @@ func (*endpoint) Disconnect() tcpip.Error { // Connect implements tcpip.Endpoint.Connect. func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { + netProto := e.net.NetProto() + // Raw sockets do not support connecting to a IPv4 address on a IPv6 endpoint. - if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize { + if netProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize { return &tcpip.ErrAddressFamilyNotSupported{} } - e.mu.Lock() - defer e.mu.Unlock() - - if e.closed { - return &tcpip.ErrInvalidEndpointState{} - } - - nic := addr.NIC - if e.bound { - if e.BindNICID == 0 { - // If we're bound, but not to a specific NIC, the NIC - // in addr will be used. Nothing to do here. - } else if addr.NIC == 0 { - // If we're bound to a specific NIC, but addr doesn't - // specify a NIC, use the bound NIC. - nic = e.BindNICID - } else if addr.NIC != e.BindNICID { - // We're bound and addr specifies a NIC. They must be - // the same. - return &tcpip.ErrInvalidEndpointState{} - } - } - - // Find a route to the destination. - route, err := e.stack.FindRoute(nic, "", addr.Addr, e.NetProto, false) - if err != nil { - return err - } - - if e.associated { - // Re-register the endpoint with the appropriate NIC. - if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil { - route.Release() - return err + return e.net.ConnectAndThen(addr, func(_ tcpip.NetworkProtocolNumber, _, _ stack.TransportEndpointID) tcpip.Error { + if e.associated { + // Re-register the endpoint with the appropriate NIC. + if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil { + return err + } + e.stack.UnregisterRawTransportEndpoint(netProto, e.transProto, e) } - e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e) - e.RegisterNICID = nic - } - if e.route != nil { - // If the endpoint was previously connected then release any previous route. - e.route.Release() - } - e.route = route - e.connected = true - - return nil + return nil + }) } // Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets. func (e *endpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error { - e.mu.Lock() - defer e.mu.Unlock() - - if !e.connected { + if e.net.State() != transport.DatagramEndpointStateConnected { return &tcpip.ErrNotConnected{} } return nil @@ -429,46 +323,26 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi // Bind implements tcpip.Endpoint.Bind. func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { - e.mu.Lock() - defer e.mu.Unlock() - - // If a local address was specified, verify that it's valid. - if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, addr.Addr) == 0 { - return &tcpip.ErrBadLocalAddress{} - } + return e.net.BindAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, _ tcpip.Address) tcpip.Error { + if !e.associated { + return nil + } - if e.associated { // Re-register the endpoint with the appropriate NIC. - if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil { + if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil { return err } - e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e) - e.RegisterNICID = addr.NIC - e.BindNICID = addr.NIC - } - - e.BindAddr = addr.Addr - e.bound = true - - return nil + e.stack.UnregisterRawTransportEndpoint(netProto, e.transProto, e) + return nil + }) } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { - e.mu.RLock() - defer e.mu.RUnlock() - - addr := e.BindAddr - if e.connected { - addr = e.route.LocalAddress() - } - - return tcpip.FullAddress{ - NIC: e.RegisterNICID, - Addr: addr, - // Linux returns the protocol in the port field. - Port: uint16(e.TransProto), - }, nil + a := e.net.GetLocalAddress() + // Linux returns the protocol in the port field. + a.Port = uint16(e.transProto) + return a, nil } // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. @@ -501,17 +375,17 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { return nil default: - return &tcpip.ErrUnknownProtocolOption{} + return e.net.SetSockOpt(opt) } } -func (*endpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error { - return &tcpip.ErrUnknownProtocolOption{} +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { + return e.net.SetSockOptInt(opt, v) } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { - return &tcpip.ErrUnknownProtocolOption{} +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { + return e.net.GetSockOpt(opt) } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. @@ -528,103 +402,108 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { return v, nil default: - return -1, &tcpip.ErrUnknownProtocolOption{} + return e.net.GetSockOptInt(opt) } } // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { - e.mu.RLock() - e.rcvMu.Lock() + notifyReadableEvents := func() bool { + e.mu.RLock() + defer e.mu.RUnlock() + e.rcvMu.Lock() + defer e.rcvMu.Unlock() + + // Drop the packet if our buffer is currently full or if this is an unassociated + // endpoint (i.e endpoint created w/ IPPROTO_RAW). Such endpoints are send only + // See: https://man7.org/linux/man-pages/man7/raw.7.html + // + // An IPPROTO_RAW socket is send only. If you really want to receive + // all IP packets, use a packet(7) socket with the ETH_P_IP protocol. + // Note that packet sockets don't reassemble IP fragments, unlike raw + // sockets. + if e.rcvClosed || !e.associated { + e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.ClosedReceiver.Increment() + return false + } - // Drop the packet if our buffer is currently full or if this is an unassociated - // endpoint (i.e endpoint created w/ IPPROTO_RAW). Such endpoints are send only - // See: https://man7.org/linux/man-pages/man7/raw.7.html - // - // An IPPROTO_RAW socket is send only. If you really want to receive - // all IP packets, use a packet(7) socket with the ETH_P_IP protocol. - // Note that packet sockets don't reassemble IP fragments, unlike raw - // sockets. - if e.rcvClosed || !e.associated { - e.rcvMu.Unlock() - e.mu.RUnlock() - e.stack.Stats().DroppedPackets.Increment() - e.stats.ReceiveErrors.ClosedReceiver.Increment() - return - } + rcvBufSize := e.ops.GetReceiveBufferSize() + if e.frozen || e.rcvBufSize >= int(rcvBufSize) { + e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() + return false + } - rcvBufSize := e.ops.GetReceiveBufferSize() - if e.frozen || e.rcvBufSize >= int(rcvBufSize) { - e.rcvMu.Unlock() - e.mu.RUnlock() - e.stack.Stats().DroppedPackets.Increment() - e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() - return - } + srcAddr := pkt.Network().SourceAddress() + info := e.net.Info() - remoteAddr := pkt.Network().SourceAddress() + switch state := e.net.State(); state { + case transport.DatagramEndpointStateInitial: + case transport.DatagramEndpointStateConnected: + // If connected, only accept packets from the remote address we + // connected to. + if info.ID.RemoteAddress != srcAddr { + return false + } - if e.bound { - // If bound to a NIC, only accept data for that NIC. - if e.BindNICID != 0 && e.BindNICID != pkt.NICID { - e.rcvMu.Unlock() - e.mu.RUnlock() - return - } - // If bound to an address, only accept data for that address. - if e.BindAddr != "" && e.BindAddr != remoteAddr { - e.rcvMu.Unlock() - e.mu.RUnlock() - return + // Connected sockets may also have been bound to a specific + // address/NIC. + fallthrough + case transport.DatagramEndpointStateBound: + // If bound to a NIC, only accept data for that NIC. + if info.BindNICID != 0 && info.BindNICID != pkt.NICID { + return false + } + + // If bound to an address, only accept data for that address. + if info.BindAddr != "" && info.BindAddr != pkt.Network().DestinationAddress() { + return false + } + default: + panic(fmt.Sprintf("unhandled state = %s", state)) } - } - // If connected, only accept packets from the remote address we - // connected to. - if e.connected && e.route.RemoteAddress() != remoteAddr { - e.rcvMu.Unlock() - e.mu.RUnlock() - return - } + wasEmpty := e.rcvBufSize == 0 - wasEmpty := e.rcvBufSize == 0 + // Push new packet into receive list and increment the buffer size. + packet := &rawPacket{ + senderAddr: tcpip.FullAddress{ + NIC: pkt.NICID, + Addr: srcAddr, + }, + } - // Push new packet into receive list and increment the buffer size. - packet := &rawPacket{ - senderAddr: tcpip.FullAddress{ - NIC: pkt.NICID, - Addr: remoteAddr, - }, - } + // Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not. + // We copy headers' underlying bytes because pkt.*Header may point to + // the middle of a slice, and another struct may point to the "outer" + // slice. Save/restore doesn't support overlapping slices and will fail. + // + // TODO(https://gvisor.dev/issue/6517): Avoid the copy once S/R supports + // overlapping slices. + var combinedVV buffer.VectorisedView + if info.NetProto == header.IPv4ProtocolNumber { + network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() + headers := make(buffer.View, 0, len(network)+len(transport)) + headers = append(headers, network...) + headers = append(headers, transport...) + combinedVV = headers.ToVectorisedView() + } else { + combinedVV = append(buffer.View(nil), pkt.TransportHeader().View()...).ToVectorisedView() + } + combinedVV.Append(pkt.Data().ExtractVV()) + packet.data = combinedVV + packet.receivedAt = e.stack.Clock().Now() - // Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not. - // We copy headers' underlying bytes because pkt.*Header may point to - // the middle of a slice, and another struct may point to the "outer" - // slice. Save/restore doesn't support overlapping slices and will fail. - // - // TODO(https://gvisor.dev/issue/6517): Avoid the copy once S/R supports - // overlapping slices. - var combinedVV buffer.VectorisedView - if e.TransportEndpointInfo.NetProto == header.IPv4ProtocolNumber { - network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() - headers := make(buffer.View, 0, len(network)+len(transport)) - headers = append(headers, network...) - headers = append(headers, transport...) - combinedVV = headers.ToVectorisedView() - } else { - combinedVV = append(buffer.View(nil), pkt.TransportHeader().View()...).ToVectorisedView() - } - combinedVV.Append(pkt.Data().ExtractVV()) - packet.data = combinedVV - packet.receivedAt = e.stack.Clock().Now() + e.rcvList.PushBack(packet) + e.rcvBufSize += packet.data.Size() + e.stats.PacketsReceived.Increment() - e.rcvList.PushBack(packet) - e.rcvBufSize += packet.data.Size() - e.rcvMu.Unlock() - e.mu.RUnlock() - e.stats.PacketsReceived.Increment() - // Notify waiters that there's data to be read. - if wasEmpty { + // Notify waiters that there is data to be read now. + return wasEmpty + }() + + if notifyReadableEvents { e.waiterQueue.Notify(waiter.ReadableEvents) } } @@ -636,10 +515,7 @@ func (e *endpoint) State() uint32 { // Info returns a copy of the endpoint info. func (e *endpoint) Info() tcpip.EndpointInfo { - e.mu.RLock() - // Make a copy of the endpoint info. - ret := e.TransportEndpointInfo - e.mu.RUnlock() + ret := e.net.Info() return &ret } diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index 39669b445..e74713064 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -15,6 +15,7 @@ package raw import ( + "fmt" "time" "gvisor.dev/gvisor/pkg/tcpip" @@ -60,35 +61,16 @@ func (e *endpoint) beforeSave() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { + e.net.Resume(s) + e.thaw() e.stack = s e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) - // If the endpoint is connected, re-connect. - if e.connected { - var err tcpip.Error - // TODO(gvisor.dev/issue/4906): Properly restore the route with the right - // remote address. We used to pass e.remote.RemoteAddress which was - // effectively the empty address but since moving e.route to hold a pointer - // to a route instead of the route by value, we pass the empty address - // directly. Obviously this was always wrong since we should provide the - // remote address we were connected to, to properly restore the route. - e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, "", e.NetProto, false) - if err != nil { - panic(err) - } - } - - // If the endpoint is bound, re-bind. - if e.bound { - if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.BindAddr) == 0 { - panic(&tcpip.ErrBadLocalAddress{}) - } - } - if e.associated { - if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil { - panic(err) + netProto := e.net.NetProto() + if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil { + panic(fmt.Sprintf("e.stack.RegisterRawTransportEndpoint(%d, %d, _): %s", netProto, e.transProto, err)) } } } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index bc8708a5b..58817371e 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -1382,8 +1382,12 @@ func TestListenerReadinessOnEvent(t *testing.T) { if err := s.CreateNIC(id, ep); err != nil { t.Fatalf("CreateNIC(%d, %T): %s", id, ep, err) } - if err := s.AddAddress(id, ipv4.ProtocolNumber, context.StackAddr); err != nil { - t.Fatalf("AddAddress(%d, ipv4.ProtocolNumber, %s): %s", id, context.StackAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.Address(context.StackAddr).WithPrefix(), + } + if err := s.AddProtocolAddress(id, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", id, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{ {Destination: header.IPv4EmptySubnet, NIC: id}, @@ -2145,12 +2149,15 @@ func TestSmallReceiveBufferReadiness(t *testing.T) { t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %s", nicOpts, err) } - addr := tcpip.AddressWithPrefix{ - Address: tcpip.Address("\x7f\x00\x00\x01"), - PrefixLen: 8, + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address("\x7f\x00\x00\x01"), + PrefixLen: 8, + }, } - if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddressWithPrefix(_, _, %s) failed: %s", addr, err) + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) failed: %s", nicID, protocolAddr, err) } { @@ -4954,13 +4961,17 @@ func makeStack() (*stack.Stack, tcpip.Error) { } for _, ct := range []struct { - number tcpip.NetworkProtocolNumber - address tcpip.Address + number tcpip.NetworkProtocolNumber + addrWithPrefix tcpip.AddressWithPrefix }{ - {ipv4.ProtocolNumber, context.StackAddr}, - {ipv6.ProtocolNumber, context.StackV6Addr}, + {ipv4.ProtocolNumber, context.StackAddrWithPrefix}, + {ipv6.ProtocolNumber, context.StackV6AddrWithPrefix}, } { - if err := s.AddAddress(1, ct.number, ct.address); err != nil { + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ct.number, + AddressWithPrefix: ct.addrWithPrefix, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { return nil, err } } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 6e55a7a32..88bb99354 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -243,8 +243,8 @@ func NewWithOpts(t *testing.T, opts Options) *Context { Protocol: ipv4.ProtocolNumber, AddressWithPrefix: StackAddrWithPrefix, } - if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err) + if err := s.AddProtocolAddress(1, v4ProtocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", v4ProtocolAddr, err) } routeTable = append(routeTable, tcpip.Route{ Destination: header.IPv4EmptySubnet, @@ -257,8 +257,8 @@ func NewWithOpts(t *testing.T, opts Options) *Context { Protocol: ipv6.ProtocolNumber, AddressWithPrefix: StackV6AddrWithPrefix, } - if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err) + if err := s.AddProtocolAddress(1, v6ProtocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", v6ProtocolAddr, err) } routeTable = append(routeTable, tcpip.Route{ Destination: header.IPv6EmptySubnet, diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index f171a16f8..4255457f9 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -547,7 +547,7 @@ func (e *endpoint) Disconnect() tcpip.Error { info := e.net.Info() info.ID.LocalPort = e.localPort info.ID.RemotePort = e.remotePort - if info.BindNICID != 0 || info.ID.LocalAddress == "" { + if e.net.WasBound() { var err tcpip.Error id = stack.TransportEndpointID{ LocalPort: info.ID.LocalPort, diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 554ce1de4..2d15830a7 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -323,12 +323,20 @@ func newDualTestContextWithHandleLocal(t *testing.T, mtu uint32, handleLocal boo t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, stackAddr, err) + protocolAddrV4 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.Address(stackAddr).WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddrV4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV4, err) } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, stackV6Addr); err != nil { - t.Fatalf("AddAddress((%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, stackV6Addr, err) + protocolAddrV6 := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: tcpip.Address(stackV6Addr).WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddrV6, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV6, err) } s.SetRouteTable([]tcpip.Route{ @@ -2504,8 +2512,8 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { if err := s.CreateNIC(nicID1, e); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) } - if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err) + if err := s.AddProtocolAddress(nicID1, test.nicAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, test.nicAddr, err) } s.SetRouteTable(test.routes) |