diff options
Diffstat (limited to 'pkg/tcpip')
103 files changed, 6742 insertions, 2517 deletions
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index dbe4506cc..b98de54c5 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -25,6 +25,7 @@ go_library( "stdclock.go", "stdclock_state.go", "tcpip.go", + "tcpip_state.go", "timer.go", ], visibility = ["//visibility:public"], diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index 010e2e833..1f2bcaf65 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -19,6 +19,7 @@ import ( "bytes" "context" "errors" + "fmt" "io" "net" "time" @@ -471,9 +472,9 @@ func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtoc return DialContextTCP(context.Background(), s, addr, network) } -// DialContextTCP creates a new TCPConn connected to the specified address -// with the option of adding cancellation and timeouts. -func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) { +// DialTCPWithBind creates a new TCPConn connected to the specified +// remoteAddress with its local address bound to localAddr. +func DialTCPWithBind(ctx context.Context, s *stack.Stack, localAddr, remoteAddr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) { // Create TCP endpoint, then connect. var wq waiter.Queue ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq) @@ -494,7 +495,14 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, default: } - err = ep.Connect(addr) + // Bind before connect if requested. + if localAddr != (tcpip.FullAddress{}) { + if err = ep.Bind(localAddr); err != nil { + return nil, fmt.Errorf("ep.Bind(%+v) = %s", localAddr, err) + } + } + + err = ep.Connect(remoteAddr) if _, ok := err.(*tcpip.ErrConnectStarted); ok { select { case <-ctx.Done(): @@ -510,7 +518,7 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, return nil, &net.OpError{ Op: "connect", Net: "tcp", - Addr: fullToTCPAddr(addr), + Addr: fullToTCPAddr(remoteAddr), Err: errors.New(err.String()), } } @@ -518,6 +526,12 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, return NewTCPConn(&wq, ep), nil } +// DialContextTCP creates a new TCPConn connected to the specified address +// with the option of adding cancellation and timeouts. +func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) { + return DialTCPWithBind(ctx, s, tcpip.FullAddress{} /* localAddr */, addr /* remoteAddr */, network) +} + // A UDPConn is a wrapper around a UDP tcpip.Endpoint that implements // net.Conn and net.PacketConn. type UDPConn struct { diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index 48b24692b..dcc9fff17 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -137,7 +137,13 @@ func TestCloseReader(t *testing.T) { addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr.Addr.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err) + } l, e := ListenTCP(s, addr, ipv4.ProtocolNumber) if e != nil { @@ -190,7 +196,13 @@ func TestCloseReaderWithForwarder(t *testing.T) { }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr.Addr.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err) + } done := make(chan struct{}) @@ -244,7 +256,13 @@ func TestCloseRead(t *testing.T) { }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr.Addr.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err) + } fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { var wq waiter.Queue @@ -288,7 +306,13 @@ func TestCloseWrite(t *testing.T) { }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr.Addr.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err) + } fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { var wq waiter.Queue @@ -349,10 +373,22 @@ func TestUDPForwarder(t *testing.T) { ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr1 := tcpip.FullAddress{NICID, ip1, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip1) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ip1.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr1, err) + } ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4()) addr2 := tcpip.FullAddress{NICID, ip2, 11311} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip2) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ip2.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr2, err) + } done := make(chan struct{}) fwd := udp.NewForwarder(s, func(r *udp.ForwarderRequest) { @@ -410,7 +446,13 @@ func TestDeadlineChange(t *testing.T) { addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr.Addr.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err) + } l, e := ListenTCP(s, addr, ipv4.ProtocolNumber) if e != nil { @@ -465,10 +507,22 @@ func TestPacketConnTransfer(t *testing.T) { ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr1 := tcpip.FullAddress{NICID, ip1, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip1) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ip1.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr1, err) + } ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4()) addr2 := tcpip.FullAddress{NICID, ip2, 11311} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip2) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ip2.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr2, err) + } c1, err := DialUDP(s, &addr1, nil, ipv4.ProtocolNumber) if err != nil { @@ -521,7 +575,13 @@ func TestConnectedPacketConnTransfer(t *testing.T) { ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr := tcpip.FullAddress{NICID, ip, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ip.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err) + } c1, err := DialUDP(s, &addr, nil, ipv4.ProtocolNumber) if err != nil { @@ -565,24 +625,30 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) { ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr := tcpip.FullAddress{NICID, ip, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ip.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + return nil, nil, nil, fmt.Errorf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err) + } l, err := ListenTCP(s, addr, ipv4.ProtocolNumber) if err != nil { - return nil, nil, nil, fmt.Errorf("NewListener: %v", err) + return nil, nil, nil, fmt.Errorf("NewListener: %w", err) } c1, err = DialTCP(s, addr, ipv4.ProtocolNumber) if err != nil { l.Close() - return nil, nil, nil, fmt.Errorf("DialTCP: %v", err) + return nil, nil, nil, fmt.Errorf("DialTCP: %w", err) } c2, err = l.Accept() if err != nil { l.Close() c1.Close() - return nil, nil, nil, fmt.Errorf("l.Accept: %v", err) + return nil, nil, nil, fmt.Errorf("l.Accept: %w", err) } stop = func() { @@ -594,7 +660,7 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) { if err := l.Close(); err != nil { stop() - return nil, nil, nil, fmt.Errorf("l.Close(): %v", err) + return nil, nil, nil, fmt.Errorf("l.Close(): %w", err) } return c1, c2, stop, nil @@ -681,7 +747,13 @@ func TestDialContextTCPCanceled(t *testing.T) { }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr.Addr.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err) + } ctx := context.Background() ctx, cancel := context.WithCancel(ctx) @@ -703,7 +775,13 @@ func TestDialContextTCPTimeout(t *testing.T) { }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr.Addr.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err) + } fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { time.Sleep(time.Second) diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 2f34bf8dd..24c2c3e6b 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -324,6 +324,19 @@ func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker { } } +// ReceiveIPv6PacketInfo creates a checker that checks the IPv6PacketInfo field +// in ControlMessages. +func ReceiveIPv6PacketInfo(want tcpip.IPv6PacketInfo) ControlMessagesChecker { + return func(t *testing.T, cm tcpip.ControlMessages) { + t.Helper() + if !cm.HasIPv6PacketInfo { + t.Errorf("got cm.HasIPv6PacketInfo = %t, want = true", cm.HasIPv6PacketInfo) + } else if diff := cmp.Diff(want, cm.IPv6PacketInfo); diff != "" { + t.Errorf("IPv6PacketInfo mismatch (-want +got):\n%s", diff) + } + } +} + // ReceiveOriginalDstAddr creates a checker that checks the OriginalDstAddress // field in ControlMessages. func ReceiveOriginalDstAddr(want tcpip.FullAddress) ControlMessagesChecker { diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index dcc549c7b..7baaf0d17 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -208,6 +208,15 @@ var IPv4EmptySubnet = func() tcpip.Subnet { return subnet }() +// IPv4LoopbackSubnet is the loopback subnet for IPv4. +var IPv4LoopbackSubnet = func() tcpip.Subnet { + subnet, err := tcpip.NewSubnet(tcpip.Address("\x7f\x00\x00\x00"), tcpip.AddressMask("\xff\x00\x00\x00")) + if err != nil { + panic(err) + } + return subnet +}() + // IPVersion returns the version of IP used in the given packet. It returns -1 // if the packet is not large enough to contain the version field. func IPVersion(b []byte) int { diff --git a/pkg/tcpip/header/parse/parse.go b/pkg/tcpip/header/parse/parse.go index 1c913b5e1..80a9ad6be 100644 --- a/pkg/tcpip/header/parse/parse.go +++ b/pkg/tcpip/header/parse/parse.go @@ -110,6 +110,16 @@ traverseExtensions: switch extHdr := extHdr.(type) { case header.IPv6FragmentExtHdr: + if extHdr.IsAtomic() { + // This fragment extension header indicates that this packet is an + // atomic fragment. An atomic fragment is a fragment that contains + // all the data required to reassemble a full packet. As per RFC 6946, + // atomic fragments must not interfere with "normal" fragmented traffic + // so we skip processing the fragment instead of feeding it through the + // reassembly process below. + continue + } + if fragID == 0 && fragOffset == 0 && !fragMore { fragID = extHdr.ID() fragOffset = extHdr.FragmentOffset() @@ -175,3 +185,61 @@ func TCP(pkt *stack.PacketBuffer) bool { pkt.TransportProtocolNumber = header.TCPProtocolNumber return ok } + +// ICMPv4 populates the packet buffer's transport header with an ICMPv4 header, +// if present. +// +// Returns true if an ICMPv4 header was successfully parsed. +func ICMPv4(pkt *stack.PacketBuffer) bool { + if _, ok := pkt.TransportHeader().Consume(header.ICMPv4MinimumSize); ok { + pkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber + return true + } + return false +} + +// ICMPv6 populates the packet buffer's transport header with an ICMPv4 header, +// if present. +// +// Returns true if an ICMPv6 header was successfully parsed. +func ICMPv6(pkt *stack.PacketBuffer) bool { + hdr, ok := pkt.Data().PullUp(header.ICMPv6MinimumSize) + if !ok { + return false + } + + h := header.ICMPv6(hdr) + switch h.Type() { + case header.ICMPv6RouterSolicit, + header.ICMPv6RouterAdvert, + header.ICMPv6NeighborSolicit, + header.ICMPv6NeighborAdvert, + header.ICMPv6RedirectMsg: + size := pkt.Data().Size() + if _, ok := pkt.TransportHeader().Consume(size); !ok { + panic(fmt.Sprintf("expected to consume the full data of size = %d bytes into transport header", size)) + } + case header.ICMPv6MulticastListenerQuery, + header.ICMPv6MulticastListenerReport, + header.ICMPv6MulticastListenerDone: + size := header.ICMPv6HeaderSize + header.MLDMinimumSize + if _, ok := pkt.TransportHeader().Consume(size); !ok { + return false + } + case header.ICMPv6DstUnreachable, + header.ICMPv6PacketTooBig, + header.ICMPv6TimeExceeded, + header.ICMPv6ParamProblem, + header.ICMPv6EchoRequest, + header.ICMPv6EchoReply: + fallthrough + default: + if _, ok := pkt.TransportHeader().Consume(header.ICMPv6MinimumSize); !ok { + // Checked above if the packet buffer holds at least the minimum size for + // an ICMPv6 packet. + panic(fmt.Sprintf("expected to consume %d bytes", header.ICMPv6MinimumSize)) + } + } + pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber + return true +} diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go index 3ed0aa3fe..c67ca98ea 100644 --- a/pkg/tcpip/link/pipe/pipe.go +++ b/pkg/tcpip/link/pipe/pipe.go @@ -123,4 +123,6 @@ func (*Endpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber } // WriteRawPacket implements stack.LinkEndpoint. -func (*Endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} } +func (e *Endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error { + return e.WritePacket(stack.RouteInfo{}, 0, pkt) +} diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index 87a0b9a62..e53789d92 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -152,10 +152,22 @@ type PollEvent struct { // no data is available, it will block in a poll() syscall until the file // descriptor becomes readable. func BlockingRead(fd int, b []byte) (int, tcpip.Error) { + n, err := BlockingReadUntranslated(fd, b) + if err != 0 { + return n, TranslateErrno(err) + } + return n, nil +} + +// BlockingReadUntranslated reads from a file descriptor that is set up as +// non-blocking. If no data is available, it will block in a poll() syscall +// until the file descriptor becomes readable. It returns the raw unix.Errno +// value returned by the underlying syscalls. +func BlockingReadUntranslated(fd int, b []byte) (int, unix.Errno) { for { n, _, e := unix.RawSyscall(unix.SYS_READ, uintptr(fd), uintptr(unsafe.Pointer(&b[0])), uintptr(len(b))) if e == 0 { - return int(n), nil + return int(n), 0 } event := PollEvent{ @@ -165,7 +177,7 @@ func BlockingRead(fd int, b []byte) (int, tcpip.Error) { _, e = BlockingPoll(&event, 1, nil) if e != 0 && e != unix.EINTR { - return 0, TranslateErrno(e) + return 0, e } } } diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD index 4215ee852..af755473c 100644 --- a/pkg/tcpip/link/sharedmem/BUILD +++ b/pkg/tcpip/link/sharedmem/BUILD @@ -5,19 +5,27 @@ package(licenses = ["notice"]) go_library( name = "sharedmem", srcs = [ + "queuepair.go", "rx.go", + "server_rx.go", + "server_tx.go", "sharedmem.go", + "sharedmem_server.go", "sharedmem_unsafe.go", "tx.go", ], visibility = ["//visibility:public"], deps = [ + "//pkg/cleanup", + "//pkg/eventfd", "//pkg/log", + "//pkg/memutil", "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/link/rawfile", + "//pkg/tcpip/link/sharedmem/pipe", "//pkg/tcpip/link/sharedmem/queue", "//pkg/tcpip/stack", "@org_golang_x_sys//unix:go_default_library", @@ -26,9 +34,7 @@ go_library( go_test( name = "sharedmem_test", - srcs = [ - "sharedmem_test.go", - ], + srcs = ["sharedmem_test.go"], library = ":sharedmem", deps = [ "//pkg/sync", @@ -41,3 +47,22 @@ go_test( "@org_golang_x_sys//unix:go_default_library", ], ) + +go_test( + name = "sharedmem_server_test", + size = "small", + srcs = ["sharedmem_server_test.go"], + deps = [ + ":sharedmem", + "//pkg/tcpip", + "//pkg/tcpip/adapters/gonet", + "//pkg/tcpip/header", + "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport/tcp", + "//pkg/tcpip/transport/udp", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/pkg/tcpip/link/sharedmem/queue/rx.go b/pkg/tcpip/link/sharedmem/queue/rx.go index 696e6c9e5..a78826ebc 100644 --- a/pkg/tcpip/link/sharedmem/queue/rx.go +++ b/pkg/tcpip/link/sharedmem/queue/rx.go @@ -119,7 +119,6 @@ func (r *Rx) PostBuffers(buffers []RxBuffer) bool { } r.tx.Flush() - return true } @@ -131,7 +130,6 @@ func (r *Rx) PostBuffers(buffers []RxBuffer) bool { func (r *Rx) Dequeue(bufs []RxBuffer) ([]RxBuffer, uint32) { for { outBufs := bufs - // Pull the next descriptor from the rx pipe. b := r.rx.Pull() if b == nil { diff --git a/pkg/tcpip/link/sharedmem/queuepair.go b/pkg/tcpip/link/sharedmem/queuepair.go new file mode 100644 index 000000000..b12647fdd --- /dev/null +++ b/pkg/tcpip/link/sharedmem/queuepair.go @@ -0,0 +1,199 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build linux +// +build linux + +package sharedmem + +import ( + "fmt" + "io/ioutil" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/eventfd" +) + +const ( + // defaultQueueDataSize is the size of the shared memory data region that + // holds the scatter/gather buffers. + defaultQueueDataSize = 1 << 20 // 1MiB + + // defaultQueuePipeSize is the size of the pipe that holds the packet descriptors. + // + // Assuming each packet data is approximately 1280 bytes (IPv6 Minimum MTU) + // then we can hold approximately 1024*1024/1280 ~ 819 packets in the data + // area. Which means the pipe needs to be big enough to hold 819 + // descriptors. + // + // Each descriptor is approximately 8 (slot descriptor in pipe) + + // 16 (packet descriptor) + 12 (for buffer descriptor) assuming each packet is + // stored in exactly 1 buffer descriptor (see queue/tx.go and pipe/tx.go.) + // + // Which means we need approximately 36*819 ~ 29 KiB to store all packet + // descriptors. We could go with a 32 KiB pipe but to give it some slack in + // how the upper layer may make use of the scatter gather buffers we double + // this to hold enough descriptors. + defaultQueuePipeSize = 64 << 10 // 64KiB + + // defaultSharedDataSize is the size of the sharedData region used to + // enable/disable notifications. + defaultSharedDataSize = 4 << 10 // 4KiB +) + +// A QueuePair represents a pair of TX/RX queues. +type QueuePair struct { + // txCfg is the QueueConfig to be used for transmit queue. + txCfg QueueConfig + + // rxCfg is the QueueConfig to be used for receive queue. + rxCfg QueueConfig +} + +// NewQueuePair creates a shared memory QueuePair. +func NewQueuePair() (*QueuePair, error) { + txCfg, err := createQueueFDs(queueSizes{ + dataSize: defaultQueueDataSize, + txPipeSize: defaultQueuePipeSize, + rxPipeSize: defaultQueuePipeSize, + sharedDataSize: defaultSharedDataSize, + }) + + if err != nil { + return nil, fmt.Errorf("failed to create tx queue: %s", err) + } + + rxCfg, err := createQueueFDs(queueSizes{ + dataSize: defaultQueueDataSize, + txPipeSize: defaultQueuePipeSize, + rxPipeSize: defaultQueuePipeSize, + sharedDataSize: defaultSharedDataSize, + }) + + if err != nil { + closeFDs(txCfg) + return nil, fmt.Errorf("failed to create rx queue: %s", err) + } + + return &QueuePair{ + txCfg: txCfg, + rxCfg: rxCfg, + }, nil +} + +// Close closes underlying tx/rx queue fds. +func (q *QueuePair) Close() { + closeFDs(q.txCfg) + closeFDs(q.rxCfg) +} + +// TXQueueConfig returns the QueueConfig for the receive queue. +func (q *QueuePair) TXQueueConfig() QueueConfig { + return q.txCfg +} + +// RXQueueConfig returns the QueueConfig for the transmit queue. +func (q *QueuePair) RXQueueConfig() QueueConfig { + return q.rxCfg +} + +type queueSizes struct { + dataSize int64 + txPipeSize int64 + rxPipeSize int64 + sharedDataSize int64 +} + +func createQueueFDs(s queueSizes) (QueueConfig, error) { + success := false + var eventFD eventfd.Eventfd + var dataFD, txPipeFD, rxPipeFD, sharedDataFD int + defer func() { + if success { + return + } + closeFDs(QueueConfig{ + EventFD: eventFD, + DataFD: dataFD, + TxPipeFD: txPipeFD, + RxPipeFD: rxPipeFD, + SharedDataFD: sharedDataFD, + }) + }() + eventFD, err := eventfd.Create() + if err != nil { + return QueueConfig{}, fmt.Errorf("eventfd failed: %v", err) + } + dataFD, err = createFile(s.dataSize, false) + if err != nil { + return QueueConfig{}, fmt.Errorf("failed to create dataFD: %s", err) + } + txPipeFD, err = createFile(s.txPipeSize, true) + if err != nil { + return QueueConfig{}, fmt.Errorf("failed to create txPipeFD: %s", err) + } + rxPipeFD, err = createFile(s.rxPipeSize, true) + if err != nil { + return QueueConfig{}, fmt.Errorf("failed to create rxPipeFD: %s", err) + } + sharedDataFD, err = createFile(s.sharedDataSize, false) + if err != nil { + return QueueConfig{}, fmt.Errorf("failed to create sharedDataFD: %s", err) + } + success = true + return QueueConfig{ + EventFD: eventFD, + DataFD: dataFD, + TxPipeFD: txPipeFD, + RxPipeFD: rxPipeFD, + SharedDataFD: sharedDataFD, + }, nil +} + +func createFile(size int64, initQueue bool) (fd int, err error) { + const tmpDir = "/dev/shm/" + f, err := ioutil.TempFile(tmpDir, "sharedmem_test") + if err != nil { + return -1, fmt.Errorf("TempFile failed: %v", err) + } + defer f.Close() + unix.Unlink(f.Name()) + + if initQueue { + // Write the "slot-free" flag in the initial queue. + if _, err := f.WriteAt([]byte{0, 0, 0, 0, 0, 0, 0, 0x80}, 0); err != nil { + return -1, fmt.Errorf("WriteAt failed: %v", err) + } + } + + fd, err = unix.Dup(int(f.Fd())) + if err != nil { + return -1, fmt.Errorf("unix.Dup(%d) failed: %v", f.Fd(), err) + } + + if err := unix.Ftruncate(fd, size); err != nil { + unix.Close(fd) + return -1, fmt.Errorf("ftruncate(%d, %d) failed: %v", fd, size, err) + } + + return fd, nil +} + +func closeFDs(c QueueConfig) { + unix.Close(c.DataFD) + c.EventFD.Close() + unix.Close(c.TxPipeFD) + unix.Close(c.RxPipeFD) + unix.Close(c.SharedDataFD) +} diff --git a/pkg/tcpip/link/sharedmem/rx.go b/pkg/tcpip/link/sharedmem/rx.go index e882a128c..87747dcc7 100644 --- a/pkg/tcpip/link/sharedmem/rx.go +++ b/pkg/tcpip/link/sharedmem/rx.go @@ -21,7 +21,7 @@ import ( "sync/atomic" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" + "gvisor.dev/gvisor/pkg/eventfd" "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" ) @@ -30,7 +30,7 @@ type rx struct { data []byte sharedData []byte q queue.Rx - eventFD int + eventFD eventfd.Eventfd } // init initializes all state needed by the rx queue based on the information @@ -68,7 +68,7 @@ func (r *rx) init(mtu uint32, c *QueueConfig) error { // Duplicate the eventFD so that caller can close it but we can still // use it. - efd, err := unix.Dup(c.EventFD) + efd, err := c.EventFD.Dup() if err != nil { unix.Munmap(txPipe) unix.Munmap(rxPipe) @@ -77,16 +77,6 @@ func (r *rx) init(mtu uint32, c *QueueConfig) error { return err } - // Set the eventfd as non-blocking. - if err := unix.SetNonblock(efd, true); err != nil { - unix.Munmap(txPipe) - unix.Munmap(rxPipe) - unix.Munmap(data) - unix.Munmap(sharedData) - unix.Close(efd) - return err - } - // Initialize state based on buffers. r.q.Init(txPipe, rxPipe, sharedDataPointer(sharedData)) r.data = data @@ -105,7 +95,13 @@ func (r *rx) cleanup() { unix.Munmap(r.data) unix.Munmap(r.sharedData) - unix.Close(r.eventFD) + r.eventFD.Close() +} + +// notify writes to the tx.eventFD to indicate to the peer that there is data to +// be read. +func (r *rx) notify() { + r.eventFD.Notify() } // postAndReceive posts the provided buffers (if any), and then tries to read @@ -122,8 +118,7 @@ func (r *rx) postAndReceive(b []queue.RxBuffer, stopRequested *uint32) ([]queue. if len(b) != 0 && !r.q.PostBuffers(b) { r.q.EnableNotification() for !r.q.PostBuffers(b) { - var tmp [8]byte - rawfile.BlockingRead(r.eventFD, tmp[:]) + r.eventFD.Wait() if atomic.LoadUint32(stopRequested) != 0 { r.q.DisableNotification() return nil, 0 @@ -147,8 +142,7 @@ func (r *rx) postAndReceive(b []queue.RxBuffer, stopRequested *uint32) ([]queue. } // Wait for notification. - var tmp [8]byte - rawfile.BlockingRead(r.eventFD, tmp[:]) + r.eventFD.Wait() if atomic.LoadUint32(stopRequested) != 0 { r.q.DisableNotification() return nil, 0 diff --git a/pkg/tcpip/link/sharedmem/server_rx.go b/pkg/tcpip/link/sharedmem/server_rx.go new file mode 100644 index 000000000..6ea21ffd1 --- /dev/null +++ b/pkg/tcpip/link/sharedmem/server_rx.go @@ -0,0 +1,142 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build linux +// +build linux + +package sharedmem + +import ( + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/cleanup" + "gvisor.dev/gvisor/pkg/eventfd" + "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe" + "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" +) + +type serverRx struct { + // packetPipe represents the receive end of the pipe that carries the packet + // descriptors sent by the client. + packetPipe pipe.Rx + + // completionPipe represents the transmit end of the pipe that will carry + // completion notifications from the server to the client. + completionPipe pipe.Tx + + // data represents the buffer area where the packet payload is held. + data []byte + + // eventFD is used to notify the peer when transmission is completed. + eventFD eventfd.Eventfd + + // sharedData the memory region to use to enable/disable notifications. + sharedData []byte +} + +// init initializes all state needed by the serverTx queue based on the +// information provided. +// +// The caller always retains ownership of all file descriptors passed in. The +// queue implementation will duplicate any that it may need in the future. +func (s *serverRx) init(c *QueueConfig) error { + // Map in all buffers. + packetPipeMem, err := getBuffer(c.TxPipeFD) + if err != nil { + return err + } + cu := cleanup.Make(func() { unix.Munmap(packetPipeMem) }) + defer cu.Clean() + + completionPipeMem, err := getBuffer(c.RxPipeFD) + if err != nil { + return err + } + cu.Add(func() { unix.Munmap(completionPipeMem) }) + + data, err := getBuffer(c.DataFD) + if err != nil { + return err + } + cu.Add(func() { unix.Munmap(data) }) + + sharedData, err := getBuffer(c.SharedDataFD) + if err != nil { + return err + } + cu.Add(func() { unix.Munmap(sharedData) }) + + // Duplicate the eventFD so that caller can close it but we can still + // use it. + efd, err := c.EventFD.Dup() + if err != nil { + return err + } + cu.Add(func() { efd.Close() }) + + s.packetPipe.Init(packetPipeMem) + s.completionPipe.Init(completionPipeMem) + s.data = data + s.eventFD = efd + s.sharedData = sharedData + + cu.Release() + return nil +} + +func (s *serverRx) cleanup() { + unix.Munmap(s.packetPipe.Bytes()) + unix.Munmap(s.completionPipe.Bytes()) + unix.Munmap(s.data) + unix.Munmap(s.sharedData) + s.eventFD.Close() +} + +// completionNotificationSize is size in bytes of a completion notification sent +// on the completion queue after a transmitted packet has been handled. +const completionNotificationSize = 8 + +// receive receives a single packet from the packetPipe. +func (s *serverRx) receive() []byte { + desc := s.packetPipe.Pull() + if desc == nil { + return nil + } + + pktInfo := queue.DecodeTxPacketHeader(desc) + contents := make([]byte, 0, pktInfo.Size) + toCopy := pktInfo.Size + for i := 0; i < pktInfo.BufferCount; i++ { + txBuf := queue.DecodeTxBufferHeader(desc, i) + if txBuf.Size <= toCopy { + contents = append(contents, s.data[txBuf.Offset:][:txBuf.Size]...) + toCopy -= txBuf.Size + continue + } + contents = append(contents, s.data[txBuf.Offset:][:toCopy]...) + break + } + + // Flush to let peer know that slots queued for transmission have been handled + // and its free to reuse the slots. + s.packetPipe.Flush() + // Encode packet completion. + b := s.completionPipe.Push(completionNotificationSize) + queue.EncodeTxCompletion(b, pktInfo.ID) + s.completionPipe.Flush() + return contents +} + +func (s *serverRx) waitForPackets() { + s.eventFD.Wait() +} diff --git a/pkg/tcpip/link/sharedmem/server_tx.go b/pkg/tcpip/link/sharedmem/server_tx.go new file mode 100644 index 000000000..13a82903f --- /dev/null +++ b/pkg/tcpip/link/sharedmem/server_tx.go @@ -0,0 +1,175 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build linux +// +build linux + +package sharedmem + +import ( + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/cleanup" + "gvisor.dev/gvisor/pkg/eventfd" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe" + "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" +) + +// serverTx represents the server end of the sharedmem queue and is used to send +// packets to the peer in the buffers posted by the peer in the fillPipe. +type serverTx struct { + // fillPipe represents the receive end of the pipe that carries the RxBuffers + // posted by the peer. + fillPipe pipe.Rx + + // completionPipe represents the transmit end of the pipe that carries the + // descriptors for filled RxBuffers. + completionPipe pipe.Tx + + // data represents the buffer area where the packet payload is held. + data []byte + + // eventFD is used to notify the peer when fill requests are fulfilled. + eventFD eventfd.Eventfd + + // sharedData the memory region to use to enable/disable notifications. + sharedData []byte +} + +// init initializes all tstate needed by the serverTx queue based on the +// information provided. +// +// The caller always retains ownership of all file descriptors passed in. The +// queue implementation will duplicate any that it may need in the future. +func (s *serverTx) init(c *QueueConfig) error { + // Map in all buffers. + fillPipeMem, err := getBuffer(c.TxPipeFD) + if err != nil { + return err + } + cu := cleanup.Make(func() { unix.Munmap(fillPipeMem) }) + defer cu.Clean() + + completionPipeMem, err := getBuffer(c.RxPipeFD) + if err != nil { + return err + } + cu.Add(func() { unix.Munmap(completionPipeMem) }) + + data, err := getBuffer(c.DataFD) + if err != nil { + return err + } + cu.Add(func() { unix.Munmap(data) }) + + sharedData, err := getBuffer(c.SharedDataFD) + if err != nil { + return err + } + cu.Add(func() { unix.Munmap(sharedData) }) + + // Duplicate the eventFD so that caller can close it but we can still + // use it. + efd, err := c.EventFD.Dup() + if err != nil { + return err + } + cu.Add(func() { efd.Close() }) + + cu.Release() + + s.fillPipe.Init(fillPipeMem) + s.completionPipe.Init(completionPipeMem) + s.data = data + s.eventFD = efd + s.sharedData = sharedData + + return nil +} + +func (s *serverTx) cleanup() { + unix.Munmap(s.fillPipe.Bytes()) + unix.Munmap(s.completionPipe.Bytes()) + unix.Munmap(s.data) + unix.Munmap(s.sharedData) + s.eventFD.Close() +} + +// fillPacket copies the data in the provided views into buffers pulled from the +// fillPipe and returns a slice of RxBuffers that contain the copied data as +// well as the total number of bytes copied. +// +// To avoid allocations the filledBuffers are appended to the buffers slice +// which will be grown as required. +func (s *serverTx) fillPacket(views []buffer.View, buffers []queue.RxBuffer) (filledBuffers []queue.RxBuffer, totalCopied uint32) { + filledBuffers = buffers[:0] + // fillBuffer copies as much of the views as possible into the provided buffer + // and returns any left over views (if any). + fillBuffer := func(buffer *queue.RxBuffer, views []buffer.View) (left []buffer.View) { + if len(views) == 0 { + return nil + } + availBytes := buffer.Size + copied := uint64(0) + for availBytes > 0 && len(views) > 0 { + n := copy(s.data[buffer.Offset+copied:][:uint64(buffer.Size)-copied], views[0]) + views[0].TrimFront(n) + if !views[0].IsEmpty() { + break + } + views = views[1:] + copied += uint64(n) + availBytes -= uint32(n) + } + buffer.Size = uint32(copied) + return views + } + + for len(views) > 0 { + var b []byte + // Spin till we get a free buffer reposted by the peer. + for { + if b = s.fillPipe.Pull(); b != nil { + break + } + } + rxBuffer := queue.DecodeRxBufferHeader(b) + // Copy the packet into the posted buffer. + views = fillBuffer(&rxBuffer, views) + totalCopied += rxBuffer.Size + filledBuffers = append(filledBuffers, rxBuffer) + } + + return filledBuffers, totalCopied +} + +func (s *serverTx) transmit(views []buffer.View) bool { + buffers := make([]queue.RxBuffer, 8) + buffers, totalCopied := s.fillPacket(views, buffers) + b := s.completionPipe.Push(queue.RxCompletionSize(len(buffers))) + if b == nil { + return false + } + queue.EncodeRxCompletion(b, totalCopied, 0 /* reserved */) + for i := 0; i < len(buffers); i++ { + queue.EncodeRxCompletionBuffer(b, i, buffers[i]) + } + s.completionPipe.Flush() + s.fillPipe.Flush() + return true +} + +func (s *serverTx) notify() { + s.eventFD.Notify() +} diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 66efe6472..b75522a51 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -24,14 +24,16 @@ package sharedmem import ( + "fmt" "sync/atomic" - "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/eventfd" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -47,7 +49,7 @@ type QueueConfig struct { // EventFD is a file descriptor for the event that is signaled when // data is becomes available in this queue. - EventFD int + EventFD eventfd.Eventfd // TxPipeFD is a file descriptor for the tx pipe associated with the // queue. @@ -63,16 +65,97 @@ type QueueConfig struct { SharedDataFD int } +// FDs returns the FD's in the QueueConfig as a slice of ints. This must +// be used in conjunction with QueueConfigFromFDs to ensure the order +// of FDs matches when reconstructing the config when serialized or sent +// as part of control messages. +func (q *QueueConfig) FDs() []int { + return []int{q.DataFD, q.EventFD.FD(), q.TxPipeFD, q.RxPipeFD, q.SharedDataFD} +} + +// QueueConfigFromFDs constructs a QueueConfig out of a slice of ints where each +// entry represents an file descriptor. The order of FDs in the slice must be in +// the order specified below for the config to be valid. QueueConfig.FDs() +// should be used when the config needs to be serialized or sent as part of a +// control message to ensure the correct order. +func QueueConfigFromFDs(fds []int) (QueueConfig, error) { + if len(fds) != 5 { + return QueueConfig{}, fmt.Errorf("insufficient number of fds: len(fds): %d, want: 5", len(fds)) + } + return QueueConfig{ + DataFD: fds[0], + EventFD: eventfd.Wrap(fds[1]), + TxPipeFD: fds[2], + RxPipeFD: fds[3], + SharedDataFD: fds[4], + }, nil +} + +// Options specify the details about the sharedmem endpoint to be created. +type Options struct { + // MTU is the mtu to use for this endpoint. + MTU uint32 + + // BufferSize is the size of each scatter/gather buffer that will hold packet + // data. + // + // NOTE: This directly determines number of packets that can be held in + // the ring buffer at any time. This does not have to be sized to the MTU as + // the shared memory queue design allows usage of more than one buffer to be + // used to make up a given packet. + BufferSize uint32 + + // LinkAddress is the link address for this endpoint (required). + LinkAddress tcpip.LinkAddress + + // TX is the transmit queue configuration for this shared memory endpoint. + TX QueueConfig + + // RX is the receive queue configuration for this shared memory endpoint. + RX QueueConfig + + // PeerFD is the fd for the connected peer which can be used to detect + // peer disconnects. + PeerFD int + + // OnClosed is a function that is called when the endpoint is being closed + // (probably due to peer going away) + OnClosed func(err tcpip.Error) + + // TXChecksumOffload if true, indicates that this endpoints capability + // set should include CapabilityTXChecksumOffload. + TXChecksumOffload bool + + // RXChecksumOffload if true, indicates that this endpoints capability + // set should include CapabilityRXChecksumOffload. + RXChecksumOffload bool +} + type endpoint struct { // mtu (maximum transmission unit) is the maximum size of a packet. + // mtu is immutable. mtu uint32 // bufferSize is the size of each individual buffer. + // bufferSize is immutable. bufferSize uint32 // addr is the local address of this endpoint. + // addr is immutable. addr tcpip.LinkAddress + // peerFD is an fd to the peer that can be used to detect when the + // peer is gone. + // peerFD is immutable. + peerFD int + + // caps holds the endpoint capabilities. + caps stack.LinkEndpointCapabilities + + // hdrSize is the size of the link layer header if any. + // hdrSize is immutable. + hdrSize uint32 + // rx is the receive queue. rx rx @@ -83,34 +166,55 @@ type endpoint struct { // Wait group used to indicate that all workers have stopped. completed sync.WaitGroup + // onClosed is a function to be called when the FD's peer (if any) closes + // its end of the communication pipe. + onClosed func(tcpip.Error) + // mu protects the following fields. mu sync.Mutex // tx is the transmit queue. + // +checklocks:mu tx tx // workerStarted specifies whether the worker goroutine was started. + // +checklocks:mu workerStarted bool } // New creates a new shared-memory-based endpoint. Buffers will be broken up // into buffers of "bufferSize" bytes. -func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (stack.LinkEndpoint, error) { +func New(opts Options) (stack.LinkEndpoint, error) { e := &endpoint{ - mtu: mtu, - bufferSize: bufferSize, - addr: addr, + mtu: opts.MTU, + bufferSize: opts.BufferSize, + addr: opts.LinkAddress, + peerFD: opts.PeerFD, + onClosed: opts.OnClosed, } - if err := e.tx.init(bufferSize, &tx); err != nil { + if err := e.tx.init(opts.BufferSize, &opts.TX); err != nil { return nil, err } - if err := e.rx.init(bufferSize, &rx); err != nil { + if err := e.rx.init(opts.BufferSize, &opts.RX); err != nil { e.tx.cleanup() return nil, err } + e.caps = stack.LinkEndpointCapabilities(0) + if opts.RXChecksumOffload { + e.caps |= stack.CapabilityRXChecksumOffload + } + + if opts.TXChecksumOffload { + e.caps |= stack.CapabilityTXChecksumOffload + } + + if opts.LinkAddress != "" { + e.hdrSize = header.EthernetMinimumSize + e.caps |= stack.CapabilityResolutionRequired + } return e, nil } @@ -119,13 +223,13 @@ func (e *endpoint) Close() { // Tell dispatch goroutine to stop, then write to the eventfd so that // it wakes up in case it's sleeping. atomic.StoreUint32(&e.stopRequested, 1) - unix.Write(e.rx.eventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) + e.rx.eventFD.Notify() // Cleanup the queues inline if the worker hasn't started yet; we also // know it won't start from now on because stopRequested is set to 1. e.mu.Lock() + defer e.mu.Unlock() workerPresent := e.workerStarted - e.mu.Unlock() if !workerPresent { e.tx.cleanup() @@ -146,6 +250,22 @@ func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { if !e.workerStarted && atomic.LoadUint32(&e.stopRequested) == 0 { e.workerStarted = true e.completed.Add(1) + + // Spin up a goroutine to monitor for peer shutdown. + if e.peerFD >= 0 { + e.completed.Add(1) + go func() { + defer e.completed.Done() + b := make([]byte, 1) + // When sharedmem endpoint is in use the peerFD is never used for any data + // transfer and this Read should only return if the peer is shutting down. + _, err := rawfile.BlockingRead(e.peerFD, b) + if e.onClosed != nil { + e.onClosed(err) + } + }() + } + // Link endpoints are not savable. When transportation endpoints // are saved, they stop sending outgoing packets and all // incoming packets are rejected. @@ -164,18 +284,18 @@ func (e *endpoint) IsAttached() bool { // MTU implements stack.LinkEndpoint.MTU. It returns the value initialized // during construction. func (e *endpoint) MTU() uint32 { - return e.mtu - header.EthernetMinimumSize + return e.mtu - e.hdrSize } // Capabilities implements stack.LinkEndpoint.Capabilities. -func (*endpoint) Capabilities() stack.LinkEndpointCapabilities { - return 0 +func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.caps } // MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It returns the // ethernet frame header size. -func (*endpoint) MaxHeaderLength() uint16 { - return header.EthernetMinimumSize +func (e *endpoint) MaxHeaderLength() uint16 { + return uint16(e.hdrSize) } // LinkAddress implements stack.LinkEndpoint.LinkAddress. It returns the local @@ -205,17 +325,15 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net // WriteRawPacket implements stack.LinkEndpoint. func (*endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} } -// WritePacket writes outbound packets to the file descriptor. If it is not -// currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) +// +checklocks:e.mu +func (e *endpoint) writePacketLocked(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { + if e.addr != "" { + e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) + } views := pkt.Views() // Transmit the packet. - e.mu.Lock() ok := e.tx.transmit(views...) - e.mu.Unlock() - if !ok { return &tcpip.ErrWouldBlock{} } @@ -223,9 +341,37 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol return nil } +// WritePacket writes outbound packets to the file descriptor. If it is not +// currently writable, the packet is dropped. +func (e *endpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + if err := e.writePacketLocked(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil { + return err + } + e.tx.notify() + return nil +} + // WritePackets implements stack.LinkEndpoint.WritePackets. -func (*endpoint) WritePackets(stack.RouteInfo, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - panic("not implemented") +func (e *endpoint) WritePackets(_ stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { + n := 0 + var err tcpip.Error + e.mu.Lock() + defer e.mu.Unlock() + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if err = e.writePacketLocked(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil { + break + } + n++ + } + // WritePackets never returns an error if it successfully transmitted at least + // one packet. + if err != nil && n == 0 { + return 0, err + } + e.tx.notify() + return n, nil } // dispatchLoop reads packets from the rx queue in a loop and dispatches them @@ -268,16 +414,42 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { Data: buffer.View(b).ToVectorisedView(), }) - hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) - if !ok { - continue + var src, dst tcpip.LinkAddress + var proto tcpip.NetworkProtocolNumber + if e.addr != "" { + hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) + if !ok { + continue + } + eth := header.Ethernet(hdr) + src = eth.SourceAddress() + dst = eth.DestinationAddress() + proto = eth.Type() + } else { + // We don't get any indication of what the packet is, so try to guess + // if it's an IPv4 or IPv6 packet. + // IP version information is at the first octet, so pulling up 1 byte. + h, ok := pkt.Data().PullUp(1) + if !ok { + continue + } + switch header.IPVersion(h) { + case header.IPv4Version: + proto = header.IPv4ProtocolNumber + case header.IPv6Version: + proto = header.IPv6ProtocolNumber + default: + continue + } } - eth := header.Ethernet(hdr) // Send packet up the stack. - d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), pkt) + d.DeliverNetworkPacket(src, dst, proto, pkt) } + e.mu.Lock() + defer e.mu.Unlock() + // Clean state. e.tx.cleanup() e.rx.cleanup() diff --git a/pkg/tcpip/link/sharedmem/sharedmem_server.go b/pkg/tcpip/link/sharedmem/sharedmem_server.go new file mode 100644 index 000000000..43c5b8c63 --- /dev/null +++ b/pkg/tcpip/link/sharedmem/sharedmem_server.go @@ -0,0 +1,344 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build linux +// +build linux + +package sharedmem + +import ( + "sync/atomic" + + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +type serverEndpoint struct { + // mtu (maximum transmission unit) is the maximum size of a packet. + // mtu is immutable. + mtu uint32 + + // bufferSize is the size of each individual buffer. + // bufferSize is immutable. + bufferSize uint32 + + // addr is the local address of this endpoint. + // addr is immutable + addr tcpip.LinkAddress + + // rx is the receive queue. + rx serverRx + + // stopRequested is to be accessed atomically only, and determines if the + // worker goroutines should stop. + stopRequested uint32 + + // Wait group used to indicate that all workers have stopped. + completed sync.WaitGroup + + // peerFD is an fd to the peer that can be used to detect when the peer is + // gone. + // peerFD is immutable. + peerFD int + + // caps holds the endpoint capabilities. + caps stack.LinkEndpointCapabilities + + // hdrSize is the size of the link layer header if any. + // hdrSize is immutable. + hdrSize uint32 + + // onClosed is a function to be called when the FD's peer (if any) closes its + // end of the communication pipe. + onClosed func(tcpip.Error) + + // mu protects the following fields. + mu sync.Mutex + + // tx is the transmit queue. + // +checklocks:mu + tx serverTx + + // workerStarted specifies whether the worker goroutine was started. + // +checklocks:mu + workerStarted bool +} + +// NewServerEndpoint creates a new shared-memory-based endpoint. Buffers will be +// broken up into buffers of "bufferSize" bytes. +func NewServerEndpoint(opts Options) (stack.LinkEndpoint, error) { + e := &serverEndpoint{ + mtu: opts.MTU, + bufferSize: opts.BufferSize, + addr: opts.LinkAddress, + peerFD: opts.PeerFD, + onClosed: opts.OnClosed, + } + + if err := e.tx.init(&opts.RX); err != nil { + return nil, err + } + + if err := e.rx.init(&opts.TX); err != nil { + e.tx.cleanup() + return nil, err + } + + e.caps = stack.LinkEndpointCapabilities(0) + if opts.RXChecksumOffload { + e.caps |= stack.CapabilityRXChecksumOffload + } + + if opts.TXChecksumOffload { + e.caps |= stack.CapabilityTXChecksumOffload + } + + if opts.LinkAddress != "" { + e.hdrSize = header.EthernetMinimumSize + e.caps |= stack.CapabilityResolutionRequired + } + + return e, nil +} + +// Close frees all resources associated with the endpoint. +func (e *serverEndpoint) Close() { + // Tell dispatch goroutine to stop, then write to the eventfd so that it wakes + // up in case it's sleeping. + atomic.StoreUint32(&e.stopRequested, 1) + e.rx.eventFD.Notify() + + // Cleanup the queues inline if the worker hasn't started yet; we also know it + // won't start from now on because stopRequested is set to 1. + e.mu.Lock() + defer e.mu.Unlock() + workerPresent := e.workerStarted + + if !workerPresent { + e.tx.cleanup() + e.rx.cleanup() + } +} + +// Wait implements stack.LinkEndpoint.Wait. It waits until all workers have +// stopped after a Close() call. +func (e *serverEndpoint) Wait() { + e.completed.Wait() +} + +// Attach implements stack.LinkEndpoint.Attach. It launches the goroutine that +// reads packets from the rx queue. +func (e *serverEndpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.mu.Lock() + if !e.workerStarted && atomic.LoadUint32(&e.stopRequested) == 0 { + e.workerStarted = true + e.completed.Add(1) + if e.peerFD >= 0 { + e.completed.Add(1) + // Spin up a goroutine to monitor for peer shutdown. + go func() { + b := make([]byte, 1) + // When sharedmem endpoint is in use the peerFD is never used for any + // data transfer and this Read should only return if the peer is + // shutting down. + _, err := rawfile.BlockingRead(e.peerFD, b) + if e.onClosed != nil { + e.onClosed(err) + } + e.completed.Done() + }() + } + // Link endpoints are not savable. When transportation endpoints are saved, + // they stop sending outgoing packets and all incoming packets are rejected. + go e.dispatchLoop(dispatcher) // S/R-SAFE: see above. + } + e.mu.Unlock() +} + +// IsAttached implements stack.LinkEndpoint.IsAttached. +func (e *serverEndpoint) IsAttached() bool { + e.mu.Lock() + defer e.mu.Unlock() + return e.workerStarted +} + +// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized +// during construction. +func (e *serverEndpoint) MTU() uint32 { + return e.mtu - e.hdrSize +} + +// Capabilities implements stack.LinkEndpoint.Capabilities. +func (e *serverEndpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.caps +} + +// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It returns the +// ethernet frame header size. +func (e *serverEndpoint) MaxHeaderLength() uint16 { + return uint16(e.hdrSize) +} + +// LinkAddress implements stack.LinkEndpoint.LinkAddress. It returns the local +// link address. +func (e *serverEndpoint) LinkAddress() tcpip.LinkAddress { + return e.addr +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *serverEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + // Add ethernet header if needed. + eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize)) + ethHdr := &header.EthernetFields{ + DstAddr: remote, + Type: protocol, + } + + // Preserve the src address if it's set in the route. + if local != "" { + ethHdr.SrcAddr = local + } else { + ethHdr.SrcAddr = e.addr + } + eth.Encode(ethHdr) +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket +func (e *serverEndpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error { + views := pkt.Views() + e.mu.Lock() + defer e.mu.Unlock() + ok := e.tx.transmit(views) + if !ok { + return &tcpip.ErrWouldBlock{} + } + e.tx.notify() + return nil +} + +// +checklocks:e.mu +func (e *serverEndpoint) writePacketLocked(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { + if e.addr != "" { + e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) + } + + views := pkt.Views() + ok := e.tx.transmit(views) + if !ok { + return &tcpip.ErrWouldBlock{} + } + + return nil +} + +// WritePacket writes outbound packets to the file descriptor. If it is not +// currently writable, the packet is dropped. +// WritePacket implements stack.LinkEndpoint.WritePacket. +func (e *serverEndpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { + // Transmit the packet. + e.mu.Lock() + defer e.mu.Unlock() + if err := e.writePacketLocked(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil { + return err + } + e.tx.notify() + return nil +} + +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (e *serverEndpoint) WritePackets(_ stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { + n := 0 + var err tcpip.Error + e.mu.Lock() + defer e.mu.Unlock() + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if err = e.writePacketLocked(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil { + break + } + n++ + } + // WritePackets never returns an error if it successfully transmitted at least + // one packet. + if err != nil && n == 0 { + return 0, err + } + e.tx.notify() + return n, nil +} + +// dispatchLoop reads packets from the rx queue in a loop and dispatches them +// to the network stack. +func (e *serverEndpoint) dispatchLoop(d stack.NetworkDispatcher) { + for atomic.LoadUint32(&e.stopRequested) == 0 { + b := e.rx.receive() + if b == nil { + e.rx.waitForPackets() + continue + } + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.View(b).ToVectorisedView(), + }) + var src, dst tcpip.LinkAddress + var proto tcpip.NetworkProtocolNumber + if e.addr != "" { + hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) + if !ok { + continue + } + eth := header.Ethernet(hdr) + src = eth.SourceAddress() + dst = eth.DestinationAddress() + proto = eth.Type() + } else { + // We don't get any indication of what the packet is, so try to guess + // if it's an IPv4 or IPv6 packet. + // IP version information is at the first octet, so pulling up 1 byte. + h, ok := pkt.Data().PullUp(1) + if !ok { + continue + } + switch header.IPVersion(h) { + case header.IPv4Version: + proto = header.IPv4ProtocolNumber + case header.IPv6Version: + proto = header.IPv6ProtocolNumber + default: + continue + } + } + // Send packet up the stack. + d.DeliverNetworkPacket(src, dst, proto, pkt) + } + + e.mu.Lock() + defer e.mu.Unlock() + + // Clean state. + e.tx.cleanup() + e.rx.cleanup() + + e.completed.Done() +} + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType +func (e *serverEndpoint) ARPHardwareType() header.ARPHardwareType { + if e.hdrSize > 0 { + return header.ARPHardwareEther + } + return header.ARPHardwareNone +} diff --git a/pkg/tcpip/link/sharedmem/sharedmem_server_test.go b/pkg/tcpip/link/sharedmem/sharedmem_server_test.go new file mode 100644 index 000000000..1bc58614e --- /dev/null +++ b/pkg/tcpip/link/sharedmem/sharedmem_server_test.go @@ -0,0 +1,220 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build linux +// +build linux + +package sharedmem_server_test + +import ( + "fmt" + "io" + "net" + "net/http" + "syscall" + "testing" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem" + "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" +) + +const ( + localLinkAddr = "\xde\xad\xbe\xef\x56\x78" + remoteLinkAddr = "\xde\xad\xbe\xef\x12\x34" + localIPv4Address = tcpip.Address("\x0a\x00\x00\x01") + remoteIPv4Address = tcpip.Address("\x0a\x00\x00\x02") + serverPort = 10001 + + defaultMTU = 1500 + defaultBufferSize = 1500 +) + +type stackOptions struct { + ep stack.LinkEndpoint + addr tcpip.Address +} + +func newStackWithOptions(stackOpts stackOptions) (*stack.Stack, error) { + st := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocolWithOptions(ipv4.Options{ + AllowExternalLoopbackTraffic: true, + }), + ipv6.NewProtocolWithOptions(ipv6.Options{ + AllowExternalLoopbackTraffic: true, + }), + }, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol}, + }) + nicID := tcpip.NICID(1) + sniffEP := sniffer.New(stackOpts.ep) + opts := stack.NICOptions{Name: "eth0"} + if err := st.CreateNICWithOptions(nicID, sniffEP, opts); err != nil { + return nil, fmt.Errorf("method CreateNICWithOptions(%d, _, %v) failed: %s", nicID, opts, err) + } + + // Add Protocol Address. + protocolNum := ipv4.ProtocolNumber + routeTable := []tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}} + if len(stackOpts.addr) == 16 { + routeTable = []tcpip.Route{{Destination: header.IPv6EmptySubnet, NIC: nicID}} + protocolNum = ipv6.ProtocolNumber + } + protocolAddr := tcpip.ProtocolAddress{ + Protocol: protocolNum, + AddressWithPrefix: stackOpts.addr.WithPrefix(), + } + if err := st.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + return nil, fmt.Errorf("AddProtocolAddress(%d, %v, {}): %s", nicID, protocolAddr, err) + } + + // Setup route table. + st.SetRouteTable(routeTable) + + return st, nil +} + +func newClientStack(t *testing.T, qPair *sharedmem.QueuePair, peerFD int) (*stack.Stack, error) { + ep, err := sharedmem.New(sharedmem.Options{ + MTU: defaultMTU, + BufferSize: defaultBufferSize, + LinkAddress: localLinkAddr, + TX: qPair.TXQueueConfig(), + RX: qPair.RXQueueConfig(), + PeerFD: peerFD, + }) + if err != nil { + return nil, fmt.Errorf("failed to create sharedmem endpoint: %s", err) + } + st, err := newStackWithOptions(stackOptions{ep: ep, addr: localIPv4Address}) + if err != nil { + return nil, fmt.Errorf("failed to create client stack: %s", err) + } + return st, nil +} + +func newServerStack(t *testing.T, qPair *sharedmem.QueuePair, peerFD int) (*stack.Stack, error) { + ep, err := sharedmem.NewServerEndpoint(sharedmem.Options{ + MTU: defaultMTU, + BufferSize: defaultBufferSize, + LinkAddress: remoteLinkAddr, + TX: qPair.TXQueueConfig(), + RX: qPair.RXQueueConfig(), + PeerFD: peerFD, + }) + if err != nil { + return nil, fmt.Errorf("failed to create sharedmem endpoint: %s", err) + } + st, err := newStackWithOptions(stackOptions{ep: ep, addr: remoteIPv4Address}) + if err != nil { + return nil, fmt.Errorf("failed to create client stack: %s", err) + } + return st, nil +} + +type testContext struct { + clientStk *stack.Stack + serverStk *stack.Stack + peerFDs [2]int +} + +func newTestContext(t *testing.T) *testContext { + peerFDs, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET|syscall.SOCK_NONBLOCK, 0) + if err != nil { + t.Fatalf("failed to create peerFDs: %s", err) + } + q, err := sharedmem.NewQueuePair() + if err != nil { + t.Fatalf("failed to create sharedmem queue: %s", err) + } + clientStack, err := newClientStack(t, q, peerFDs[0]) + if err != nil { + q.Close() + unix.Close(peerFDs[0]) + unix.Close(peerFDs[1]) + t.Fatalf("failed to create client stack: %s", err) + } + serverStack, err := newServerStack(t, q, peerFDs[1]) + if err != nil { + q.Close() + unix.Close(peerFDs[0]) + unix.Close(peerFDs[1]) + clientStack.Close() + t.Fatalf("failed to create server stack: %s", err) + } + return &testContext{ + clientStk: clientStack, + serverStk: serverStack, + peerFDs: peerFDs, + } +} + +func (ctx *testContext) cleanup() { + unix.Close(ctx.peerFDs[0]) + unix.Close(ctx.peerFDs[1]) + ctx.clientStk.Close() + ctx.serverStk.Close() +} + +func TestServerRoundTrip(t *testing.T) { + ctx := newTestContext(t) + defer ctx.cleanup() + listenAddr := tcpip.FullAddress{Addr: remoteIPv4Address, Port: serverPort} + l, err := gonet.ListenTCP(ctx.serverStk, listenAddr, ipv4.ProtocolNumber) + if err != nil { + t.Fatalf("failed to start TCP Listener: %s", err) + } + defer l.Close() + var responseString = "response" + go func() { + http.Serve(l, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(responseString)) + })) + }() + + dialFunc := func(address, protocol string) (net.Conn, error) { + return gonet.DialTCP(ctx.clientStk, listenAddr, ipv4.ProtocolNumber) + } + + httpClient := &http.Client{ + Transport: &http.Transport{ + Dial: dialFunc, + }, + } + serverURL := fmt.Sprintf("http://[%s]:%d/", net.IP(remoteIPv4Address), serverPort) + response, err := httpClient.Get(serverURL) + if err != nil { + t.Fatalf("httpClient.Get(\"/\") failed: %s", err) + } + if got, want := response.StatusCode, http.StatusOK; got != want { + t.Fatalf("unexpected status code got: %d, want: %d", got, want) + } + body, err := io.ReadAll(response.Body) + if err != nil { + t.Fatalf("io.ReadAll(response.Body) failed: %s", err) + } + response.Body.Close() + if got, want := string(body), responseString; got != want { + t.Fatalf("unexpected response got: %s, want: %s", got, want) + } +} diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index d6d953085..a49f5f87d 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -19,9 +19,7 @@ package sharedmem import ( "bytes" - "io/ioutil" "math/rand" - "os" "strings" "testing" "time" @@ -104,24 +102,36 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress t: t, packetCh: make(chan struct{}, 1000000), } - c.txCfg = createQueueFDs(t, queueSizes{ + c.txCfg, err = createQueueFDs(queueSizes{ dataSize: queueDataSize, txPipeSize: queuePipeSize, rxPipeSize: queuePipeSize, sharedDataSize: 4096, }) - - c.rxCfg = createQueueFDs(t, queueSizes{ + if err != nil { + t.Fatalf("createQueueFDs for tx failed: %s", err) + } + c.rxCfg, err = createQueueFDs(queueSizes{ dataSize: queueDataSize, txPipeSize: queuePipeSize, rxPipeSize: queuePipeSize, sharedDataSize: 4096, }) + if err != nil { + t.Fatalf("createQueueFDs for rx failed: %s", err) + } initQueue(t, &c.txq, &c.txCfg) initQueue(t, &c.rxq, &c.rxCfg) - ep, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg) + ep, err := New(Options{ + MTU: mtu, + BufferSize: bufferSize, + LinkAddress: addr, + TX: c.txCfg, + RX: c.rxCfg, + PeerFD: -1, + }) if err != nil { t.Fatalf("New failed: %v", err) } @@ -150,8 +160,8 @@ func (c *testContext) DeliverOutboundPacket(remoteLinkAddr, localLinkAddr tcpip. func (c *testContext) cleanup() { c.ep.Close() - closeFDs(&c.txCfg) - closeFDs(&c.rxCfg) + closeFDs(c.txCfg) + closeFDs(c.rxCfg) c.txq.cleanup() c.rxq.cleanup() } @@ -191,69 +201,6 @@ func shuffle(b []int) { } } -func createFile(t *testing.T, size int64, initQueue bool) int { - tmpDir, ok := os.LookupEnv("TEST_TMPDIR") - if !ok { - tmpDir = os.Getenv("TMPDIR") - } - f, err := ioutil.TempFile(tmpDir, "sharedmem_test") - if err != nil { - t.Fatalf("TempFile failed: %v", err) - } - defer f.Close() - unix.Unlink(f.Name()) - - if initQueue { - // Write the "slot-free" flag in the initial queue. - _, err := f.WriteAt([]byte{0, 0, 0, 0, 0, 0, 0, 0x80}, 0) - if err != nil { - t.Fatalf("WriteAt failed: %v", err) - } - } - - fd, err := unix.Dup(int(f.Fd())) - if err != nil { - t.Fatalf("Dup failed: %v", err) - } - - if err := unix.Ftruncate(fd, size); err != nil { - unix.Close(fd) - t.Fatalf("Ftruncate failed: %v", err) - } - - return fd -} - -func closeFDs(c *QueueConfig) { - unix.Close(c.DataFD) - unix.Close(c.EventFD) - unix.Close(c.TxPipeFD) - unix.Close(c.RxPipeFD) - unix.Close(c.SharedDataFD) -} - -type queueSizes struct { - dataSize int64 - txPipeSize int64 - rxPipeSize int64 - sharedDataSize int64 -} - -func createQueueFDs(t *testing.T, s queueSizes) QueueConfig { - fd, _, err := unix.RawSyscall(unix.SYS_EVENTFD2, 0, 0, 0) - if err != 0 { - t.Fatalf("eventfd failed: %v", error(err)) - } - - return QueueConfig{ - EventFD: int(fd), - DataFD: createFile(t, s.dataSize, false), - TxPipeFD: createFile(t, s.txPipeSize, true), - RxPipeFD: createFile(t, s.rxPipeSize, true), - SharedDataFD: createFile(t, s.sharedDataSize, false), - } -} - // TestSimpleSend sends 1000 packets with random header and payload sizes, // then checks that the right payload is received on the shared memory queues. func TestSimpleSend(t *testing.T) { @@ -263,6 +210,7 @@ func TestSimpleSend(t *testing.T) { // Prepare route. var r stack.RouteInfo r.RemoteLinkAddress = remoteLinkAddr + r.LocalLinkAddress = localLinkAddr for iters := 1000; iters > 0; iters-- { func() { @@ -280,8 +228,11 @@ func TestSimpleSend(t *testing.T) { Data: data.ToVectorisedView(), }) copy(pkt.NetworkHeader().Push(hdrLen), hdrBuf) - proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) + // Every PacketBuffer must have these set: + // See nic.writePacket. + pkt.EgressRoute = r + pkt.NetworkProtocolNumber = proto if err := c.ep.WritePacket(r, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -350,8 +301,11 @@ func TestPreserveSrcAddressInSend(t *testing.T) { // the minimum size of the ethernet header. ReserveHeaderBytes: header.EthernetMinimumSize, }) - proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) + // Every PacketBuffer must have these set: + // See nic.writePacket. + pkt.EgressRoute = r + pkt.NetworkProtocolNumber = proto if err := c.ep.WritePacket(r, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -672,7 +626,7 @@ func TestSimpleReceive(t *testing.T) { // Push completion. c.pushRxCompletion(uint32(len(contents)), bufs) c.rxq.rx.Flush() - unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) + c.rxCfg.EventFD.Notify() // Wait for packet to be received, then check it. c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet") @@ -718,7 +672,7 @@ func TestRxBuffersReposted(t *testing.T) { // Complete the buffer. c.pushRxCompletion(buffers[i].Size, buffers[i:][:1]) c.rxq.rx.Flush() - unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) + c.rxCfg.EventFD.Notify() // Wait for it to be reposted. bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffer to be reposted")) @@ -734,7 +688,7 @@ func TestRxBuffersReposted(t *testing.T) { // Complete with two buffers. c.pushRxCompletion(2*bufferSize, buffers[2*i:][:2]) c.rxq.rx.Flush() - unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) + c.rxCfg.EventFD.Notify() // Wait for them to be reposted. for j := 0; j < 2; j++ { @@ -759,7 +713,7 @@ func TestReceivePostingIsFull(t *testing.T) { first := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for first buffer to be posted")) c.pushRxCompletion(first.Size, []queue.RxBuffer{first}) c.rxq.rx.Flush() - unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) + c.rxCfg.EventFD.Notify() // Check that packet is received. c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet") @@ -768,7 +722,7 @@ func TestReceivePostingIsFull(t *testing.T) { second := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for second buffer to be posted")) c.pushRxCompletion(second.Size, []queue.RxBuffer{second}) c.rxq.rx.Flush() - unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) + c.rxCfg.EventFD.Notify() // Check that no packet is received yet, as the worker is blocked trying // to repost. @@ -781,7 +735,7 @@ func TestReceivePostingIsFull(t *testing.T) { // Flush tx queue, which will allow the first buffer to be reposted, // and the second completion to be pulled. c.rxq.tx.Flush() - unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) + c.rxCfg.EventFD.Notify() // Check that second packet completes. c.waitForPackets(1, time.After(time.Second), "Timeout waiting for second completed packet") @@ -803,7 +757,7 @@ func TestCloseWhileWaitingToPost(t *testing.T) { bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for initial buffer to be posted")) c.pushRxCompletion(bi.Size, []queue.RxBuffer{bi}) c.rxq.rx.Flush() - unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) + c.rxCfg.EventFD.Notify() // Wait for packet to be indicated. c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet") diff --git a/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go b/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go index f7e816a41..d974c266e 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go @@ -15,7 +15,12 @@ package sharedmem import ( + "fmt" + "reflect" "unsafe" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/memutil" ) // sharedDataPointer converts the shared data slice into a pointer so that it @@ -23,3 +28,31 @@ import ( func sharedDataPointer(sharedData []byte) *uint32 { return (*uint32)(unsafe.Pointer(&sharedData[0:4][0])) } + +// getBuffer returns a memory region mapped to the full contents of the given +// file descriptor. +func getBuffer(fd int) ([]byte, error) { + var s unix.Stat_t + if err := unix.Fstat(fd, &s); err != nil { + return nil, err + } + + // Check that size doesn't overflow an int. + if s.Size > int64(^uint(0)>>1) { + return nil, unix.EDOM + } + + addr, err := memutil.MapFile(0 /* addr */, uintptr(s.Size), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED|unix.MAP_FILE, uintptr(fd), 0 /*offset*/) + if err != nil { + return nil, fmt.Errorf("failed to map memory for buffer fd: %d, error: %s", fd, err) + } + + // Use unsafe to conver addr into a []byte. + var b []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&b)) + hdr.Data = addr + hdr.Len = int(s.Size) + hdr.Cap = int(s.Size) + + return b, nil +} diff --git a/pkg/tcpip/link/sharedmem/tx.go b/pkg/tcpip/link/sharedmem/tx.go index e3210051f..d6c61afee 100644 --- a/pkg/tcpip/link/sharedmem/tx.go +++ b/pkg/tcpip/link/sharedmem/tx.go @@ -18,6 +18,7 @@ import ( "math" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/eventfd" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" ) @@ -28,10 +29,12 @@ const ( // tx holds all state associated with a tx queue. type tx struct { - data []byte - q queue.Tx - ids idManager - bufs bufferManager + data []byte + q queue.Tx + ids idManager + bufs bufferManager + eventFD eventfd.Eventfd + sharedDataFD int } // init initializes all state needed by the tx queue based on the information @@ -64,7 +67,8 @@ func (t *tx) init(mtu uint32, c *QueueConfig) error { t.ids.init() t.bufs.init(0, len(data), int(mtu)) t.data = data - + t.eventFD = c.EventFD + t.sharedDataFD = c.SharedDataFD return nil } @@ -142,20 +146,10 @@ func (t *tx) transmit(bufs ...buffer.View) bool { return true } -// getBuffer returns a memory region mapped to the full contents of the given -// file descriptor. -func getBuffer(fd int) ([]byte, error) { - var s unix.Stat_t - if err := unix.Fstat(fd, &s); err != nil { - return nil, err - } - - // Check that size doesn't overflow an int. - if s.Size > int64(^uint(0)>>1) { - return nil, unix.EDOM - } - - return unix.Mmap(fd, 0, int(s.Size), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED|unix.MAP_FILE) +// notify writes to the tx.eventFD to indicate to the peer that there is data to +// be read. +func (t *tx) notify() { + t.eventFD.Notify() } // idDescriptor is used by idManager to either point to a tx buffer (in case diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 6515c31e5..e08243547 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -272,7 +272,6 @@ type protocol struct { func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber } func (p *protocol) MinimumPacketSize() int { return header.ARPSize } -func (p *protocol) DefaultPrefixLen() int { return 0 } func (*protocol) ParseAddresses(buffer.View) (src, dst tcpip.Address) { return "", "" diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 5fcbfeaa2..061cc35ae 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -153,8 +153,12 @@ func makeTestContext(t *testing.T, eventDepth int, packetDepth int) testContext t.Fatalf("CreateNIC failed: %s", err) } - if err := tc.s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress for ipv4 failed: %s", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: stackAddr.WithPrefix(), + } + if err := tc.s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } tc.s.SetRouteTable([]tcpip.Route{{ @@ -569,8 +573,12 @@ func TestLinkAddressRequest(t *testing.T) { } if len(test.nicAddr) != 0 { - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, test.nicAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, test.nicAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: test.nicAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } } diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 2179302d3..87f650661 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -233,7 +233,13 @@ func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) { TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, }) s.CreateNIC(nicID, loopback.New()) - s.AddAddress(nicID, ipv4.ProtocolNumber, local) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: local.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + return nil, err + } s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv4EmptySubnet, Gateway: ipv4Gateway, @@ -249,7 +255,13 @@ func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) { TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, }) s.CreateNIC(nicID, loopback.New()) - s.AddAddress(nicID, ipv6.ProtocolNumber, local) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: local.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + return nil, err + } s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv6EmptySubnet, Gateway: ipv6Gateway, @@ -272,13 +284,13 @@ func buildDummyStackWithLinkEndpoint(t *testing.T, mtu uint32) (*stack.Stack, *c } v4Addr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: localIPv4AddrWithPrefix} - if err := s.AddProtocolAddress(nicID, v4Addr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v4Addr, err) + if err := s.AddProtocolAddress(nicID, v4Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, v4Addr, err) } v6Addr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: localIPv6AddrWithPrefix} - if err := s.AddProtocolAddress(nicID, v6Addr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v6Addr, err) + if err := s.AddProtocolAddress(nicID, v6Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, v6Addr, err) } return s, e @@ -713,8 +725,8 @@ func TestReceive(t *testing.T) { if !ok { t.Fatalf("expected network endpoint with number = %d to implement stack.AddressableEndpoint", test.protoNum) } - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", test.epAddr, err) + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", test.epAddr, err) } else { ep.DecRef() } @@ -885,8 +897,8 @@ func TestIPv4ReceiveControl(t *testing.T) { t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint") } addr := localIPv4Addr.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err) } else { ep.DecRef() } @@ -971,8 +983,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint") } addr := localIPv4Addr.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err) } else { ep.DecRef() } @@ -1237,8 +1249,8 @@ func TestIPv6ReceiveControl(t *testing.T) { t.Fatal("expected IPv6 network endpoint to implement stack.AddressableEndpoint") } addr := localIPv6Addr.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err) } else { ep.DecRef() } @@ -1304,7 +1316,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name string protoFactory stack.NetworkProtocolFactory protoNum tcpip.NetworkProtocolNumber - nicAddr tcpip.Address + nicAddr tcpip.AddressWithPrefix remoteAddr tcpip.Address pktGen func(*testing.T, tcpip.Address) buffer.VectorisedView checker func(*testing.T, *stack.PacketBuffer, tcpip.Address) @@ -1314,7 +1326,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv4", protoFactory: ipv4.NewProtocol, protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, + nicAddr: localIPv4AddrWithPrefix, remoteAddr: remoteIPv4Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv4MinimumSize + len(data) @@ -1355,7 +1367,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv4 with IHL too small", protoFactory: ipv4.NewProtocol, protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, + nicAddr: localIPv4AddrWithPrefix, remoteAddr: remoteIPv4Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv4MinimumSize + len(data) @@ -1379,7 +1391,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv4 too small", protoFactory: ipv4.NewProtocol, protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, + nicAddr: localIPv4AddrWithPrefix, remoteAddr: remoteIPv4Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) @@ -1397,7 +1409,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv4 minimum size", protoFactory: ipv4.NewProtocol, protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, + nicAddr: localIPv4AddrWithPrefix, remoteAddr: remoteIPv4Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) @@ -1433,7 +1445,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv4 with options", protoFactory: ipv4.NewProtocol, protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, + nicAddr: localIPv4AddrWithPrefix, remoteAddr: remoteIPv4Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ipHdrLen := int(header.IPv4MinimumSize + ipv4Options.Length()) @@ -1478,7 +1490,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv4 with options and data across views", protoFactory: ipv4.NewProtocol, protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, + nicAddr: localIPv4AddrWithPrefix, remoteAddr: remoteIPv4Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.Length())) @@ -1519,7 +1531,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv6", protoFactory: ipv6.NewProtocol, protoNum: ipv6.ProtocolNumber, - nicAddr: localIPv6Addr, + nicAddr: localIPv6AddrWithPrefix, remoteAddr: remoteIPv6Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv6MinimumSize + len(data) @@ -1559,7 +1571,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv6 with extension header", protoFactory: ipv6.NewProtocol, protoNum: ipv6.ProtocolNumber, - nicAddr: localIPv6Addr, + nicAddr: localIPv6AddrWithPrefix, remoteAddr: remoteIPv6Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data) @@ -1604,7 +1616,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv6 minimum size", protoFactory: ipv6.NewProtocol, protoNum: ipv6.ProtocolNumber, - nicAddr: localIPv6Addr, + nicAddr: localIPv6AddrWithPrefix, remoteAddr: remoteIPv6Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) @@ -1639,7 +1651,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv6 too small", protoFactory: ipv6.NewProtocol, protoNum: ipv6.ProtocolNumber, - nicAddr: localIPv6Addr, + nicAddr: localIPv6AddrWithPrefix, remoteAddr: remoteIPv6Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) @@ -1663,11 +1675,11 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { }{ { name: "unspecified source", - srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr))), + srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr.Address))), }, { name: "random source", - srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr))), + srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr.Address))), }, } @@ -1680,15 +1692,19 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, test.protoNum, test.nicAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.protoNum, test.nicAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: test.protoNum, + AddressWithPrefix: test.nicAddr, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{{Destination: test.remoteAddr.WithPrefix().Subnet(), NIC: nicID}}) - r, err := s.FindRoute(nicID, test.nicAddr, test.remoteAddr, test.protoNum, false /* multicastLoop */) + r, err := s.FindRoute(nicID, test.nicAddr.Address, test.remoteAddr, test.protoNum, false /* multicastLoop */) if err != nil { - t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr, test.protoNum, err) + t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr.Address, test.protoNum, err) } defer r.Release() @@ -2072,8 +2088,12 @@ func TestSetNICIDBeforeDeliveringToRawEndpoint(t *testing.T) { if err := s.CreateNIC(nicID, loopback.New()); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddressWithPrefix(nicID, test.proto, test.addr); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, test.proto, test.addr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: test.proto, + AddressWithPrefix: test.addr, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{ diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 2aa38eb98..3eff0bbd8 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -167,23 +167,22 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet p := hdr.TransportProtocol() dstAddr := hdr.DestinationAddress() // Skip the ip header, then deliver the error. - pkt.Data().DeleteFront(hlen) + if _, ok := pkt.Data().Consume(hlen); !ok { + panic(fmt.Sprintf("could not consume the IP header of %d bytes", hlen)) + } e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, errInfo, pkt) } func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { received := e.stats.icmp.packetsReceived - // ICMP packets don't have their TransportHeader fields set. See - // icmp/protocol.go:protocol.Parse for a full explanation. - v, ok := pkt.Data().PullUp(header.ICMPv4MinimumSize) - if !ok { + h := header.ICMPv4(pkt.TransportHeader().View()) + if len(h) < header.ICMPv4MinimumSize { received.invalid.Increment() return } - h := header.ICMPv4(v) // Only do in-stack processing if the checksum is correct. - if pkt.Data().AsRange().Checksum() != 0xffff { + if header.Checksum(h, pkt.Data().AsRange().Checksum()) != 0xffff { received.invalid.Increment() // It's possible that a raw socket expects to receive this regardless // of checksum errors. If it's an echo request we know it's safe because @@ -240,20 +239,15 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { case header.ICMPv4Echo: received.echoRequest.Increment() - sent := e.stats.icmp.packetsSent - if !e.protocol.stack.AllowICMPMessage() { - sent.rateLimited.Increment() - return - } - // DeliverTransportPacket will take ownership of pkt so don't use it beyond // this point. Make a deep copy of the data before pkt gets sent as we will - // be modifying fields. + // be modifying fields. Both the ICMP header (with its type modified to + // EchoReply) and payload are reused in the reply packet. // // TODO(gvisor.dev/issue/4399): The copy may not be needed if there are no // waiting endpoints. Consider moving responsibility for doing the copy to // DeliverTransportPacket so that is is only done when needed. - replyData := pkt.Data().AsRange().ToOwnedView() + replyData := stack.PayloadSince(pkt.TransportHeader()) ipHdr := header.IPv4(pkt.NetworkHeader().View()) localAddressBroadcast := pkt.NetworkPacketInfo.LocalAddressBroadcast @@ -281,6 +275,12 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { } defer r.Release() + sent := e.stats.icmp.packetsSent + if !e.protocol.allowICMPReply(header.ICMPv4EchoReply, header.ICMPv4UnusedCode) { + sent.rateLimited.Increment() + return + } + // TODO(gvisor.dev/issue/3810:) When adding protocol numbers into the // header information, we may have to change this code to handle the // ICMP header no longer being in the data buffer. @@ -331,6 +331,8 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { case header.ICMPv4EchoReply: received.echoReply.Increment() + // ICMP sockets expect the ICMP header to be present, so we don't consume + // the ICMP header. e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt) case header.ICMPv4DstUnreachable: @@ -338,7 +340,6 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { mtu := h.MTU() code := h.Code() - pkt.Data().DeleteFront(header.ICMPv4MinimumSize) switch code { case header.ICMPv4HostUnreachable: e.handleControl(&icmpv4DestinationHostUnreachableSockError{}, pkt) @@ -562,31 +563,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip return &tcpip.ErrNotConnected{} } - sent := netEP.stats.icmp.packetsSent - - if !p.stack.AllowICMPMessage() { - sent.rateLimited.Increment() - return nil - } - transportHeader := pkt.TransportHeader().View() // Don't respond to icmp error packets. if origIPHdr.Protocol() == uint8(header.ICMPv4ProtocolNumber) { - // TODO(gvisor.dev/issue/3810): - // Unfortunately the current stack pretty much always has ICMPv4 headers - // in the Data section of the packet but there is no guarantee that is the - // case. If this is the case grab the header to make it like all other - // packet types. When this is cleaned up the Consume should be removed. - if transportHeader.IsEmpty() { - var ok bool - transportHeader, ok = pkt.TransportHeader().Consume(header.ICMPv4MinimumSize) - if !ok { - return nil - } - } else if transportHeader.Size() < header.ICMPv4MinimumSize { - return nil - } // We need to decide to explicitly name the packets we can respond to or // the ones we can not respond to. The decision is somewhat arbitrary and // if problems arise this could be reversed. It was judged less of a breach @@ -606,6 +586,35 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip } } + sent := netEP.stats.icmp.packetsSent + icmpType, icmpCode, counter, pointer := func() (header.ICMPv4Type, header.ICMPv4Code, tcpip.MultiCounterStat, byte) { + switch reason := reason.(type) { + case *icmpReasonPortUnreachable: + return header.ICMPv4DstUnreachable, header.ICMPv4PortUnreachable, sent.dstUnreachable, 0 + case *icmpReasonProtoUnreachable: + return header.ICMPv4DstUnreachable, header.ICMPv4ProtoUnreachable, sent.dstUnreachable, 0 + case *icmpReasonNetworkUnreachable: + return header.ICMPv4DstUnreachable, header.ICMPv4NetUnreachable, sent.dstUnreachable, 0 + case *icmpReasonHostUnreachable: + return header.ICMPv4DstUnreachable, header.ICMPv4HostUnreachable, sent.dstUnreachable, 0 + case *icmpReasonFragmentationNeeded: + return header.ICMPv4DstUnreachable, header.ICMPv4FragmentationNeeded, sent.dstUnreachable, 0 + case *icmpReasonTTLExceeded: + return header.ICMPv4TimeExceeded, header.ICMPv4TTLExceeded, sent.timeExceeded, 0 + case *icmpReasonReassemblyTimeout: + return header.ICMPv4TimeExceeded, header.ICMPv4ReassemblyTimeout, sent.timeExceeded, 0 + case *icmpReasonParamProblem: + return header.ICMPv4ParamProblem, header.ICMPv4UnusedCode, sent.paramProblem, reason.pointer + default: + panic(fmt.Sprintf("unsupported ICMP type %T", reason)) + } + }() + + if !p.allowICMPReply(icmpType, icmpCode) { + sent.rateLimited.Increment() + return nil + } + // Now work out how much of the triggering packet we should return. // As per RFC 1812 Section 4.3.2.3 // @@ -658,44 +667,9 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize)) - var counter tcpip.MultiCounterStat - switch reason := reason.(type) { - case *icmpReasonPortUnreachable: - icmpHdr.SetType(header.ICMPv4DstUnreachable) - icmpHdr.SetCode(header.ICMPv4PortUnreachable) - counter = sent.dstUnreachable - case *icmpReasonProtoUnreachable: - icmpHdr.SetType(header.ICMPv4DstUnreachable) - icmpHdr.SetCode(header.ICMPv4ProtoUnreachable) - counter = sent.dstUnreachable - case *icmpReasonNetworkUnreachable: - icmpHdr.SetType(header.ICMPv4DstUnreachable) - icmpHdr.SetCode(header.ICMPv4NetUnreachable) - counter = sent.dstUnreachable - case *icmpReasonHostUnreachable: - icmpHdr.SetType(header.ICMPv4DstUnreachable) - icmpHdr.SetCode(header.ICMPv4HostUnreachable) - counter = sent.dstUnreachable - case *icmpReasonFragmentationNeeded: - icmpHdr.SetType(header.ICMPv4DstUnreachable) - icmpHdr.SetCode(header.ICMPv4FragmentationNeeded) - counter = sent.dstUnreachable - case *icmpReasonTTLExceeded: - icmpHdr.SetType(header.ICMPv4TimeExceeded) - icmpHdr.SetCode(header.ICMPv4TTLExceeded) - counter = sent.timeExceeded - case *icmpReasonReassemblyTimeout: - icmpHdr.SetType(header.ICMPv4TimeExceeded) - icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout) - counter = sent.timeExceeded - case *icmpReasonParamProblem: - icmpHdr.SetType(header.ICMPv4ParamProblem) - icmpHdr.SetCode(header.ICMPv4UnusedCode) - icmpHdr.SetPointer(reason.pointer) - counter = sent.paramProblem - default: - panic(fmt.Sprintf("unsupported ICMP type %T", reason)) - } + icmpHdr.SetCode(icmpCode) + icmpHdr.SetType(icmpType) + icmpHdr.SetPointer(pointer) icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data().AsRange().Checksum())) if err := route.WritePacket( diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go index 4bd6f462e..c6576fcbc 100644 --- a/pkg/tcpip/network/ipv4/igmp_test.go +++ b/pkg/tcpip/network/ipv4/igmp_test.go @@ -120,9 +120,12 @@ func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, ma // cycles. func TestIGMPV1Present(t *testing.T) { e, s, clock := createStack(t, true) - addr := tcpip.AddressWithPrefix{Address: stackAddr, PrefixLen: defaultPrefixLength} - if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{Address: stackAddr, PrefixLen: defaultPrefixLength}, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { @@ -215,8 +218,15 @@ func TestSendQueuedIGMPReports(t *testing.T) { // The initial set of IGMP reports that were queued should be sent once an // address is assigned. - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, stackAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: stackAddr, + PrefixLen: defaultPrefixLength, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } if got := reportStat.Value(); got != 1 { t.Errorf("got reportStat.Value() = %d, want = 1", got) @@ -350,8 +360,12 @@ func TestIGMPPacketValidation(t *testing.T) { t.Run(test.name, func(t *testing.T) { e, s, _ := createStack(t, true) for _, address := range test.stackAddresses { - if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, address); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, address, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: address, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } } stats := s.Stats() diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index e2472c851..d1d509702 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -167,6 +167,13 @@ func (p *protocol) findEndpointWithAddress(addr tcpip.Address) *endpoint { return nil } +func (p *protocol) getEndpointForNIC(id tcpip.NICID) (*endpoint, bool) { + p.mu.RLock() + defer p.mu.RUnlock() + ep, ok := p.mu.eps[id] + return ep, ok +} + func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { p.mu.Lock() defer p.mu.Unlock() @@ -240,7 +247,7 @@ func (e *endpoint) Enable() tcpip.Error { } // Create an endpoint to receive broadcast packets on this interface. - ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(ipv4BroadcastAddr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */) + ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(ipv4BroadcastAddr, stack.AddressProperties{PEB: stack.NeverPrimaryEndpoint}) if err != nil { return err } @@ -419,7 +426,7 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, // iptables filtering. All packets that reach here are locally // generated. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { + if ok := e.protocol.stack.IPTables().CheckOutput(pkt, r, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesOutputDropped.Increment() return nil @@ -432,7 +439,7 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, // We should do this for every packet, rather than only NATted packets, but // removing this check short circuits broadcasts before they are sent out to // other hosts. - if pkt.NatDone { + if pkt.DNATDone { netHeader := header.IPv4(pkt.NetworkHeader().View()) if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil { // Since we rewrote the packet but it is being routed back to us, we @@ -459,7 +466,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, headerIn // Postrouting NAT can only change the source address, and does not alter the // route or outgoing interface of the packet. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { + if ok := e.protocol.stack.IPTables().CheckPostrouting(pkt, r, e, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesPostroutingDropped.Increment() return nil @@ -542,7 +549,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) // iptables filtering. All packets that reach here are locally // generated. - outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, r, "" /* inNicName */, outNicName) + outputDropped, natPkts := e.protocol.stack.IPTables().CheckOutputPackets(pkts, r, outNicName) stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped))) for pkt := range outputDropped { pkts.Remove(pkt) @@ -569,7 +576,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par // We ignore the list of NAT-ed packets here because Postrouting NAT can only // change the source address, and does not alter the route or outgoing // interface of the packet. - postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, r, "" /* inNicName */, outNicName) + postroutingDropped, _ := e.protocol.stack.IPTables().CheckPostroutingPackets(pkts, r, e, outNicName) stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped))) for pkt := range postroutingDropped { pkts.Remove(pkt) @@ -710,7 +717,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { inNicName := stk.FindNICNameFromID(e.nic.ID()) outNicName := stk.FindNICNameFromID(ep.nic.ID()) - if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesForwardDropped.Increment() return nil @@ -737,7 +744,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { inNicName := stk.FindNICNameFromID(e.nic.ID()) outNicName := stk.FindNICNameFromID(r.NICID()) - if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesForwardDropped.Increment() return nil @@ -746,7 +753,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { // We need to do a deep copy of the IP packet because // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do // not own it. - newHdr := header.IPv4(stack.PayloadSince(pkt.NetworkHeader())) + newPkt := pkt.DeepCopyForForwarding(int(r.MaxHeaderLength())) + newHdr := header.IPv4(newPkt.NetworkHeader().View()) // As per RFC 791 page 30, Time to Live, // @@ -755,12 +763,19 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { // Even if no local information is available on the time actually // spent, the field must be decremented by 1. newHdr.SetTTL(ttl - 1) + // We perform a full checksum as we may have updated options above. The IP + // header is relatively small so this is not expected to be an expensive + // operation. + newHdr.SetChecksum(0) + newHdr.SetChecksum(^newHdr.CalculateChecksum()) + + forwardToEp, ok := e.protocol.getEndpointForNIC(r.NICID()) + if !ok { + // The interface was removed after we obtained the route. + return &ip.ErrOther{Err: &tcpip.ErrUnknownDevice{}} + } - switch err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()), - Data: buffer.View(newHdr).ToVectorisedView(), - IsForwardedPacket: true, - })); err.(type) { + switch err := forwardToEp.writePacket(r, newPkt, true /* headerIncluded */); err.(type) { case nil: return nil case *tcpip.ErrMessageTooLong: @@ -826,7 +841,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // Loopback traffic skips the prerouting chain. inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().CheckPrerouting(pkt, e, inNicName); !ok { // iptables is telling us to drop the packet. stats.IPTablesPreroutingDropped.Increment() return @@ -925,7 +940,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer, // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().CheckInput(pkt, inNICName); !ok { // iptables is telling us to drop the packet. stats.ip.IPTablesInputDropped.Increment() return @@ -969,7 +984,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer, } proto := h.Protocol() - resPkt, _, ready, err := e.protocol.fragmentation.Process( + resPkt, transProtoNum, ready, err := e.protocol.fragmentation.Process( // As per RFC 791 section 2.3, the identification value is unique // for a source-destination pair and protocol. fragmentation.FragmentID{ @@ -1000,6 +1015,8 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer, h.SetTotalLength(uint16(pkt.Data().Size() + len(h))) h.SetFlagsFragmentOffset(0, 0) + e.protocol.parseTransport(pkt, tcpip.TransportProtocolNumber(transProtoNum)) + // Now that the packet is reassembled, it can be sent to raw sockets. e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt) } @@ -1075,11 +1092,11 @@ func (e *endpoint) Close() { } // AddAndAcquirePermanentAddress implements stack.AddressableEndpoint. -func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { +func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties stack.AddressProperties) (stack.AddressEndpoint, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) + ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, properties) if err == nil { e.mu.igmp.sendQueuedReports() } @@ -1200,6 +1217,9 @@ type protocol struct { // eps is keyed by NICID to allow protocol methods to retrieve an endpoint // when handling a packet, by looking at which NIC handled the packet. eps map[tcpip.NICID]*endpoint + + // ICMP types for which the stack's global rate limiting must apply. + icmpRateLimitedTypes map[header.ICMPv4Type]struct{} } // defaultTTL is the current default TTL for the protocol. Only the @@ -1226,11 +1246,6 @@ func (p *protocol) MinimumPacketSize() int { return header.IPv4MinimumSize } -// DefaultPrefixLen returns the IPv4 default prefix length. -func (p *protocol) DefaultPrefixLen() int { - return header.IPv4AddressSize * 8 -} - // ParseAddresses implements stack.NetworkProtocol. func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { h := header.IPv4(v) @@ -1297,19 +1312,29 @@ func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv4, bool) } if hasTransportHdr { - switch err := p.stack.ParsePacketBufferTransport(transProtoNum, pkt); err { - case stack.ParsedOK: - case stack.UnknownTransportProtocol, stack.TransportLayerParseError: - // The transport layer will handle unknown protocols and transport layer - // parsing errors. - default: - panic(fmt.Sprintf("unexpected error parsing transport header = %d", err)) - } + p.parseTransport(pkt, transProtoNum) } return h, true } +func (p *protocol) parseTransport(pkt *stack.PacketBuffer, transProtoNum tcpip.TransportProtocolNumber) { + if transProtoNum == header.ICMPv4ProtocolNumber { + // The transport layer will handle transport layer parsing errors. + _ = parse.ICMPv4(pkt) + return + } + + switch err := p.stack.ParsePacketBufferTransport(transProtoNum, pkt); err { + case stack.ParsedOK: + case stack.UnknownTransportProtocol, stack.TransportLayerParseError: + // The transport layer will handle unknown protocols and transport layer + // parsing errors. + default: + panic(fmt.Sprintf("unexpected error parsing transport header = %d", err)) + } +} + // Parse implements stack.NetworkProtocol. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { if ok := parse.IPv4(pkt); !ok { @@ -1320,6 +1345,23 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true } +// allowICMPReply reports whether an ICMP reply with provided type and code may +// be sent following the rate mask options and global ICMP rate limiter. +func (p *protocol) allowICMPReply(icmpType header.ICMPv4Type, code header.ICMPv4Code) bool { + // Mimic linux and never rate limit for PMTU discovery. + // https://github.com/torvalds/linux/blob/9e9fb7655ed585da8f468e29221f0ba194a5f613/net/ipv4/icmp.c#L288 + if icmpType == header.ICMPv4DstUnreachable && code == header.ICMPv4FragmentationNeeded { + return true + } + p.mu.RLock() + defer p.mu.RUnlock() + + if _, ok := p.mu.icmpRateLimitedTypes[icmpType]; ok { + return p.stack.AllowICMPMessage() + } + return true +} + // calculateNetworkMTU calculates the network-layer payload MTU based on the // link-layer payload mtu. func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error) { @@ -1399,6 +1441,14 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { } p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p) p.mu.eps = make(map[tcpip.NICID]*endpoint) + // Set ICMP rate limiting to Linux defaults. + // See https://man7.org/linux/man-pages/man7/icmp.7.html. + p.mu.icmpRateLimitedTypes = map[header.ICMPv4Type]struct{}{ + header.ICMPv4DstUnreachable: struct{}{}, + header.ICMPv4SrcQuench: struct{}{}, + header.ICMPv4TimeExceeded: struct{}{}, + header.ICMPv4ParamProblem: struct{}{}, + } return p } } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 73407be67..ef91245d7 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -101,8 +101,12 @@ func TestExcludeBroadcast(t *testing.T) { defer ep.Close() // Add a valid primary endpoint address, now we can connect. - if err := s.AddAddress(1, ipv4.ProtocolNumber, "\x0a\x00\x00\x02"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.Address("\x0a\x00\x00\x02").WithPrefix(), + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } if err := ep.Connect(randomAddr); err != nil { t.Errorf("Connect failed: %v", err) @@ -356,8 +360,8 @@ func TestForwarding(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err) } incomingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: incomingIPv4Addr} - if err := s.AddProtocolAddress(incomingNICID, incomingIPv4ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv4ProtoAddr, err) + if err := s.AddProtocolAddress(incomingNICID, incomingIPv4ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingIPv4ProtoAddr, err) } expectedEmittedPacketCount := 1 @@ -369,8 +373,8 @@ func TestForwarding(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err) } outgoingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: outgoingIPv4Addr} - if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv4ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv4ProtoAddr, err) + if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv4ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingIPv4ProtoAddr, err) } s.SetRouteTable([]tcpip.Route{ @@ -1184,8 +1188,8 @@ func TestIPv4Sanity(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr} - if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err) + if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtoAddr, err) } // Default routes for IPv4 so ICMP can find a route to the remote @@ -1745,8 +1749,8 @@ func TestInvalidFragments(t *testing.T) { const ( nicID = 1 linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") - addr1 = "\x0a\x00\x00\x01" - addr2 = "\x0a\x00\x00\x02" + addr1 = tcpip.Address("\x0a\x00\x00\x01") + addr2 = tcpip.Address("\x0a\x00\x00\x02") tos = 0 ident = 1 ttl = 48 @@ -2012,8 +2016,12 @@ func TestInvalidFragments(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr2.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } for _, f := range test.fragments { @@ -2061,8 +2069,8 @@ func TestFragmentReassemblyTimeout(t *testing.T) { const ( nicID = 1 linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") - addr1 = "\x0a\x00\x00\x01" - addr2 = "\x0a\x00\x00\x02" + addr1 = tcpip.Address("\x0a\x00\x00\x01") + addr2 = tcpip.Address("\x0a\x00\x00\x02") tos = 0 ident = 1 ttl = 48 @@ -2237,8 +2245,12 @@ func TestFragmentReassemblyTimeout(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr2.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv4EmptySubnet, @@ -2308,9 +2320,9 @@ func TestReceiveFragments(t *testing.T) { const ( nicID = 1 - addr1 = "\x0c\xa8\x00\x01" // 192.168.0.1 - addr2 = "\x0c\xa8\x00\x02" // 192.168.0.2 - addr3 = "\x0c\xa8\x00\x03" // 192.168.0.3 + addr1 = tcpip.Address("\x0c\xa8\x00\x01") // 192.168.0.1 + addr2 = tcpip.Address("\x0c\xa8\x00\x02") // 192.168.0.2 + addr3 = tcpip.Address("\x0c\xa8\x00\x03") // 192.168.0.3 ) // Build and return a UDP header containing payload. @@ -2703,8 +2715,12 @@ func TestReceiveFragments(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: addr2.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } wq := waiter.Queue{} @@ -2985,11 +3001,15 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route { t.Fatalf("CreateNIC(1, _) failed: %s", err) } const ( - src = "\x10\x00\x00\x01" - dst = "\x10\x00\x00\x02" + src = tcpip.Address("\x10\x00\x00\x01") + dst = tcpip.Address("\x10\x00\x00\x02") ) - if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil { - t.Fatalf("AddAddress(1, %d, %s) failed: %s", ipv4.ProtocolNumber, src, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: src.WithPrefix(), + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } { mask := tcpip.AddressMask(header.IPv4Broadcast) @@ -3161,8 +3181,8 @@ func TestPacketQueuing(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddProtocolAddress(nicID, host1IPv4Addr); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv4Addr, err) + if err := s.AddProtocolAddress(nicID, host1IPv4Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv4Addr, err) } s.SetRouteTable([]tcpip.Route{ @@ -3285,8 +3305,12 @@ func TestCloseLocking(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) } - if err := s.AddAddress(nicID1, ipv4.ProtocolNumber, src); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) failed: %s", nicID1, ipv4.ProtocolNumber, src, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: src.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{{ @@ -3349,3 +3373,139 @@ func TestCloseLocking(t *testing.T) { } }() } + +func TestIcmpRateLimit(t *testing.T) { + var ( + host1IPv4Addr = tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), + PrefixLen: 24, + }, + } + host2IPv4Addr = tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), + PrefixLen: 24, + }, + } + ) + const icmpBurst = 5 + e := channel.New(1, defaultMTU, tcpip.LinkAddress("")) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: faketime.NewManualClock(), + }) + s.SetICMPBurst(icmpBurst) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddProtocolAddress(nicID, host1IPv4Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv4Addr, err) + } + s.SetRouteTable([]tcpip.Route{ + { + Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + NIC: nicID, + }, + }) + tests := []struct { + name string + createPacket func() buffer.View + check func(*testing.T, *channel.Endpoint, int) + }{ + { + name: "echo", + createPacket: func() buffer.View { + totalLength := header.IPv4MinimumSize + header.ICMPv4MinimumSize + hdr := buffer.NewPrependable(totalLength) + icmpH := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + icmpH.SetIdent(1) + icmpH.SetSequence(1) + icmpH.SetType(header.ICMPv4Echo) + icmpH.SetCode(header.ICMPv4UnusedCode) + icmpH.SetChecksum(0) + icmpH.SetChecksum(^header.Checksum(icmpH, 0)) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLength), + Protocol: uint8(header.ICMPv4ProtocolNumber), + TTL: 1, + SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, + DstAddr: host1IPv4Addr.AddressWithPrefix.Address, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + return hdr.View() + }, + check: func(t *testing.T, e *channel.Endpoint, round int) { + p, ok := e.Read() + if !ok { + t.Fatalf("expected echo response, no packet read in endpoint in round %d", round) + } + if got, want := p.Proto, header.IPv4ProtocolNumber; got != want { + t.Errorf("got p.Proto = %d, want = %d", got, want) + } + checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4EchoReply), + )) + }, + }, + { + name: "dst unreachable", + createPacket: func() buffer.View { + totalLength := header.IPv4MinimumSize + header.UDPMinimumSize + hdr := buffer.NewPrependable(totalLength) + udpH := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + udpH.Encode(&header.UDPFields{ + SrcPort: 100, + DstPort: 101, + Length: header.UDPMinimumSize, + }) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLength), + Protocol: uint8(header.UDPProtocolNumber), + TTL: 1, + SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, + DstAddr: host1IPv4Addr.AddressWithPrefix.Address, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + return hdr.View() + }, + check: func(t *testing.T, e *channel.Endpoint, round int) { + p, ok := e.Read() + if round >= icmpBurst { + if ok { + t.Errorf("got packet %x in round %d, expected ICMP rate limit to stop it", p.Pkt.Data().Views(), round) + } + return + } + if !ok { + t.Fatalf("expected unreachable in round %d, no packet read in endpoint", round) + } + checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4DstUnreachable), + )) + }, + }, + } + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + for round := 0; round < icmpBurst+1; round++ { + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: testCase.createPacket().ToVectorisedView(), + })) + testCase.check(t, e, round) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index f99cbf8f3..f814926a3 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -51,6 +51,7 @@ go_test( "//pkg/tcpip/transport/udp", "//pkg/waiter", "@com_github_google_go_cmp//cmp:go_default_library", + "@org_golang_x_time//rate:go_default_library", ], ) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 94caaae6c..adfc8d8da 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -187,7 +187,9 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe // Skip the IP header, then handle the fragmentation header if there // is one. - pkt.Data().DeleteFront(header.IPv6MinimumSize) + if _, ok := pkt.Data().Consume(header.IPv6MinimumSize); !ok { + panic("could not consume IPv6MinimumSize bytes") + } if p == header.IPv6FragmentHeader { f, ok := pkt.Data().PullUp(header.IPv6FragmentHeaderSize) if !ok { @@ -203,7 +205,9 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe // Skip fragmentation header and find out the actual protocol // number. - pkt.Data().DeleteFront(header.IPv6FragmentHeaderSize) + if _, ok := pkt.Data().Consume(header.IPv6FragmentHeaderSize); !ok { + panic("could not consume IPv6FragmentHeaderSize bytes") + } } e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, transErr, pkt) @@ -270,7 +274,7 @@ func isMLDValid(pkt *stack.PacketBuffer, iph header.IPv6, routerAlert *header.IP if routerAlert == nil || routerAlert.Value != header.IPv6RouterAlertMLD { return false } - if pkt.Data().Size() < header.ICMPv6HeaderSize+header.MLDMinimumSize { + if pkt.TransportHeader().View().Size() < header.ICMPv6HeaderSize+header.MLDMinimumSize { return false } if iph.HopLimit() != header.MLDHopLimit { @@ -285,20 +289,17 @@ func isMLDValid(pkt *stack.PacketBuffer, iph header.IPv6, routerAlert *header.IP func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, routerAlert *header.IPv6RouterAlertOption) { sent := e.stats.icmp.packetsSent received := e.stats.icmp.packetsReceived - // ICMP packets don't have their TransportHeader fields set. See - // icmp/protocol.go:protocol.Parse for a full explanation. - v, ok := pkt.Data().PullUp(header.ICMPv6HeaderSize) - if !ok { + h := header.ICMPv6(pkt.TransportHeader().View()) + if len(h) < header.ICMPv6MinimumSize { received.invalid.Increment() return } - h := header.ICMPv6(v) iph := header.IPv6(pkt.NetworkHeader().View()) srcAddr := iph.SourceAddress() dstAddr := iph.DestinationAddress() // Validate ICMPv6 checksum before processing the packet. - payload := pkt.Data().AsRange().SubRange(len(h)) + payload := pkt.Data().AsRange() if got, want := h.Checksum(), header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: h, Src: srcAddr, @@ -325,28 +326,15 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r switch icmpType := h.Type(); icmpType { case header.ICMPv6PacketTooBig: received.packetTooBig.Increment() - hdr, ok := pkt.Data().PullUp(header.ICMPv6PacketTooBigMinimumSize) - if !ok { - received.invalid.Increment() - return - } - networkMTU, err := calculateNetworkMTU(header.ICMPv6(hdr).MTU(), header.IPv6MinimumSize) + networkMTU, err := calculateNetworkMTU(h.MTU(), header.IPv6MinimumSize) if err != nil { networkMTU = 0 } - pkt.Data().DeleteFront(header.ICMPv6PacketTooBigMinimumSize) e.handleControl(&icmpv6PacketTooBigSockError{mtu: networkMTU}, pkt) case header.ICMPv6DstUnreachable: received.dstUnreachable.Increment() - hdr, ok := pkt.Data().PullUp(header.ICMPv6DstUnreachableMinimumSize) - if !ok { - received.invalid.Increment() - return - } - code := header.ICMPv6(hdr).Code() - pkt.Data().DeleteFront(header.ICMPv6DstUnreachableMinimumSize) - switch code { + switch h.Code() { case header.ICMPv6NetworkUnreachable: e.handleControl(&icmpv6DestinationNetworkUnreachableSockError{}, pkt) case header.ICMPv6PortUnreachable: @@ -354,16 +342,12 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r } case header.ICMPv6NeighborSolicit: received.neighborSolicit.Increment() - if !isNDPValid() || pkt.Data().Size() < header.ICMPv6NeighborSolicitMinimumSize { + if !isNDPValid() || len(h) < header.ICMPv6NeighborSolicitMinimumSize { received.invalid.Increment() return } - // The remainder of payload must be only the neighbor solicitation, so - // payload.AsView() always returns the solicitation. Per RFC 6980 section 5, - // NDP messages cannot be fragmented. Also note that in the common case NDP - // datagrams are very small and AsView() will not incur allocations. - ns := header.NDPNeighborSolicit(payload.AsView()) + ns := header.NDPNeighborSolicit(h.MessageBody()) targetAddr := ns.TargetAddress() // As per RFC 4861 section 4.3, the Target Address MUST NOT be a multicast @@ -576,16 +560,12 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r case header.ICMPv6NeighborAdvert: received.neighborAdvert.Increment() - if !isNDPValid() || pkt.Data().Size() < header.ICMPv6NeighborAdvertMinimumSize { + if !isNDPValid() || len(h) < header.ICMPv6NeighborAdvertMinimumSize { received.invalid.Increment() return } - // The remainder of payload must be only the neighbor advertisement, so - // payload.AsView() always returns the advertisement. Per RFC 6980 section - // 5, NDP messages cannot be fragmented. Also note that in the common case - // NDP datagrams are very small and AsView() will not incur allocations. - na := header.NDPNeighborAdvert(payload.AsView()) + na := header.NDPNeighborAdvert(h.MessageBody()) it, err := na.Options().Iter(false /* check */) if err != nil { @@ -672,12 +652,6 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r case header.ICMPv6EchoRequest: received.echoRequest.Increment() - icmpHdr, ok := pkt.TransportHeader().Consume(header.ICMPv6EchoMinimumSize) - if !ok { - received.invalid.Increment() - return - } - // As per RFC 4291 section 2.7, multicast addresses must not be used as // source addresses in IPv6 packets. localAddr := dstAddr @@ -692,13 +666,18 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r } defer r.Release() + if !e.protocol.allowICMPReply(header.ICMPv6EchoReply) { + sent.rateLimited.Increment() + return + } + replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize, Data: pkt.Data().ExtractVV(), }) icmp := header.ICMPv6(replyPkt.TransportHeader().Push(header.ICMPv6EchoMinimumSize)) pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber - copy(icmp, icmpHdr) + copy(icmp, h) icmp.SetType(header.ICMPv6EchoReply) dataRange := replyPkt.Data().AsRange() icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ @@ -720,7 +699,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r case header.ICMPv6EchoReply: received.echoReply.Increment() - if pkt.Data().Size() < header.ICMPv6EchoMinimumSize { + if len(h) < header.ICMPv6EchoMinimumSize { received.invalid.Increment() return } @@ -740,7 +719,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r // // Is the NDP payload of sufficient size to hold a Router Solictation? - if !isNDPValid() || pkt.Data().Size()-header.ICMPv6HeaderSize < header.NDPRSMinimumSize { + if !isNDPValid() || len(h)-header.ICMPv6HeaderSize < header.NDPRSMinimumSize { received.invalid.Increment() return } @@ -750,9 +729,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r return } - // Note that in the common case NDP datagrams are very small and AsView() - // will not incur allocations. - rs := header.NDPRouterSolicit(payload.AsView()) + rs := header.NDPRouterSolicit(h.MessageBody()) it, err := rs.Options().Iter(false /* check */) if err != nil { // Options are not valid as per the wire format, silently drop the packet. @@ -796,7 +773,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r // // Is the NDP payload of sufficient size to hold a Router Advertisement? - if !isNDPValid() || pkt.Data().Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize { + if !isNDPValid() || len(h)-header.ICMPv6HeaderSize < header.NDPRAMinimumSize { received.invalid.Increment() return } @@ -810,9 +787,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r return } - // Note that in the common case NDP datagrams are very small and AsView() - // will not incur allocations. - ra := header.NDPRouterAdvert(payload.AsView()) + ra := header.NDPRouterAdvert(h.MessageBody()) it, err := ra.Options().Iter(false /* check */) if err != nil { // Options are not valid as per the wire format, silently drop the packet. @@ -890,11 +865,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r switch icmpType { case header.ICMPv6MulticastListenerQuery: e.mu.Lock() - e.mu.mld.handleMulticastListenerQuery(header.MLD(payload.AsView())) + e.mu.mld.handleMulticastListenerQuery(header.MLD(h.MessageBody())) e.mu.Unlock() case header.ICMPv6MulticastListenerReport: e.mu.Lock() - e.mu.mld.handleMulticastListenerReport(header.MLD(payload.AsView())) + e.mu.mld.handleMulticastListenerReport(header.MLD(h.MessageBody())) e.mu.Unlock() case header.ICMPv6MulticastListenerDone: default: @@ -1174,28 +1149,37 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip return &tcpip.ErrNotConnected{} } - sent := netEP.stats.icmp.packetsSent - - if !p.stack.AllowICMPMessage() { - sent.rateLimited.Increment() - return nil - } - if pkt.TransportProtocolNumber == header.ICMPv6ProtocolNumber { - // TODO(gvisor.dev/issues/3810): Sort this out when ICMP headers are stored. - // Unfortunately at this time ICMP Packets do not have a transport - // header separated out. It is in the Data part so we need to - // separate it out now. We will just pretend it is a minimal length - // ICMP packet as we don't really care if any later bits of a - // larger ICMP packet are in the header view or in the Data view. - transport, ok := pkt.TransportHeader().Consume(header.ICMPv6MinimumSize) - if !ok { + if typ := header.ICMPv6(pkt.TransportHeader().View()).Type(); typ.IsErrorType() || typ == header.ICMPv6RedirectMsg { return nil } - typ := header.ICMPv6(transport).Type() - if typ.IsErrorType() || typ == header.ICMPv6RedirectMsg { - return nil + } + + sent := netEP.stats.icmp.packetsSent + icmpType, icmpCode, counter, typeSpecific := func() (header.ICMPv6Type, header.ICMPv6Code, tcpip.MultiCounterStat, uint32) { + switch reason := reason.(type) { + case *icmpReasonParameterProblem: + return header.ICMPv6ParamProblem, reason.code, sent.paramProblem, reason.pointer + case *icmpReasonPortUnreachable: + return header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, sent.dstUnreachable, 0 + case *icmpReasonNetUnreachable: + return header.ICMPv6DstUnreachable, header.ICMPv6NetworkUnreachable, sent.dstUnreachable, 0 + case *icmpReasonHostUnreachable: + return header.ICMPv6DstUnreachable, header.ICMPv6AddressUnreachable, sent.dstUnreachable, 0 + case *icmpReasonPacketTooBig: + return header.ICMPv6PacketTooBig, header.ICMPv6UnusedCode, sent.packetTooBig, 0 + case *icmpReasonHopLimitExceeded: + return header.ICMPv6TimeExceeded, header.ICMPv6HopLimitExceeded, sent.timeExceeded, 0 + case *icmpReasonReassemblyTimeout: + return header.ICMPv6TimeExceeded, header.ICMPv6ReassemblyTimeout, sent.timeExceeded, 0 + default: + panic(fmt.Sprintf("unsupported ICMP type %T", reason)) } + }() + + if !p.allowICMPReply(icmpType) { + sent.rateLimited.Increment() + return nil } network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() @@ -1232,40 +1216,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip newPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber icmpHdr := header.ICMPv6(newPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize)) - var counter tcpip.MultiCounterStat - switch reason := reason.(type) { - case *icmpReasonParameterProblem: - icmpHdr.SetType(header.ICMPv6ParamProblem) - icmpHdr.SetCode(reason.code) - icmpHdr.SetTypeSpecific(reason.pointer) - counter = sent.paramProblem - case *icmpReasonPortUnreachable: - icmpHdr.SetType(header.ICMPv6DstUnreachable) - icmpHdr.SetCode(header.ICMPv6PortUnreachable) - counter = sent.dstUnreachable - case *icmpReasonNetUnreachable: - icmpHdr.SetType(header.ICMPv6DstUnreachable) - icmpHdr.SetCode(header.ICMPv6NetworkUnreachable) - counter = sent.dstUnreachable - case *icmpReasonHostUnreachable: - icmpHdr.SetType(header.ICMPv6DstUnreachable) - icmpHdr.SetCode(header.ICMPv6AddressUnreachable) - counter = sent.dstUnreachable - case *icmpReasonPacketTooBig: - icmpHdr.SetType(header.ICMPv6PacketTooBig) - icmpHdr.SetCode(header.ICMPv6UnusedCode) - counter = sent.packetTooBig - case *icmpReasonHopLimitExceeded: - icmpHdr.SetType(header.ICMPv6TimeExceeded) - icmpHdr.SetCode(header.ICMPv6HopLimitExceeded) - counter = sent.timeExceeded - case *icmpReasonReassemblyTimeout: - icmpHdr.SetType(header.ICMPv6TimeExceeded) - icmpHdr.SetCode(header.ICMPv6ReassemblyTimeout) - counter = sent.timeExceeded - default: - panic(fmt.Sprintf("unsupported ICMP type %T", reason)) - } + icmpHdr.SetType(icmpType) + icmpHdr.SetCode(icmpCode) + icmpHdr.SetTypeSpecific(typeSpecific) + dataRange := newPkt.Data().AsRange() icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: icmpHdr, diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 7c2a3e56b..03d9f425c 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "golang.org/x/time/rate" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -225,8 +226,8 @@ func TestICMPCounts(t *testing.T) { t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") } addr := lladdr0.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err) } else { ep.DecRef() } @@ -407,8 +408,12 @@ func newTestContext(t *testing.T) *testContext { if err := c.s0.CreateNIC(nicID, wrappedEP0); err != nil { t.Fatalf("CreateNIC s0: %v", err) } - if err := c.s0.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress lladdr0: %v", err) + llProtocolAddr0 := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := c.s0.AddProtocolAddress(nicID, llProtocolAddr0, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, llProtocolAddr0, err) } c.linkEP1 = channel.New(defaultChannelSize, defaultMTU, linkAddr1) @@ -416,8 +421,12 @@ func newTestContext(t *testing.T) *testContext { if err := c.s1.CreateNIC(nicID, wrappedEP1); err != nil { t.Fatalf("CreateNIC failed: %v", err) } - if err := c.s1.AddAddress(nicID, ProtocolNumber, lladdr1); err != nil { - t.Fatalf("AddAddress lladdr1: %v", err) + llProtocolAddr1 := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr1.WithPrefix(), + } + if err := c.s1.AddProtocolAddress(nicID, llProtocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, llProtocolAddr1, err) } subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) @@ -690,8 +699,12 @@ func TestICMPChecksumValidationSimple(t *testing.T) { t.Fatalf("CreateNIC(_, _) = %s", err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } { subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) @@ -883,8 +896,12 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { t.Fatalf("CreateNIC(_, _) = %s", err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } { subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) @@ -1065,8 +1082,12 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } { subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) @@ -1240,8 +1261,12 @@ func TestLinkAddressRequest(t *testing.T) { } if len(test.nicAddr) != 0 { - if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: test.nicAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } } @@ -1411,12 +1436,14 @@ func TestPacketQueing(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, Clock: clock, }) + // Make sure ICMP rate limiting doesn't get in our way. + s.SetICMPLimit(rate.Inf) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddProtocolAddress(nicID, host1IPv6Addr); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv6Addr, err) + if err := s.AddProtocolAddress(nicID, host1IPv6Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv6Addr, err) } s.SetRouteTable([]tcpip.Route{ @@ -1669,8 +1696,12 @@ func TestCallsToNeighborCache(t *testing.T) { if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { t.Fatalf("CreateNIC(_, _) = %s", err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } } { @@ -1704,8 +1735,8 @@ func TestCallsToNeighborCache(t *testing.T) { t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") } addr := lladdr0.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err) } else { ep.DecRef() } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index d4bd61748..7d3e1fd53 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -748,7 +748,7 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, // iptables filtering. All packets that reach here are locally // generated. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { + if ok := e.protocol.stack.IPTables().CheckOutput(pkt, r, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesOutputDropped.Increment() return nil @@ -761,7 +761,7 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, // We should do this for every packet, rather than only NATted packets, but // removing this check short circuits broadcasts before they are sent out to // other hosts. - if pkt.NatDone { + if pkt.DNATDone { netHeader := header.IPv6(pkt.NetworkHeader().View()) if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil { // Since we rewrote the packet but it is being routed back to us, we @@ -788,7 +788,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, protocol // Postrouting NAT can only change the source address, and does not alter the // route or outgoing interface of the packet. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { + if ok := e.protocol.stack.IPTables().CheckPostrouting(pkt, r, e, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesPostroutingDropped.Increment() return nil @@ -871,7 +871,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par // iptables filtering. All packets that reach here are locally // generated. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, r, "" /* inNicName */, outNicName) + outputDropped, natPkts := e.protocol.stack.IPTables().CheckOutputPackets(pkts, r, outNicName) stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped))) for pkt := range outputDropped { pkts.Remove(pkt) @@ -897,7 +897,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par // We ignore the list of NAT-ed packets here because Postrouting NAT can only // change the source address, and does not alter the route or outgoing // interface of the packet. - postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, r, "" /* inNicName */, outNicName) + postroutingDropped, _ := e.protocol.stack.IPTables().CheckPostroutingPackets(pkts, r, e, outNicName) stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped))) for pkt := range postroutingDropped { pkts.Remove(pkt) @@ -984,7 +984,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { inNicName := stk.FindNICNameFromID(e.nic.ID()) outNicName := stk.FindNICNameFromID(ep.nic.ID()) - if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesForwardDropped.Increment() return nil @@ -1015,7 +1015,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { inNicName := stk.FindNICNameFromID(e.nic.ID()) outNicName := stk.FindNICNameFromID(r.NICID()) - if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesForwardDropped.Increment() return nil @@ -1024,7 +1024,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { // We need to do a deep copy of the IP packet because // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do // not own it. - newHdr := header.IPv6(stack.PayloadSince(pkt.NetworkHeader())) + newPkt := pkt.DeepCopyForForwarding(int(r.MaxHeaderLength())) + newHdr := header.IPv6(newPkt.NetworkHeader().View()) // As per RFC 8200 section 3, // @@ -1032,11 +1033,13 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { // each node that forwards the packet. newHdr.SetHopLimit(hopLimit - 1) - switch err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()), - Data: buffer.View(newHdr).ToVectorisedView(), - IsForwardedPacket: true, - })); err.(type) { + forwardToEp, ok := e.protocol.getEndpointForNIC(r.NICID()) + if !ok { + // The interface was removed after we obtained the route. + return &ip.ErrOther{Err: &tcpip.ErrUnknownDevice{}} + } + + switch err := forwardToEp.writePacket(r, newPkt, newPkt.TransportProtocolNumber, true /* headerIncluded */); err.(type) { case nil: return nil case *tcpip.ErrMessageTooLong: @@ -1097,7 +1100,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // Loopback traffic skips the prerouting chain. inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().CheckPrerouting(pkt, e, inNicName); !ok { // iptables is telling us to drop the packet. stats.IPTablesPreroutingDropped.Increment() return @@ -1180,7 +1183,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer, // iptables filtering. All packets that reach here are intended for // this machine and need not be forwarded. - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().CheckInput(pkt, inNICName); !ok { // iptables is telling us to drop the packet. stats.IPTablesInputDropped.Increment() return @@ -1534,27 +1537,36 @@ func (e *endpoint) processExtensionHeaders(h header.IPv6, pkt *stack.PacketBuffe // If the last header in the payload isn't a known IPv6 extension header, // handle it as if it is transport layer data. - // Calculate the number of octets parsed from data. We want to remove all - // the data except the unparsed portion located at the end, which its size - // is extHdr.Buf.Size(). + // Calculate the number of octets parsed from data. We want to consume all + // the data except the unparsed portion located at the end, whose size is + // extHdr.Buf.Size(). trim := pkt.Data().Size() - extHdr.Buf.Size() // For unfragmented packets, extHdr still contains the transport header. - // Get rid of it. + // Consume that too. // // For reassembled fragments, pkt.TransportHeader is unset, so this is a // no-op and pkt.Data begins with the transport header. trim += pkt.TransportHeader().View().Size() - pkt.Data().DeleteFront(trim) + if _, ok := pkt.Data().Consume(trim); !ok { + stats.MalformedPacketsReceived.Increment() + return fmt.Errorf("could not consume %d bytes", trim) + } + + proto := tcpip.TransportProtocolNumber(extHdr.Identifier) + // If the packet was reassembled from a fragment, it will not have a + // transport header set yet. + if pkt.TransportHeader().View().IsEmpty() { + e.protocol.parseTransport(pkt, proto) + } stats.PacketsDelivered.Increment() - if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { - pkt.TransportProtocolNumber = p + if proto == header.ICMPv6ProtocolNumber { e.handleICMP(pkt, hasFragmentHeader, routerAlert) } else { stats.PacketsDelivered.Increment() - switch res := e.dispatcher.DeliverTransportPacket(p, pkt); res { + switch res := e.dispatcher.DeliverTransportPacket(proto, pkt); res { case stack.TransportPacketHandled: case stack.TransportPacketDestinationPortUnreachable: // As per RFC 4443 section 3.1: @@ -1628,12 +1640,12 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { } // AddAndAcquirePermanentAddress implements stack.AddressableEndpoint. -func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { +func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties stack.AddressProperties) (stack.AddressEndpoint, tcpip.Error) { // TODO(b/169350103): add checks here after making sure we no longer receive // an empty address. e.mu.Lock() defer e.mu.Unlock() - return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated) + return e.addAndAcquirePermanentAddressLocked(addr, properties) } // addAndAcquirePermanentAddressLocked is like AddAndAcquirePermanentAddress but @@ -1643,8 +1655,8 @@ func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, p // solicited-node multicast group and start duplicate address detection. // // Precondition: e.mu must be write locked. -func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { - addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) +func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, properties stack.AddressProperties) (stack.AddressEndpoint, tcpip.Error) { + addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, properties) if err != nil { return nil, err } @@ -1987,6 +1999,9 @@ type protocol struct { // eps is keyed by NICID to allow protocol methods to retrieve an endpoint // when handling a packet, by looking at which NIC handled the packet. eps map[tcpip.NICID]*endpoint + + // ICMP types for which the stack's global rate limiting must apply. + icmpRateLimitedTypes map[header.ICMPv6Type]struct{} } ids []uint32 @@ -1998,7 +2013,8 @@ type protocol struct { // Must be accessed using atomic operations. defaultTTL uint32 - fragmentation *fragmentation.Fragmentation + fragmentation *fragmentation.Fragmentation + icmpRateLimiter *stack.ICMPRateLimiter } // Number returns the ipv6 protocol number. @@ -2011,11 +2027,6 @@ func (p *protocol) MinimumPacketSize() int { return header.IPv6MinimumSize } -// DefaultPrefixLen returns the IPv6 default prefix length. -func (p *protocol) DefaultPrefixLen() int { - return header.IPv6AddressSize * 8 -} - // ParseAddresses implements stack.NetworkProtocol. func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { h := header.IPv6(v) @@ -2087,6 +2098,13 @@ func (p *protocol) findEndpointWithAddress(addr tcpip.Address) *endpoint { return nil } +func (p *protocol) getEndpointForNIC(id tcpip.NICID) (*endpoint, bool) { + p.mu.RLock() + defer p.mu.RUnlock() + ep, ok := p.mu.eps[id] + return ep, ok +} + func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { p.mu.Lock() defer p.mu.Unlock() @@ -2149,19 +2167,29 @@ func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv6, bool) } if hasTransportHdr { - switch err := p.stack.ParsePacketBufferTransport(transProtoNum, pkt); err { - case stack.ParsedOK: - case stack.UnknownTransportProtocol, stack.TransportLayerParseError: - // The transport layer will handle unknown protocols and transport layer - // parsing errors. - default: - panic(fmt.Sprintf("unexpected error parsing transport header = %d", err)) - } + p.parseTransport(pkt, transProtoNum) } return h, true } +func (p *protocol) parseTransport(pkt *stack.PacketBuffer, transProtoNum tcpip.TransportProtocolNumber) { + if transProtoNum == header.ICMPv6ProtocolNumber { + // The transport layer will handle transport layer parsing errors. + _ = parse.ICMPv6(pkt) + return + } + + switch err := p.stack.ParsePacketBufferTransport(transProtoNum, pkt); err { + case stack.ParsedOK: + case stack.UnknownTransportProtocol, stack.TransportLayerParseError: + // The transport layer will handle unknown protocols and transport layer + // parsing errors. + default: + panic(fmt.Sprintf("unexpected error parsing transport header = %d", err)) + } +} + // Parse implements stack.NetworkProtocol. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { proto, _, fragOffset, fragMore, ok := parse.IPv6(pkt) @@ -2172,6 +2200,18 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu return proto, !fragMore && fragOffset == 0, true } +// allowICMPReply reports whether an ICMP reply with provided type may +// be sent following the rate mask options and global ICMP rate limiter. +func (p *protocol) allowICMPReply(icmpType header.ICMPv6Type) bool { + p.mu.RLock() + defer p.mu.RUnlock() + + if _, ok := p.mu.icmpRateLimitedTypes[icmpType]; ok { + return p.stack.AllowICMPMessage() + } + return true +} + // calculateNetworkMTU calculates the network-layer payload MTU based on the // link-layer payload MTU and the length of every IPv6 header. // Note that this is different than the Payload Length field of the IPv6 header, @@ -2268,6 +2308,21 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { p.fragmentation = fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p) p.mu.eps = make(map[tcpip.NICID]*endpoint) p.SetDefaultTTL(DefaultTTL) + // Set default ICMP rate limiting to Linux defaults. + // + // Default: 0-1,3-127 (rate limit ICMPv6 errors except Packet Too Big) + // See https://www.kernel.org/doc/Documentation/networking/ip-sysctl.txt. + defaultIcmpTypes := make(map[header.ICMPv6Type]struct{}) + for i := header.ICMPv6Type(0); i < header.ICMPv6EchoRequest; i++ { + switch i { + case header.ICMPv6PacketTooBig: + // Do not rate limit packet too big by default. + default: + defaultIcmpTypes[i] = struct{}{} + } + } + p.mu.icmpRateLimitedTypes = defaultIcmpTypes + return p } } diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index d2a23fd4f..e5286081e 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -41,12 +41,12 @@ import ( ) const ( - addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") // The least significant 3 bytes are the same as addr2 so both addr2 and // addr3 will have the same solicited-node address. - addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02" - addr4 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x03" + addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02") + addr4 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x03") // Tests use the extension header identifier values as uint8 instead of // header.IPv6ExtensionHeaderIdentifier. @@ -298,16 +298,24 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) { // addr2/addr3 yet as we haven't added those addresses. test.rxf(t, s, e, addr1, snmc, 0) - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: addr2.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr2, err) } // Should receive a packet destined to the solicited node address of // addr2/addr3 now that we have added added addr2. test.rxf(t, s, e, addr1, snmc, 1) - if err := s.AddAddress(nicID, ProtocolNumber, addr3); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr3, err) + protocolAddr3 := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: addr3.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr3, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr3, err) } // Should still receive a packet destined to the solicited node address of @@ -374,8 +382,12 @@ func TestAddIpv6Address(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, test.addr); err != nil { - t.Fatalf("AddAddress(%d, %d, nil) = %s", nicID, ProtocolNumber, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: test.addr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } if addr, err := s.GetMainNICAddress(nicID, ProtocolNumber); err != nil { @@ -898,8 +910,12 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: addr2.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } // Add a default route so that a return packet knows where to go. @@ -1992,8 +2008,12 @@ func TestReceiveIPv6Fragments(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: addr2.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } wq := waiter.Queue{} @@ -2060,8 +2080,8 @@ func TestReceiveIPv6Fragments(t *testing.T) { func TestInvalidIPv6Fragments(t *testing.T) { const ( - addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") nicID = 1 hoplimit = 255 @@ -2150,8 +2170,12 @@ func TestInvalidIPv6Fragments(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: addr2.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv6EmptySubnet, @@ -2216,8 +2240,8 @@ func TestInvalidIPv6Fragments(t *testing.T) { func TestFragmentReassemblyTimeout(t *testing.T) { const ( - addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") nicID = 1 hoplimit = 255 @@ -2402,8 +2426,12 @@ func TestFragmentReassemblyTimeout(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr2, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: addr2.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv6EmptySubnet, @@ -2645,11 +2673,15 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route { t.Fatalf("CreateNIC(1, _) failed: %s", err) } const ( - src = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + src = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + dst = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") ) - if err := s.AddAddress(1, ProtocolNumber, src); err != nil { - t.Fatalf("AddAddress(1, %d, %s) failed: %s", ProtocolNumber, src, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: src.WithPrefix(), + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } { mask := tcpip.AddressMask("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff") @@ -3297,8 +3329,8 @@ func TestForwarding(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err) } incomingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: incomingIPv6Addr} - if err := s.AddProtocolAddress(incomingNICID, incomingIPv6ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv6ProtoAddr, err) + if err := s.AddProtocolAddress(incomingNICID, incomingIPv6ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingIPv6ProtoAddr, err) } outgoingEndpoint := channel.New(1, header.IPv6MinimumMTU, "") @@ -3306,8 +3338,8 @@ func TestForwarding(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err) } outgoingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: outgoingIPv6Addr} - if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv6ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv6ProtoAddr, err) + if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv6ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingIPv6ProtoAddr, err) } s.SetRouteTable([]tcpip.Route{ @@ -3341,7 +3373,8 @@ func TestForwarding(t *testing.T) { ipHeaderLength := header.IPv6MinimumSize icmpHeaderLength := header.ICMPv6MinimumSize - totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength + extHdrLen + payloadLength := icmpHeaderLength + test.payloadLength + extHdrLen + totalLength := ipHeaderLength + payloadLength hdr := buffer.NewPrependable(totalLength) hdr.Prepend(test.payloadLength) icmpH := header.ICMPv6(hdr.Prepend(icmpHeaderLength)) @@ -3359,7 +3392,7 @@ func TestForwarding(t *testing.T) { copy(hdr.Prepend(extHdrLen), extHdrBytes) ip := header.IPv6(hdr.Prepend(ipHeaderLength)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(header.ICMPv6MinimumSize + test.payloadLength), + PayloadLength: uint16(payloadLength), TransportProtocol: transportProtocol, HopLimit: test.TTL, SrcAddr: test.sourceAddr, @@ -3489,3 +3522,149 @@ func TestMultiCounterStatsInitialization(t *testing.T) { t.Error(err) } } + +func TestIcmpRateLimit(t *testing.T) { + var ( + host1IPv6Addr = tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("10::1").To16()), + PrefixLen: 64, + }, + } + host2IPv6Addr = tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("10::2").To16()), + PrefixLen: 64, + }, + } + ) + const icmpBurst = 5 + e := channel.New(1, defaultMTU, tcpip.LinkAddress("")) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: faketime.NewManualClock(), + }) + s.SetICMPBurst(icmpBurst) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddProtocolAddress(nicID, host1IPv6Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv6Addr, err) + } + s.SetRouteTable([]tcpip.Route{ + { + Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + NIC: nicID, + }, + }) + tests := []struct { + name string + createPacket func() buffer.View + check func(*testing.T, *channel.Endpoint, int) + }{ + { + name: "echo", + createPacket: func() buffer.View { + totalLength := header.IPv6MinimumSize + header.ICMPv6MinimumSize + hdr := buffer.NewPrependable(totalLength) + icmpH := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + icmpH.SetIdent(1) + icmpH.SetSequence(1) + icmpH.SetType(header.ICMPv6EchoRequest) + icmpH.SetCode(header.ICMPv6UnusedCode) + icmpH.SetChecksum(0) + icmpH.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpH, + Src: host2IPv6Addr.AddressWithPrefix.Address, + Dst: host1IPv6Addr.AddressWithPrefix.Address, + })) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 1, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + }) + return hdr.View() + }, + check: func(t *testing.T, e *channel.Endpoint, round int) { + p, ok := e.Read() + if !ok { + t.Fatalf("expected echo response, no packet read in endpoint in round %d", round) + } + if got, want := p.Proto, header.IPv6ProtocolNumber; got != want { + t.Errorf("got p.Proto = %d, want = %d", got, want) + } + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6EchoReply), + )) + }, + }, + { + name: "dst unreachable", + createPacket: func() buffer.View { + totalLength := header.IPv6MinimumSize + header.UDPMinimumSize + hdr := buffer.NewPrependable(totalLength) + udpH := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + udpH.Encode(&header.UDPFields{ + SrcPort: 100, + DstPort: 101, + Length: header.UDPMinimumSize, + }) + + // Calculate the UDP checksum and set it. + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, header.UDPMinimumSize) + sum = header.Checksum(nil, sum) + udpH.SetChecksum(^udpH.CalculateChecksum(sum)) + + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + TransportProtocol: header.UDPProtocolNumber, + HopLimit: 1, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + }) + return hdr.View() + }, + check: func(t *testing.T, e *channel.Endpoint, round int) { + p, ok := e.Read() + if round >= icmpBurst { + if ok { + t.Errorf("got packet %x in round %d, expected ICMP rate limit to stop it", p.Pkt.Data().Views(), round) + } + return + } + if !ok { + t.Fatalf("expected unreachable in round %d, no packet read in endpoint", round) + } + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6DstUnreachable), + )) + }, + }, + } + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + for round := 0; round < icmpBurst+1; round++ { + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: testCase.createPacket().ToVectorisedView(), + })) + testCase.check(t, e, round) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go index bc9cf6999..3e5c438d3 100644 --- a/pkg/tcpip/network/ipv6/mld_test.go +++ b/pkg/tcpip/network/ipv6/mld_test.go @@ -75,8 +75,12 @@ func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) { // The stack will join an address's solicited node multicast address when // an address is added. An MLD report message should be sent for the // solicited-node group. - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: linkLocalAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") @@ -216,8 +220,13 @@ func TestSendQueuedMLDReports(t *testing.T) { // Note, we will still expect to send a report for the global address's // solicited node address from the unspecified address as per RFC 3590 // section 4. - if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint, err) + properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint} + globalProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: globalAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, globalProtocolAddr, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, globalProtocolAddr, properties, err) } reportCounter++ if got := reportStat.Value(); got != reportCounter { @@ -252,8 +261,12 @@ func TestSendQueuedMLDReports(t *testing.T) { // Adding a link-local address should send a report for its solicited node // address and globalMulticastAddr. - if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint, err) + linkLocalProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: linkLocalAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, linkLocalProtocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, linkLocalProtocolAddr, err) } if dadResolutionTime != 0 { reportCounter++ @@ -567,8 +580,12 @@ func TestMLDSkipProtocol(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: linkLocalAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index 8837d66d8..938427420 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -1130,7 +1130,11 @@ func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, config return nil } - addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.FirstPrimaryEndpoint, configType, deprecated) + addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.AddressProperties{ + PEB: stack.FirstPrimaryEndpoint, + ConfigType: configType, + Deprecated: deprecated, + }) if err != nil { panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", addr, err)) } diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index f0186c64e..8297a7e10 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -144,8 +144,12 @@ func TestNeighborSolicitationWithSourceLinkLayerOption(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf) @@ -406,8 +410,12 @@ func TestNeighborSolicitationResponse(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: nicAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{ @@ -602,8 +610,12 @@ func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf) @@ -831,8 +843,12 @@ func TestNDPValidation(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber) @@ -962,8 +978,12 @@ func TestNeighborAdvertisementValidation(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } ndpNASize := header.ICMPv6NeighborAdvertMinimumSize @@ -1283,8 +1303,12 @@ func TestCheckDuplicateAddress(t *testing.T) { checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonces[dadPacketsSent])}), )) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } checkDADMsg() diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go index 1b96b1fb8..26640b7ee 100644 --- a/pkg/tcpip/network/multicast_group_test.go +++ b/pkg/tcpip/network/multicast_group_test.go @@ -151,15 +151,22 @@ func createStackWithLinkEndpoint(t *testing.T, v4, mgpEnabled bool, e stack.Link if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - addr := tcpip.AddressWithPrefix{ - Address: stackIPv4Addr, - PrefixLen: defaultIPv4PrefixLength, + addr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: stackIPv4Addr, + PrefixLen: defaultIPv4PrefixLength, + }, + } + if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err) } - if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: linkLocalIPv6Addr1.WithPrefix(), } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalIPv6Addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, linkLocalIPv6Addr1, err) + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } return s, clock diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go index 009cab643..05b879543 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ b/pkg/tcpip/sample/tun_tcp_connect/main.go @@ -146,8 +146,12 @@ func main() { log.Fatal(err) } - if err := s.AddAddress(1, ipv4.ProtocolNumber, addr); err != nil { - log.Fatal(err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr.WithPrefix(), + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + log.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } // Add default route. diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go index c10b19aa0..a72afadda 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ b/pkg/tcpip/sample/tun_tcp_echo/main.go @@ -124,13 +124,13 @@ func main() { log.Fatalf("Bad IP address: %v", addrName) } - var addr tcpip.Address + var addrWithPrefix tcpip.AddressWithPrefix var proto tcpip.NetworkProtocolNumber if parsedAddr.To4() != nil { - addr = tcpip.Address(parsedAddr.To4()) + addrWithPrefix = tcpip.Address(parsedAddr.To4()).WithPrefix() proto = ipv4.ProtocolNumber } else if parsedAddr.To16() != nil { - addr = tcpip.Address(parsedAddr.To16()) + addrWithPrefix = tcpip.Address(parsedAddr.To16()).WithPrefix() proto = ipv6.ProtocolNumber } else { log.Fatalf("Unknown IP type: %v", addrName) @@ -176,11 +176,15 @@ func main() { log.Fatal(err) } - if err := s.AddAddress(1, proto, addr); err != nil { - log.Fatal(err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: proto, + AddressWithPrefix: addrWithPrefix, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + log.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } - subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addr))), tcpip.AddressMask(strings.Repeat("\x00", len(addr)))) + subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addrWithPrefix.Address))), tcpip.AddressMask(strings.Repeat("\x00", len(addrWithPrefix.Address)))) if err != nil { log.Fatal(err) } diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index 34ac62444..b0b2d0afd 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -170,10 +170,14 @@ type SocketOptions struct { // message is passed with incoming packets. receiveTClassEnabled uint32 - // receivePacketInfoEnabled is used to specify if more inforamtion is - // provided with incoming packets such as interface index and address. + // receivePacketInfoEnabled is used to specify if more information is + // provided with incoming IPv4 packets. receivePacketInfoEnabled uint32 + // receivePacketInfoEnabled is used to specify if more information is + // provided with incoming IPv6 packets. + receiveIPv6PacketInfoEnabled uint32 + // hdrIncludeEnabled is used to indicate for a raw endpoint that all packets // being written have an IP header and the endpoint should not attach an IP // header. @@ -360,6 +364,16 @@ func (so *SocketOptions) SetReceivePacketInfo(v bool) { storeAtomicBool(&so.receivePacketInfoEnabled, v) } +// GetIPv6ReceivePacketInfo gets value for IPV6_RECVPKTINFO option. +func (so *SocketOptions) GetIPv6ReceivePacketInfo() bool { + return atomic.LoadUint32(&so.receiveIPv6PacketInfoEnabled) != 0 +} + +// SetIPv6ReceivePacketInfo sets value for IPV6_RECVPKTINFO option. +func (so *SocketOptions) SetIPv6ReceivePacketInfo(v bool) { + storeAtomicBool(&so.receiveIPv6PacketInfoEnabled, v) +} + // GetHeaderIncluded gets value for IP_HDRINCL option. func (so *SocketOptions) GetHeaderIncluded() bool { return atomic.LoadUint32(&so.hdrIncludedEnabled) != 0 diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 6c42ab29b..ead36880f 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -48,7 +48,6 @@ go_library( "hook_string.go", "icmp_rate_limit.go", "iptables.go", - "iptables_state.go", "iptables_targets.go", "iptables_types.go", "neighbor_cache.go", @@ -133,6 +132,7 @@ go_test( name = "stack_test", size = "small", srcs = [ + "conntrack_test.go", "forwarding_test.go", "neighbor_cache_test.go", "neighbor_entry_test.go", diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index ae0bb4ace..7e4b5bf74 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -117,10 +117,10 @@ func (a *AddressableEndpointState) releaseAddressStateLocked(addrState *addressS } // AddAndAcquirePermanentAddress implements AddressableEndpoint. -func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, tcpip.Error) { +func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties AddressProperties) (AddressEndpoint, tcpip.Error) { a.mu.Lock() defer a.mu.Unlock() - ep, err := a.addAndAcquireAddressLocked(addr, peb, configType, deprecated, true /* permanent */) + ep, err := a.addAndAcquireAddressLocked(addr, properties, true /* permanent */) // From https://golang.org/doc/faq#nil_error: // // Under the covers, interfaces are implemented as two elements, a type T and @@ -149,7 +149,7 @@ func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.Addr func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) (AddressEndpoint, tcpip.Error) { a.mu.Lock() defer a.mu.Unlock() - ep, err := a.addAndAcquireAddressLocked(addr, peb, AddressConfigStatic, false /* deprecated */, false /* permanent */) + ep, err := a.addAndAcquireAddressLocked(addr, AddressProperties{PEB: peb}, false /* permanent */) // From https://golang.org/doc/faq#nil_error: // // Under the covers, interfaces are implemented as two elements, a type T and @@ -180,7 +180,7 @@ func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.Addr // returned, regardless the kind of address that is being added. // // Precondition: a.mu must be write locked. -func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated, permanent bool) (*addressState, tcpip.Error) { +func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, properties AddressProperties, permanent bool) (*addressState, tcpip.Error) { // attemptAddToPrimary is false when the address is already in the primary // address list. attemptAddToPrimary := true @@ -208,7 +208,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address // We now promote the address. for i, s := range a.mu.primary { if s == addrState { - switch peb { + switch properties.PEB { case CanBePrimaryEndpoint: // The address is already in the primary address list. attemptAddToPrimary = false @@ -222,7 +222,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address case NeverPrimaryEndpoint: a.mu.primary = append(a.mu.primary[:i], a.mu.primary[i+1:]...) default: - panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb)) + panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", properties.PEB)) } break } @@ -262,11 +262,11 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address } // Acquire the address before returning it. addrState.mu.refs++ - addrState.mu.deprecated = deprecated - addrState.mu.configType = configType + addrState.mu.deprecated = properties.Deprecated + addrState.mu.configType = properties.ConfigType if attemptAddToPrimary { - switch peb { + switch properties.PEB { case NeverPrimaryEndpoint: case CanBePrimaryEndpoint: a.mu.primary = append(a.mu.primary, addrState) @@ -285,7 +285,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address a.mu.primary[0] = addrState } default: - panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb)) + panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", properties.PEB)) } } @@ -489,12 +489,12 @@ func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tc // Proceed to add a new temporary endpoint. addr := localAddr.WithPrefix() - ep, err := a.addAndAcquireAddressLocked(addr, tempPEB, AddressConfigStatic, false /* deprecated */, false /* permanent */) + ep, err := a.addAndAcquireAddressLocked(addr, AddressProperties{PEB: tempPEB}, false /* permanent */) if err != nil { // addAndAcquireAddressLocked only returns an error if the address is // already assigned but we just checked above if the address exists so we // expect no error. - panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, %d, %d, false, false): %s", addr, tempPEB, AddressConfigStatic, err)) + panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, AddressProperties{PEB: %s}, false): %s", addr, tempPEB, err)) } // From https://golang.org/doc/faq#nil_error: diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go index 140f146f6..c55f85743 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state_test.go +++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go @@ -38,9 +38,9 @@ func TestAddressableEndpointStateCleanup(t *testing.T) { } { - ep, err := s.AddAndAcquirePermanentAddress(addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */) + ep, err := s.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{PEB: stack.NeverPrimaryEndpoint}) if err != nil { - t.Fatalf("s.AddAndAcquirePermanentAddress(%s, %d, %d, false): %s", addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, err) + t.Fatalf("s.AddAndAcquirePermanentAddress(%s, AddressProperties{PEB: NeverPrimaryEndpoint}): %s", addr, err) } // We don't need the address endpoint. ep.DecRef() diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 068dab7ce..a3f403855 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -37,23 +37,9 @@ import ( // Our hash table has 16K buckets. const numBuckets = 1 << 14 -// Direction of the tuple. -type direction int - -const ( - dirOriginal direction = iota - dirReply -) - -// Manipulation type for the connection. -// TODO(gvisor.dev/issue/5696): Define this as a bit set and support SNAT and -// DNAT at the same time. -type manipType int - const ( - manipNone manipType = iota - manipSource - manipDestination + establishedTimeout time.Duration = 5 * 24 * time.Hour + unestablishedTimeout time.Duration = 120 * time.Second ) // tuple holds a connection's identifying and manipulating data in one @@ -64,13 +50,22 @@ type tuple struct { // tupleEntry is used to build an intrusive list of tuples. tupleEntry - tupleID - // conn is the connection tracking entry this tuple belongs to. conn *conn - // direction is the direction of the tuple. - direction direction + // reply is true iff the tuple's direction is opposite that of the first + // packet seen on the connection. + reply bool + + mu sync.RWMutex `state:"nosave"` + // +checklocks:mu + tupleID tupleID +} + +func (t *tuple) id() tupleID { + t.mu.RLock() + defer t.mu.RUnlock() + return t.tupleID } // tupleID uniquely identifies a connection in one direction. It currently @@ -103,50 +98,43 @@ func (ti tupleID) reply() tupleID { // // +stateify savable type conn struct { + ct *ConnTrack + // original is the tuple in original direction. It is immutable. original tuple - // reply is the tuple in reply direction. It is immutable. + // reply is the tuple in reply direction. reply tuple - // manip indicates if the packet should be manipulated. It is immutable. - // TODO(gvisor.dev/issue/5696): Support updating manipulation type. - manip manipType - - // tcbHook indicates if the packet is inbound or outbound to - // update the state of tcb. It is immutable. - tcbHook Hook - - // mu protects all mutable state. - mu sync.Mutex `state:"nosave"` + mu sync.RWMutex `state:"nosave"` + // Indicates that the connection has been finalized and may handle replies. + // + // +checklocks:mu + finalized bool + // sourceManip indicates the packet's source is manipulated. + // + // +checklocks:mu + sourceManip bool + // destinationManip indicates the packet's destination is manipulated. + // + // +checklocks:mu + destinationManip bool // tcb is TCB control block. It is used to keep track of states - // of tcp connection and is protected by mu. + // of tcp connection. + // + // +checklocks:mu tcb tcpconntrack.TCB // lastUsed is the last time the connection saw a relevant packet, and - // is updated by each packet on the connection. It is protected by mu. + // is updated by each packet on the connection. // - // TODO(gvisor.dev/issue/5939): do not use the ambient clock. - lastUsed time.Time `state:".(unixTime)"` -} - -// newConn creates new connection. -func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn { - conn := conn{ - manip: manip, - tcbHook: hook, - lastUsed: time.Now(), - } - conn.original = tuple{conn: &conn, tupleID: orig} - conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply} - return &conn + // +checklocks:mu + lastUsed tcpip.MonotonicTime } // timedOut returns whether the connection timed out based on its state. -func (cn *conn) timedOut(now time.Time) bool { - const establishedTimeout = 5 * 24 * time.Hour - const defaultTimeout = 120 * time.Second - cn.mu.Lock() - defer cn.mu.Unlock() +func (cn *conn) timedOut(now tcpip.MonotonicTime) bool { + cn.mu.RLock() + defer cn.mu.RUnlock() if cn.tcb.State() == tcpconntrack.ResultAlive { // Use the same default as Linux, which doesn't delete // established connections for 5(!) days. @@ -154,22 +142,31 @@ func (cn *conn) timedOut(now time.Time) bool { } // Use the same default as Linux, which lets connections in most states // other than established remain for <= 120 seconds. - return now.Sub(cn.lastUsed) > defaultTimeout + return now.Sub(cn.lastUsed) > unestablishedTimeout } // update the connection tracking state. // -// Precondition: cn.mu must be held. -func (cn *conn) updateLocked(tcpHeader header.TCP, hook Hook) { +// +checklocks:cn.mu +func (cn *conn) updateLocked(pkt *PacketBuffer, reply bool) { + if pkt.TransportProtocolNumber != header.TCPProtocolNumber { + return + } + + tcpHeader := header.TCP(pkt.TransportHeader().View()) + // Update the state of tcb. tcb assumes it's always initialized on the // client. However, we only need to know whether the connection is // established or not, so the client/server distinction isn't important. if cn.tcb.IsEmpty() { cn.tcb.Init(tcpHeader) - } else if hook == cn.tcbHook { - cn.tcb.UpdateStateOutbound(tcpHeader) - } else { + return + } + + if reply { cn.tcb.UpdateStateInbound(tcpHeader) + } else { + cn.tcb.UpdateStateOutbound(tcpHeader) } } @@ -194,44 +191,37 @@ type ConnTrack struct { // It is immutable. seed uint32 + // clock provides timing used to determine conntrack reapings. + clock tcpip.Clock + + mu sync.RWMutex `state:"nosave"` // mu protects the buckets slice, but not buckets' contents. Only take // the write lock if you are modifying the slice or saving for S/R. - mu sync.RWMutex `state:"nosave"` - - // buckets is protected by mu. + // + // +checklocks:mu buckets []bucket } // +stateify savable type bucket struct { - // mu protects tuples. - mu sync.Mutex `state:"nosave"` + mu sync.RWMutex `state:"nosave"` + // +checklocks:mu tuples tupleList } -// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid -// TCP header. -// -// Preconditions: pkt.NetworkHeader() is valid. -func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) { - netHeader := pkt.Network() - if netHeader.TransportProtocol() != header.TCPProtocolNumber { - return tupleID{}, &tcpip.ErrUnknownProtocol{} - } - - tcpHeader := header.TCP(pkt.TransportHeader().View()) - if len(tcpHeader) < header.TCPMinimumSize { - return tupleID{}, &tcpip.ErrUnknownProtocol{} +func getTransportHeader(pkt *PacketBuffer) (header.ChecksummableTransport, bool) { + switch pkt.TransportProtocolNumber { + case header.TCPProtocolNumber: + if tcpHeader := header.TCP(pkt.TransportHeader().View()); len(tcpHeader) >= header.TCPMinimumSize { + return tcpHeader, true + } + case header.UDPProtocolNumber: + if udpHeader := header.UDP(pkt.TransportHeader().View()); len(udpHeader) >= header.UDPMinimumSize { + return udpHeader, true + } } - return tupleID{ - srcAddr: netHeader.SourceAddress(), - srcPort: tcpHeader.SourcePort(), - dstAddr: netHeader.DestinationAddress(), - dstPort: tcpHeader.DestinationPort(), - transProto: netHeader.TransportProtocol(), - netProto: pkt.NetworkProtocolNumber, - }, nil + return nil, false } func (ct *ConnTrack) init() { @@ -240,278 +230,285 @@ func (ct *ConnTrack) init() { ct.buckets = make([]bucket, numBuckets) } -// connFor gets the conn for pkt if it exists, or returns nil -// if it does not. It returns an error when pkt does not contain a valid TCP -// header. -// TODO(gvisor.dev/issue/6168): Support UDP. -func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) { - tid, err := packetToTupleID(pkt) - if err != nil { - return nil, dirOriginal +func (ct *ConnTrack) getConnOrMaybeInsertNoop(pkt *PacketBuffer) *tuple { + netHeader := pkt.Network() + transportHeader, ok := getTransportHeader(pkt) + if !ok { + return nil + } + + tid := tupleID{ + srcAddr: netHeader.SourceAddress(), + srcPort: transportHeader.SourcePort(), + dstAddr: netHeader.DestinationAddress(), + dstPort: transportHeader.DestinationPort(), + transProto: pkt.TransportProtocolNumber, + netProto: pkt.NetworkProtocolNumber, } - return ct.connForTID(tid) -} -func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) { - bucket := ct.bucket(tid) - now := time.Now() + bktID := ct.bucket(tid) ct.mu.RLock() - defer ct.mu.RUnlock() - ct.buckets[bucket].mu.Lock() - defer ct.buckets[bucket].mu.Unlock() - - // Iterate over the tuples in a bucket, cleaning up any unused - // connections we find. - for other := ct.buckets[bucket].tuples.Front(); other != nil; other = other.Next() { - // Clean up any timed-out connections we happen to find. - if ct.reapTupleLocked(other, bucket, now) { - // The tuple expired. - continue - } - if tid == other.tupleID { - return other.conn, other.direction - } + bkt := &ct.buckets[bktID] + ct.mu.RUnlock() + + now := ct.clock.NowMonotonic() + if t := bkt.connForTID(tid, now); t != nil { + return t } - return nil, dirOriginal -} + bkt.mu.Lock() + defer bkt.mu.Unlock() -func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn { - tid, err := packetToTupleID(pkt) - if err != nil { - return nil + // Make sure a connection wasn't added between when we last checked the + // bucket and acquired the bucket's write lock. + if t := bkt.connForTIDRLocked(tid, now); t != nil { + return t } - if hook != Prerouting && hook != Output { - return nil + + // This is the first packet we're seeing for the connection. Create an entry + // for this new connection. + conn := &conn{ + ct: ct, + original: tuple{tupleID: tid}, + reply: tuple{tupleID: tid.reply(), reply: true}, + lastUsed: now, } + conn.original.conn = conn + conn.reply.conn = conn - replyTID := tid.reply() - replyTID.srcAddr = address - replyTID.srcPort = port + // For now, we only map an entry for the packet's original tuple as NAT may be + // performed on this connection. Until the packet goes through all the hooks + // and its final address/port is known, we cannot know what the response + // packet's addresses/ports will look like. + // + // This is okay because the destination cannot send its response until it + // receives the packet; the packet will only be received once all the hooks + // have been performed. + // + // See (*conn).finalize. + bkt.tuples.PushFront(&conn.original) + return &conn.original +} - conn, _ := ct.connForTID(tid) - if conn != nil { - // The connection is already tracked. - // TODO(gvisor.dev/issue/5696): Support updating an existing connection. - return nil - } - conn = newConn(tid, replyTID, manipDestination, hook) - ct.insertConn(conn) - return conn +func (ct *ConnTrack) connForTID(tid tupleID) *tuple { + bktID := ct.bucket(tid) + + ct.mu.RLock() + bkt := &ct.buckets[bktID] + ct.mu.RUnlock() + + return bkt.connForTID(tid, ct.clock.NowMonotonic()) } -func (ct *ConnTrack) insertSNATConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn { - tid, err := packetToTupleID(pkt) - if err != nil { - return nil - } - if hook != Input && hook != Postrouting { - return nil +func (bkt *bucket) connForTID(tid tupleID, now tcpip.MonotonicTime) *tuple { + bkt.mu.RLock() + defer bkt.mu.RUnlock() + return bkt.connForTIDRLocked(tid, now) +} + +// +checklocksread:bkt.mu +func (bkt *bucket) connForTIDRLocked(tid tupleID, now tcpip.MonotonicTime) *tuple { + for other := bkt.tuples.Front(); other != nil; other = other.Next() { + if tid == other.id() && !other.conn.timedOut(now) { + return other + } } + return nil +} - replyTID := tid.reply() - replyTID.dstAddr = address - replyTID.dstPort = port +func (ct *ConnTrack) finalize(cn *conn) { + tid := cn.reply.id() + id := ct.bucket(tid) - conn, _ := ct.connForTID(tid) - if conn != nil { - // The connection is already tracked. - // TODO(gvisor.dev/issue/5696): Support updating an existing connection. - return nil + ct.mu.RLock() + bkt := &ct.buckets[id] + ct.mu.RUnlock() + + bkt.mu.Lock() + defer bkt.mu.Unlock() + + if t := bkt.connForTIDRLocked(tid, ct.clock.NowMonotonic()); t != nil { + // Another connection for the reply already exists. We can't do much about + // this so we leave the connection cn represents in a state where it can + // send packets but its responses will be mapped to some other connection. + // This may be okay if the connection only expects to send packets without + // any responses. + return } - conn = newConn(tid, replyTID, manipSource, hook) - ct.insertConn(conn) - return conn + + bkt.tuples.PushFront(&cn.reply) } -// insertConn inserts conn into the appropriate table bucket. -func (ct *ConnTrack) insertConn(conn *conn) { - // Lock the buckets in the correct order. - tupleBucket := ct.bucket(conn.original.tupleID) - replyBucket := ct.bucket(conn.reply.tupleID) - ct.mu.RLock() - defer ct.mu.RUnlock() - if tupleBucket < replyBucket { - ct.buckets[tupleBucket].mu.Lock() - ct.buckets[replyBucket].mu.Lock() - } else if tupleBucket > replyBucket { - ct.buckets[replyBucket].mu.Lock() - ct.buckets[tupleBucket].mu.Lock() - } else { - // Both tuples are in the same bucket. - ct.buckets[tupleBucket].mu.Lock() - } - - // Now that we hold the locks, ensure the tuple hasn't been inserted by - // another thread. - // TODO(gvisor.dev/issue/5773): Should check conn.reply.tupleID, too? - alreadyInserted := false - for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() { - if other.tupleID == conn.original.tupleID { - alreadyInserted = true - break +func (cn *conn) finalize() { + { + cn.mu.RLock() + finalized := cn.finalized + cn.mu.RUnlock() + if finalized { + return } } - if !alreadyInserted { - // Add the tuple to the map. - ct.buckets[tupleBucket].tuples.PushFront(&conn.original) - ct.buckets[replyBucket].tuples.PushFront(&conn.reply) + cn.mu.Lock() + finalized := cn.finalized + cn.finalized = true + cn.mu.Unlock() + if finalized { + return } - // Unlocking can happen in any order. - ct.buckets[tupleBucket].mu.Unlock() - if tupleBucket != replyBucket { - ct.buckets[replyBucket].mu.Unlock() // +checklocksforce - } + cn.ct.finalize(cn) } -// handlePacket will manipulate the port and address of the packet if the -// connection exists. Returns whether, after the packet traverses the tables, -// it should create a new entry in the table. -func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { - if pkt.NatDone { - return false +// performNAT setups up the connection for the specified NAT. +// +// Generally, only the first packet of a connection reaches this method; other +// other packets will be manipulated without needing to modify the connection. +func (cn *conn) performNAT(pkt *PacketBuffer, hook Hook, r *Route, port uint16, address tcpip.Address, dnat bool) { + cn.performNATIfNoop(port, address, dnat) + cn.handlePacket(pkt, hook, r) +} + +func (cn *conn) performNATIfNoop(port uint16, address tcpip.Address, dnat bool) { + cn.mu.Lock() + defer cn.mu.Unlock() + + if cn.finalized { + return } - switch hook { - case Prerouting, Input, Output, Postrouting: - default: - return false + if dnat { + if cn.destinationManip { + return + } + cn.destinationManip = true + } else { + if cn.sourceManip { + return + } + cn.sourceManip = true } - // TODO(gvisor.dev/issue/6168): Support UDP. - if pkt.Network().TransportProtocol() != header.TCPProtocolNumber { + cn.reply.mu.Lock() + defer cn.reply.mu.Unlock() + + if dnat { + cn.reply.tupleID.srcAddr = address + cn.reply.tupleID.srcPort = port + } else { + cn.reply.tupleID.dstAddr = address + cn.reply.tupleID.dstPort = port + } +} + +// handlePacket attempts to handle a packet and perform NAT if the connection +// has had NAT performed on it. +// +// Returns true if the packet can skip the NAT table. +func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, rt *Route) bool { + transportHeader, ok := getTransportHeader(pkt) + if !ok { return false } - conn, dir := ct.connFor(pkt) - // Connection not found for the packet. - if conn == nil { - // If this is the last hook in the data path for this packet (Input if - // incoming, Postrouting if outgoing), indicate that a connection should be - // inserted by the end of this hook. - return hook == Input || hook == Postrouting + fullChecksum := false + updatePseudoHeader := false + natDone := &pkt.SNATDone + dnat := false + switch hook { + case Prerouting: + // Packet came from outside the stack so it must have a checksum set + // already. + fullChecksum = true + updatePseudoHeader = true + + natDone = &pkt.DNATDone + dnat = true + case Input: + case Forward: + panic("should not handle packet in the forwarding hook") + case Output: + natDone = &pkt.DNATDone + dnat = true + fallthrough + case Postrouting: + if pkt.TransportProtocolNumber == header.TCPProtocolNumber && pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum { + updatePseudoHeader = true + } else if rt.RequiresTXTransportChecksum() { + fullChecksum = true + updatePseudoHeader = true + } + default: + panic(fmt.Sprintf("unrecognized hook = %d", hook)) } - netHeader := pkt.Network() - tcpHeader := header.TCP(pkt.TransportHeader().View()) - if len(tcpHeader) < header.TCPMinimumSize { - return false + if *natDone { + panic(fmt.Sprintf("packet already had NAT(dnat=%t) performed at hook=%s; pkt=%#v", dnat, hook, pkt)) } // TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be // validated if checksum offloading is off. It may require IP defrag if the // packets are fragmented. - var newAddr tcpip.Address - var newPort uint16 - - updateSRCFields := false - - switch hook { - case Prerouting, Output: - if conn.manip == manipDestination { - switch dir { - case dirOriginal: - newPort = conn.reply.srcPort - newAddr = conn.reply.srcAddr - case dirReply: - newPort = conn.original.dstPort - newAddr = conn.original.dstAddr - - updateSRCFields = true + reply := pkt.tuple.reply + tid, performManip := func() (tupleID, bool) { + cn.mu.Lock() + defer cn.mu.Unlock() + + // Mark the connection as having been used recently so it isn't reaped. + cn.lastUsed = cn.ct.clock.NowMonotonic() + // Update connection state. + cn.updateLocked(pkt, reply) + + var tuple *tuple + if reply { + if dnat { + if !cn.sourceManip { + return tupleID{}, false + } + } else if !cn.destinationManip { + return tupleID{}, false } - pkt.NatDone = true - } - case Input, Postrouting: - if conn.manip == manipSource { - switch dir { - case dirOriginal: - newPort = conn.reply.dstPort - newAddr = conn.reply.dstAddr - - updateSRCFields = true - case dirReply: - newPort = conn.original.srcPort - newAddr = conn.original.srcAddr + + tuple = &cn.original + } else { + if dnat { + if !cn.destinationManip { + return tupleID{}, false + } + } else if !cn.sourceManip { + return tupleID{}, false } - pkt.NatDone = true + + tuple = &cn.reply } - default: - panic(fmt.Sprintf("unrecognized hook = %s", hook)) - } - if !pkt.NatDone { + + return tuple.id(), true + }() + if !performManip { return false } - fullChecksum := false - updatePseudoHeader := false - switch hook { - case Prerouting, Input: - case Output, Postrouting: - // Calculate the TCP checksum and set it. - if pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum { - updatePseudoHeader = true - } else if r.RequiresTXTransportChecksum() { - fullChecksum = true - updatePseudoHeader = true - } - default: - panic(fmt.Sprintf("unrecognized hook = %s", hook)) + newPort := tid.dstPort + newAddr := tid.dstAddr + if dnat { + newPort = tid.srcPort + newAddr = tid.srcAddr } rewritePacket( - netHeader, - tcpHeader, - updateSRCFields, + pkt.Network(), + transportHeader, + !dnat, fullChecksum, updatePseudoHeader, newPort, newAddr, ) - // Update the state of tcb. - conn.mu.Lock() - defer conn.mu.Unlock() - - // Mark the connection as having been used recently so it isn't reaped. - conn.lastUsed = time.Now() - // Update connection state. - conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook) - - return false -} - -// maybeInsertNoop tries to insert a no-op connection entry to keep connections -// from getting clobbered when replies arrive. It only inserts if there isn't -// already a connection for pkt. -// -// This should be called after traversing iptables rules only, to ensure that -// pkt.NatDone is set correctly. -func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { - // If there were a rule applying to this packet, it would be marked - // with NatDone. - if pkt.NatDone { - return - } - - // We only track TCP connections. - if pkt.Network().TransportProtocol() != header.TCPProtocolNumber { - return - } - - // This is the first packet we're seeing for the TCP connection. Insert - // the noop entry (an identity mapping) so that the response doesn't - // get NATed, breaking the connection. - tid, err := packetToTupleID(pkt) - if err != nil { - return - } - conn := newConn(tid, tid.reply(), manipNone, hook) - conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook) - ct.insertConn(conn) + *natDone = true + return true } // bucket gets the conntrack bucket for a tupleID. @@ -555,7 +552,7 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim const minInterval = 10 * time.Millisecond const maxInterval = maxFullTraversal / fractionPerReaping - now := time.Now() + now := ct.clock.NowMonotonic() checked := 0 expired := 0 var idx int @@ -563,14 +560,20 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim defer ct.mu.RUnlock() for i := 0; i < len(ct.buckets)/fractionPerReaping; i++ { idx = (i + start) % len(ct.buckets) - ct.buckets[idx].mu.Lock() - for tuple := ct.buckets[idx].tuples.Front(); tuple != nil; tuple = tuple.Next() { + bkt := &ct.buckets[idx] + bkt.mu.Lock() + for tuple := bkt.tuples.Front(); tuple != nil; { + // reapTupleLocked updates tuple's next pointer so we grab it here. + nextTuple := tuple.Next() + checked++ - if ct.reapTupleLocked(tuple, idx, now) { + if ct.reapTupleLocked(tuple, idx, bkt, now) { expired++ } + + tuple = nextTuple } - ct.buckets[idx].mu.Unlock() + bkt.mu.Unlock() } // We already checked buckets[idx]. idx++ @@ -595,44 +598,51 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim // reapTupleLocked tries to remove tuple and its reply from the table. It // returns whether the tuple's connection has timed out. // -// Preconditions: -// * ct.mu is locked for reading. -// * bucket is locked. -func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bool { +// Precondition: ct.mu is read locked and bkt.mu is write locked. +// +checklocksread:ct.mu +// +checklocks:bkt.mu +func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bktID int, bkt *bucket, now tcpip.MonotonicTime) bool { if !tuple.conn.timedOut(now) { return false } - // To maintain lock order, we can only reap these tuples if the reply - // appears later in the table. - replyBucket := ct.bucket(tuple.reply()) - if bucket > replyBucket { + // To maintain lock order, we can only reap both tuples if the reply appears + // later in the table. + replyBktID := ct.bucket(tuple.id().reply()) + tuple.conn.mu.RLock() + replyTupleInserted := tuple.conn.finalized + tuple.conn.mu.RUnlock() + if bktID > replyBktID && replyTupleInserted { return true } - // Don't re-lock if both tuples are in the same bucket. - differentBuckets := bucket != replyBucket - if differentBuckets { - ct.buckets[replyBucket].mu.Lock() + // Reap the reply. + if replyTupleInserted { + // Don't re-lock if both tuples are in the same bucket. + if bktID != replyBktID { + replyBkt := &ct.buckets[replyBktID] + replyBkt.mu.Lock() + removeConnFromBucket(replyBkt, tuple) + replyBkt.mu.Unlock() + } else { + removeConnFromBucket(bkt, tuple) + } } - // We have the buckets locked and can remove both tuples. - if tuple.direction == dirOriginal { - ct.buckets[replyBucket].tuples.Remove(&tuple.conn.reply) - } else { - ct.buckets[replyBucket].tuples.Remove(&tuple.conn.original) - } - ct.buckets[bucket].tuples.Remove(tuple) + bkt.tuples.Remove(tuple) + return true +} - // Don't re-unlock if both tuples are in the same bucket. - if differentBuckets { - ct.buckets[replyBucket].mu.Unlock() // +checklocksforce +// +checklocks:b.mu +func removeConnFromBucket(b *bucket, tuple *tuple) { + if tuple.reply { + b.tuples.Remove(&tuple.conn.original) + } else { + b.tuples.Remove(&tuple.conn.reply) } - - return true } -func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { +func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { // Lookup the connection. The reply's original destination // describes the original address. tid := tupleID{ @@ -640,17 +650,22 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ srcPort: epID.LocalPort, dstAddr: epID.RemoteAddress, dstPort: epID.RemotePort, - transProto: header.TCPProtocolNumber, + transProto: transProto, netProto: netProto, } - conn, _ := ct.connForTID(tid) - if conn == nil { + t := ct.connForTID(tid) + if t == nil { // Not a tracked connection. return "", 0, &tcpip.ErrNotConnected{} - } else if conn.manip != manipDestination { + } + + t.conn.mu.RLock() + defer t.conn.mu.RUnlock() + if !t.conn.destinationManip { // Unmanipulated destination. return "", 0, &tcpip.ErrInvalidOptionValue{} } - return conn.original.dstAddr, conn.original.dstPort, nil + id := t.conn.original.id() + return id.dstAddr, id.dstPort, nil } diff --git a/pkg/tcpip/stack/conntrack_test.go b/pkg/tcpip/stack/conntrack_test.go new file mode 100644 index 000000000..fb0645ed1 --- /dev/null +++ b/pkg/tcpip/stack/conntrack_test.go @@ -0,0 +1,132 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/testutil" +) + +func TestReap(t *testing.T) { + // Initialize conntrack. + clock := faketime.NewManualClock() + ct := ConnTrack{ + clock: clock, + } + ct.init() + ct.checkNumTuples(t, 0) + + // Simulate sending a SYN. This will get the connection into conntrack, but + // the connection won't be considered established. Thus the timeout for + // reaping is unestablishedTimeout. + pkt1 := genTCPPacket() + pkt1.tuple = ct.getConnOrMaybeInsertNoop(pkt1) + // We set rt.routeInfo.Loop to avoid a panic when handlePacket calls + // rt.RequiresTXTransportChecksum. + var rt Route + rt.routeInfo.Loop = PacketLoop + if pkt1.tuple.conn.handlePacket(pkt1, Output, &rt) { + t.Fatal("handlePacket() shouldn't perform any NAT") + } + ct.checkNumTuples(t, 1) + + // Travel a little into the future and send the same SYN. This should update + // lastUsed, but per #6748 didn't. + clock.Advance(unestablishedTimeout / 2) + pkt2 := genTCPPacket() + pkt2.tuple = ct.getConnOrMaybeInsertNoop(pkt2) + if pkt2.tuple.conn.handlePacket(pkt2, Output, &rt) { + t.Fatal("handlePacket() shouldn't perform any NAT") + } + ct.checkNumTuples(t, 1) + + // Travel farther into the future - enough that failing to update lastUsed + // would cause a reaping - and reap the whole table. Make sure the connection + // hasn't been reaped. + clock.Advance(unestablishedTimeout * 3 / 4) + ct.reapEverything() + ct.checkNumTuples(t, 1) + + // Travel past unestablishedTimeout to confirm the tuple is gone. + clock.Advance(unestablishedTimeout / 2) + ct.reapEverything() + ct.checkNumTuples(t, 0) +} + +// genTCPPacket returns an initialized IPv4 TCP packet. +func genTCPPacket() *PacketBuffer { + const packetLen = header.IPv4MinimumSize + header.TCPMinimumSize + pkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: packetLen, + }) + pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber + pkt.TransportProtocolNumber = header.TCPProtocolNumber + tcpHdr := header.TCP(pkt.TransportHeader().Push(header.TCPMinimumSize)) + tcpHdr.Encode(&header.TCPFields{ + SrcPort: 5555, + DstPort: 6666, + SeqNum: 7777, + AckNum: 8888, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 50000, + Checksum: 0, // Conntrack doesn't verify the checksum. + }) + ipHdr := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize)) + ipHdr.Encode(&header.IPv4Fields{ + TotalLength: packetLen, + Protocol: uint8(header.TCPProtocolNumber), + SrcAddr: testutil.MustParse4("1.0.0.1"), + DstAddr: testutil.MustParse4("1.0.0.2"), + Checksum: 0, // Conntrack doesn't verify the checksum. + }) + + return pkt +} + +// checkNumTuples checks that there are exactly want tuples tracked by +// conntrack. +func (ct *ConnTrack) checkNumTuples(t *testing.T, want int) { + t.Helper() + ct.mu.RLock() + defer ct.mu.RUnlock() + + var total int + for idx := range ct.buckets { + ct.buckets[idx].mu.RLock() + total += ct.buckets[idx].tuples.Len() + ct.buckets[idx].mu.RUnlock() + } + + if total != want { + t.Fatalf("checkNumTuples: got %d, wanted %d", total, want) + } +} + +func (ct *ConnTrack) reapEverything() { + var bucket int + for { + newBucket, _ := ct.reapUnused(bucket, 0 /* ignored */) + // We started reaping at bucket 0. If the next bucket isn't after our + // current bucket, we've gone through them all. + if newBucket <= bucket { + break + } + bucket = newBucket + } +} diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index ccb69393b..c2f1f4798 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -181,10 +181,6 @@ func (*fwdTestNetworkProtocol) MinimumPacketSize() int { return fwdTestNetHeaderLen } -func (*fwdTestNetworkProtocol) DefaultPrefixLen() int { - return fwdTestNetDefaultPrefixLen -} - func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) } @@ -384,8 +380,15 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (*faketime.M if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC #1 failed:", err) } - if err := s.AddAddress(1, fwdTestNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress #1 failed:", err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: fwdTestNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fwdTestNetDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr1, AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err) } // NIC 2 has the link address "b", and added the network address 2. @@ -397,8 +400,15 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (*faketime.M if err := s.CreateNIC(2, ep2); err != nil { t.Fatal("CreateNIC #2 failed:", err) } - if err := s.AddAddress(2, fwdTestNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress #2 failed:", err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: fwdTestNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x02", + PrefixLen: fwdTestNetDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(2, protocolAddr2, AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err) } nic, ok := s.nics[2] diff --git a/pkg/tcpip/stack/icmp_rate_limit.go b/pkg/tcpip/stack/icmp_rate_limit.go index 3a20839da..99e5d2df7 100644 --- a/pkg/tcpip/stack/icmp_rate_limit.go +++ b/pkg/tcpip/stack/icmp_rate_limit.go @@ -16,6 +16,7 @@ package stack import ( "golang.org/x/time/rate" + "gvisor.dev/gvisor/pkg/tcpip" ) const ( @@ -31,11 +32,41 @@ const ( // ICMPRateLimiter is a global rate limiter that controls the generation of // ICMP messages generated by the stack. type ICMPRateLimiter struct { - *rate.Limiter + limiter *rate.Limiter + clock tcpip.Clock } // NewICMPRateLimiter returns a global rate limiter for controlling the rate -// at which ICMP messages are generated by the stack. -func NewICMPRateLimiter() *ICMPRateLimiter { - return &ICMPRateLimiter{Limiter: rate.NewLimiter(icmpLimit, icmpBurst)} +// at which ICMP messages are generated by the stack. The returned limiter +// does not apply limits to any ICMP types by default. +func NewICMPRateLimiter(clock tcpip.Clock) *ICMPRateLimiter { + return &ICMPRateLimiter{ + clock: clock, + limiter: rate.NewLimiter(icmpLimit, icmpBurst), + } +} + +// SetLimit sets a new Limit for the limiter. +func (l *ICMPRateLimiter) SetLimit(limit rate.Limit) { + l.limiter.SetLimitAt(l.clock.Now(), limit) +} + +// Limit returns the maximum overall event rate. +func (l *ICMPRateLimiter) Limit() rate.Limit { + return l.limiter.Limit() +} + +// SetBurst sets a new burst size for the limiter. +func (l *ICMPRateLimiter) SetBurst(burst int) { + l.limiter.SetBurstAt(l.clock.Now(), burst) +} + +// Burst returns the maximum burst size. +func (l *ICMPRateLimiter) Burst() int { + return l.limiter.Burst() +} + +// Allow reports whether one ICMP message may be sent now. +func (l *ICMPRateLimiter) Allow() bool { + return l.limiter.AllowN(l.clock.Now(), 1) } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index f152c0d83..fd61387bf 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -42,7 +42,7 @@ const reaperDelay = 5 * time.Second // DefaultTables returns a default set of tables. Each chain is set to accept // all packets. -func DefaultTables(seed uint32) *IPTables { +func DefaultTables(seed uint32, clock tcpip.Clock) *IPTables { return &IPTables{ v4Tables: [NumTables]Table{ NATID: { @@ -182,7 +182,8 @@ func DefaultTables(seed uint32) *IPTables { Postrouting: {MangleID, NATID}, }, connections: ConnTrack{ - seed: seed, + seed: seed, + clock: clock, }, reaperDone: make(chan struct{}, 1), } @@ -264,33 +265,125 @@ const ( chainReturn ) -// Check runs pkt through the rules for hook. It returns true when the packet -// should continue traversing the network stack and false when it should be -// dropped. +// CheckPrerouting performs the prerouting hook on the packet. +// +// Returns true iff the packet may continue traversing the stack; the packet +// must be dropped if false is returned. +// +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) CheckPrerouting(pkt *PacketBuffer, addressEP AddressableEndpoint, inNicName string) bool { + const hook = Prerouting + + if it.shouldSkip(pkt.NetworkProtocolNumber) { + return true + } + + pkt.tuple = it.connections.getConnOrMaybeInsertNoop(pkt) + + return it.check(hook, pkt, nil /* route */, addressEP, inNicName, "" /* outNicName */) +} + +// CheckInput performs the input hook on the packet. +// +// Returns true iff the packet may continue traversing the stack; the packet +// must be dropped if false is returned. +// +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) CheckInput(pkt *PacketBuffer, inNicName string) bool { + const hook = Input + + if it.shouldSkip(pkt.NetworkProtocolNumber) { + return true + } + + ret := it.check(hook, pkt, nil /* route */, nil /* addressEP */, inNicName, "" /* outNicName */) + if t := pkt.tuple; t != nil { + t.conn.finalize() + } + pkt.tuple = nil + return ret +} + +// CheckForward performs the forward hook on the packet. +// +// Returns true iff the packet may continue traversing the stack; the packet +// must be dropped if false is returned. +// +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) CheckForward(pkt *PacketBuffer, inNicName, outNicName string) bool { + if it.shouldSkip(pkt.NetworkProtocolNumber) { + return true + } + return it.check(Forward, pkt, nil /* route */, nil /* addressEP */, inNicName, outNicName) +} + +// CheckOutput performs the output hook on the packet. +// +// Returns true iff the packet may continue traversing the stack; the packet +// must be dropped if false is returned. +// +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) CheckOutput(pkt *PacketBuffer, r *Route, outNicName string) bool { + const hook = Output + + if it.shouldSkip(pkt.NetworkProtocolNumber) { + return true + } + + pkt.tuple = it.connections.getConnOrMaybeInsertNoop(pkt) + + return it.check(hook, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName) +} + +// CheckPostrouting performs the postrouting hook on the packet. // -// Precondition: pkt.NetworkHeader is set. -func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool { - if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber { +// Returns true iff the packet may continue traversing the stack; the packet +// must be dropped if false is returned. +// +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) CheckPostrouting(pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, outNicName string) bool { + const hook = Postrouting + + if it.shouldSkip(pkt.NetworkProtocolNumber) { return true } + + ret := it.check(hook, pkt, r, addressEP, "" /* inNicName */, outNicName) + if t := pkt.tuple; t != nil { + t.conn.finalize() + } + pkt.tuple = nil + return ret +} + +func (it *IPTables) shouldSkip(netProto tcpip.NetworkProtocolNumber) bool { + switch netProto { + case header.IPv4ProtocolNumber, header.IPv6ProtocolNumber: + default: + // IPTables only supports IPv4/IPv6. + return true + } + + it.mu.RLock() + defer it.mu.RUnlock() // Many users never configure iptables. Spare them the cost of rule // traversal if rules have never been set. + return !it.modified +} + +// check runs pkt through the rules for hook. It returns true when the packet +// should continue traversing the network stack and false when it should be +// dropped. +// +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) check(hook Hook, pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) bool { it.mu.RLock() defer it.mu.RUnlock() - if !it.modified { - return true - } - - // Packets are manipulated only if connection and matching - // NAT rule exists. - shouldTrack := it.connections.handlePacket(pkt, hook, r) // Go through each table containing the hook. priorities := it.priorities[hook] for _, tableID := range priorities { - // If handlePacket already NATed the packet, we don't need to - // check the NAT table. - if tableID == NATID && pkt.NatDone { + if t := pkt.tuple; t != nil && tableID == NATID && t.conn.handlePacket(pkt, hook, r) { continue } var table Table @@ -300,7 +393,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr table = it.v4Tables[tableID] } ruleIdx := table.BuiltinChains[hook] - switch verdict := it.checkChain(hook, pkt, table, ruleIdx, r, preroutingAddr, inNicName, outNicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, ruleIdx, r, addressEP, inNicName, outNicName); verdict { // If the table returns Accept, move on to the next table. case chainAccept: continue @@ -311,7 +404,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr // Any Return from a built-in chain means we have to // call the underflow. underflow := table.Rules[table.Underflows[hook]] - switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, r, preroutingAddr); v { + switch v, _ := underflow.Target.Action(pkt, hook, r, addressEP); v { case RuleAccept: continue case RuleDrop: @@ -327,21 +420,6 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr } } - // If this connection should be tracked, try to add an entry for it. If - // traversing the nat table didn't end in adding an entry, - // maybeInsertNoop will add a no-op entry for the connection. This is - // needeed when establishing connections so that the SYN/ACK reply to an - // outgoing SYN is delivered to the correct endpoint rather than being - // redirected by a prerouting rule. - // - // From the iptables documentation: "If there is no rule, a `null' - // binding is created: this usually does not map the packet, but exists - // to ensure we don't map another stream over an existing one." - if shouldTrack { - it.connections.maybeInsertNoop(pkt, hook) - } - - // Every table returned Accept. return true } @@ -375,30 +453,46 @@ func (it *IPTables) startReaper(interval time.Duration) { }() } -// CheckPackets runs pkts through the rules for hook and returns a map of packets that -// should not go forward. +// CheckOutputPackets performs the output hook on the packets. // -// Preconditions: -// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// * pkt.NetworkHeader is not nil. +// Returns a map of packets that must be dropped. // -// NOTE: unlike the Check API the returned map contains packets that should be -// dropped. -func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, r *Route, inNicName, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { +// Precondition: The packets' network and transport header must be set. +func (it *IPTables) CheckOutputPackets(pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { + return checkPackets(pkts, func(pkt *PacketBuffer) bool { + return it.CheckOutput(pkt, r, outNicName) + }, true /* dnat */) +} + +// CheckPostroutingPackets performs the postrouting hook on the packets. +// +// Returns a map of packets that must be dropped. +// +// Precondition: The packets' network and transport header must be set. +func (it *IPTables) CheckPostroutingPackets(pkts PacketBufferList, r *Route, addressEP AddressableEndpoint, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { + return checkPackets(pkts, func(pkt *PacketBuffer) bool { + return it.CheckPostrouting(pkt, r, addressEP, outNicName) + }, false /* dnat */) +} + +func checkPackets(pkts PacketBufferList, f func(*PacketBuffer) bool, dnat bool) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - if !pkt.NatDone { - if ok := it.Check(hook, pkt, r, "", inNicName, outNicName); !ok { - if drop == nil { - drop = make(map[*PacketBuffer]struct{}) - } - drop[pkt] = struct{}{} + natDone := &pkt.SNATDone + if dnat { + natDone = &pkt.DNATDone + } + + if ok := f(pkt); !ok { + if drop == nil { + drop = make(map[*PacketBuffer]struct{}) } - if pkt.NatDone { - if natPkts == nil { - natPkts = make(map[*PacketBuffer]struct{}) - } - natPkts[pkt] = struct{}{} + drop[pkt] = struct{}{} + } + if *natDone { + if natPkts == nil { + natPkts = make(map[*PacketBuffer]struct{}) } + natPkts[pkt] = struct{}{} } } return drop, natPkts @@ -407,11 +501,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, r *Route, inN // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) chainVerdict { +func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. for ruleIdx < len(table.Rules) { - switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, r, preroutingAddr, inNicName, outNicName); verdict { + switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, r, addressEP, inNicName, outNicName); verdict { case RuleAccept: return chainAccept @@ -428,7 +522,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId ruleIdx++ continue } - switch verdict := it.checkChain(hook, pkt, table, jumpTo, r, preroutingAddr, inNicName, outNicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, jumpTo, r, addressEP, inNicName, outNicName); verdict { case chainAccept: return chainAccept case chainDrop: @@ -454,7 +548,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) (RuleVerdict, int) { +func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // Check whether the packet matches the IP header filter. @@ -477,16 +571,16 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx } // All the matchers matched, so run the target. - return rule.Target.Action(pkt, &it.connections, hook, r, preroutingAddr) + return rule.Target.Action(pkt, hook, r, addressEP) } // OriginalDst returns the original destination of redirected connections. It // returns an error if the connection doesn't exist or isn't redirected. -func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { +func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { it.mu.RLock() defer it.mu.RUnlock() if !it.modified { return "", 0, &tcpip.ErrNotConnected{} } - return it.connections.originalDst(epID, netProto) + return it.connections.originalDst(epID, netProto, transProto) } diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 96cc899bb..ef515bdd2 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -29,7 +29,7 @@ type AcceptTarget struct { } // Action implements Target.Action. -func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { +func (*AcceptTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { return RuleAccept, 0 } @@ -40,7 +40,7 @@ type DropTarget struct { } // Action implements Target.Action. -func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { +func (*DropTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { return RuleDrop, 0 } @@ -52,7 +52,7 @@ type ErrorTarget struct { } // Action implements Target.Action. -func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { +func (*ErrorTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { log.Debugf("ErrorTarget triggered.") return RuleDrop, 0 } @@ -67,7 +67,7 @@ type UserChainTarget struct { } // Action implements Target.Action. -func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { +func (*UserChainTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { panic("UserChainTarget should never be called.") } @@ -79,10 +79,49 @@ type ReturnTarget struct { } // Action implements Target.Action. -func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { +func (*ReturnTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { return RuleReturn, 0 } +// DNATTarget modifies the destination port/IP of packets. +type DNATTarget struct { + // The new destination address for packets. + // + // Immutable. + Addr tcpip.Address + + // The new destination port for packets. + // + // Immutable. + Port uint16 + + // NetworkProtocol is the network protocol the target is used with. + // + // Immutable. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// Action implements Target.Action. +func (rt *DNATTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) { + // Sanity check. + if rt.NetworkProtocol != pkt.NetworkProtocolNumber { + panic(fmt.Sprintf( + "DNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d", + rt.NetworkProtocol, pkt.NetworkProtocolNumber)) + } + + switch hook { + case Prerouting, Output: + case Input, Forward, Postrouting: + panic(fmt.Sprintf("%s not supported for DNAT", hook)) + default: + panic(fmt.Sprintf("%s unrecognized", hook)) + } + + return natAction(pkt, hook, r, rt.Port, rt.Addr, true /* dnat */) + +} + // RedirectTarget redirects the packet to this machine by modifying the // destination port/IP. Outgoing packets are redirected to the loopback device, // and incoming packets are redirected to the incoming interface (rather than @@ -97,7 +136,7 @@ type RedirectTarget struct { } // Action implements Target.Action. -func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) { +func (rt *RedirectTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) { // Sanity check. if rt.NetworkProtocol != pkt.NetworkProtocolNumber { panic(fmt.Sprintf( @@ -105,18 +144,9 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r rt.NetworkProtocol, pkt.NetworkProtocolNumber)) } - // Packet is already manipulated. - if pkt.NatDone { - return RuleAccept, 0 - } - - // Drop the packet if network and transport header are not set. - if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() { - return RuleDrop, 0 - } - // Change the address to loopback (127.0.0.1 or ::1) in Output and to // the primary address of the incoming interface in Prerouting. + var address tcpip.Address switch hook { case Output: if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { @@ -125,48 +155,13 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r address = header.IPv6Loopback } case Prerouting: - // No-op, as address is already set correctly. + // addressEP is expected to be set for the prerouting hook. + address = addressEP.MainAddress().Address default: panic("redirect target is supported only on output and prerouting hooks") } - switch protocol := pkt.TransportProtocolNumber; protocol { - case header.UDPProtocolNumber: - udpHeader := header.UDP(pkt.TransportHeader().View()) - - if hook == Output { - // Only calculate the checksum if offloading isn't supported. - requiresChecksum := r.RequiresTXTransportChecksum() - rewritePacket( - pkt.Network(), - udpHeader, - false, /* updateSRCFields */ - requiresChecksum, - requiresChecksum, - rt.Port, - address, - ) - } else { - udpHeader.SetDestinationPort(rt.Port) - } - - pkt.NatDone = true - case header.TCPProtocolNumber: - if ct == nil { - return RuleAccept, 0 - } - - // Set up conection for matching NAT rule. Only the first - // packet of the connection comes here. Other packets will be - // manipulated in connection tracking. - if conn := ct.insertRedirectConn(pkt, hook, rt.Port, address); conn != nil { - ct.handlePacket(pkt, hook, r) - } - default: - return RuleDrop, 0 - } - - return RuleAccept, 0 + return natAction(pkt, hook, r, rt.Port, address, true /* dnat */) } // SNATTarget modifies the source port/IP in the outgoing packets. @@ -179,8 +174,36 @@ type SNATTarget struct { NetworkProtocol tcpip.NetworkProtocolNumber } +func natAction(pkt *PacketBuffer, hook Hook, r *Route, port uint16, address tcpip.Address, dnat bool) (RuleVerdict, int) { + // Drop the packet if network and transport header are not set. + if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() { + return RuleDrop, 0 + } + + t := pkt.tuple + if t == nil { + return RuleDrop, 0 + } + + // TODO(https://gvisor.dev/issue/5773): If the port is in use, pick a + // different port. + if port == 0 { + switch protocol := pkt.TransportProtocolNumber; protocol { + case header.UDPProtocolNumber: + port = header.UDP(pkt.TransportHeader().View()).SourcePort() + case header.TCPProtocolNumber: + port = header.TCP(pkt.TransportHeader().View()).SourcePort() + default: + panic(fmt.Sprintf("unsupported transport protocol = %d", pkt.TransportProtocolNumber)) + } + } + + t.conn.performNAT(pkt, hook, r, port, address, dnat) + return RuleAccept, 0 +} + // Action implements Target.Action. -func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) { +func (st *SNATTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, _ AddressableEndpoint) (RuleVerdict, int) { // Sanity check. if st.NetworkProtocol != pkt.NetworkProtocolNumber { panic(fmt.Sprintf( @@ -188,16 +211,6 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou st.NetworkProtocol, pkt.NetworkProtocolNumber)) } - // Packet is already manipulated. - if pkt.NatDone { - return RuleAccept, 0 - } - - // Drop the packet if network and transport header are not set. - if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() { - return RuleDrop, 0 - } - switch hook { case Postrouting, Input: case Prerouting, Output, Forward: @@ -206,37 +219,43 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou panic(fmt.Sprintf("%s unrecognized", hook)) } - switch protocol := pkt.TransportProtocolNumber; protocol { - case header.UDPProtocolNumber: - // Only calculate the checksum if offloading isn't supported. - requiresChecksum := r.RequiresTXTransportChecksum() - rewritePacket( - pkt.Network(), - header.UDP(pkt.TransportHeader().View()), - true, /* updateSRCFields */ - requiresChecksum, - requiresChecksum, - st.Port, - st.Addr, - ) - - pkt.NatDone = true - case header.TCPProtocolNumber: - if ct == nil { - return RuleAccept, 0 - } + return natAction(pkt, hook, r, st.Port, st.Addr, false /* dnat */) +} - // Set up conection for matching NAT rule. Only the first - // packet of the connection comes here. Other packets will be - // manipulated in connection tracking. - if conn := ct.insertSNATConn(pkt, hook, st.Port, st.Addr); conn != nil { - ct.handlePacket(pkt, hook, r) - } +// MasqueradeTarget modifies the source port/IP in the outgoing packets. +type MasqueradeTarget struct { + // NetworkProtocol is the network protocol the target is used with. It + // is immutable. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// Action implements Target.Action. +func (mt *MasqueradeTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) { + // Sanity check. + if mt.NetworkProtocol != pkt.NetworkProtocolNumber { + panic(fmt.Sprintf( + "MasqueradeTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d", + mt.NetworkProtocol, pkt.NetworkProtocolNumber)) + } + + switch hook { + case Postrouting: + case Prerouting, Input, Forward, Output: + panic(fmt.Sprintf("masquerade target is supported only on postrouting hook; hook = %d", hook)) default: + panic(fmt.Sprintf("%s unrecognized", hook)) + } + + // addressEP is expected to be set for the postrouting hook. + ep := addressEP.AcquireOutgoingPrimaryAddress(pkt.Network().DestinationAddress(), false /* allowExpired */) + if ep == nil { + // No address exists that we can use as a source address. return RuleDrop, 0 } - return RuleAccept, 0 + address := ep.AddressWithPrefix().Address + ep.DecRef() + return natAction(pkt, hook, r, 0 /* port */, address, false /* dnat */) } func rewritePacket(n header.Network, t header.ChecksummableTransport, updateSRCFields, fullChecksum, updatePseudoHeader bool, newPort uint16, newAddr tcpip.Address) { diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 66e5f22ac..b22024667 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -81,17 +81,6 @@ const ( // // +stateify savable type IPTables struct { - // mu protects v4Tables, v6Tables, and modified. - mu sync.RWMutex - // v4Tables and v6tables map tableIDs to tables. They hold builtin - // tables only, not user tables. mu must be locked for accessing. - v4Tables [NumTables]Table - v6Tables [NumTables]Table - // modified is whether tables have been modified at least once. It is - // used to elide the iptables performance overhead for workloads that - // don't utilize iptables. - modified bool - // priorities maps each hook to a list of table names. The order of the // list is the order in which each table should be visited for that // hook. It is immutable. @@ -101,6 +90,21 @@ type IPTables struct { // reaperDone can be signaled to stop the reaper goroutine. reaperDone chan struct{} + + mu sync.RWMutex + // v4Tables and v6tables map tableIDs to tables. They hold builtin + // tables only, not user tables. + // + // +checklocks:mu + v4Tables [NumTables]Table + // +checklocks:mu + v6Tables [NumTables]Table + // modified is whether tables have been modified at least once. It is + // used to elide the iptables performance overhead for workloads that + // don't utilize iptables. + // + // +checklocks:mu + modified bool } // VisitTargets traverses all the targets of all tables and replaces each with @@ -352,5 +356,5 @@ type Target interface { // Action takes an action on the packet and returns a verdict on how // traversal should (or should not) continue. If the return value is // Jump, it also returns the index of the rule to jump to. - Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) + Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 4d5431da1..40b33b6b5 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -333,8 +333,12 @@ func TestDADDisabled(t *testing.T) { Address: addr1, PrefixLen: defaultPrefixLen, } - if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addrWithPrefix, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err) } // Should get the address immediately since we should not have performed @@ -379,12 +383,15 @@ func TestDADResolveLoopback(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - addrWithPrefix := tcpip.AddressWithPrefix{ - Address: addr1, - PrefixLen: defaultPrefixLen, + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: addr1, + PrefixLen: defaultPrefixLen, + }, } - if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err) } // Address should not be considered bound to the NIC yet (DAD ongoing). @@ -517,8 +524,12 @@ func TestDADResolve(t *testing.T) { Address: addr1, PrefixLen: defaultPrefixLen, } - if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addrWithPrefix, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err) } // Make sure the address does not resolve before the resolution time has @@ -740,8 +751,12 @@ func TestDADFail(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addr1.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } // Address should not be considered bound to the NIC yet @@ -778,8 +793,8 @@ func TestDADFail(t *testing.T) { // Attempting to add the address again should not fail if the address's // state was cleaned up when DAD failed. - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } }) } @@ -851,8 +866,12 @@ func TestDADStop(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, header.IPv6ProtocolNumber, addr1, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addr1.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } // Address should not be considered bound to the NIC yet (DAD ongoing). @@ -975,17 +994,29 @@ func TestSetNDPConfigurations(t *testing.T) { // Add addresses for each NIC. addrWithPrefix1 := tcpip.AddressWithPrefix{Address: addr1, PrefixLen: defaultPrefixLen} - if err := s.AddAddressWithPrefix(nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addrWithPrefix1, err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addrWithPrefix1, + } + if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID1, protocolAddr1, err) } addrWithPrefix2 := tcpip.AddressWithPrefix{Address: addr2, PrefixLen: defaultPrefixLen} - if err := s.AddAddressWithPrefix(nicID2, header.IPv6ProtocolNumber, addrWithPrefix2); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addrWithPrefix2, err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addrWithPrefix2, + } + if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID2, protocolAddr2, err) } expectDADEvent(nicID2, addr2) addrWithPrefix3 := tcpip.AddressWithPrefix{Address: addr3, PrefixLen: defaultPrefixLen} - if err := s.AddAddressWithPrefix(nicID3, header.IPv6ProtocolNumber, addrWithPrefix3); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addrWithPrefix3, err) + protocolAddr3 := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addrWithPrefix3, + } + if err := s.AddProtocolAddress(nicID3, protocolAddr3, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID3, protocolAddr3, err) } expectDADEvent(nicID3, addr3) @@ -2788,8 +2819,12 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { continue } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, test.addrs[j].Address); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, test.addrs[j].Address, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: test.addrs[j].Address.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } manuallyAssignedAddresses[test.addrs[j].Address] = struct{}{} @@ -3644,8 +3679,9 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) { Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr2, } - if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err) + properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint} + if err := s.AddProtocolAddress(nicID, protoAddr2, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v) = %s", nicID, protoAddr2, properties, err) } // addr2 should be more preferred now since it is at the front of the primary // list. @@ -3733,8 +3769,9 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { } // Add the address as a static address before SLAAC tries to add it. - if err := s.AddProtocolAddress(1, tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr.Address, err) + protocolAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr} + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(1, %+v, {}) = %s", protocolAddr, err) } if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) { t.Fatalf("Should have %s in the list of addresses", addr1) @@ -4073,8 +4110,12 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { // Attempting to add the address manually should not fail if the // address's state was cleaned up when DAD failed. - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr.Address); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr.Address, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addr, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } if err := s.RemoveAddress(nicID, addr.Address); err != nil { t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err) @@ -5362,8 +5403,12 @@ func TestRouterSolicitation(t *testing.T) { } if addr := test.nicAddr; addr != "" { - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index a796942ab..e251e3b24 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -97,6 +97,8 @@ type packetEndpointList struct { mu sync.RWMutex // eps is protected by mu, but the contained PacketEndpoint values are not. + // + // +checklocks:mu eps []PacketEndpoint } @@ -117,6 +119,12 @@ func (p *packetEndpointList) remove(ep PacketEndpoint) { } } +func (p *packetEndpointList) len() int { + p.mu.RLock() + defer p.mu.RUnlock() + return len(p.eps) +} + // forEach calls fn with each endpoints in p while holding the read lock on p. func (p *packetEndpointList) forEach(fn func(PacketEndpoint)) { p.mu.RLock() @@ -157,14 +165,8 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC resolutionRequired := ep.Capabilities()&CapabilityResolutionRequired != 0 - // Register supported packet and network endpoint protocols. - for _, netProto := range header.Ethertypes { - nic.packetEPs.eps[netProto] = new(packetEndpointList) - } for _, netProto := range stack.networkProtocols { netNum := netProto.Number() - nic.packetEPs.eps[netNum] = new(packetEndpointList) - netEP := netProto.NewEndpoint(nic, nic) nic.networkEndpoints[netNum] = netEP @@ -514,7 +516,7 @@ func (n *nic) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber, // addAddress adds a new address to n, so that it starts accepting packets // targeted at the given address (and network protocol). -func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error { +func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, properties AddressProperties) tcpip.Error { ep, ok := n.networkEndpoints[protocolAddress.Protocol] if !ok { return &tcpip.ErrUnknownProtocol{} @@ -525,7 +527,7 @@ func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo return &tcpip.ErrNotSupported{} } - addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */) + addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, properties) if err == nil { // We have no need for the address endpoint. addressEndpoint.DecRef() @@ -831,24 +833,9 @@ func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt transProto := state.proto - // TransportHeader is empty only when pkt is an ICMP packet or was reassembled - // from fragments. if pkt.TransportHeader().View().IsEmpty() { - // ICMP packets don't have their TransportHeader fields set yet, parse it - // here. See icmp/protocol.go:protocol.Parse for a full explanation. - if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { - // ICMP packets may be longer, but until icmp.Parse is implemented, here - // we parse it using the minimum size. - if _, ok := pkt.TransportHeader().Consume(transProto.MinimumPacketSize()); !ok { - n.stats.malformedL4RcvdPackets.Increment() - // We consider a malformed transport packet handled because there is - // nothing the caller can do. - return TransportPacketHandled - } - } else if !transProto.Parse(pkt) { - n.stats.malformedL4RcvdPackets.Increment() - return TransportPacketHandled - } + n.stats.malformedL4RcvdPackets.Increment() + return TransportPacketHandled } srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader().View()) @@ -974,7 +961,8 @@ func (n *nic) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa eps, ok := n.packetEPs.eps[netProto] if !ok { - return &tcpip.ErrNotSupported{} + eps = new(packetEndpointList) + n.packetEPs.eps[netProto] = eps } eps.add(ep) @@ -990,6 +978,9 @@ func (n *nic) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep return } eps.remove(ep) + if eps.len() == 0 { + delete(n.packetEPs.eps, netProto) + } } // isValidForOutgoing returns true if the endpoint can be used to send out a diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 5cb342f78..c8ad93f29 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -127,11 +127,6 @@ func (*testIPv6Protocol) MinimumPacketSize() int { return header.IPv6MinimumSize } -// DefaultPrefixLen implements NetworkProtocol.DefaultPrefixLen. -func (*testIPv6Protocol) DefaultPrefixLen() int { - return header.IPv6AddressSize * 8 -} - // ParseAddresses implements NetworkProtocol.ParseAddresses. func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { h := header.IPv6(v) diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 29c22bfd4..c4a4bbd22 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -126,9 +126,13 @@ type PacketBuffer struct { EgressRoute RouteInfo GSOOptions GSO - // NatDone indicates if the packet has been manipulated as per NAT - // iptables rule. - NatDone bool + // SNATDone indicates if the packet's source has been manipulated as per + // iptables NAT table. + SNATDone bool + + // DNATDone indicates if the packet's destination has been manipulated as per + // iptables NAT table. + DNATDone bool // PktType indicates the SockAddrLink.PacketType of the packet as defined in // https://www.man7.org/linux/man-pages/man7/packet.7.html. @@ -143,6 +147,8 @@ type PacketBuffer struct { // NetworkPacketInfo holds an incoming packet's network-layer information. NetworkPacketInfo NetworkPacketInfo + + tuple *tuple } // NewPacketBuffer creates a new PacketBuffer with opts. @@ -296,12 +302,14 @@ func (pk *PacketBuffer) Clone() *PacketBuffer { Owner: pk.Owner, GSOOptions: pk.GSOOptions, NetworkProtocolNumber: pk.NetworkProtocolNumber, - NatDone: pk.NatDone, + DNATDone: pk.DNATDone, + SNATDone: pk.SNATDone, TransportProtocolNumber: pk.TransportProtocolNumber, PktType: pk.PktType, NICID: pk.NICID, RXTransportChecksumValidated: pk.RXTransportChecksumValidated, NetworkPacketInfo: pk.NetworkPacketInfo, + tuple: pk.tuple, } } @@ -329,15 +337,41 @@ func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { buf: pk.buf.Clone(), // Treat unfilled header portion as reserved. reserved: pk.AvailableHeaderBytes(), + tuple: pk.tuple, + } + return newPk +} + +// DeepCopyForForwarding creates a deep copy of the packet buffer for +// forwarding. +// +// The returned packet buffer will have the network and transport headers +// set if the original packet buffer did. +func (pk *PacketBuffer) DeepCopyForForwarding(reservedHeaderBytes int) *PacketBuffer { + newPk := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: reservedHeaderBytes, + Data: PayloadSince(pk.NetworkHeader()).ToVectorisedView(), + IsForwardedPacket: true, + }) + + { + consumeBytes := pk.NetworkHeader().View().Size() + if _, consumed := newPk.NetworkHeader().Consume(consumeBytes); !consumed { + panic(fmt.Sprintf("expected to consume network header %d bytes from new packet", consumeBytes)) + } + newPk.NetworkProtocolNumber = pk.NetworkProtocolNumber } - // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to - // maintain this flag in the packet. Currently conntrack needs this flag to - // tell if a noop connection should be inserted at Input hook. Once conntrack - // redefines the manipulation field as mutable, we won't need the special noop - // connection. - if pk.NatDone { - newPk.NatDone = true + + { + consumeBytes := pk.TransportHeader().View().Size() + if _, consumed := newPk.TransportHeader().Consume(consumeBytes); !consumed { + panic(fmt.Sprintf("expected to consume transport header %d bytes from new packet", consumeBytes)) + } + newPk.TransportProtocolNumber = pk.TransportProtocolNumber } + + newPk.tuple = pk.tuple + return newPk } @@ -389,13 +423,14 @@ func (d PacketData) PullUp(size int) (tcpipbuffer.View, bool) { return d.pk.buf.PullUp(d.pk.dataOffset(), size) } -// DeleteFront removes count from the beginning of d. It panics if count > -// d.Size(). All backing storage references after the front of the d are -// invalidated. -func (d PacketData) DeleteFront(count int) { - if !d.pk.buf.Remove(d.pk.dataOffset(), count) { - panic("count > d.Size()") +// Consume is the same as PullUp except that is additionally consumes the +// returned bytes. Subsequent PullUp or Consume will not return these bytes. +func (d PacketData) Consume(size int) (tcpipbuffer.View, bool) { + v, ok := d.PullUp(size) + if ok { + d.pk.consumed += size } + return v, ok } // CapLength reduces d to at most length bytes. diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go index 87b023445..c376ed1a1 100644 --- a/pkg/tcpip/stack/packet_buffer_test.go +++ b/pkg/tcpip/stack/packet_buffer_test.go @@ -123,32 +123,6 @@ func TestPacketHeaderPush(t *testing.T) { } } -func TestPacketBufferClone(t *testing.T) { - data := concatViews(makeView(20), makeView(30), makeView(40)) - pk := NewPacketBuffer(PacketBufferOptions{ - // Make a copy of data to make sure our truth data won't be taint by - // PacketBuffer. - Data: buffer.NewViewFromBytes(data).ToVectorisedView(), - }) - - bytesToDelete := 30 - originalSize := data.Size() - - clonedPks := []*PacketBuffer{ - pk.Clone(), - pk.CloneToInbound(), - } - pk.Data().DeleteFront(bytesToDelete) - if got, want := pk.Data().Size(), originalSize-bytesToDelete; got != want { - t.Errorf("original packet was not changed: size expected = %d, got = %d", want, got) - } - for _, clonedPk := range clonedPks { - if got := clonedPk.Data().Size(); got != originalSize { - t.Errorf("cloned packet should not be modified: expected size = %d, got = %d", originalSize, got) - } - } -} - func TestPacketHeaderConsume(t *testing.T) { for _, test := range []struct { name string @@ -461,11 +435,17 @@ func TestPacketBufferData(t *testing.T) { } }) - // DeleteFront + // Consume. for _, n := range []int{1, len(tc.data)} { - t.Run(fmt.Sprintf("DeleteFront%d", n), func(t *testing.T) { + t.Run(fmt.Sprintf("Consume%d", n), func(t *testing.T) { pkt := tc.makePkt(t) - pkt.Data().DeleteFront(n) + v, ok := pkt.Data().Consume(n) + if !ok { + t.Fatalf("Consume failed") + } + if want := []byte(tc.data)[:n]; !bytes.Equal(v, want) { + t.Fatalf("pkt.Data().Consume(n) = 0x%x, want 0x%x", v, want) + } checkData(t, pkt, []byte(tc.data)[n:]) }) diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 113baaaae..31b3a554d 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -318,8 +318,7 @@ type PrimaryEndpointBehavior int const ( // CanBePrimaryEndpoint indicates the endpoint can be used as a primary - // endpoint for new connections with no local address. This is the - // default when calling NIC.AddAddress. + // endpoint for new connections with no local address. CanBePrimaryEndpoint PrimaryEndpointBehavior = iota // FirstPrimaryEndpoint indicates the endpoint should be the first @@ -332,6 +331,19 @@ const ( NeverPrimaryEndpoint ) +func (peb PrimaryEndpointBehavior) String() string { + switch peb { + case CanBePrimaryEndpoint: + return "CanBePrimaryEndpoint" + case FirstPrimaryEndpoint: + return "FirstPrimaryEndpoint" + case NeverPrimaryEndpoint: + return "NeverPrimaryEndpoint" + default: + panic(fmt.Sprintf("unknown primary endpoint behavior: %d", peb)) + } +} + // AddressConfigType is the method used to add an address. type AddressConfigType int @@ -351,6 +363,14 @@ const ( AddressConfigSlaacTemp ) +// AddressProperties contains additional properties that can be configured when +// adding an address. +type AddressProperties struct { + PEB PrimaryEndpointBehavior + ConfigType AddressConfigType + Deprecated bool +} + // AssignableAddressEndpoint is a reference counted address endpoint that may be // assigned to a NetworkEndpoint. type AssignableAddressEndpoint interface { @@ -457,7 +477,7 @@ type AddressableEndpoint interface { // Returns *tcpip.ErrDuplicateAddress if the address exists. // // Acquires and returns the AddressEndpoint for the added address. - AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, tcpip.Error) + AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties AddressProperties) (AddressEndpoint, tcpip.Error) // RemovePermanentAddress removes the passed address if it is a permanent // address. @@ -685,9 +705,6 @@ type NetworkProtocol interface { // than this targeted at this protocol. MinimumPacketSize() int - // DefaultPrefixLen returns the protocol's default prefix length. - DefaultPrefixLen() int - // ParseAddresses returns the source and destination addresses stored in a // packet of this protocol. ParseAddresses(v buffer.View) (src, dst tcpip.Address) diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index cb741e540..a05fd7036 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -238,7 +238,7 @@ type Options struct { // DefaultIPTables is an optional iptables rules constructor that is called // if IPTables is nil. If both fields are nil, iptables will allow all // traffic. - DefaultIPTables func(uint32) *IPTables + DefaultIPTables func(seed uint32, clock tcpip.Clock) *IPTables // SecureRNG is a cryptographically secure random number generator. SecureRNG io.Reader @@ -358,7 +358,7 @@ func New(opts Options) *Stack { if opts.DefaultIPTables == nil { opts.DefaultIPTables = DefaultTables } - opts.IPTables = opts.DefaultIPTables(seed) + opts.IPTables = opts.DefaultIPTables(seed, clock) } opts.NUDConfigs.resetInvalidFields() @@ -375,7 +375,7 @@ func New(opts Options) *Stack { stats: opts.Stats.FillIn(), handleLocal: opts.HandleLocal, tables: opts.IPTables, - icmpRateLimiter: NewICMPRateLimiter(), + icmpRateLimiter: NewICMPRateLimiter(clock), seed: seed, nudConfigs: opts.NUDConfigs, uniqueIDGenerator: opts.UniqueID, @@ -916,46 +916,9 @@ type NICStateFlags struct { Loopback bool } -// AddAddress adds a new network-layer address to the specified NIC. -func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { - return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint) -} - -// AddAddressWithPrefix is the same as AddAddress, but allows you to specify -// the address prefix. -func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.AddressWithPrefix) tcpip.Error { - ap := tcpip.ProtocolAddress{ - Protocol: protocol, - AddressWithPrefix: addr, - } - return s.AddProtocolAddressWithOptions(id, ap, CanBePrimaryEndpoint) -} - -// AddProtocolAddress adds a new network-layer protocol address to the -// specified NIC. -func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) tcpip.Error { - return s.AddProtocolAddressWithOptions(id, protocolAddress, CanBePrimaryEndpoint) -} - -// AddAddressWithOptions is the same as AddAddress, but allows you to specify -// whether the new endpoint can be primary or not. -func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) tcpip.Error { - netProto, ok := s.networkProtocols[protocol] - if !ok { - return &tcpip.ErrUnknownProtocol{} - } - return s.AddProtocolAddressWithOptions(id, tcpip.ProtocolAddress{ - Protocol: protocol, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addr, - PrefixLen: netProto.DefaultPrefixLen(), - }, - }, peb) -} - -// AddProtocolAddressWithOptions is the same as AddProtocolAddress, but allows -// you to specify whether the new endpoint can be primary or not. -func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error { +// AddProtocolAddress adds an address to the specified NIC, possibly with extra +// properties. +func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, properties AddressProperties) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() @@ -964,7 +927,7 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc return &tcpip.ErrUnknownNICID{} } - return nic.addAddress(protocolAddress, peb) + return nic.addAddress(protocolAddress, properties) } // RemoveAddress removes an existing network-layer address from the specified @@ -1902,12 +1865,6 @@ const ( // ParsePacketBufferTransport parses the provided packet buffer's transport // header. func (s *Stack) ParsePacketBufferTransport(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) ParseResult { - // ICMP packets don't have their TransportHeader fields set yet, parse it - // here. See icmp/protocol.go:protocol.Parse for a full explanation. - if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { - return ParsedOK - } - pkt.TransportProtocolNumber = protocol // Parse the transport header if present. state, ok := s.transportProtocols[protocol] diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 3089c0ef4..f5a35eac4 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -139,18 +139,15 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { // Handle control packets. if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) { - hdr, ok := pkt.Data().PullUp(fakeNetHeaderLen) + hdr, ok := pkt.Data().Consume(fakeNetHeaderLen) if !ok { return } - // DeleteFront invalidates slices. Make a copy before trimming. - nb := append([]byte(nil), hdr...) - pkt.Data().DeleteFront(fakeNetHeaderLen) f.dispatcher.DeliverTransportError( - tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]), - tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]), + tcpip.Address(hdr[srcAddrOffset:srcAddrOffset+1]), + tcpip.Address(hdr[dstAddrOffset:dstAddrOffset+1]), fakeNetNumber, - tcpip.TransportProtocolNumber(nb[protocolNumberOffset]), + tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), // Nothing checks the error. nil, /* transport error */ pkt, @@ -158,8 +155,18 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { return } + transProtoNum := tcpip.TransportProtocolNumber(netHdr[protocolNumberOffset]) + switch err := f.proto.stack.ParsePacketBufferTransport(transProtoNum, pkt); err { + case stack.ParsedOK: + case stack.UnknownTransportProtocol, stack.TransportLayerParseError: + // The transport layer will handle unknown protocols and transport layer + // parsing errors. + default: + panic(fmt.Sprintf("unexpected error parsing transport header = %d", err)) + } + // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) + f.dispatcher.DeliverTransportPacket(transProtoNum, pkt) } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { @@ -221,6 +228,8 @@ func (*fakeNetworkEndpointStats) IsNetworkEndpointStats() {} // number of packets sent and received via endpoints of this protocol. The index // where packets are added is given by the packet's destination address MOD 10. type fakeNetworkProtocol struct { + stack *stack.Stack + packetCount [10]int sendPacketCount [10]int defaultTTL uint8 @@ -234,10 +243,6 @@ func (*fakeNetworkProtocol) MinimumPacketSize() int { return fakeNetHeaderLen } -func (*fakeNetworkProtocol) DefaultPrefixLen() int { - return fakeDefaultPrefixLen -} - func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int { return f.packetCount[int(intfAddr)%len(f.packetCount)] } @@ -306,8 +311,8 @@ func (f *fakeNetworkEndpoint) SetForwarding(v bool) { f.mu.forwarding = v } -func fakeNetFactory(*stack.Stack) stack.NetworkProtocol { - return &fakeNetworkProtocol{} +func fakeNetFactory(s *stack.Stack) stack.NetworkProtocol { + return &fakeNetworkProtocol{stack: s} } // linkEPWithMockedAttach is a stack.LinkEndpoint that tests can use to verify @@ -349,12 +354,26 @@ func TestNetworkReceive(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err) } - if err := s.AddAddress(1, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x02", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr2, err) } fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) @@ -517,8 +536,15 @@ func TestNetworkSend(t *testing.T) { s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } // Make sure that the link-layer endpoint received the outbound packet. @@ -538,12 +564,26 @@ func TestNetworkSendMultiRoute(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err) } - if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr3 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x03", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr3, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr3, err) } ep2 := channel.New(10, defaultMTU, "") @@ -551,12 +591,26 @@ func TestNetworkSendMultiRoute(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x02", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(2, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err) } - if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr4 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x04", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(2, protocolAddr4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr4, err) } // Set a route table that sends all packets with odd destination @@ -812,8 +866,15 @@ func TestRouteWithDownNIC(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) } - if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: addr1, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr1, err) } ep2 := channel.New(1, defaultMTU, "") @@ -821,8 +882,15 @@ func TestRouteWithDownNIC(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) } - if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: addr2, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddr2, err) } // Set a route table that sends all packets with odd destination @@ -978,12 +1046,26 @@ func TestRoutes(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err) } - if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr3 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x03", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr3, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr3, err) } ep2 := channel.New(10, defaultMTU, "") @@ -991,12 +1073,26 @@ func TestRoutes(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x02", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(2, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err) } - if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr4 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x04", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(2, protocolAddr4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr4, err) } // Set a route table that sends all packets with odd destination @@ -1058,8 +1154,15 @@ func TestAddressRemoval(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: localAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } { subnet, err := tcpip.NewSubnet("\x00", "\x00") @@ -1108,8 +1211,15 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) buf := buffer.NewView(30) - if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: localAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } { subnet, err := tcpip.NewSubnet("\x00", "\x00") @@ -1242,8 +1352,15 @@ func TestEndpointExpiration(t *testing.T) { // 2. Add Address, everything should work. //----------------------- - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: localAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) @@ -1270,8 +1387,8 @@ func TestEndpointExpiration(t *testing.T) { // 4. Add Address back, everything should work again. //----------------------- - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) @@ -1310,8 +1427,8 @@ func TestEndpointExpiration(t *testing.T) { // 7. Add Address back, everything should work again. //----------------------- - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) @@ -1453,8 +1570,15 @@ func TestExternalSendWithHandleLocal(t *testing.T) { if err := s.CreateNIC(nicID, ep); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: localAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: nicID}}) @@ -1510,8 +1634,15 @@ func TestSpoofingWithAddress(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: localAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } { @@ -1633,8 +1764,8 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { } protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{Address: header.IPv4Any}} - if err := s.AddProtocolAddress(1, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err) + if err := s.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(1, %+v, {}) failed: %s", protoAddr, err) } r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) if err != nil { @@ -1678,13 +1809,13 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { t.Fatalf("CreateNIC failed: %s", err) } nic1ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic1Addr} - if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err) + if err := s.AddProtocolAddress(1, nic1ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(1, %+v, {}) failed: %s", nic1ProtoAddr, err) } nic2ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic2Addr} - if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil { - t.Fatalf("AddAddress(2, %v) failed: %v", nic2ProtoAddr, err) + if err := s.AddProtocolAddress(2, nic2ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(2, %+v, {}) failed: %s", nic2ProtoAddr, err) } // Set the initial route table. @@ -1726,7 +1857,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { // 2. Case: Having an explicit route for broadcast will select that one. rt = append( []tcpip.Route{ - {Destination: tcpip.AddressWithPrefix{Address: header.IPv4Broadcast, PrefixLen: 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, + {Destination: header.IPv4Broadcast.WithPrefix().Subnet(), NIC: 1}, }, rt..., ) @@ -1808,8 +1939,15 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, want) } - if err := s.AddAddress(1, fakeNetNumber, anyAddr); err != nil { - t.Fatalf("AddAddress(%v, %v) failed: %v", fakeNetNumber, anyAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: anyAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } if r, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); tc.routeNeeded { @@ -1886,22 +2024,27 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { // Add an address and in case of a primary one include a // prefixLen. address := tcpip.Address(bytes.Repeat([]byte{byte(i)}, addrLen)) + properties := stack.AddressProperties{PEB: behavior} if behavior == stack.CanBePrimaryEndpoint { protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: address, - PrefixLen: addrLen * 8, - }, + Protocol: fakeNetNumber, + AddressWithPrefix: address.WithPrefix(), } - if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, protocolAddress, behavior, err) + if err := s.AddProtocolAddress(nicID, protocolAddress, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddress, properties, err) } // Remember the address/prefix. primaryAddrAdded[protocolAddress.AddressWithPrefix] = struct{}{} } else { - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s:", nicID, fakeNetNumber, address, behavior, err) + protocolAddress := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddress, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddress, properties, err) } } } @@ -1996,8 +2139,8 @@ func TestGetMainNICAddressAddRemove(t *testing.T) { PrefixLen: tc.prefixLen, }, } - if err := s.AddProtocolAddress(1, protocolAddress); err != nil { - t.Fatal("AddProtocolAddress failed:", err) + if err := s.AddProtocolAddress(1, protocolAddress, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", protocolAddress, err) } // Check that we get the right initial address and prefix length. @@ -2047,33 +2190,6 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto } } -func TestAddAddress(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - var addrGen addressGenerator - expectedAddresses := make([]tcpip.ProtocolAddress, 0, 2) - for _, addrLen := range []int{4, 16} { - address := addrGen.next(addrLen) - if err := s.AddAddress(nicID, fakeNetNumber, address); err != nil { - t.Fatalf("AddAddress(address=%s) failed: %s", address, err) - } - expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, - }) - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - func TestAddProtocolAddress(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ @@ -2084,96 +2200,43 @@ func TestAddProtocolAddress(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - var addrGen addressGenerator - addrLenRange := []int{4, 16} - prefixLenRange := []int{8, 13, 20, 32} - expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)) - for _, addrLen := range addrLenRange { - for _, prefixLen := range prefixLenRange { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addrGen.next(addrLen), - PrefixLen: prefixLen, - }, - } - if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil { - t.Errorf("AddProtocolAddress(%+v) failed: %s", protocolAddress, err) - } - expectedAddresses = append(expectedAddresses, protocolAddress) - } - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - -func TestAddAddressWithOptions(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - addrLenRange := []int{4, 16} behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint} - expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(behaviorRange)) + configTypeRange := []stack.AddressConfigType{stack.AddressConfigStatic, stack.AddressConfigSlaac, stack.AddressConfigSlaacTemp} + deprecatedRange := []bool{false, true} + wantAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(behaviorRange)*len(configTypeRange)*len(deprecatedRange)) var addrGen addressGenerator for _, addrLen := range addrLenRange { for _, behavior := range behaviorRange { - address := addrGen.next(addrLen) - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil { - t.Fatalf("AddAddressWithOptions(address=%s, behavior=%d) failed: %s", address, behavior, err) - } - expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, - }) - } - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - -func TestAddProtocolAddressWithOptions(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - addrLenRange := []int{4, 16} - prefixLenRange := []int{8, 13, 20, 32} - behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint} - expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)*len(behaviorRange)) - var addrGen addressGenerator - for _, addrLen := range addrLenRange { - for _, prefixLen := range prefixLenRange { - for _, behavior := range behaviorRange { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addrGen.next(addrLen), - PrefixLen: prefixLen, - }, - } - if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%+v, %d) failed: %s", protocolAddress, behavior, err) + for _, configType := range configTypeRange { + for _, deprecated := range deprecatedRange { + address := addrGen.next(addrLen) + properties := stack.AddressProperties{ + PEB: behavior, + ConfigType: configType, + Deprecated: deprecated, + } + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v) failed: %s", nicID, protocolAddr, properties, err) + } + wantAddresses = append(wantAddresses, tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, + }) } - expectedAddresses = append(expectedAddresses, protocolAddress) } } } gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) + verifyAddresses(t, wantAddresses, gotAddresses) } func TestCreateNICWithOptions(t *testing.T) { @@ -2290,8 +2353,15 @@ func TestNICStats(t *testing.T) { if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed: ", err) } - if err := s.AddAddress(nicid, fakeNetNumber, nic.addr); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: nic.addr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicid, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicid, protocolAddr, err) } { @@ -2735,8 +2805,16 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // be returned by a call to GetMainNICAddress; // else, it should. const address1 = tcpip.Address("\x01") - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address1, pi); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address1, pi, err) + properties := stack.AddressProperties{PEB: pi} + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address1, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr, properties, err) } addr, err := s.GetMainNICAddress(nicID, fakeNetNumber) if err != nil { @@ -2785,16 +2863,31 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // Add some other address with peb set to // FirstPrimaryEndpoint. const address3 = tcpip.Address("\x03") - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address3, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address3, stack.FirstPrimaryEndpoint, err) - + protocolAddr3 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address3, + PrefixLen: fakeDefaultPrefixLen, + }, + } + properties = stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint} + if err := s.AddProtocolAddress(nicID, protocolAddr3, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr3, properties, err) } // Add back the address we removed earlier and // make sure the new peb was respected. // (The address should just be promoted now). - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address1, ps); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address1, pi, err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address1, + PrefixLen: fakeDefaultPrefixLen, + }, + } + properties = stack.AddressProperties{PEB: ps} + if err := s.AddProtocolAddress(nicID, protocolAddr1, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr1, properties, err) } var primaryAddrs []tcpip.Address for _, pa := range s.NICInfo()[nicID].ProtocolAddresses { @@ -3096,8 +3189,12 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { } for _, a := range test.nicAddrs { - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, a); err != nil { - t.Errorf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, a, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: a.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } } @@ -3203,8 +3300,12 @@ func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, addr1, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: addr1.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } // The NIC should have joined addr1's solicited node multicast address. @@ -3359,8 +3460,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) { PrefixLen: 128, }, } - if err := s.AddProtocolAddress(nicID, addr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err) + if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err) } // Address should be in the list of all addresses. @@ -3687,8 +3788,8 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { if err := s.CreateNIC(nicID1, ep); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) } - if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err) + if err := s.AddProtocolAddress(nicID1, test.nicAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, test.nicAddr, err) } s.SetRouteTable(test.routes) @@ -3750,8 +3851,8 @@ func TestResolveWith(t *testing.T) { PrefixLen: 24, }, } - if err := s.AddProtocolAddress(nicID, addr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, addr, err) + if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err) } s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}}) @@ -3792,8 +3893,15 @@ func TestRouteReleaseAfterAddrRemoval(t *testing.T) { if err := s.CreateNIC(nicID, ep); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: localAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } { subnet, err := tcpip.NewSubnet("\x00", "\x00") @@ -3881,8 +3989,8 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) { PrefixLen: 8, }, } - if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protocolAddress, err) + if err := s.AddProtocolAddress(nicID, protocolAddress, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddress, err) } // Check that we get the right initial address and prefix length. @@ -3990,44 +4098,44 @@ func TestFindRouteWithForwarding(t *testing.T) { ) type netCfg struct { - proto tcpip.NetworkProtocolNumber - factory stack.NetworkProtocolFactory - nic1Addr tcpip.Address - nic2Addr tcpip.Address - remoteAddr tcpip.Address + proto tcpip.NetworkProtocolNumber + factory stack.NetworkProtocolFactory + nic1AddrWithPrefix tcpip.AddressWithPrefix + nic2AddrWithPrefix tcpip.AddressWithPrefix + remoteAddr tcpip.Address } fakeNetCfg := netCfg{ - proto: fakeNetNumber, - factory: fakeNetFactory, - nic1Addr: nic1Addr, - nic2Addr: nic2Addr, - remoteAddr: remoteAddr, + proto: fakeNetNumber, + factory: fakeNetFactory, + nic1AddrWithPrefix: tcpip.AddressWithPrefix{Address: nic1Addr, PrefixLen: fakeDefaultPrefixLen}, + nic2AddrWithPrefix: tcpip.AddressWithPrefix{Address: nic2Addr, PrefixLen: fakeDefaultPrefixLen}, + remoteAddr: remoteAddr, } globalIPv6Addr1 := tcpip.Address(net.ParseIP("a::1").To16()) globalIPv6Addr2 := tcpip.Address(net.ParseIP("a::2").To16()) ipv6LinkLocalNIC1WithGlobalRemote := netCfg{ - proto: ipv6.ProtocolNumber, - factory: ipv6.NewProtocol, - nic1Addr: llAddr1, - nic2Addr: globalIPv6Addr2, - remoteAddr: globalIPv6Addr1, + proto: ipv6.ProtocolNumber, + factory: ipv6.NewProtocol, + nic1AddrWithPrefix: llAddr1.WithPrefix(), + nic2AddrWithPrefix: globalIPv6Addr2.WithPrefix(), + remoteAddr: globalIPv6Addr1, } ipv6GlobalNIC1WithLinkLocalRemote := netCfg{ - proto: ipv6.ProtocolNumber, - factory: ipv6.NewProtocol, - nic1Addr: globalIPv6Addr1, - nic2Addr: llAddr1, - remoteAddr: llAddr2, + proto: ipv6.ProtocolNumber, + factory: ipv6.NewProtocol, + nic1AddrWithPrefix: globalIPv6Addr1.WithPrefix(), + nic2AddrWithPrefix: llAddr1.WithPrefix(), + remoteAddr: llAddr2, } ipv6GlobalNIC1WithLinkLocalMulticastRemote := netCfg{ - proto: ipv6.ProtocolNumber, - factory: ipv6.NewProtocol, - nic1Addr: globalIPv6Addr1, - nic2Addr: globalIPv6Addr2, - remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + proto: ipv6.ProtocolNumber, + factory: ipv6.NewProtocol, + nic1AddrWithPrefix: globalIPv6Addr1.WithPrefix(), + nic2AddrWithPrefix: globalIPv6Addr2.WithPrefix(), + remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", } tests := []struct { @@ -4036,8 +4144,8 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg netCfg forwardingEnabled bool - addrNIC tcpip.NICID - localAddr tcpip.Address + addrNIC tcpip.NICID + localAddrWithPrefix tcpip.AddressWithPrefix findRouteErr tcpip.Error dependentOnForwarding bool @@ -4047,7 +4155,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: false, addrNIC: nicID1, - localAddr: fakeNetCfg.nic2Addr, + localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4056,7 +4164,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: true, addrNIC: nicID1, - localAddr: fakeNetCfg.nic2Addr, + localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4065,7 +4173,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: false, addrNIC: nicID1, - localAddr: fakeNetCfg.nic1Addr, + localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4074,7 +4182,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: true, addrNIC: nicID1, - localAddr: fakeNetCfg.nic1Addr, + localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: true, }, @@ -4083,7 +4191,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: false, addrNIC: nicID2, - localAddr: fakeNetCfg.nic2Addr, + localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4092,7 +4200,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: true, addrNIC: nicID2, - localAddr: fakeNetCfg.nic2Addr, + localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4101,7 +4209,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: false, addrNIC: nicID2, - localAddr: fakeNetCfg.nic1Addr, + localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4110,7 +4218,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: true, addrNIC: nicID2, - localAddr: fakeNetCfg.nic1Addr, + localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4118,7 +4226,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and localAddr on same NIC as route", netCfg: fakeNetCfg, forwardingEnabled: false, - localAddr: fakeNetCfg.nic2Addr, + localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4126,7 +4234,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and localAddr on same NIC as route", netCfg: fakeNetCfg, forwardingEnabled: false, - localAddr: fakeNetCfg.nic2Addr, + localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4134,7 +4242,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and localAddr on different NIC as route", netCfg: fakeNetCfg, forwardingEnabled: false, - localAddr: fakeNetCfg.nic1Addr, + localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4142,7 +4250,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and localAddr on different NIC as route", netCfg: fakeNetCfg, forwardingEnabled: true, - localAddr: fakeNetCfg.nic1Addr, + localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: true, }, @@ -4166,7 +4274,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and link-local local addr with route on different NIC", netCfg: ipv6LinkLocalNIC1WithGlobalRemote, forwardingEnabled: false, - localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, + localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4174,7 +4282,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and link-local local addr with route on same NIC", netCfg: ipv6LinkLocalNIC1WithGlobalRemote, forwardingEnabled: true, - localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, + localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4182,7 +4290,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and global local addr with route on same NIC", netCfg: ipv6LinkLocalNIC1WithGlobalRemote, forwardingEnabled: true, - localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic2Addr, + localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4190,7 +4298,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and link-local local addr with route on same NIC", netCfg: ipv6GlobalNIC1WithLinkLocalRemote, forwardingEnabled: false, - localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4198,7 +4306,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and link-local local addr with route on same NIC", netCfg: ipv6GlobalNIC1WithLinkLocalRemote, forwardingEnabled: true, - localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4206,7 +4314,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and global local addr with link-local remote on different NIC", netCfg: ipv6GlobalNIC1WithLinkLocalRemote, forwardingEnabled: false, - localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, @@ -4214,7 +4322,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and global local addr with link-local remote on different NIC", netCfg: ipv6GlobalNIC1WithLinkLocalRemote, forwardingEnabled: true, - localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, @@ -4222,7 +4330,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and global local addr with link-local multicast remote on different NIC", netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, forwardingEnabled: false, - localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, @@ -4230,7 +4338,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and global local addr with link-local multicast remote on different NIC", netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, forwardingEnabled: true, - localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, @@ -4238,7 +4346,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and global local addr with link-local multicast remote on same NIC", netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, forwardingEnabled: false, - localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4246,7 +4354,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and global local addr with link-local multicast remote on same NIC", netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, forwardingEnabled: true, - localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4268,12 +4376,20 @@ func TestFindRouteWithForwarding(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s:", nicID2, err) } - if err := s.AddAddress(nicID1, test.netCfg.proto, test.netCfg.nic1Addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, test.netCfg.proto, test.netCfg.nic1Addr, err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: test.netCfg.proto, + AddressWithPrefix: test.netCfg.nic1AddrWithPrefix, + } + if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr1, err) } - if err := s.AddAddress(nicID2, test.netCfg.proto, test.netCfg.nic2Addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, test.netCfg.proto, test.netCfg.nic2Addr, err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: test.netCfg.proto, + AddressWithPrefix: test.netCfg.nic2AddrWithPrefix, + } + if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddr2, err) } if err := s.SetForwardingDefaultAndAllNICs(test.netCfg.proto, test.forwardingEnabled); err != nil { @@ -4282,20 +4398,20 @@ func TestFindRouteWithForwarding(t *testing.T) { s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}}) - r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */) + r, err := s.FindRoute(test.addrNIC, test.localAddrWithPrefix.Address, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */) if err == nil { defer r.Release() } if diff := cmp.Diff(test.findRouteErr, err); diff != "" { - t.Fatalf("unexpected error from FindRoute(%d, %s, %s, %d, false), (-want, +got):\n%s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, diff) + t.Fatalf("unexpected error from FindRoute(%d, %s, %s, %d, false), (-want, +got):\n%s", test.addrNIC, test.localAddrWithPrefix.Address, test.netCfg.remoteAddr, test.netCfg.proto, diff) } if test.findRouteErr != nil { return } - if r.LocalAddress() != test.localAddr { - t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), test.localAddr) + if r.LocalAddress() != test.localAddrWithPrefix.Address { + t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), test.localAddrWithPrefix.Address) } if r.RemoteAddress() != test.netCfg.remoteAddr { t.Errorf("got r.RemoteAddress() = %s, want = %s", r.RemoteAddress(), test.netCfg.remoteAddr) @@ -4318,8 +4434,8 @@ func TestFindRouteWithForwarding(t *testing.T) { if !ok { t.Fatal("packet not sent through ep2") } - if pkt.Route.LocalAddress != test.localAddr { - t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddr) + if pkt.Route.LocalAddress != test.localAddrWithPrefix.Address { + t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddrWithPrefix.Address) } if pkt.Route.RemoteAddress != test.netCfg.remoteAddr { t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.netCfg.remoteAddr) diff --git a/pkg/tcpip/stack/tcp.go b/pkg/tcpip/stack/tcp.go index dc7289441..a941091b0 100644 --- a/pkg/tcpip/stack/tcp.go +++ b/pkg/tcpip/stack/tcp.go @@ -289,6 +289,12 @@ type TCPSenderState struct { // RACKState holds the state related to RACK loss detection algorithm. RACKState TCPRACKState + + // RetransmitTS records the timestamp used to detect spurious recovery. + RetransmitTS uint32 + + // SpuriousRecovery indicates if the sender entered recovery spuriously. + SpuriousRecovery bool } // TCPSACKInfo holds TCP SACK related information for a given TCP endpoint. diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 824cf6526..3474c292a 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -32,11 +32,13 @@ type protocolIDs struct { // transportEndpoints manages all endpoints of a given protocol. It has its own // mutex so as to reduce interference between protocols. type transportEndpoints struct { - // mu protects all fields of the transportEndpoints. - mu sync.RWMutex + mu sync.RWMutex + // +checklocks:mu endpoints map[TransportEndpointID]*endpointsByNIC // rawEndpoints contains endpoints for raw sockets, which receive all // traffic of a given protocol regardless of port. + // + // +checklocks:mu rawEndpoints []RawTransportEndpoint } @@ -69,7 +71,7 @@ func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint { // descending order of match quality. If a call to yield returns false, // iterEndpointsLocked stops iteration and returns immediately. // -// Preconditions: eps.mu must be locked. +// +checklocksread:eps.mu func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield func(*endpointsByNIC) bool) { // Try to find a match with the id as provided. if ep, ok := eps.endpoints[id]; ok { @@ -110,7 +112,7 @@ func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield // findAllEndpointsLocked returns all endpointsByNIC in eps that match id, in // descending order of match quality. // -// Preconditions: eps.mu must be locked. +// +checklocksread:eps.mu func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []*endpointsByNIC { var matchedEPs []*endpointsByNIC eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool { @@ -122,7 +124,7 @@ func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) [] // findEndpointLocked returns the endpoint that most closely matches the given id. // -// Preconditions: eps.mu must be locked. +// +checklocksread:eps.mu func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpointsByNIC { var matchedEP *endpointsByNIC eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool { @@ -133,10 +135,12 @@ func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpo } type endpointsByNIC struct { - mu sync.RWMutex - endpoints map[tcpip.NICID]*multiPortEndpoint // seed is a random secret for a jenkins hash. seed uint32 + + mu sync.RWMutex + // +checklocks:mu + endpoints map[tcpip.NICID]*multiPortEndpoint } func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint { @@ -171,7 +175,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *Packet return true } // multiPortEndpoints are guaranteed to have at least one element. - transEP := selectEndpoint(id, mpep, epsByNIC.seed) + transEP := mpep.selectEndpoint(id, epsByNIC.seed) if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue { queuedProtocol.QueuePacket(transEP, id, pkt) epsByNIC.mu.RUnlock() @@ -200,7 +204,7 @@ func (epsByNIC *endpointsByNIC) handleError(n *nic, id TransportEndpointID, tran // broadcast like we are doing with handlePacket above? // multiPortEndpoints are guaranteed to have at least one element. - selectEndpoint(id, mpep, epsByNIC.seed).HandleError(transErr, pkt) + mpep.selectEndpoint(id, epsByNIC.seed).HandleError(transErr, pkt) } // registerEndpoint returns true if it succeeds. It fails and returns @@ -333,15 +337,18 @@ func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber // // +stateify savable type multiPortEndpoint struct { - mu sync.RWMutex `state:"nosave"` demux *transportDemuxer netProto tcpip.NetworkProtocolNumber transProto tcpip.TransportProtocolNumber + flags ports.FlagCounter + + mu sync.RWMutex `state:"nosave"` // endpoints stores the transport endpoints in the order in which they // were bound. This is required for UDP SO_REUSEADDR. + // + // +checklocks:mu endpoints []TransportEndpoint - flags ports.FlagCounter } func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint { @@ -362,13 +369,16 @@ func reciprocalScale(val, n uint32) uint32 { // selectEndpoint calculates a hash of destination and source addresses and // ports then uses it to select a socket. In this case, all packets from one // address will be sent to same endpoint. -func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint { - if len(mpep.endpoints) == 1 { - return mpep.endpoints[0] +func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID, seed uint32) TransportEndpoint { + ep.mu.RLock() + defer ep.mu.RUnlock() + + if len(ep.endpoints) == 1 { + return ep.endpoints[0] } - if mpep.flags.SharedFlags().ToFlags().Effective().MostRecent { - return mpep.endpoints[len(mpep.endpoints)-1] + if ep.flags.SharedFlags().ToFlags().Effective().MostRecent { + return ep.endpoints[len(ep.endpoints)-1] } payload := []byte{ @@ -384,8 +394,8 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 h.Write([]byte(id.RemoteAddress)) hash := h.Sum32() - idx := reciprocalScale(hash, uint32(len(mpep.endpoints))) - return mpep.endpoints[idx] + idx := reciprocalScale(hash, uint32(len(ep.endpoints))) + return ep.endpoints[idx] } func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *PacketBuffer) { @@ -657,7 +667,7 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN } } - ep := selectEndpoint(id, mpep, epsByNIC.seed) + ep := mpep.selectEndpoint(id, epsByNIC.seed) epsByNIC.mu.RUnlock() return ep } diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 45b09110d..cd3a8c25a 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -35,7 +35,7 @@ import ( const ( testSrcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - testDstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + testDstAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") testSrcAddrV4 = "\x0a\x00\x00\x01" testDstAddrV4 = "\x0a\x00\x00\x02" @@ -64,12 +64,20 @@ func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICI } linkEps[linkEpID] = channelEp - if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, testDstAddrV4); err != nil { - t.Fatalf("AddAddress IPv4 failed: %s", err) + protocolAddrV4 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.Address(testDstAddrV4).WithPrefix(), + } + if err := s.AddProtocolAddress(linkEpID, protocolAddrV4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", linkEpID, protocolAddrV4, err) } - if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, testDstAddrV6); err != nil { - t.Fatalf("AddAddress IPv6 failed: %s", err) + protocolAddrV6 := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: testDstAddrV6.WithPrefix(), + } + if err := s.AddProtocolAddress(linkEpID, protocolAddrV6, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", linkEpID, protocolAddrV6, err) } } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 839178809..51870d03f 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -331,8 +331,11 @@ func (*fakeTransportProtocol) Wait() {} // Parse implements TransportProtocol.Parse. func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool { - _, ok := pkt.TransportHeader().Consume(fakeTransHeaderLen) - return ok + if _, ok := pkt.TransportHeader().Consume(fakeTransHeaderLen); ok { + pkt.TransportProtocolNumber = fakeTransNumber + return true + } + return false } func fakeTransFactory(s *stack.Stack) stack.TransportProtocol { @@ -357,8 +360,15 @@ func TestTransportReceive(t *testing.T) { s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } // Create endpoint and connect to remote address. @@ -428,8 +438,15 @@ func TestTransportControlReceive(t *testing.T) { s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } // Create endpoint and connect to remote address. @@ -497,8 +514,15 @@ func TestTransportSend(t *testing.T) { t.Fatalf("CreateNIC failed: %v", err) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 55683b4fb..460a6afaf 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -19,7 +19,7 @@ // The starting point is the creation and configuration of a stack. A stack can // be created by calling the New() function of the tcpip/stack/stack package; // configuring a stack involves creating NICs (via calls to Stack.CreateNIC()), -// adding network addresses (via calls to Stack.AddAddress()), and +// adding network addresses (via calls to Stack.AddProtocolAddress()), and // setting a route table (via a call to Stack.SetRouteTable()). // // Once a stack is configured, endpoints can be created by calling @@ -423,9 +423,9 @@ type ControlMessages struct { // HasTimestamp indicates whether Timestamp is valid/set. HasTimestamp bool - // Timestamp is the time (in ns) that the last packet used to create - // the read data was received. - Timestamp int64 + // Timestamp is the time that the last packet used to create the read data + // was received. + Timestamp time.Time `state:".(int64)"` // HasInq indicates whether Inq is valid/set. HasInq bool @@ -451,6 +451,12 @@ type ControlMessages struct { // PacketInfo holds interface and address data on an incoming packet. PacketInfo IPPacketInfo + // HasIPv6PacketInfo indicates whether IPv6PacketInfo is set. + HasIPv6PacketInfo bool + + // IPv6PacketInfo holds interface and address data on an incoming packet. + IPv6PacketInfo IPv6PacketInfo + // HasOriginalDestinationAddress indicates whether OriginalDstAddress is // set. HasOriginalDstAddress bool @@ -465,10 +471,10 @@ type ControlMessages struct { // PacketOwner is used to get UID and GID of the packet. type PacketOwner interface { - // UID returns KUID of the packet. + // KUID returns KUID of the packet. KUID() uint32 - // GID returns KGID of the packet. + // KGID returns KGID of the packet. KGID() uint32 } @@ -1164,6 +1170,14 @@ type IPPacketInfo struct { DestinationAddr Address } +// IPv6PacketInfo is the message structure for IPV6_PKTINFO. +// +// +stateify savable +type IPv6PacketInfo struct { + Addr Address + NIC NICID +} + // SendBufferSizeOption is used by stack.(Stack*).Option/SetOption to // get/set the default, min and max send buffer sizes. type SendBufferSizeOption struct { @@ -1231,11 +1245,11 @@ type Route struct { // String implements the fmt.Stringer interface. func (r Route) String() string { var out strings.Builder - fmt.Fprintf(&out, "%s", r.Destination) + _, _ = fmt.Fprintf(&out, "%s", r.Destination) if len(r.Gateway) > 0 { - fmt.Fprintf(&out, " via %s", r.Gateway) + _, _ = fmt.Fprintf(&out, " via %s", r.Gateway) } - fmt.Fprintf(&out, " nic %d", r.NIC) + _, _ = fmt.Fprintf(&out, " nic %d", r.NIC) return out.String() } @@ -1255,6 +1269,8 @@ type TransportProtocolNumber uint32 type NetworkProtocolNumber uint32 // A StatCounter keeps track of a statistic. +// +// +stateify savable type StatCounter struct { count atomicbitops.AlignedAtomicUint64 } @@ -1270,7 +1286,7 @@ func (s *StatCounter) Decrement() { } // Value returns the current value of the counter. -func (s *StatCounter) Value(name ...string) uint64 { +func (s *StatCounter) Value(...string) uint64 { return s.count.Load() } @@ -1849,6 +1865,10 @@ type TCPStats struct { // SegmentsAckedWithDSACK is the number of segments acknowledged with // DSACK. SegmentsAckedWithDSACK *StatCounter + + // SpuriousRecovery is the number of times the connection entered loss + // recovery spuriously. + SpuriousRecovery *StatCounter } // UDPStats collects UDP-specific stats. @@ -1981,6 +2001,8 @@ type Stats struct { } // ReceiveErrors collects packet receive errors within transport endpoint. +// +// +stateify savable type ReceiveErrors struct { // ReceiveBufferOverflow is the number of received packets dropped // due to the receive buffer being full. @@ -1998,8 +2020,10 @@ type ReceiveErrors struct { ChecksumErrors StatCounter } -// SendErrors collects packet send errors within the transport layer for -// an endpoint. +// SendErrors collects packet send errors within the transport layer for an +// endpoint. +// +// +stateify savable type SendErrors struct { // SendToNetworkFailed is the number of packets failed to be written to // the network endpoint. @@ -2010,6 +2034,8 @@ type SendErrors struct { } // ReadErrors collects segment read errors from an endpoint read call. +// +// +stateify savable type ReadErrors struct { // ReadClosed is the number of received packet drops because the endpoint // was shutdown for read. @@ -2025,6 +2051,8 @@ type ReadErrors struct { } // WriteErrors collects packet write errors from an endpoint write call. +// +// +stateify savable type WriteErrors struct { // WriteClosed is the number of packet drops because the endpoint // was shutdown for write. @@ -2040,6 +2068,8 @@ type WriteErrors struct { } // TransportEndpointStats collects statistics about the endpoint. +// +// +stateify savable type TransportEndpointStats struct { // PacketsReceived is the number of successful packet receives. PacketsReceived StatCounter diff --git a/pkg/tcpip/stack/iptables_state.go b/pkg/tcpip/tcpip_state.go index 529e02a07..1953e24a1 100644 --- a/pkg/tcpip/stack/iptables_state.go +++ b/pkg/tcpip/tcpip_state.go @@ -1,4 +1,4 @@ -// Copyright 2020 The gVisor Authors. +// Copyright 2021 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,29 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -package stack +package tcpip import ( "time" ) -// +stateify savable -type unixTime struct { - second int64 - nano int64 +func (c *ControlMessages) saveTimestamp() int64 { + return c.Timestamp.UnixNano() } -// saveLastUsed is invoked by stateify. -func (cn *conn) saveLastUsed() unixTime { - return unixTime{cn.lastUsed.Unix(), cn.lastUsed.UnixNano()} -} - -// loadLastUsed is invoked by stateify. -func (cn *conn) loadLastUsed(unix unixTime) { - cn.lastUsed = time.Unix(unix.second, unix.nano) -} - -// beforeSave is invoked by stateify. -func (ct *ConnTrack) beforeSave() { - ct.mu.Lock() +func (c *ControlMessages) loadTimestamp(nsec int64) { + c.Timestamp = time.Unix(0, nsec) } diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 181ef799e..99f4d4d0e 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -34,12 +34,16 @@ go_test( "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", + "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/tests/utils", "//pkg/tcpip/testutil", + "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", + "//pkg/waiter", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) @@ -139,3 +143,25 @@ go_test( "@com_github_google_go_cmp//cmp:go_default_library", ], ) + +go_test( + name = "istio_test", + size = "small", + srcs = ["istio_test.go"], + deps = [ + "//pkg/context", + "//pkg/rand", + "//pkg/sync", + "//pkg/tcpip", + "//pkg/tcpip/adapters/gonet", + "//pkg/tcpip/header", + "//pkg/tcpip/link/loopback", + "//pkg/tcpip/link/pipe", + "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", + "//pkg/tcpip/transport/tcp", + "@com_github_google_go_cmp//cmp:go_default_library", + ], +) diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index 92fa6257d..6e1d4720d 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -473,11 +473,19 @@ func TestMulticastForwarding(t *testing.T) { t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err) } - if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err) + protocolAddrV4 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: utils.Ipv4Addr, } - if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err) + if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err) + } + protocolAddrV6 := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: utils.Ipv6Addr, + } + if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err) } if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { @@ -612,8 +620,8 @@ func TestPerInterfaceForwarding(t *testing.T) { addr: utils.RouterNIC2IPv6Addr, }, } { - if err := s.AddProtocolAddress(add.nicID, add.addr); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", add.nicID, add.addr, err) + if err := s.AddProtocolAddress(add.nicID, add.addr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", add.nicID, add.addr, err) } } diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go index f9ab7d0af..957a779bf 100644 --- a/pkg/tcpip/tests/integration/iptables_test.go +++ b/pkg/tcpip/tests/integration/iptables_test.go @@ -15,19 +15,24 @@ package iptables_test import ( + "bytes" "testing" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/tests/utils" "gvisor.dev/gvisor/pkg/tcpip/testutil" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" ) type inputIfNameMatcher struct { @@ -49,10 +54,10 @@ const ( nicName = "nic1" anotherNicName = "nic2" linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") - srcAddrV4 = "\x0a\x00\x00\x01" - dstAddrV4 = "\x0a\x00\x00\x02" - srcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - dstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + srcAddrV4 = tcpip.Address("\x0a\x00\x00\x01") + dstAddrV4 = tcpip.Address("\x0a\x00\x00\x02") + srcAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + dstAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") payloadSize = 20 ) @@ -66,8 +71,12 @@ func genStackV6(t *testing.T) (*stack.Stack, *channel.Endpoint) { if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err) } - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, dstAddrV6); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, dstAddrV6, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: dstAddrV6.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } return s, e } @@ -82,8 +91,12 @@ func genStackV4(t *testing.T) (*stack.Stack, *channel.Endpoint) { if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err) } - if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, dstAddrV4); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, dstAddrV4, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: dstAddrV4.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } return s, e } @@ -601,11 +614,19 @@ func TestIPTableWritePackets(t *testing.T) { if err := s.CreateNIC(nicID, &e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, srcAddrV6); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, srcAddrV6, err) + protocolAddrV6 := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: srcAddrV6.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddrV6, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV6, err) + } + protocolAddrV4 := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: srcAddrV4.WithPrefix(), } - if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, srcAddrV4); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, srcAddrV4, err) + if err := s.AddProtocolAddress(nicID, protocolAddrV4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV4, err) } s.SetRouteTable([]tcpip.Route{ @@ -856,11 +877,19 @@ func TestForwardingHook(t *testing.T) { t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err) } - if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err) + protocolAddrV4 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: utils.Ipv4Addr.Address.WithPrefix(), } - if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err) + if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err) + } + protocolAddrV6 := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: utils.Ipv6Addr.Address.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err) } if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { @@ -1037,22 +1066,22 @@ func TestInputHookWithLocalForwarding(t *testing.T) { if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil { t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err) } - if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv4Addr1, err) + if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID1, utils.Ipv4Addr1, err) } - if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv6Addr1, err) + if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID1, utils.Ipv6Addr1, err) } e2 := channel.New(1, header.IPv6MinimumMTU, "") if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil { t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err) } - if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv4Addr2, err) + if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID2, utils.Ipv4Addr2, err) } - if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv6Addr2, err) + if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID2, utils.Ipv6Addr2, err) } if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { @@ -1132,3 +1161,621 @@ func TestInputHookWithLocalForwarding(t *testing.T) { }) } } + +func TestNAT(t *testing.T) { + const listenPort uint16 = 8080 + + type endpointAndAddresses struct { + serverEP tcpip.Endpoint + serverAddr tcpip.FullAddress + serverReadableCH chan struct{} + serverConnectAddr tcpip.Address + + clientEP tcpip.Endpoint + clientAddr tcpip.Address + clientReadableCH chan struct{} + clientConnectAddr tcpip.FullAddress + } + + newEP := func(t *testing.T, s *stack.Stack, transProto tcpip.TransportProtocolNumber, netProto tcpip.NetworkProtocolNumber) (tcpip.Endpoint, chan struct{}) { + t.Helper() + var wq waiter.Queue + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.ReadableEvents) + t.Cleanup(func() { + wq.EventUnregister(&we) + }) + + ep, err := s.NewEndpoint(transProto, netProto, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", transProto, netProto, err) + } + t.Cleanup(ep.Close) + + return ep, ch + } + + setupNAT := func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, hook stack.Hook, filter stack.IPHeaderFilter, target stack.Target) { + t.Helper() + + ipv6 := netProto == ipv6.ProtocolNumber + ipt := s.IPTables() + table := ipt.GetTable(stack.NATID, ipv6) + ruleIdx := table.BuiltinChains[hook] + table.Rules[ruleIdx].Filter = filter + table.Rules[ruleIdx].Target = target + // Make sure the packet is not dropped by the next rule. + table.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.NATID, table, ipv6); err != nil { + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err) + } + } + + setupDNAT := func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, target stack.Target) { + t.Helper() + + setupNAT( + t, + s, + netProto, + stack.Prerouting, + stack.IPHeaderFilter{ + Protocol: transProto, + CheckProtocol: true, + InputInterface: utils.RouterNIC2Name, + }, + target) + } + + setupSNAT := func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, target stack.Target) { + t.Helper() + + setupNAT( + t, + s, + netProto, + stack.Postrouting, + stack.IPHeaderFilter{ + Protocol: transProto, + CheckProtocol: true, + OutputInterface: utils.RouterNIC1Name, + }, + target) + } + + type natType struct { + name string + setupNAT func(_ *testing.T, _ *stack.Stack, _ tcpip.NetworkProtocolNumber, _ tcpip.TransportProtocolNumber, snatAddr, dnatAddr tcpip.Address) + } + + snatTypes := []natType{ + { + name: "SNAT", + setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, snatAddr, _ tcpip.Address) { + t.Helper() + + setupSNAT(t, s, netProto, transProto, &stack.SNATTarget{NetworkProtocol: netProto, Addr: snatAddr}) + }, + }, + { + name: "Masquerade", + setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, _, _ tcpip.Address) { + t.Helper() + + setupSNAT(t, s, netProto, transProto, &stack.MasqueradeTarget{NetworkProtocol: netProto}) + }, + }, + } + dnatTypes := []natType{ + { + name: "Redirect", + setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, _, _ tcpip.Address) { + t.Helper() + + setupDNAT(t, s, netProto, transProto, &stack.RedirectTarget{NetworkProtocol: netProto, Port: listenPort}) + }, + }, + { + name: "DNAT", + setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, _, dnatAddr tcpip.Address) { + t.Helper() + + setupDNAT(t, s, netProto, transProto, &stack.DNATTarget{NetworkProtocol: netProto, Addr: dnatAddr, Port: listenPort}) + }, + }, + } + + setupTwiceNAT := func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, dnatAddr tcpip.Address, snatTarget stack.Target) { + t.Helper() + + ipv6 := netProto == ipv6.ProtocolNumber + ipt := s.IPTables() + + table := stack.Table{ + Rules: []stack.Rule{ + // Prerouting + { + Filter: stack.IPHeaderFilter{ + Protocol: transProto, + CheckProtocol: true, + InputInterface: utils.RouterNIC2Name, + }, + Target: &stack.DNATTarget{NetworkProtocol: netProto, Addr: dnatAddr, Port: listenPort}, + }, + { + Target: &stack.AcceptTarget{}, + }, + + // Input + { + Target: &stack.AcceptTarget{}, + }, + + // Forward + { + Target: &stack.AcceptTarget{}, + }, + + // Output + { + Target: &stack.AcceptTarget{}, + }, + + // Postrouting + { + Filter: stack.IPHeaderFilter{ + Protocol: transProto, + CheckProtocol: true, + OutputInterface: utils.RouterNIC1Name, + }, + Target: snatTarget, + }, + { + Target: &stack.AcceptTarget{}, + }, + }, + BuiltinChains: [stack.NumHooks]int{ + stack.Prerouting: 0, + stack.Input: 2, + stack.Forward: 3, + stack.Output: 4, + stack.Postrouting: 5, + }, + } + + if err := ipt.ReplaceTable(stack.NATID, table, ipv6); err != nil { + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err) + } + } + twiceNATTypes := []natType{ + { + name: "DNAT-Masquerade", + setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, snatAddr, dnatAddr tcpip.Address) { + t.Helper() + + setupTwiceNAT(t, s, netProto, transProto, dnatAddr, &stack.MasqueradeTarget{NetworkProtocol: netProto}) + }, + }, + { + name: "DNAT-SNAT", + setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, snatAddr, dnatAddr tcpip.Address) { + t.Helper() + + setupTwiceNAT(t, s, netProto, transProto, dnatAddr, &stack.SNATTarget{NetworkProtocol: netProto, Addr: snatAddr}) + }, + }, + } + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + // Setups up the stacks in such a way that: + // + // - Host2 is the client for all tests. + // - When performing SNAT only: + // + Host1 is the server. + // + NAT will transform client-originating packets' source addresses to + // the router's NIC1's address before reaching Host1. + // - When performing DNAT only: + // + Router is the server. + // + Client will send packets directed to Host1. + // + NAT will transform client-originating packets' destination addresses + // to the router's NIC2's address. + // - When performing Twice-NAT: + // + Host1 is the server. + // + Client will send packets directed to router's NIC2. + // + NAT will transform client originating packets' destination addresses + // to Host1's address. + // + NAT will transform client-originating packets' source addresses to + // the router's NIC1's address before reaching Host1. + epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses + natTypes []natType + }{ + { + name: "IPv4 SNAT", + netProto: ipv4.ProtocolNumber, + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { + t.Helper() + + listenerStack := host1Stack + serverAddr := tcpip.FullAddress{ + Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address, + Port: listenPort, + } + serverConnectAddr := utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address + clientConnectPort := serverAddr.Port + ep1, ep1WECH := newEP(t, listenerStack, proto, ipv4.ProtocolNumber) + ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber) + return endpointAndAddresses{ + serverEP: ep1, + serverAddr: serverAddr, + serverReadableCH: ep1WECH, + serverConnectAddr: serverConnectAddr, + + clientEP: ep2, + clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address, + clientReadableCH: ep2WECH, + clientConnectAddr: tcpip.FullAddress{ + Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address, + Port: clientConnectPort, + }, + } + }, + natTypes: snatTypes, + }, + { + name: "IPv4 DNAT", + netProto: ipv4.ProtocolNumber, + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { + t.Helper() + + // If we are performing DNAT, then the packet will be redirected + // to the router. + listenerStack := routerStack + serverAddr := tcpip.FullAddress{ + Addr: utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, + Port: listenPort, + } + serverConnectAddr := utils.Host2IPv4Addr.AddressWithPrefix.Address + // DNAT will update the destination port to what the server is + // bound to. + clientConnectPort := serverAddr.Port + 1 + ep1, ep1WECH := newEP(t, listenerStack, proto, ipv4.ProtocolNumber) + ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber) + return endpointAndAddresses{ + serverEP: ep1, + serverAddr: serverAddr, + serverReadableCH: ep1WECH, + serverConnectAddr: serverConnectAddr, + + clientEP: ep2, + clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address, + clientReadableCH: ep2WECH, + clientConnectAddr: tcpip.FullAddress{ + Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address, + Port: clientConnectPort, + }, + } + }, + natTypes: dnatTypes, + }, + { + name: "IPv4 Twice-NAT", + netProto: ipv4.ProtocolNumber, + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { + t.Helper() + + listenerStack := host1Stack + serverAddr := tcpip.FullAddress{ + Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address, + Port: listenPort, + } + serverConnectAddr := utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address + clientConnectPort := serverAddr.Port + ep1, ep1WECH := newEP(t, listenerStack, proto, ipv4.ProtocolNumber) + ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber) + return endpointAndAddresses{ + serverEP: ep1, + serverAddr: serverAddr, + serverReadableCH: ep1WECH, + serverConnectAddr: serverConnectAddr, + + clientEP: ep2, + clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address, + clientReadableCH: ep2WECH, + clientConnectAddr: tcpip.FullAddress{ + Addr: utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, + Port: clientConnectPort, + }, + } + }, + natTypes: twiceNATTypes, + }, + { + name: "IPv6 SNAT", + netProto: ipv6.ProtocolNumber, + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { + t.Helper() + + listenerStack := host1Stack + serverAddr := tcpip.FullAddress{ + Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address, + Port: listenPort, + } + serverConnectAddr := utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address + clientConnectPort := serverAddr.Port + ep1, ep1WECH := newEP(t, listenerStack, proto, ipv6.ProtocolNumber) + ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber) + return endpointAndAddresses{ + serverEP: ep1, + serverAddr: serverAddr, + serverReadableCH: ep1WECH, + serverConnectAddr: serverConnectAddr, + + clientEP: ep2, + clientAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address, + clientReadableCH: ep2WECH, + clientConnectAddr: tcpip.FullAddress{ + Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address, + Port: clientConnectPort, + }, + } + }, + natTypes: snatTypes, + }, + { + name: "IPv6 DNAT", + netProto: ipv6.ProtocolNumber, + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { + t.Helper() + + // If we are performing DNAT, then the packet will be redirected + // to the router. + listenerStack := routerStack + serverAddr := tcpip.FullAddress{ + Addr: utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, + Port: listenPort, + } + serverConnectAddr := utils.Host2IPv6Addr.AddressWithPrefix.Address + // DNAT will update the destination port to what the server is + // bound to. + clientConnectPort := serverAddr.Port + 1 + ep1, ep1WECH := newEP(t, listenerStack, proto, ipv6.ProtocolNumber) + ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber) + return endpointAndAddresses{ + serverEP: ep1, + serverAddr: serverAddr, + serverReadableCH: ep1WECH, + serverConnectAddr: serverConnectAddr, + + clientEP: ep2, + clientAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address, + clientReadableCH: ep2WECH, + clientConnectAddr: tcpip.FullAddress{ + Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address, + Port: clientConnectPort, + }, + } + }, + natTypes: dnatTypes, + }, + { + name: "IPv6 Twice-NAT", + netProto: ipv6.ProtocolNumber, + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { + t.Helper() + + listenerStack := host1Stack + serverAddr := tcpip.FullAddress{ + Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address, + Port: listenPort, + } + serverConnectAddr := utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address + clientConnectPort := serverAddr.Port + ep1, ep1WECH := newEP(t, listenerStack, proto, ipv6.ProtocolNumber) + ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber) + return endpointAndAddresses{ + serverEP: ep1, + serverAddr: serverAddr, + serverReadableCH: ep1WECH, + serverConnectAddr: serverConnectAddr, + + clientEP: ep2, + clientAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address, + clientReadableCH: ep2WECH, + clientConnectAddr: tcpip.FullAddress{ + Addr: utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, + Port: clientConnectPort, + }, + } + }, + natTypes: twiceNATTypes, + }, + } + + subTests := []struct { + name string + proto tcpip.TransportProtocolNumber + expectedConnectErr tcpip.Error + setupServer func(t *testing.T, ep tcpip.Endpoint) + setupServerConn func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) + needRemoteAddr bool + }{ + { + name: "UDP", + proto: udp.ProtocolNumber, + expectedConnectErr: nil, + setupServerConn: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { + t.Helper() + + if err := ep.Connect(clientAddr); err != nil { + t.Fatalf("ep.Connect(%#v): %s", clientAddr, err) + } + return nil, nil + }, + needRemoteAddr: true, + }, + { + name: "TCP", + proto: tcp.ProtocolNumber, + expectedConnectErr: &tcpip.ErrConnectStarted{}, + setupServer: func(t *testing.T, ep tcpip.Endpoint) { + t.Helper() + + if err := ep.Listen(1); err != nil { + t.Fatalf("ep.Listen(1): %s", err) + } + }, + setupServerConn: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { + t.Helper() + + var addr tcpip.FullAddress + for { + newEP, wq, err := ep.Accept(&addr) + if _, ok := err.(*tcpip.ErrWouldBlock); ok { + <-ch + continue + } + if err != nil { + t.Fatalf("ep.Accept(_): %s", err) + } + if diff := cmp.Diff(clientAddr, addr, checker.IgnoreCmpPath( + "NIC", + )); diff != "" { + t.Errorf("accepted address mismatch (-want +got):\n%s", diff) + } + + we, newCH := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.ReadableEvents) + return newEP, newCH + } + }, + needRemoteAddr: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + for _, natType := range test.natTypes { + t.Run(natType.name, func(t *testing.T) { + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, + } + + host1Stack := stack.New(stackOpts) + routerStack := stack.New(stackOpts) + host2Stack := stack.New(stackOpts) + utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack) + + epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto) + natType.setupNAT(t, routerStack, test.netProto, subTest.proto, epsAndAddrs.serverConnectAddr, epsAndAddrs.serverAddr.Addr) + + if err := epsAndAddrs.serverEP.Bind(epsAndAddrs.serverAddr); err != nil { + t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", epsAndAddrs.serverAddr, err) + } + clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr} + if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil { + t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err) + } + + if subTest.setupServer != nil { + subTest.setupServer(t, epsAndAddrs.serverEP) + } + { + err := epsAndAddrs.clientEP.Connect(epsAndAddrs.clientConnectAddr) + if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" { + t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", epsAndAddrs.clientConnectAddr, diff) + } + } + serverConnectAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverConnectAddr} + if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil { + t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err) + } else { + serverConnectAddr.Port = addr.Port + } + + serverEP := epsAndAddrs.serverEP + serverCH := epsAndAddrs.serverReadableCH + if ep, ch := subTest.setupServerConn(t, serverEP, serverCH, serverConnectAddr); ep != nil { + defer ep.Close() + serverEP = ep + serverCH = ch + } + + write := func(ep tcpip.Endpoint, data []byte) { + t.Helper() + + var r bytes.Reader + r.Reset(data) + var wOpts tcpip.WriteOptions + n, err := ep.Write(&r, wOpts) + if err != nil { + t.Fatalf("ep.Write(_, %#v): %s", wOpts, err) + } + if want := int64(len(data)); n != want { + t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want) + } + } + + read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) { + t.Helper() + + var buf bytes.Buffer + var res tcpip.ReadResult + for { + var err tcpip.Error + opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr} + res, err = ep.Read(&buf, opts) + if _, ok := err.(*tcpip.ErrWouldBlock); ok { + <-ch + continue + } + if err != nil { + t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err) + } + break + } + + readResult := tcpip.ReadResult{ + Count: len(data), + Total: len(data), + } + if subTest.needRemoteAddr { + readResult.RemoteAddr = expectedFrom + } + if diff := cmp.Diff(readResult, res, checker.IgnoreCmpPath( + "ControlMessages", + "RemoteAddr.NIC", + )); diff != "" { + t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) + } + if diff := cmp.Diff(buf.Bytes(), data); diff != "" { + t.Errorf("received data mismatch (-want +got):\n%s", diff) + } + + if t.Failed() { + t.FailNow() + } + } + + { + data := []byte{1, 2, 3, 4} + write(epsAndAddrs.clientEP, data) + read(serverCH, serverEP, data, serverConnectAddr) + } + + { + data := []byte{5, 6, 7, 8, 9, 10, 11, 12} + write(serverEP, data) + read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, epsAndAddrs.clientConnectAddr) + } + }) + } + }) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/istio_test.go b/pkg/tcpip/tests/integration/istio_test.go new file mode 100644 index 000000000..95d994ef8 --- /dev/null +++ b/pkg/tcpip/tests/integration/istio_test.go @@ -0,0 +1,365 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package istio_test + +import ( + "fmt" + "io" + "net" + "net/http" + "strconv" + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/rand" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/link/pipe" + "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" +) + +// testContext encapsulates the state required to run tests that simulate +// an istio like environment. +// +// A diagram depicting the setup is shown below. +// +-----------------------------------------------------------------------+ +// | +-------------------------------------------------+ | +// | + ----------+ | + -----------------+ PROXY +----------+ | | +// | | clientEP | | | serverListeningEP|--accepted-> | serverEP |-+ | | +// | + ----------+ | + -----------------+ +----------+ | | | +// | | -------|-------------+ +----------+ | | | +// | | | | | proxyEP |-+ | | +// | +-----redirect | +----------+ | | +// | + ------------+---|------+---+ | +// | | | +// | Local Stack. | | +// +-------------------------------------------------------|---------------+ +// | +// +-----------------------------------------------------------------------+ +// | remoteStack | | +// | +-------------SYN ---------------| | +// | | | | +// | +-------------------|--------------------------------|-_---+ | +// | | + -----------------+ + ----------+ | | | +// | | | remoteListeningEP|--accepted--->| remoteEP |<++ | | +// | | + -----------------+ + ----------+ | | +// | | Remote HTTP Server | | +// | +----------------------------------------------------------+ | +// +-----------------------------------------------------------------------+ +// +type testContext struct { + // localServerListener is the listening port for the server which will proxy + // all traffic to the remote EP. + localServerListener *gonet.TCPListener + + // remoteListenListener is the remote listening endpoint that will receive + // connections from server. + remoteServerListener *gonet.TCPListener + + // localStack is the stack used to create client/server endpoints and + // also the stack on which we install NAT redirect rules. + localStack *stack.Stack + + // remoteStack is the stack that represents a *remote* server. + remoteStack *stack.Stack + + // defaultResponse is the response served by the HTTP server for all GET + defaultResponse []byte + + // requests. wg is used to wait for HTTP server and Proxy to terminate before + // returning from cleanup. + wg sync.WaitGroup +} + +func (ctx *testContext) cleanup() { + ctx.localServerListener.Close() + ctx.localStack.Close() + ctx.remoteServerListener.Close() + ctx.remoteStack.Close() + ctx.wg.Wait() +} + +const ( + localServerPort = 8080 + remoteServerPort = 9090 +) + +var ( + localIPv4Addr1 = testutil.MustParse4("10.0.0.1") + localIPv4Addr2 = testutil.MustParse4("10.0.0.2") + loopbackIPv4Addr = testutil.MustParse4("127.0.0.1") + remoteIPv4Addr1 = testutil.MustParse4("10.0.0.3") +) + +func newTestContext(t *testing.T) *testContext { + t.Helper() + localNIC, remoteNIC := pipe.New("" /* linkAddr1 */, "" /* linkAddr2 */) + + localStack := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + HandleLocal: true, + }) + + remoteStack := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + HandleLocal: true, + }) + + // Add loopback NIC. We need a loopback NIC as NAT redirect rule redirect to + // loopback address + specified port. + loopbackNIC := loopback.New() + const loopbackNICID = tcpip.NICID(1) + if err := localStack.CreateNIC(loopbackNICID, sniffer.New(loopbackNIC)); err != nil { + t.Fatalf("localStack.CreateNIC(%d, _): %s", loopbackNICID, err) + } + loopbackAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: loopbackIPv4Addr.WithPrefix(), + } + if err := localStack.AddProtocolAddress(loopbackNICID, loopbackAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("localStack.AddProtocolAddress(%d, %+v, {}): %s", loopbackNICID, loopbackAddr, err) + } + + // Create linked NICs that connects the local and remote stack. + const localNICID = tcpip.NICID(2) + const remoteNICID = tcpip.NICID(3) + if err := localStack.CreateNIC(localNICID, sniffer.New(localNIC)); err != nil { + t.Fatalf("localStack.CreateNIC(%d, _): %s", localNICID, err) + } + if err := remoteStack.CreateNIC(remoteNICID, sniffer.New(remoteNIC)); err != nil { + t.Fatalf("remoteStack.CreateNIC(%d, _): %s", remoteNICID, err) + } + + for _, addr := range []tcpip.Address{localIPv4Addr1, localIPv4Addr2} { + localProtocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: addr.WithPrefix(), + } + if err := localStack.AddProtocolAddress(localNICID, localProtocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("localStack.AddProtocolAddress(%d, %+v, {}): %s", localNICID, localProtocolAddr, err) + } + } + + remoteProtocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: remoteIPv4Addr1.WithPrefix(), + } + if err := remoteStack.AddProtocolAddress(remoteNICID, remoteProtocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("remoteStack.AddProtocolAddress(%d, %+v, {}): %s", remoteNICID, remoteProtocolAddr, err) + } + + // Setup route table for local and remote stacks. + localStack.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4LoopbackSubnet, + NIC: loopbackNICID, + }, + { + Destination: header.IPv4EmptySubnet, + NIC: localNICID, + }, + }) + remoteStack.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: remoteNICID, + }, + }) + + const netProto = ipv4.ProtocolNumber + localServerAddress := tcpip.FullAddress{ + Port: localServerPort, + } + + localServerListener, err := gonet.ListenTCP(localStack, localServerAddress, netProto) + if err != nil { + t.Fatalf("gonet.ListenTCP(_, %+v, %d) = %s", localServerAddress, netProto, err) + } + + remoteServerAddress := tcpip.FullAddress{ + Port: remoteServerPort, + } + remoteServerListener, err := gonet.ListenTCP(remoteStack, remoteServerAddress, netProto) + if err != nil { + t.Fatalf("gonet.ListenTCP(_, %+v, %d) = %s", remoteServerAddress, netProto, err) + } + + // Initialize a random default response served by the HTTP server. + defaultResponse := make([]byte, 512<<10) + if _, err := rand.Read(defaultResponse); err != nil { + t.Fatalf("rand.Read(buf) failed: %s", err) + } + + tc := &testContext{ + localServerListener: localServerListener, + remoteServerListener: remoteServerListener, + localStack: localStack, + remoteStack: remoteStack, + defaultResponse: defaultResponse, + } + + tc.startServers(t) + return tc +} + +func (ctx *testContext) startServers(t *testing.T) { + ctx.wg.Add(1) + go func() { + defer ctx.wg.Done() + ctx.startHTTPServer() + }() + ctx.wg.Add(1) + go func() { + defer ctx.wg.Done() + ctx.startTCPProxyServer(t) + }() +} + +func (ctx *testContext) startTCPProxyServer(t *testing.T) { + t.Helper() + for { + conn, err := ctx.localServerListener.Accept() + if err != nil { + t.Logf("terminating local proxy server: %s", err) + return + } + // Start a goroutine to handle this inbound connection. + go func() { + remoteServerAddr := tcpip.FullAddress{ + Addr: remoteIPv4Addr1, + Port: remoteServerPort, + } + localServerAddr := tcpip.FullAddress{ + Addr: localIPv4Addr2, + } + serverConn, err := gonet.DialTCPWithBind(context.Background(), ctx.localStack, localServerAddr, remoteServerAddr, ipv4.ProtocolNumber) + if err != nil { + t.Logf("gonet.DialTCP(_, %+v, %d) = %s", remoteServerAddr, ipv4.ProtocolNumber, err) + return + } + proxy(conn, serverConn) + t.Logf("proxying completed") + }() + } +} + +// proxy transparently proxies the TCP payload from conn1 to conn2 +// and vice versa. +func proxy(conn1, conn2 net.Conn) { + var wg sync.WaitGroup + wg.Add(1) + go func() { + io.Copy(conn2, conn1) + conn1.Close() + conn2.Close() + }() + wg.Add(1) + go func() { + io.Copy(conn1, conn2) + conn1.Close() + conn2.Close() + }() + wg.Wait() +} + +func (ctx *testContext) startHTTPServer() { + handlerFunc := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(ctx.defaultResponse)) + }) + s := &http.Server{ + Handler: handlerFunc, + } + s.Serve(ctx.remoteServerListener) +} + +func TestOutboundNATRedirect(t *testing.T) { + ctx := newTestContext(t) + defer ctx.cleanup() + + // Install an IPTable rule to redirect all TCP traffic with the sourceIP of + // localIPv4Addr1 to the tcp proxy port. + ipt := ctx.localStack.IPTables() + tbl := ipt.GetTable(stack.NATID, false /* ipv6 */) + ruleIdx := tbl.BuiltinChains[stack.Output] + tbl.Rules[ruleIdx].Filter = stack.IPHeaderFilter{ + Protocol: tcp.ProtocolNumber, + CheckProtocol: true, + Src: localIPv4Addr1, + SrcMask: tcpip.Address("\xff\xff\xff\xff"), + } + tbl.Rules[ruleIdx].Target = &stack.RedirectTarget{ + Port: localServerPort, + NetworkProtocol: ipv4.ProtocolNumber, + } + tbl.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.NATID, tbl, false /* ipv6 */); err != nil { + t.Fatalf("ipt.ReplaceTable(%d, _, false): %s", stack.NATID, err) + } + + dialFunc := func(protocol, address string) (net.Conn, error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, fmt.Errorf("unable to parse address: %s, err: %s", address, err) + } + + remoteServerIP := net.ParseIP(host) + remoteServerPort, err := strconv.Atoi(port) + if err != nil { + return nil, fmt.Errorf("unable to parse port from string %s, err: %s", port, err) + } + remoteAddress := tcpip.FullAddress{ + Addr: tcpip.Address(remoteServerIP.To4()), + Port: uint16(remoteServerPort), + } + + // Dial with an explicit source address bound so that the redirect rule will + // be able to correctly redirect these packets. + localAddr := tcpip.FullAddress{Addr: localIPv4Addr1} + return gonet.DialTCPWithBind(context.Background(), ctx.localStack, localAddr, remoteAddress, ipv4.ProtocolNumber) + } + + httpClient := &http.Client{ + Transport: &http.Transport{ + Dial: dialFunc, + }, + } + + serverURL := fmt.Sprintf("http://[%s]:%d/", net.IP(remoteIPv4Addr1), remoteServerPort) + response, err := httpClient.Get(serverURL) + if err != nil { + t.Fatalf("httpClient.Get(\"/\") failed: %s", err) + } + if got, want := response.StatusCode, http.StatusOK; got != want { + t.Fatalf("unexpected status code got: %d, want: %d", got, want) + } + body, err := io.ReadAll(response.Body) + if err != nil { + t.Fatalf("io.ReadAll(response.Body) failed: %s", err) + } + response.Body.Close() + if diff := cmp.Diff(body, ctx.defaultResponse); diff != "" { + t.Fatalf("unexpected response (-want +got): \n %s", diff) + } +} diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index 27caa0c28..95ddd8ec3 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -56,17 +56,17 @@ func setupStack(t *testing.T, stackOpts stack.Options, host1NICID, host2NICID tc t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) } - if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv4Addr1); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, utils.Ipv4Addr1, err) + if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv4Addr1, stack.AddressProperties{}); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", host1NICID, utils.Ipv4Addr1, err) } - if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv4Addr2); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, utils.Ipv4Addr2, err) + if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv4Addr2, stack.AddressProperties{}); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", host2NICID, utils.Ipv4Addr2, err) } - if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv6Addr1); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, utils.Ipv6Addr1, err) + if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv6Addr1, stack.AddressProperties{}); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", host1NICID, utils.Ipv6Addr1, err) } - if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv6Addr2); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, utils.Ipv6Addr2, err) + if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv6Addr2, stack.AddressProperties{}); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", host2NICID, utils.Ipv6Addr2, err) } host1Stack.SetRouteTable([]tcpip.Route{ @@ -568,8 +568,8 @@ func TestForwardingWithLinkResolutionFailure(t *testing.T) { Protocol: test.networkProtocolNumber, AddressWithPrefix: test.incomingAddr, } - if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingProtoAddr, err) + if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingProtoAddr, err) } // Set up endpoint through which we will attempt to forward packets. @@ -582,8 +582,8 @@ func TestForwardingWithLinkResolutionFailure(t *testing.T) { Protocol: test.networkProtocolNumber, AddressWithPrefix: test.outgoingAddr, } - if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingProtoAddr, err) + if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingProtoAddr, err) } s.SetRouteTable([]tcpip.Route{ diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index b2008f0b2..f33223e79 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -195,8 +195,8 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) { if err := s.CreateNIC(nicID, loopback.New()); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddProtocolAddress(nicID, test.addAddress); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, test.addAddress, err) + if err := s.AddProtocolAddress(nicID, test.addAddress, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, test.addAddress, err) } s.SetRouteTable([]tcpip.Route{ { @@ -290,8 +290,8 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) { if err := s.CreateNIC(nicID, loopback.New()); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) + if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err) } s.SetRouteTable([]tcpip.Route{ { @@ -431,8 +431,8 @@ func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) { if err := s.CreateNIC(nicID, loopback.New()); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddProtocolAddress(nicID, test.addAddress); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, test.addAddress, err) + if err := s.AddProtocolAddress(nicID, test.addAddress, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, test.addAddress, err) } s.SetRouteTable([]tcpip.Route{ { @@ -693,21 +693,40 @@ func TestExternalLoopbackTraffic(t *testing.T) { if err := s.CreateNIC(nicID1, e); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) } - if err := s.AddAddressWithPrefix(nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr, err) + v4Addr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: utils.Ipv4Addr, } - if err := s.AddAddressWithPrefix(nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr, err) + if err := s.AddProtocolAddress(nicID1, v4Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, v4Addr, err) + } + v6Addr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: utils.Ipv6Addr, + } + if err := s.AddProtocolAddress(nicID1, v6Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, v6Addr, err) } if err := s.CreateNIC(nicID2, loopback.New()); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) } - if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, ipv4Loopback); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, ipv4Loopback, err) + protocolAddrV4 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: ipv4Loopback, + PrefixLen: 8, + }, + } + if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err) + } + protocolAddrV6 := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: header.IPv6Loopback.WithPrefix(), } - if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, header.IPv6Loopback); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, header.IPv6Loopback, err) + if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err) } if test.forwarding { diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 2d0a6e6a7..7753e7d6e 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -119,12 +119,12 @@ func TestPingMulticastBroadcast(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: utils.Ipv4Addr} - if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err) + if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtoAddr, err) } ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: utils.Ipv6Addr} - if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv6ProtoAddr, err) + if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtoAddr, err) } // Default routes for IPv4 and IPv6 so ICMP can find a route to the remote @@ -396,8 +396,8 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr} - if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) + if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err) } var wq waiter.Queue @@ -474,8 +474,8 @@ func TestReuseAddrAndBroadcast(t *testing.T) { PrefixLen: 8, }, } - if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) + if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err) } s.SetRouteTable([]tcpip.Route{ @@ -642,8 +642,8 @@ func TestUDPAddRemoveMembershipSocketOption(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr} - if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) + if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err) } // Set the route table so that UDP can find a NIC that is diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go index ac3c703d4..422eb8408 100644 --- a/pkg/tcpip/tests/integration/route_test.go +++ b/pkg/tcpip/tests/integration/route_test.go @@ -47,7 +47,10 @@ func TestLocalPing(t *testing.T) { // request/reply packets. icmpDataOffset = 8 ) - ipv4Loopback := testutil.MustParse4("127.0.0.1") + ipv4Loopback := tcpip.AddressWithPrefix{ + Address: testutil.MustParse4("127.0.0.1"), + PrefixLen: 8, + } channelEP := func() stack.LinkEndpoint { return channel.New(1, header.IPv6MinimumMTU, "") } channelEPCheck := func(t *testing.T, e stack.LinkEndpoint) { @@ -82,7 +85,7 @@ func TestLocalPing(t *testing.T) { transProto tcpip.TransportProtocolNumber netProto tcpip.NetworkProtocolNumber linkEndpoint func() stack.LinkEndpoint - localAddr tcpip.Address + localAddr tcpip.AddressWithPrefix icmpBuf func(*testing.T) buffer.View expectedConnectErr tcpip.Error checkLinkEndpoint func(t *testing.T, e stack.LinkEndpoint) @@ -101,7 +104,7 @@ func TestLocalPing(t *testing.T) { transProto: icmp.ProtocolNumber6, netProto: ipv6.ProtocolNumber, linkEndpoint: loopback.New, - localAddr: header.IPv6Loopback, + localAddr: header.IPv6Loopback.WithPrefix(), icmpBuf: ipv6ICMPBuf, checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, }, @@ -110,7 +113,7 @@ func TestLocalPing(t *testing.T) { transProto: icmp.ProtocolNumber4, netProto: ipv4.ProtocolNumber, linkEndpoint: channelEP, - localAddr: utils.Ipv4Addr.Address, + localAddr: utils.Ipv4Addr, icmpBuf: ipv4ICMPBuf, checkLinkEndpoint: channelEPCheck, }, @@ -119,7 +122,7 @@ func TestLocalPing(t *testing.T) { transProto: icmp.ProtocolNumber6, netProto: ipv6.ProtocolNumber, linkEndpoint: channelEP, - localAddr: utils.Ipv6Addr.Address, + localAddr: utils.Ipv6Addr, icmpBuf: ipv6ICMPBuf, checkLinkEndpoint: channelEPCheck, }, @@ -182,9 +185,13 @@ func TestLocalPing(t *testing.T) { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } - if len(test.localAddr) != 0 { - if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err) + if len(test.localAddr.Address) != 0 { + protocolAddr := tcpip.ProtocolAddress{ + Protocol: test.netProto, + AddressWithPrefix: test.localAddr, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } } @@ -197,7 +204,7 @@ func TestLocalPing(t *testing.T) { } defer ep.Close() - connAddr := tcpip.FullAddress{Addr: test.localAddr} + connAddr := tcpip.FullAddress{Addr: test.localAddr.Address} if err := ep.Connect(connAddr); err != test.expectedConnectErr { t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr) } @@ -229,8 +236,8 @@ func TestLocalPing(t *testing.T) { if diff := cmp.Diff(buffer.View(w.Bytes()[icmpDataOffset:]), payload[icmpDataOffset:]); diff != "" { t.Errorf("received data mismatch (-want +got):\n%s", diff) } - if rr.RemoteAddr.Addr != test.localAddr { - t.Errorf("got addr.Addr = %s, want = %s", rr.RemoteAddr.Addr, test.localAddr) + if rr.RemoteAddr.Addr != test.localAddr.Address { + t.Errorf("got addr.Addr = %s, want = %s", rr.RemoteAddr.Addr, test.localAddr.Address) } test.checkLinkEndpoint(t, e) @@ -302,11 +309,12 @@ func TestLocalUDP(t *testing.T) { } if subTest.addAddress { - if err := s.AddProtocolAddressWithOptions(nicID, test.canBePrimaryAddr, stack.CanBePrimaryEndpoint); err != nil { - t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.canBePrimaryAddr, stack.FirstPrimaryEndpoint, err) + if err := s.AddProtocolAddress(nicID, test.canBePrimaryAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, test.canBePrimaryAddr, err) } - if err := s.AddProtocolAddressWithOptions(nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint, err) + properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint} + if err := s.AddProtocolAddress(nicID, test.firstPrimaryAddr, properties); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, %+v): %s", nicID, test.firstPrimaryAddr, properties, err) } } diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go index 2e6ae55ea..c69410859 100644 --- a/pkg/tcpip/tests/utils/utils.go +++ b/pkg/tcpip/tests/utils/utils.go @@ -40,6 +40,14 @@ const ( Host2NICID = 4 ) +// Common NIC names used by tests. +const ( + Host1NICName = "host1NIC" + RouterNIC1Name = "routerNIC1" + RouterNIC2Name = "routerNIC2" + Host2NICName = "host2NIC" +) + // Common link addresses used by tests. const ( LinkAddr1 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") @@ -211,17 +219,29 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack. host1NIC, routerNIC1 := pipe.New(LinkAddr1, LinkAddr2) routerNIC2, host2NIC := pipe.New(LinkAddr3, LinkAddr4) - if err := host1Stack.CreateNIC(Host1NICID, NewEthernetEndpoint(host1NIC)); err != nil { - t.Fatalf("host1Stack.CreateNIC(%d, _): %s", Host1NICID, err) + { + opts := stack.NICOptions{Name: Host1NICName} + if err := host1Stack.CreateNICWithOptions(Host1NICID, NewEthernetEndpoint(host1NIC), opts); err != nil { + t.Fatalf("host1Stack.CreateNICWithOptions(%d, _, %#v): %s", Host1NICID, opts, err) + } } - if err := routerStack.CreateNIC(RouterNICID1, NewEthernetEndpoint(routerNIC1)); err != nil { - t.Fatalf("routerStack.CreateNIC(%d, _): %s", RouterNICID1, err) + { + opts := stack.NICOptions{Name: RouterNIC1Name} + if err := routerStack.CreateNICWithOptions(RouterNICID1, NewEthernetEndpoint(routerNIC1), opts); err != nil { + t.Fatalf("routerStack.CreateNICWithOptions(%d, _, %#v): %s", RouterNICID1, opts, err) + } } - if err := routerStack.CreateNIC(RouterNICID2, NewEthernetEndpoint(routerNIC2)); err != nil { - t.Fatalf("routerStack.CreateNIC(%d, _): %s", RouterNICID2, err) + { + opts := stack.NICOptions{Name: RouterNIC2Name} + if err := routerStack.CreateNICWithOptions(RouterNICID2, NewEthernetEndpoint(routerNIC2), opts); err != nil { + t.Fatalf("routerStack.CreateNICWithOptions(%d, _, %#v): %s", RouterNICID2, opts, err) + } } - if err := host2Stack.CreateNIC(Host2NICID, NewEthernetEndpoint(host2NIC)); err != nil { - t.Fatalf("host2Stack.CreateNIC(%d, _): %s", Host2NICID, err) + { + opts := stack.NICOptions{Name: Host2NICName} + if err := host2Stack.CreateNICWithOptions(Host2NICID, NewEthernetEndpoint(host2NIC), opts); err != nil { + t.Fatalf("host2Stack.CreateNICWithOptions(%d, _, %#v): %s", Host2NICID, opts, err) + } } if err := routerStack.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { @@ -231,29 +251,29 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack. t.Fatalf("routerStack.SetForwardingDefaultAndAllNICs(%d): %s", ipv6.ProtocolNumber, err) } - if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv4Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", Host1NICID, Host1IPv4Addr, err) + if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv4Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", Host1NICID, Host1IPv4Addr, err) } - if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv4Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID1, RouterNIC1IPv4Addr, err) + if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv4Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID1, RouterNIC1IPv4Addr, err) } - if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv4Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID2, RouterNIC2IPv4Addr, err) + if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv4Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID2, RouterNIC2IPv4Addr, err) } - if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv4Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", Host2NICID, Host2IPv4Addr, err) + if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv4Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", Host2NICID, Host2IPv4Addr, err) } - if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv6Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", Host1NICID, Host1IPv6Addr, err) + if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv6Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", Host1NICID, Host1IPv6Addr, err) } - if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv6Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID1, RouterNIC1IPv6Addr, err) + if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv6Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID1, RouterNIC1IPv6Addr, err) } - if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv6Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID2, RouterNIC2IPv6Addr, err) + if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv6Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID2, RouterNIC2IPv6Addr, err) } - if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv6Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", Host2NICID, Host2IPv6Addr, err) + if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv6Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", Host2NICID, Host2IPv6Addr, err) } host1Stack.SetRouteTable([]tcpip.Route{ diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index bbc0e3ecc..4718ec4ec 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -33,6 +33,8 @@ go_library( "//pkg/tcpip/header", "//pkg/tcpip/ports", "//pkg/tcpip/stack", + "//pkg/tcpip/transport", + "//pkg/tcpip/transport/internal/network", "//pkg/tcpip/transport/raw", "//pkg/tcpip/transport/tcp", "//pkg/waiter", diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 00497bf07..995f58616 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -15,6 +15,7 @@ package icmp import ( + "fmt" "io" "time" @@ -24,6 +25,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport" + "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network" "gvisor.dev/gvisor/pkg/waiter" ) @@ -35,15 +38,6 @@ type icmpPacket struct { receivedAt time.Time `state:".(int64)"` } -type endpointState int - -const ( - stateInitial endpointState = iota - stateBound - stateConnected - stateClosed -) - // endpoint represents an ICMP endpoint. This struct serves as the interface // between users of the endpoint and the protocol implementation; it is legal to // have concurrent goroutines make calls into the endpoint, they are properly @@ -51,14 +45,17 @@ const ( // // +stateify savable type endpoint struct { - stack.TransportEndpointInfo tcpip.DefaultSocketOptionsHandler // The following fields are initialized at creation time and are // immutable. stack *stack.Stack `state:"manual"` + transProto tcpip.TransportProtocolNumber waiterQueue *waiter.Queue uniqueID uint64 + net network.Endpoint + stats tcpip.TransportEndpointStats + ops tcpip.SocketOptions // The following fields are used to manage the receive queue, and are // protected by rcvMu. @@ -70,38 +67,23 @@ type endpoint struct { // The following fields are protected by the mu mutex. mu sync.RWMutex `state:"nosave"` - // shutdownFlags represent the current shutdown state of the endpoint. - shutdownFlags tcpip.ShutdownFlags - state endpointState - route *stack.Route `state:"manual"` - ttl uint8 - stats tcpip.TransportEndpointStats `state:"nosave"` - - // owner is used to get uid and gid of the packet. - owner tcpip.PacketOwner - - // ops is used to get socket level options. - ops tcpip.SocketOptions - // frozen indicates if the packets should be delivered to the endpoint // during restore. frozen bool + ident uint16 } func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { ep := &endpoint{ - stack: s, - TransportEndpointInfo: stack.TransportEndpointInfo{ - NetProto: netProto, - TransProto: transProto, - }, + stack: s, + transProto: transProto, waiterQueue: waiterQueue, - state: stateInitial, uniqueID: s.UniqueID(), } ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) ep.ops.SetSendBufferSize(32*1024, false /* notify */) ep.ops.SetReceiveBufferSize(32*1024, false /* notify */) + ep.net.Init(s, netProto, transProto, &ep.ops) // Override with stack defaults. var ss tcpip.SendBufferSizeOption @@ -128,35 +110,40 @@ func (e *endpoint) Abort() { // Close puts the endpoint in a closed state and frees all resources // associated with it. func (e *endpoint) Close() { - e.mu.Lock() - e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite - switch e.state { - case stateBound, stateConnected: - bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) - e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, bindToDevice) - } - - // Close the receive list and drain it. - e.rcvMu.Lock() - e.rcvClosed = true - e.rcvBufSize = 0 - for !e.rcvList.Empty() { - p := e.rcvList.Front() - e.rcvList.Remove(p) - } - e.rcvMu.Unlock() + notify := func() bool { + e.mu.Lock() + defer e.mu.Unlock() + + switch state := e.net.State(); state { + case transport.DatagramEndpointStateInitial: + case transport.DatagramEndpointStateClosed: + return false + case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: + info := e.net.Info() + info.ID.LocalPort = e.ident + e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{info.NetProto}, e.transProto, info.ID, e, ports.Flags{}, tcpip.NICID(e.ops.GetBindToDevice())) + default: + panic(fmt.Sprintf("unhandled state = %s", state)) + } - if e.route != nil { - e.route.Release() - e.route = nil - } + e.net.Shutdown() + e.net.Close() - // Update the state. - e.state = stateClosed + e.rcvMu.Lock() + defer e.rcvMu.Unlock() + e.rcvClosed = true + e.rcvBufSize = 0 + for !e.rcvList.Empty() { + p := e.rcvList.Front() + e.rcvList.Remove(p) + } - e.mu.Unlock() + return true + }() - e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) + if notify { + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) + } } // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. @@ -164,7 +151,7 @@ func (*endpoint) ModerateRecvBuf(int) {} // SetOwner implements tcpip.Endpoint.SetOwner. func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { - e.owner = owner + e.net.SetOwner(owner) } // Read implements tcpip.Endpoint.Read. @@ -193,7 +180,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult Total: p.data.Size(), ControlMessages: tcpip.ControlMessages{ HasTimestamp: true, - Timestamp: p.receivedAt.UnixNano(), + Timestamp: p.receivedAt, }, } if opts.NeedRemoteAddr { @@ -213,14 +200,13 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // reacquire the mutex in exclusive mode. // // Returns true for retry if preparation should be retried. -// +checklocks:e.mu -func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { - switch e.state { - case stateInitial: - case stateConnected: +// +checklocksread:e.mu +func (e *endpoint) prepareForWriteInner(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { + switch e.net.State() { + case transport.DatagramEndpointStateInitial: + case transport.DatagramEndpointStateConnected: return false, nil - - case stateBound: + case transport.DatagramEndpointStateBound: if to == nil { return false, &tcpip.ErrDestinationRequired{} } @@ -235,7 +221,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip // The state changed when we released the shared locked and re-acquired // it in exclusive mode. Try again. - if e.state != stateInitial { + if e.net.State() != transport.DatagramEndpointStateInitial { return true, nil } @@ -270,27 +256,15 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp return n, err } -func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { - // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) - if opts.More { - return 0, &tcpip.ErrInvalidOptionValue{} - } - - to := opts.To - +func (e *endpoint) prepareForWrite(opts tcpip.WriteOptions) (network.WriteContext, uint16, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - // If we've shutdown with SHUT_WR we are in an invalid state for sending. - if e.shutdownFlags&tcpip.ShutdownWrite != 0 { - return 0, &tcpip.ErrClosedForSend{} - } - // Prepare for write. for { - retry, err := e.prepareForWrite(to) + retry, err := e.prepareForWriteInner(opts.To) if err != nil { - return 0, err + return network.WriteContext{}, 0, err } if !retry { @@ -298,36 +272,16 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp } } - route := e.route - if to != nil { - // Reject destination address if it goes through a different - // NIC than the endpoint was bound to. - nicID := to.NIC - if nicID == 0 { - nicID = tcpip.NICID(e.ops.GetBindToDevice()) - } - if e.BindNICID != 0 { - if nicID != 0 && nicID != e.BindNICID { - return 0, &tcpip.ErrNoRoute{} - } - - nicID = e.BindNICID - } - - dst, netProto, err := e.checkV4MappedLocked(*to) - if err != nil { - return 0, err - } - - // Find the endpoint. - r, err := e.stack.FindRoute(nicID, e.BindAddr, dst.Addr, netProto, false /* multicastLoop */) - if err != nil { - return 0, err - } - defer r.Release() + ctx, err := e.net.AcquireContextForWrite(opts) + return ctx, e.ident, err +} - route = r +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { + ctx, ident, err := e.prepareForWrite(opts) + if err != nil { + return 0, err } + defer ctx.Release() // TODO(https://gvisor.dev/issue/6538): Avoid this allocation. v := make([]byte, p.Len()) @@ -335,17 +289,18 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp return 0, &tcpip.ErrBadBuffer{} } - var err tcpip.Error - switch e.NetProto { + switch netProto, pktInfo := e.net.NetProto(), ctx.PacketInfo(); netProto { case header.IPv4ProtocolNumber: - err = send4(route, e.ID.LocalPort, v, e.ttl, e.owner) + if err := send4(e.stack, &ctx, ident, v, pktInfo.MaxHeaderLength); err != nil { + return 0, err + } case header.IPv6ProtocolNumber: - err = send6(route, e.ID.LocalPort, v, e.ttl) - } - - if err != nil { - return 0, err + if err := send6(e.stack, &ctx, ident, v, pktInfo.LocalAddress, pktInfo.RemoteAddress, pktInfo.MaxHeaderLength); err != nil { + return 0, err + } + default: + panic(fmt.Sprintf("unhandled network protocol = %d", netProto)) } return int64(len(v)), nil @@ -358,24 +313,17 @@ func (e *endpoint) HasNIC(id int32) bool { return e.stack.HasNIC(tcpip.NICID(id)) } -// SetSockOpt sets a socket option. -func (*endpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error { - return nil +// SetSockOpt implements tcpip.Endpoint. +func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { + return e.net.SetSockOpt(opt) } -// SetSockOptInt sets a socket option. Currently not supported. +// SetSockOptInt implements tcpip.Endpoint. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { - switch opt { - case tcpip.TTLOption: - e.mu.Lock() - e.ttl = uint8(v) - e.mu.Unlock() - - } - return nil + return e.net.SetSockOptInt(opt, v) } -// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. +// GetSockOptInt implements tcpip.Endpoint. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: @@ -388,31 +336,24 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.TTLOption: - e.rcvMu.Lock() - v := int(e.ttl) - e.rcvMu.Unlock() - return v, nil - default: - return -1, &tcpip.ErrUnknownProtocolOption{} + return e.net.GetSockOptInt(opt) } } -// GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { - return &tcpip.ErrUnknownProtocolOption{} +// GetSockOpt implements tcpip.Endpoint. +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { + return e.net.GetSockOpt(opt) } -func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) tcpip.Error { +func send4(s *stack.Stack, ctx *network.WriteContext, ident uint16, data buffer.View, maxHeaderLength uint16) tcpip.Error { if len(data) < header.ICMPv4MinimumSize { return &tcpip.ErrInvalidEndpointState{} } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.ICMPv4MinimumSize + int(r.MaxHeaderLength()), + ReserveHeaderBytes: header.ICMPv4MinimumSize + int(maxHeaderLength), }) - pkt.Owner = owner icmpv4 := header.ICMPv4(pkt.TransportHeader().Push(header.ICMPv4MinimumSize)) pkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber @@ -427,36 +368,31 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi return &tcpip.ErrInvalidEndpointState{} } - // Because this icmp endpoint is implemented in the transport layer, we can - // only increment the 'stack-wide' stats but we can't increment the - // 'per-NetworkEndpoint' stats. - sentStat := r.Stats().ICMP.V4.PacketsSent.EchoRequest - icmpv4.SetChecksum(0) icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0))) - pkt.Data().AppendView(data) - if ttl == 0 { - ttl = r.DefaultTTL() - } + // Because this icmp endpoint is implemented in the transport layer, we can + // only increment the 'stack-wide' stats but we can't increment the + // 'per-NetworkEndpoint' stats. + stats := s.Stats().ICMP.V4.PacketsSent - if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil { - r.Stats().ICMP.V4.PacketsSent.Dropped.Increment() + if err := ctx.WritePacket(pkt, false /* headerIncluded */); err != nil { + stats.Dropped.Increment() return err } - sentStat.Increment() + stats.EchoRequest.Increment() return nil } -func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Error { +func send6(s *stack.Stack, ctx *network.WriteContext, ident uint16, data buffer.View, src, dst tcpip.Address, maxHeaderLength uint16) tcpip.Error { if len(data) < header.ICMPv6EchoMinimumSize { return &tcpip.ErrInvalidEndpointState{} } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.ICMPv6MinimumSize + int(r.MaxHeaderLength()), + ReserveHeaderBytes: header.ICMPv6MinimumSize + int(maxHeaderLength), }) icmpv6 := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6MinimumSize)) @@ -469,43 +405,31 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Erro if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 { return &tcpip.ErrInvalidEndpointState{} } - // Because this icmp endpoint is implemented in the transport layer, we can - // only increment the 'stack-wide' stats but we can't increment the - // 'per-NetworkEndpoint' stats. - sentStat := r.Stats().ICMP.V6.PacketsSent.EchoRequest pkt.Data().AppendView(data) dataRange := pkt.Data().AsRange() icmpv6.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: icmpv6, - Src: r.LocalAddress(), - Dst: r.RemoteAddress(), + Src: src, + Dst: dst, PayloadCsum: dataRange.Checksum(), PayloadLen: dataRange.Size(), })) - if ttl == 0 { - ttl = r.DefaultTTL() - } + // Because this icmp endpoint is implemented in the transport layer, we can + // only increment the 'stack-wide' stats but we can't increment the + // 'per-NetworkEndpoint' stats. + stats := s.Stats().ICMP.V6.PacketsSent - if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil { - r.Stats().ICMP.V6.PacketsSent.Dropped.Increment() + if err := ctx.WritePacket(pkt, false /* headerIncluded */); err != nil { + stats.Dropped.Increment() + return err } - sentStat.Increment() + stats.EchoRequest.Increment() return nil } -// checkV4MappedLocked determines the effective network protocol and converts -// addr to its canonical form. -func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { - unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, false /* v6only */) - if err != nil { - return tcpip.FullAddress{}, 0, err - } - return unwrapped, netProto, nil -} - // Disconnect implements tcpip.Endpoint.Disconnect. func (*endpoint) Disconnect() tcpip.Error { return &tcpip.ErrNotSupported{} @@ -516,59 +440,21 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - nicID := addr.NIC - localPort := uint16(0) - switch e.state { - case stateInitial: - case stateBound, stateConnected: - localPort = e.ID.LocalPort - if e.BindNICID == 0 { - break - } + err := e.net.ConnectAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error { + nextID.LocalPort = e.ident - if nicID != 0 && nicID != e.BindNICID { - return &tcpip.ErrInvalidEndpointState{} + nextID, err := e.registerWithStack(netProto, nextID) + if err != nil { + return err } - nicID = e.BindNICID - default: - return &tcpip.ErrInvalidEndpointState{} - } - - addr, netProto, err := e.checkV4MappedLocked(addr) - if err != nil { - return err - } - - // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicID, e.BindAddr, addr.Addr, netProto, false /* multicastLoop */) - if err != nil { - return err - } - - id := stack.TransportEndpointID{ - LocalAddress: r.LocalAddress(), - LocalPort: localPort, - RemoteAddress: r.RemoteAddress(), - } - - // Even if we're connected, this endpoint can still be used to send - // packets on a different network protocol, so we register both even if - // v6only is set to false and this is an ipv6 endpoint. - netProtos := []tcpip.NetworkProtocolNumber{netProto} - - id, err = e.registerWithStack(nicID, netProtos, id) + e.ident = nextID.LocalPort + return nil + }) if err != nil { - r.Release() return err } - e.ID = id - e.route = r - e.RegisterNICID = nicID - - e.state = stateConnected - e.rcvMu.Lock() e.rcvReady = true e.rcvMu.Unlock() @@ -586,10 +472,19 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error { func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - e.shutdownFlags |= flags - if e.state != stateConnected { + switch state := e.net.State(); state { + case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: return &tcpip.ErrNotConnected{} + case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: + default: + panic(fmt.Sprintf("unhandled state = %s", state)) + } + + if flags&tcpip.ShutdownWrite != 0 { + if err := e.net.Shutdown(); err != nil { + return err + } } if flags&tcpip.ShutdownRead != 0 { @@ -616,19 +511,18 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi return nil, nil, &tcpip.ErrNotSupported{} } -func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) { +func (e *endpoint) registerWithStack(netProto tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) { bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. - err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice) - return id, err + return id, e.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{netProto}, e.transProto, id, e, ports.Flags{}, bindToDevice) } // We need to find a port for the endpoint. _, err := e.stack.PickEphemeralPort(e.stack.Rand(), func(p uint16) (bool, tcpip.Error) { id.LocalPort = p - err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice) + err := e.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{netProto}, e.transProto, id, e, ports.Flags{}, bindToDevice) switch err.(type) { case nil: return true, nil @@ -645,42 +539,27 @@ func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkPro func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error { // Don't allow binding once endpoint is not in the initial state // anymore. - if e.state != stateInitial { + if e.net.State() != transport.DatagramEndpointStateInitial { return &tcpip.ErrInvalidEndpointState{} } - addr, netProto, err := e.checkV4MappedLocked(addr) - if err != nil { - return err - } - - // Expand netProtos to include v4 and v6 if the caller is binding to a - // wildcard (empty) address, and this is an IPv6 endpoint with v6only - // set to false. - netProtos := []tcpip.NetworkProtocolNumber{netProto} - - if len(addr.Addr) != 0 { - // A local address was specified, verify that it's valid. - if e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) == 0 { - return &tcpip.ErrBadLocalAddress{} + err := e.net.BindAndThen(addr, func(boundNetProto tcpip.NetworkProtocolNumber, boundAddr tcpip.Address) tcpip.Error { + id := stack.TransportEndpointID{ + LocalPort: addr.Port, + LocalAddress: addr.Addr, + } + id, err := e.registerWithStack(boundNetProto, id) + if err != nil { + return err } - } - id := stack.TransportEndpointID{ - LocalPort: addr.Port, - LocalAddress: addr.Addr, - } - id, err = e.registerWithStack(addr.NIC, netProtos, id) + e.ident = id.LocalPort + return nil + }) if err != nil { return err } - e.ID = id - e.RegisterNICID = addr.NIC - - // Mark endpoint as bound. - e.state = stateBound - e.rcvMu.Lock() e.rcvReady = true e.rcvMu.Unlock() @@ -688,21 +567,24 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error { return nil } +func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, addr tcpip.Address) bool { + return addr == header.IPv4Broadcast || + header.IsV4MulticastAddress(addr) || + header.IsV6MulticastAddress(addr) || + e.stack.IsSubnetBroadcast(nicID, e.net.NetProto(), addr) +} + // Bind binds the endpoint to a specific local address and port. // Specifying a NIC is optional. func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { - e.mu.Lock() - defer e.mu.Unlock() - - err := e.bindLocked(addr) - if err != nil { - return err + if len(addr.Addr) != 0 && e.isBroadcastOrMulticast(addr.NIC, addr.Addr) { + return &tcpip.ErrBadLocalAddress{} } - e.BindNICID = addr.NIC - e.BindAddr = addr.Addr + e.mu.Lock() + defer e.mu.Unlock() - return nil + return e.bindLocked(addr) } // GetLocalAddress returns the address to which the endpoint is bound. @@ -710,11 +592,9 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - return tcpip.FullAddress{ - NIC: e.RegisterNICID, - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, - }, nil + addr := e.net.GetLocalAddress() + addr.Port = e.ident + return addr, nil } // GetRemoteAddress returns the address to which the endpoint is connected. @@ -722,15 +602,11 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - if e.state != stateConnected { - return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} + if addr, connected := e.net.GetRemoteAddress(); connected { + return addr, nil } - return tcpip.FullAddress{ - NIC: e.RegisterNICID, - Addr: e.ID.RemoteAddress, - Port: e.ID.RemotePort, - }, nil + return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} } // Readiness returns the current readiness of the endpoint. For example, if @@ -755,7 +631,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // endpoint. func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Only accept echo replies. - switch e.NetProto { + switch e.net.NetProto() { case header.IPv4ProtocolNumber: h := header.ICMPv4(pkt.TransportHeader().View()) if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply { @@ -829,9 +705,9 @@ func (e *endpoint) State() uint32 { // Info returns a copy of the endpoint info. func (e *endpoint) Info() tcpip.EndpointInfo { e.mu.RLock() - // Make a copy of the endpoint info. - ret := e.TransportEndpointInfo - e.mu.RUnlock() + defer e.mu.RUnlock() + ret := e.net.Info() + ret.ID.LocalPort = e.ident return &ret } diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go index b8b839e4a..dfe453ff9 100644 --- a/pkg/tcpip/transport/icmp/endpoint_state.go +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -15,11 +15,13 @@ package icmp import ( + "fmt" "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport" ) // saveReceivedAt is invoked by stateify. @@ -61,29 +63,24 @@ func (e *endpoint) beforeSave() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { e.thaw() + + e.net.Resume(s) + e.stack = s e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) - if e.state != stateBound && e.state != stateConnected { - return - } - - var err tcpip.Error - if e.state == stateConnected { - e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.ID.RemoteAddress, e.NetProto, false /* multicastLoop */) + switch state := e.net.State(); state { + case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: + case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: + var err tcpip.Error + info := e.net.Info() + info.ID.LocalPort = e.ident + info.ID, err = e.registerWithStack(info.NetProto, info.ID) if err != nil { - panic(err) + panic(fmt.Sprintf("e.registerWithStack(%d, %#v): %s", info.NetProto, info.ID, err)) } - - e.ID.LocalAddress = e.route.LocalAddress() - } else if len(e.ID.LocalAddress) != 0 { // stateBound - if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.ID.LocalAddress) == 0 { - panic(&tcpip.ErrBadLocalAddress{}) - } - } - - e.ID, err = e.registerWithStack(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.ID) - if err != nil { - panic(err) + e.ident = info.ID.LocalPort + default: + panic(fmt.Sprintf("unhandled state = %s", state)) } } diff --git a/pkg/tcpip/transport/icmp/icmp_test.go b/pkg/tcpip/transport/icmp/icmp_test.go index cc950cbde..729f50e9a 100644 --- a/pkg/tcpip/transport/icmp/icmp_test.go +++ b/pkg/tcpip/transport/icmp/icmp_test.go @@ -55,8 +55,12 @@ func addNICWithDefaultRoute(t *testing.T, s *stack.Stack, id tcpip.NICID, name s t.Fatalf("s.CreateNIC(%d, _) = %s", id, err) } - if err := s.AddAddress(id, ipv4.ProtocolNumber, addrV4); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s) = %s", id, ipv4.ProtocolNumber, addrV4, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addrV4.WithPrefix(), + } + if err := s.AddProtocolAddress(id, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", id, protocolAddr, err) } s.AddRoute(tcpip.Route{ diff --git a/pkg/tcpip/transport/internal/network/BUILD b/pkg/tcpip/transport/internal/network/BUILD index b1edce39b..3818cb04e 100644 --- a/pkg/tcpip/transport/internal/network/BUILD +++ b/pkg/tcpip/transport/internal/network/BUILD @@ -9,6 +9,7 @@ go_library( "endpoint_state.go", ], visibility = [ + "//pkg/tcpip/transport/icmp:__pkg__", "//pkg/tcpip/transport/raw:__pkg__", "//pkg/tcpip/transport/udp:__pkg__", ], diff --git a/pkg/tcpip/transport/internal/network/endpoint.go b/pkg/tcpip/transport/internal/network/endpoint.go index 09b629022..fb31e5104 100644 --- a/pkg/tcpip/transport/internal/network/endpoint.go +++ b/pkg/tcpip/transport/internal/network/endpoint.go @@ -38,31 +38,65 @@ type Endpoint struct { netProto tcpip.NetworkProtocolNumber transProto tcpip.TransportProtocolNumber - // state holds a transport.DatagramBasedEndpointState. - // - // state must be read from/written to atomically. - state uint32 - - // The following fields are protected by mu. - mu sync.RWMutex `state:"nosave"` + mu sync.RWMutex `state:"nosave"` + // +checklocks:mu wasBound bool - info stack.TransportEndpointInfo // owner is the owner of transmitted packets. - owner tcpip.PacketOwner - writeShutdown bool - effectiveNetProto tcpip.NetworkProtocolNumber - connectedRoute *stack.Route `state:"manual"` + // + // +checklocks:mu + owner tcpip.PacketOwner + // +checklocks:mu + writeShutdown bool + // +checklocks:mu + effectiveNetProto tcpip.NetworkProtocolNumber + // +checklocks:mu + connectedRoute *stack.Route `state:"manual"` + // +checklocks:mu multicastMemberships map[multicastMembership]struct{} // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. + // +checklocks:mu ttl uint8 // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. + // +checklocks:mu multicastTTL uint8 // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. + // +checklocks:mu multicastAddr tcpip.Address // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. + // +checklocks:mu multicastNICID tcpip.NICID - ipv4TOS uint8 - ipv6TClass uint8 + // +checklocks:mu + ipv4TOS uint8 + // +checklocks:mu + ipv6TClass uint8 + + // Lock ordering: mu > infoMu. + infoMu sync.RWMutex `state:"nosave"` + // info has a dedicated mutex so that we can avoid lock ordering violations + // when reading the endpoint's info. If we used mu, we need to guarantee + // that any lock taken while mu is held is not held when calling Info() + // which is not true as of writing (we hold mu while registering transport + // endpoints (taking the transport demuxer lock but we also hold the demuxer + // lock when delivering packets/errors to endpoints). + // + // Writes must be performed through setInfo. + // + // +checklocks:infoMu + info stack.TransportEndpointInfo + + // state holds a transport.DatagramBasedEndpointState. + // + // state must be accessed with atomics so that we can avoid lock ordering + // violations when reading the state. If we used mu, we need to guarantee + // that any lock taken while mu is held is not held when calling State() + // which is not true as of writing (we hold mu while registering transport + // endpoints (taking the transport demuxer lock but we also hold the demuxer + // lock when delivering packets/errors to endpoints). + // + // Writes must be performed through setEndpointState. + // + // +checkatomics + state uint32 } // +stateify savable @@ -73,8 +107,11 @@ type multicastMembership struct { // Init initializes the endpoint. func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ops *tcpip.SocketOptions) { - if e.multicastMemberships != nil { - panic(fmt.Sprintf("endpoint is already initialized; got e.multicastMemberships = %#v, want = nil", e.multicastMemberships)) + e.mu.Lock() + memberships := e.multicastMemberships + e.mu.Unlock() + if memberships != nil { + panic(fmt.Sprintf("endpoint is already initialized; got e.multicastMemberships = %#v, want = nil", memberships)) } switch netProto { @@ -89,8 +126,6 @@ func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, tr netProto: netProto, transProto: transProto, - state: uint32(transport.DatagramEndpointStateInitial), - info: stack.TransportEndpointInfo{ NetProto: netProto, TransProto: transProto, @@ -100,6 +135,10 @@ func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, tr multicastTTL: 1, multicastMemberships: make(map[multicastMembership]struct{}), } + + e.mu.Lock() + defer e.mu.Unlock() + e.setEndpointState(transport.DatagramEndpointStateInitial) } // NetProto returns the network protocol the endpoint was initialized with. @@ -107,7 +146,12 @@ func (e *Endpoint) NetProto() tcpip.NetworkProtocolNumber { return e.netProto } -// setState sets the state of the endpoint. +// setEndpointState sets the state of the endpoint. +// +// e.mu must be held to synchronize changes to state with the rest of the +// endpoint. +// +// +checklocks:e.mu func (e *Endpoint) setEndpointState(state transport.DatagramEndpointState) { atomic.StoreUint32(&e.state, uint32(state)) } @@ -242,23 +286,24 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext if nicID == 0 { nicID = tcpip.NICID(e.ops.GetBindToDevice()) } - if e.info.BindNICID != 0 { - if nicID != 0 && nicID != e.info.BindNICID { + info := e.Info() + if info.BindNICID != 0 { + if nicID != 0 && nicID != info.BindNICID { return WriteContext{}, &tcpip.ErrNoRoute{} } - nicID = e.info.BindNICID + nicID = info.BindNICID } if nicID == 0 { - nicID = e.info.RegisterNICID + nicID = info.RegisterNICID } - dst, netProto, err := e.checkV4MappedLocked(*opts.To) + dst, netProto, err := e.checkV4Mapped(*opts.To) if err != nil { return WriteContext{}, err } - route, _, err = e.connectRoute(nicID, dst, netProto) + route, _, err = e.connectRouteRLocked(nicID, dst, netProto) if err != nil { return WriteContext{}, err } @@ -297,26 +342,30 @@ func (e *Endpoint) Disconnect() { return } + info := e.Info() // Exclude ephemerally bound endpoints. if e.wasBound { - e.info.ID = stack.TransportEndpointID{ - LocalAddress: e.info.BindAddr, + info.ID = stack.TransportEndpointID{ + LocalAddress: info.BindAddr, } e.setEndpointState(transport.DatagramEndpointStateBound) } else { - e.info.ID = stack.TransportEndpointID{} + info.ID = stack.TransportEndpointID{} e.setEndpointState(transport.DatagramEndpointStateInitial) } + e.setInfo(info) e.connectedRoute.Release() e.connectedRoute = nil } -// connectRoute establishes a route to the specified interface or the +// connectRouteRLocked establishes a route to the specified interface or the // configured multicast interface if no interface is specified and the // specified address is a multicast address. -func (e *Endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) { - localAddr := e.info.ID.LocalAddress +// +// +checklocksread:e.mu +func (e *Endpoint) connectRouteRLocked(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) { + localAddr := e.Info().ID.LocalAddress if e.isBroadcastOrMulticast(nicID, netProto, localAddr) { // A packet can only originate from a unicast address (i.e., an interface). localAddr = "" @@ -359,42 +408,43 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip. e.mu.Lock() defer e.mu.Unlock() + info := e.Info() nicID := addr.NIC switch e.State() { case transport.DatagramEndpointStateInitial: case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: - if e.info.BindNICID == 0 { + if info.BindNICID == 0 { break } - if nicID != 0 && nicID != e.info.BindNICID { + if nicID != 0 && nicID != info.BindNICID { return &tcpip.ErrInvalidEndpointState{} } - nicID = e.info.BindNICID + nicID = info.BindNICID default: return &tcpip.ErrInvalidEndpointState{} } - addr, netProto, err := e.checkV4MappedLocked(addr) + addr, netProto, err := e.checkV4Mapped(addr) if err != nil { return err } - r, nicID, err := e.connectRoute(nicID, addr, netProto) + r, nicID, err := e.connectRouteRLocked(nicID, addr, netProto) if err != nil { return err } id := stack.TransportEndpointID{ - LocalAddress: e.info.ID.LocalAddress, + LocalAddress: info.ID.LocalAddress, RemoteAddress: r.RemoteAddress(), } if e.State() == transport.DatagramEndpointStateInitial { id.LocalAddress = r.LocalAddress() } - if err := f(r.NetProto(), e.info.ID, id); err != nil { + if err := f(r.NetProto(), info.ID, id); err != nil { return err } @@ -403,8 +453,9 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip. e.connectedRoute.Release() } e.connectedRoute = r - e.info.ID = id - e.info.RegisterNICID = nicID + info.ID = id + info.RegisterNICID = nicID + e.setInfo(info) e.effectiveNetProto = netProto e.setEndpointState(transport.DatagramEndpointStateConnected) return nil @@ -426,10 +477,11 @@ func (e *Endpoint) Shutdown() tcpip.Error { } } -// checkV4MappedLocked determines the effective network protocol and converts +// checkV4MappedRLocked determines the effective network protocol and converts // addr to its canonical form. -func (e *Endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { - unwrapped, netProto, err := e.info.AddrNetProtoLocked(addr, e.ops.GetV6Only()) +func (e *Endpoint) checkV4Mapped(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { + info := e.Info() + unwrapped, netProto, err := info.AddrNetProtoLocked(addr, e.ops.GetV6Only()) if err != nil { return tcpip.FullAddress{}, 0, err } @@ -464,7 +516,7 @@ func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProto return &tcpip.ErrInvalidEndpointState{} } - addr, netProto, err := e.checkV4MappedLocked(addr) + addr, netProto, err := e.checkV4Mapped(addr) if err != nil { return err } @@ -483,12 +535,14 @@ func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProto e.wasBound = true - e.info.ID = stack.TransportEndpointID{ + info := e.Info() + info.ID = stack.TransportEndpointID{ LocalAddress: addr.Addr, } - e.info.BindNICID = addr.NIC - e.info.RegisterNICID = nicID - e.info.BindAddr = addr.Addr + info.BindNICID = addr.NIC + info.RegisterNICID = nicID + info.BindAddr = addr.Addr + e.setInfo(info) e.effectiveNetProto = netProto e.setEndpointState(transport.DatagramEndpointStateBound) return nil @@ -506,13 +560,14 @@ func (e *Endpoint) GetLocalAddress() tcpip.FullAddress { e.mu.RLock() defer e.mu.RUnlock() - addr := e.info.BindAddr + info := e.Info() + addr := info.BindAddr if e.State() == transport.DatagramEndpointStateConnected { addr = e.connectedRoute.LocalAddress() } return tcpip.FullAddress{ - NIC: e.info.RegisterNICID, + NIC: info.RegisterNICID, Addr: addr, } } @@ -528,7 +583,7 @@ func (e *Endpoint) GetRemoteAddress() (tcpip.FullAddress, bool) { return tcpip.FullAddress{ Addr: e.connectedRoute.RemoteAddress(), - NIC: e.info.RegisterNICID, + NIC: e.Info().RegisterNICID, }, true } @@ -610,7 +665,7 @@ func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { defer e.mu.Unlock() fa := tcpip.FullAddress{Addr: v.InterfaceAddr} - fa, netProto, err := e.checkV4MappedLocked(fa) + fa, netProto, err := e.checkV4Mapped(fa) if err != nil { return err } @@ -634,7 +689,7 @@ func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { } } - if e.info.BindNICID != 0 && e.info.BindNICID != nic { + if info := e.Info(); info.BindNICID != 0 && info.BindNICID != nic { return &tcpip.ErrInvalidEndpointState{} } @@ -737,7 +792,19 @@ func (e *Endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { // Info returns a copy of the endpoint info. func (e *Endpoint) Info() stack.TransportEndpointInfo { - e.mu.RLock() - defer e.mu.RUnlock() + e.infoMu.RLock() + defer e.infoMu.RUnlock() return e.info } + +// setInfo sets the endpoint's info. +// +// e.mu must be held to synchronize changes to info with the rest of the +// endpoint. +// +// +checklocks:e.mu +func (e *Endpoint) setInfo(info stack.TransportEndpointInfo) { + e.infoMu.Lock() + defer e.infoMu.Unlock() + e.info = info +} diff --git a/pkg/tcpip/transport/internal/network/endpoint_state.go b/pkg/tcpip/transport/internal/network/endpoint_state.go index 858007156..68bd1fbf6 100644 --- a/pkg/tcpip/transport/internal/network/endpoint_state.go +++ b/pkg/tcpip/transport/internal/network/endpoint_state.go @@ -35,20 +35,22 @@ func (e *Endpoint) Resume(s *stack.Stack) { } } + info := e.Info() + switch state := e.State(); state { case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: case transport.DatagramEndpointStateBound: - if len(e.info.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress) { - if e.stack.CheckLocalAddress(e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress) == 0 { - panic(fmt.Sprintf("got e.stack.CheckLocalAddress(%d, %d, %s) = 0, want != 0", e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress)) + if len(info.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(info.RegisterNICID, e.effectiveNetProto, info.ID.LocalAddress) { + if e.stack.CheckLocalAddress(info.RegisterNICID, e.effectiveNetProto, info.ID.LocalAddress) == 0 { + panic(fmt.Sprintf("got e.stack.CheckLocalAddress(%d, %d, %s) = 0, want != 0", info.RegisterNICID, e.effectiveNetProto, info.ID.LocalAddress)) } } case transport.DatagramEndpointStateConnected: var err tcpip.Error multicastLoop := e.ops.GetMulticastLoop() - e.connectedRoute, err = e.stack.FindRoute(e.info.RegisterNICID, e.info.ID.LocalAddress, e.info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop) + e.connectedRoute, err = e.stack.FindRoute(info.RegisterNICID, info.ID.LocalAddress, info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop) if err != nil { - panic(fmt.Sprintf("e.stack.FindRoute(%d, %s, %s, %d, %t): %s", e.info.RegisterNICID, e.info.ID.LocalAddress, e.info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop, err)) + panic(fmt.Sprintf("e.stack.FindRoute(%d, %s, %s, %d, %t): %s", info.RegisterNICID, info.ID.LocalAddress, info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop, err)) } default: panic(fmt.Sprintf("unhandled state = %s", state)) diff --git a/pkg/tcpip/transport/internal/network/endpoint_test.go b/pkg/tcpip/transport/internal/network/endpoint_test.go index d99c961c3..f263a9ea2 100644 --- a/pkg/tcpip/transport/internal/network/endpoint_test.go +++ b/pkg/tcpip/transport/internal/network/endpoint_test.go @@ -124,11 +124,20 @@ func TestEndpointStateTransitions(t *testing.T) { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, ipv4NICAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, ipv4NICAddr, err) + ipv4ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ipv4NICAddr.WithPrefix(), } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, ipv6NICAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, ipv6NICAddr, err) + if err := s.AddProtocolAddress(nicID, ipv4ProtocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}: %s", nicID, ipv4ProtocolAddr, err) + } + ipv6ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: ipv6NICAddr.WithPrefix(), + } + + if err := s.AddProtocolAddress(nicID, ipv6ProtocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtocolAddr, err) } s.SetRouteTable([]tcpip.Route{ @@ -257,11 +266,19 @@ func TestBindNICID(t *testing.T) { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, ipv4NICAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, ipv4NICAddr, err) + ipv4ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ipv4NICAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, ipv4ProtocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtocolAddr, err) + } + ipv6ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: ipv6NICAddr.WithPrefix(), } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, ipv6NICAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, ipv6NICAddr, err) + if err := s.AddProtocolAddress(nicID, ipv6ProtocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtocolAddr, err) } var ops tcpip.SocketOptions diff --git a/pkg/tcpip/transport/internal/noop/BUILD b/pkg/tcpip/transport/internal/noop/BUILD new file mode 100644 index 000000000..171c41eb1 --- /dev/null +++ b/pkg/tcpip/transport/internal/noop/BUILD @@ -0,0 +1,14 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "noop", + srcs = ["endpoint.go"], + visibility = ["//pkg/tcpip/transport/raw:__pkg__"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/stack", + "//pkg/waiter", + ], +) diff --git a/pkg/tcpip/transport/internal/noop/endpoint.go b/pkg/tcpip/transport/internal/noop/endpoint.go new file mode 100644 index 000000000..443b4e416 --- /dev/null +++ b/pkg/tcpip/transport/internal/noop/endpoint.go @@ -0,0 +1,172 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package noop contains an endpoint that implements all tcpip.Endpoint +// functions as noops. +package noop + +import ( + "fmt" + "io" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/waiter" +) + +// endpoint can be created, but all interactions have no effect or +// return errors. +// +// +stateify savable +type endpoint struct { + tcpip.DefaultSocketOptionsHandler + ops tcpip.SocketOptions +} + +// New returns an initialized noop endpoint. +func New(stk *stack.Stack) tcpip.Endpoint { + // ep.ops must be in a valid, initialized state for callers of + // ep.SocketOptions. + var ep endpoint + ep.ops.InitHandler(&ep, stk, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) + return &ep +} + +// Abort implements stack.TransportEndpoint.Abort. +func (*endpoint) Abort() { + // No-op. +} + +// Close implements tcpip.Endpoint.Close. +func (*endpoint) Close() { + // No-op. +} + +// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. +func (*endpoint) ModerateRecvBuf(int) { + // No-op. +} + +func (*endpoint) SetOwner(tcpip.PacketOwner) { + // No-op. +} + +// Read implements tcpip.Endpoint.Read. +func (*endpoint) Read(io.Writer, tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { + return tcpip.ReadResult{}, &tcpip.ErrNotPermitted{} +} + +// Write implements tcpip.Endpoint.Write. +func (*endpoint) Write(tcpip.Payloader, tcpip.WriteOptions) (int64, tcpip.Error) { + return 0, &tcpip.ErrNotPermitted{} +} + +// Disconnect implements tcpip.Endpoint.Disconnect. +func (*endpoint) Disconnect() tcpip.Error { + return &tcpip.ErrNotSupported{} +} + +// Connect implements tcpip.Endpoint.Connect. +func (*endpoint) Connect(tcpip.FullAddress) tcpip.Error { + return &tcpip.ErrNotPermitted{} +} + +// Shutdown implements tcpip.Endpoint.Shutdown. +func (*endpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error { + return &tcpip.ErrNotPermitted{} +} + +// Listen implements tcpip.Endpoint.Listen. +func (*endpoint) Listen(int) tcpip.Error { + return &tcpip.ErrNotSupported{} +} + +// Accept implements tcpip.Endpoint.Accept. +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) { + return nil, nil, &tcpip.ErrNotSupported{} +} + +// Bind implements tcpip.Endpoint.Bind. +func (*endpoint) Bind(tcpip.FullAddress) tcpip.Error { + return &tcpip.ErrNotPermitted{} +} + +// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. +func (*endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { + return tcpip.FullAddress{}, &tcpip.ErrNotSupported{} +} + +// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. +func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { + return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} +} + +// Readiness implements tcpip.Endpoint.Readiness. +func (*endpoint) Readiness(waiter.EventMask) waiter.EventMask { + return 0 +} + +// SetSockOpt implements tcpip.Endpoint.SetSockOpt. +func (*endpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} +} + +func (*endpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} +} + +// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. +func (*endpoint) GetSockOptInt(tcpip.SockOptInt) (int, tcpip.Error) { + return 0, &tcpip.ErrUnknownProtocolOption{} +} + +// HandlePacket implements stack.RawTransportEndpoint.HandlePacket. +func (*endpoint) HandlePacket(pkt *stack.PacketBuffer) { + panic(fmt.Sprintf("unreachable: noop.endpoint should never be registered, but got packet: %+v", pkt)) +} + +// State implements socket.Socket.State. +func (*endpoint) State() uint32 { + return 0 +} + +// Wait implements stack.TransportEndpoint.Wait. +func (*endpoint) Wait() { + // No-op. +} + +// LastError implements tcpip.Endpoint.LastError. +func (*endpoint) LastError() tcpip.Error { + return nil +} + +// SocketOptions implements tcpip.Endpoint.SocketOptions. +func (ep *endpoint) SocketOptions() *tcpip.SocketOptions { + return &ep.ops +} + +// Info implements tcpip.Endpoint.Info. +func (*endpoint) Info() tcpip.EndpointInfo { + return &stack.TransportEndpointInfo{} +} + +// Stats returns a pointer to the endpoint stats. +func (*endpoint) Stats() tcpip.EndpointStats { + return &tcpip.TransportEndpointStats{} +} diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 0554d2f4a..80eef39e9 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -59,52 +59,47 @@ type packet struct { // // +stateify savable type endpoint struct { - stack.TransportEndpointInfo tcpip.DefaultSocketOptionsHandler // The following fields are initialized at creation time and are // immutable. stack *stack.Stack `state:"manual"` - netProto tcpip.NetworkProtocolNumber waiterQueue *waiter.Queue cooked bool - - // The following fields are used to manage the receive queue and are - // protected by rcvMu. - rcvMu sync.Mutex `state:"nosave"` - rcvList packetList + ops tcpip.SocketOptions + stats tcpip.TransportEndpointStats + + // The following fields are used to manage the receive queue. + rcvMu sync.Mutex `state:"nosave"` + // +checklocks:rcvMu + rcvList packetList + // +checklocks:rcvMu rcvBufSize int - rcvClosed bool - - // The following fields are protected by mu. - mu sync.RWMutex `state:"nosave"` - closed bool - stats tcpip.TransportEndpointStats `state:"nosave"` - bound bool + // +checklocks:rcvMu + rcvClosed bool + // +checklocks:rcvMu + rcvDisabled bool + + mu sync.RWMutex `state:"nosave"` + // +checklocks:mu + closed bool + // +checklocks:mu + boundNetProto tcpip.NetworkProtocolNumber + // +checklocks:mu boundNIC tcpip.NICID - // lastErrorMu protects lastError. lastErrorMu sync.Mutex `state:"nosave"` - lastError tcpip.Error - - // ops is used to get socket level options. - ops tcpip.SocketOptions - - // frozen indicates if the packets should be delivered to the endpoint - // during restore. - frozen bool + // +checklocks:lastErrorMu + lastError tcpip.Error } // NewEndpoint returns a new packet endpoint. func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { ep := &endpoint{ - stack: s, - TransportEndpointInfo: stack.TransportEndpointInfo{ - NetProto: netProto, - }, - cooked: cooked, - netProto: netProto, - waiterQueue: waiterQueue, + stack: s, + cooked: cooked, + boundNetProto: netProto, + waiterQueue: waiterQueue, } ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) ep.ops.SetReceiveBufferSize(32*1024, false /* notify */) @@ -140,7 +135,7 @@ func (ep *endpoint) Close() { return } - ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep) + ep.stack.UnregisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep) ep.rcvMu.Lock() defer ep.rcvMu.Unlock() @@ -153,7 +148,6 @@ func (ep *endpoint) Close() { } ep.closed = true - ep.bound = false ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) } @@ -188,7 +182,7 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul Total: packet.data.Size(), ControlMessages: tcpip.ControlMessages{ HasTimestamp: true, - Timestamp: packet.receivedAt.UnixNano(), + Timestamp: packet.receivedAt, }, } if opts.NeedRemoteAddr { @@ -214,13 +208,13 @@ func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tc ep.mu.Lock() closed := ep.closed nicID := ep.boundNIC + proto := ep.boundNetProto ep.mu.Unlock() if closed { return 0, &tcpip.ErrClosedForSend{} } var remote tcpip.LinkAddress - proto := ep.netProto if to := opts.To; to != nil { remote = tcpip.LinkAddress(to.Addr) @@ -296,29 +290,42 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() - if ep.bound && ep.boundNIC == addr.NIC { - // If the NIC being bound is the same then just return success. + netProto := tcpip.NetworkProtocolNumber(addr.Port) + if netProto == 0 { + // Do not allow unbinding the network protocol. + netProto = ep.boundNetProto + } + + if ep.boundNIC == addr.NIC && ep.boundNetProto == netProto { + // Already bound to the requested NIC and network protocol. return nil } - // Unregister endpoint with all the nics. - ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep) - ep.bound = false + // TODO(https://gvisor.dev/issue/6618): Unregister after registering the new + // binding. + ep.stack.UnregisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep) + ep.boundNIC = 0 + ep.boundNetProto = 0 // Bind endpoint to receive packets from specific interface. - if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil { + if err := ep.stack.RegisterPacketEndpoint(addr.NIC, netProto, ep); err != nil { return err } - ep.bound = true ep.boundNIC = addr.NIC - + ep.boundNetProto = netProto return nil } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. -func (*endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { - return tcpip.FullAddress{}, &tcpip.ErrNotSupported{} +func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { + ep.mu.RLock() + defer ep.mu.RUnlock() + + return tcpip.FullAddress{ + NIC: ep.boundNIC, + Port: uint16(ep.boundNetProto), + }, nil } // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. @@ -402,7 +409,7 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { } // HandlePacket implements stack.PacketEndpoint.HandlePacket. -func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +func (ep *endpoint) HandlePacket(nicID tcpip.NICID, _ tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { ep.rcvMu.Lock() // Drop the packet if our buffer is currently full. @@ -414,7 +421,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, } rcvBufSize := ep.ops.GetReceiveBufferSize() - if ep.frozen || ep.rcvBufSize >= int(rcvBufSize) { + if ep.rcvDisabled || ep.rcvBufSize >= int(rcvBufSize) { ep.rcvMu.Unlock() ep.stack.Stats().DroppedPackets.Increment() ep.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() @@ -473,10 +480,8 @@ func (*endpoint) State() uint32 { // Info returns a copy of the endpoint info. func (ep *endpoint) Info() tcpip.EndpointInfo { ep.mu.RLock() - // Make a copy of the endpoint info. - ret := ep.TransportEndpointInfo - ep.mu.RUnlock() - return &ret + defer ep.mu.RUnlock() + return &stack.TransportEndpointInfo{NetProto: ep.boundNetProto} } // Stats returns a pointer to the endpoint stats. @@ -491,18 +496,3 @@ func (*endpoint) SetOwner(tcpip.PacketOwner) {} func (ep *endpoint) SocketOptions() *tcpip.SocketOptions { return &ep.ops } - -// freeze prevents any more packets from being delivered to the endpoint. -func (ep *endpoint) freeze() { - ep.mu.Lock() - ep.frozen = true - ep.mu.Unlock() -} - -// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows -// new packets to be delivered again. -func (ep *endpoint) thaw() { - ep.mu.Lock() - ep.frozen = false - ep.mu.Unlock() -} diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go index 5c688d286..88cd80ad3 100644 --- a/pkg/tcpip/transport/packet/endpoint_state.go +++ b/pkg/tcpip/transport/packet/endpoint_state.go @@ -15,6 +15,7 @@ package packet import ( + "fmt" "time" "gvisor.dev/gvisor/pkg/tcpip" @@ -44,17 +45,24 @@ func (p *packet) loadData(data buffer.VectorisedView) { // beforeSave is invoked by stateify. func (ep *endpoint) beforeSave() { - ep.freeze() + ep.rcvMu.Lock() + defer ep.rcvMu.Unlock() + ep.rcvDisabled = true } // afterLoad is invoked by stateify. func (ep *endpoint) afterLoad() { - ep.thaw() + ep.mu.Lock() + defer ep.mu.Unlock() + ep.stack = stack.StackFromEnv ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) - // TODO(gvisor.dev/173): Once bind is supported, choose the right NIC. - if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil { - panic(err) + if err := ep.stack.RegisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep); err != nil { + panic(fmt.Sprintf("RegisterPacketEndpoint(%d, %d, _): %s", ep.boundNIC, ep.boundNetProto, err)) } + + ep.rcvMu.Lock() + ep.rcvDisabled = false + ep.rcvMu.Unlock() } diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD index b7e97e218..10b0c35fb 100644 --- a/pkg/tcpip/transport/raw/BUILD +++ b/pkg/tcpip/transport/raw/BUILD @@ -35,6 +35,7 @@ go_library( "//pkg/tcpip/stack", "//pkg/tcpip/transport", "//pkg/tcpip/transport/internal/network", + "//pkg/tcpip/transport/internal/noop", "//pkg/tcpip/transport/packet", "//pkg/waiter", ], diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 3040a445b..ce76774af 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -49,6 +49,7 @@ type rawPacket struct { receivedAt time.Time `state:".(int64)"` // senderAddr is the network address of the sender. senderAddr tcpip.FullAddress + packetInfo tcpip.IPPacketInfo } // endpoint is the raw socket implementation of tcpip.Endpoint. It is legal to @@ -70,7 +71,7 @@ type endpoint struct { associated bool net network.Endpoint - stats tcpip.TransportEndpointStats `state:"nosave"` + stats tcpip.TransportEndpointStats ops tcpip.SocketOptions // The following fields are used to manage the receive queue and are @@ -202,12 +203,29 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult Total: pkt.data.Size(), ControlMessages: tcpip.ControlMessages{ HasTimestamp: true, - Timestamp: pkt.receivedAt.UnixNano(), + Timestamp: pkt.receivedAt, }, } if opts.NeedRemoteAddr { res.RemoteAddr = pkt.senderAddr } + switch netProto := e.net.NetProto(); netProto { + case header.IPv4ProtocolNumber: + if e.ops.GetReceivePacketInfo() { + res.ControlMessages.HasIPPacketInfo = true + res.ControlMessages.PacketInfo = pkt.packetInfo + } + case header.IPv6ProtocolNumber: + if e.ops.GetIPv6ReceivePacketInfo() { + res.ControlMessages.HasIPv6PacketInfo = true + res.ControlMessages.IPv6PacketInfo = tcpip.IPv6PacketInfo{ + NIC: pkt.packetInfo.NIC, + Addr: pkt.packetInfo.DestinationAddr, + } + } + default: + panic(fmt.Sprintf("unrecognized network protocol = %d", netProto)) + } n, err := pkt.data.ReadTo(dst, opts.Peek) if n == 0 && err != nil { @@ -435,7 +453,9 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return false } - srcAddr := pkt.Network().SourceAddress() + net := pkt.Network() + dstAddr := net.DestinationAddress() + srcAddr := net.SourceAddress() info := e.net.Info() switch state := e.net.State(); state { @@ -457,7 +477,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { } // If bound to an address, only accept data for that address. - if info.BindAddr != "" && info.BindAddr != pkt.Network().DestinationAddress() { + if info.BindAddr != "" && info.BindAddr != dstAddr { return false } default: @@ -472,6 +492,14 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { NIC: pkt.NICID, Addr: srcAddr, }, + packetInfo: tcpip.IPPacketInfo{ + // TODO(gvisor.dev/issue/3556): dstAddr may be a multicast or broadcast + // address. LocalAddr should hold a unicast address that can be + // used to respond to the incoming packet. + LocalAddr: dstAddr, + DestinationAddr: dstAddr, + NIC: pkt.NICID, + }, } // Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not. @@ -483,10 +511,10 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // overlapping slices. var combinedVV buffer.VectorisedView if info.NetProto == header.IPv4ProtocolNumber { - network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() - headers := make(buffer.View, 0, len(network)+len(transport)) - headers = append(headers, network...) - headers = append(headers, transport...) + networkHeader, transportHeader := pkt.NetworkHeader().View(), pkt.TransportHeader().View() + headers := make(buffer.View, 0, len(networkHeader)+len(transportHeader)) + headers = append(headers, networkHeader...) + headers = append(headers, transportHeader...) combinedVV = headers.ToVectorisedView() } else { combinedVV = append(buffer.View(nil), pkt.TransportHeader().View()...).ToVectorisedView() diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go index e393b993d..624e2dbe7 100644 --- a/pkg/tcpip/transport/raw/protocol.go +++ b/pkg/tcpip/transport/raw/protocol.go @@ -17,6 +17,7 @@ package raw import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/internal/noop" "gvisor.dev/gvisor/pkg/tcpip/transport/packet" "gvisor.dev/gvisor/pkg/waiter" ) @@ -33,3 +34,18 @@ func (EndpointFactory) NewUnassociatedEndpoint(stack *stack.Stack, netProto tcpi func (EndpointFactory) NewPacketEndpoint(stack *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { return packet.NewEndpoint(stack, cooked, netProto, waiterQueue) } + +// CreateOnlyFactory implements stack.RawFactory. It allows creation of raw +// endpoints that do not support reading, writing, binding, etc. +type CreateOnlyFactory struct{} + +// NewUnassociatedEndpoint implements stack.RawFactory.NewUnassociatedEndpoint. +func (CreateOnlyFactory) NewUnassociatedEndpoint(stk *stack.Stack, _ tcpip.NetworkProtocolNumber, _ tcpip.TransportProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { + return noop.New(stk), nil +} + +// NewPacketEndpoint implements stack.RawFactory.NewPacketEndpoint. +func (CreateOnlyFactory) NewPacketEndpoint(*stack.Stack, bool, tcpip.NetworkProtocolNumber, *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { + // This isn't needed by anything, so it isn't implemented. + return nil, &tcpip.ErrNotPermitted{} +} diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 5148fe157..20958d882 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -80,9 +80,10 @@ go_library( go_test( name = "tcp_x_test", - size = "medium", + size = "large", srcs = [ "dual_stack_test.go", + "rcv_test.go", "sack_scoreboard_test.go", "tcp_noracedetector_test.go", "tcp_rack_test.go", @@ -114,16 +115,6 @@ go_test( ) go_test( - name = "rcv_test", - size = "small", - srcs = ["rcv_test.go"], - deps = [ - "//pkg/tcpip/header", - "//pkg/tcpip/seqnum", - ], -) - -go_test( name = "tcp_test", size = "small", srcs = [ diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 03c9fafa1..caf14b0dc 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -15,12 +15,12 @@ package tcp import ( + "container/list" "crypto/sha1" "encoding/binary" "fmt" "hash" "io" - "sync/atomic" "time" "gvisor.dev/gvisor/pkg/sleep" @@ -100,18 +100,6 @@ type listenContext struct { // netProto indicates the network protocol(IPv4/v6) for the listening // endpoint. netProto tcpip.NetworkProtocolNumber - - // pendingMu protects pendingEndpoints. This should only be accessed - // by the listening endpoint's worker goroutine. - // - // Lock Ordering: listenEP.workerMu -> pendingMu - pendingMu sync.Mutex - // pending is used to wait for all pendingEndpoints to finish when - // a socket is closed. - pending sync.WaitGroup - // pendingEndpoints is a map of all endpoints for which a handshake is - // in progress. - pendingEndpoints map[stack.TransportEndpointID]*endpoint } // timeStamp returns an 8-bit timestamp with a granularity of 64 seconds. @@ -122,14 +110,13 @@ func timeStamp(clock tcpip.Clock) uint32 { // newListenContext creates a new listen context. func newListenContext(stk *stack.Stack, protocol *protocol, listenEP *endpoint, rcvWnd seqnum.Size, v6Only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { l := &listenContext{ - stack: stk, - protocol: protocol, - rcvWnd: rcvWnd, - hasher: sha1.New(), - v6Only: v6Only, - netProto: netProto, - listenEP: listenEP, - pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint), + stack: stk, + protocol: protocol, + rcvWnd: rcvWnd, + hasher: sha1.New(), + v6Only: v6Only, + netProto: netProto, + listenEP: listenEP, } for i := range l.nonce { @@ -193,14 +180,6 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true } -func (l *listenContext) useSynCookies() bool { - var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies - if err := l.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil { - panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err)) - } - return bool(alwaysUseSynCookies) || (l.listenEP != nil && l.listenEP.synRcvdBacklogFull()) -} - // createConnectingEndpoint creates a new endpoint in a connecting state, with // the connection parameters given by the arguments. func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) { @@ -273,18 +252,15 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu return nil, &tcpip.ErrConnectionAborted{} } - l.addPendingEndpoint(ep) // Propagate any inheritable options from the listening endpoint // to the newly created endpoint. - l.listenEP.propagateInheritableOptionsLocked(ep) + l.listenEP.propagateInheritableOptionsLocked(ep) // +checklocksforce if !ep.reserveTupleLocked() { ep.mu.Unlock() ep.Close() - l.removePendingEndpoint(ep) - return nil, &tcpip.ErrConnectionAborted{} } @@ -303,10 +279,6 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu ep.mu.Unlock() ep.Close() - if l.listenEP != nil { - l.removePendingEndpoint(ep) - } - ep.drainClosingSegmentQueue() return nil, err @@ -344,39 +316,12 @@ func (l *listenContext) performHandshake(s *segment, opts header.TCPSynOptions, return ep, nil } -func (l *listenContext) addPendingEndpoint(n *endpoint) { - l.pendingMu.Lock() - l.pendingEndpoints[n.TransportEndpointInfo.ID] = n - l.pending.Add(1) - l.pendingMu.Unlock() -} - -func (l *listenContext) removePendingEndpoint(n *endpoint) { - l.pendingMu.Lock() - delete(l.pendingEndpoints, n.TransportEndpointInfo.ID) - l.pending.Done() - l.pendingMu.Unlock() -} - -func (l *listenContext) closeAllPendingEndpoints() { - l.pendingMu.Lock() - for _, n := range l.pendingEndpoints { - n.notifyProtocolGoroutine(notifyClose) - } - l.pendingMu.Unlock() - l.pending.Wait() -} - -// Precondition: h.ep.mu must be held. // +checklocks:h.ep.mu func (l *listenContext) cleanupFailedHandshake(h *handshake) { e := h.ep e.mu.Unlock() e.Close() e.notifyAborted() - if l.listenEP != nil { - l.removePendingEndpoint(e) - } e.drainClosingSegmentQueue() e.h = nil } @@ -384,12 +329,9 @@ func (l *listenContext) cleanupFailedHandshake(h *handshake) { // cleanupCompletedHandshake transfers any state from the completed handshake to // the new endpoint. // -// Precondition: h.ep.mu must be held. +// +checklocks:h.ep.mu func (l *listenContext) cleanupCompletedHandshake(h *handshake) { e := h.ep - if l.listenEP != nil { - l.removePendingEndpoint(e) - } e.isConnectNotified = true // Update the receive window scaling. We can't do it before the @@ -401,47 +343,11 @@ func (l *listenContext) cleanupCompletedHandshake(h *handshake) { e.h = nil } -// deliverAccepted delivers the newly-accepted endpoint to the listener. If the -// listener has transitioned out of the listen state (accepted is the zero -// value), the new endpoint is reset instead. -func (e *endpoint) deliverAccepted(n *endpoint, withSynCookie bool) { - e.mu.Lock() - e.pendingAccepted.Add(1) - e.mu.Unlock() - defer e.pendingAccepted.Done() - - // Drop the lock before notifying to avoid deadlock in user-specified - // callbacks. - delivered := func() bool { - e.acceptMu.Lock() - defer e.acceptMu.Unlock() - for { - if e.accepted == (accepted{}) { - return false - } - if e.accepted.endpoints.Len() == e.accepted.cap { - e.acceptCond.Wait() - continue - } - - e.accepted.endpoints.PushBack(n) - if !withSynCookie { - atomic.AddInt32(&e.synRcvdCount, -1) - } - return true - } - }() - if delivered { - e.waiterQueue.Notify(waiter.ReadableEvents) - } else { - n.notifyProtocolGoroutine(notifyReset) - } -} - // propagateInheritableOptionsLocked propagates any options set on the listening // endpoint to the newly created endpoint. // -// Precondition: e.mu and n.mu must be held. +// +checklocks:e.mu +// +checklocks:n.mu func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) { n.userTimeout = e.userTimeout n.portFlags = e.portFlags @@ -452,9 +358,9 @@ func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) { // reserveTupleLocked reserves an accepted endpoint's tuple. // -// Preconditions: -// * propagateInheritableOptionsLocked has been called. -// * e.mu is held. +// Precondition: e.propagateInheritableOptionsLocked has been called. +// +// +checklocks:e.mu func (e *endpoint) reserveTupleLocked() bool { dest := tcpip.FullAddress{ Addr: e.TransportEndpointInfo.ID.RemoteAddress, @@ -489,70 +395,36 @@ func (e *endpoint) notifyAborted() { e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) } -// handleSynSegment is called in its own goroutine once the listening endpoint -// receives a SYN segment. It is responsible for completing the handshake and -// queueing the new endpoint for acceptance. -// -// A limited number of these goroutines are allowed before TCP starts using SYN -// cookies to accept connections. -// -// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. -func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts header.TCPSynOptions) tcpip.Error { - defer s.decRef() +func (e *endpoint) acceptQueueIsFull() bool { + e.acceptMu.Lock() + full := e.acceptQueue.isFull() + e.acceptMu.Unlock() + return full +} - h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner) - if err != nil { - e.stack.Stats().TCP.FailedConnectionAttempts.Increment() - e.stats.FailedConnectionAttempts.Increment() - atomic.AddInt32(&e.synRcvdCount, -1) - return err - } +// +stateify savable +type acceptQueue struct { + // NB: this could be an endpointList, but ilist only permits endpoints to + // belong to one list at a time, and endpoints are already stored in the + // dispatcher's list. + endpoints list.List `state:".([]*endpoint)"` - go func() { - // Note that startHandshake returns a locked endpoint. The - // force call here just makes it so. - if err := h.complete(); err != nil { // +checklocksforce - e.stack.Stats().TCP.FailedConnectionAttempts.Increment() - e.stats.FailedConnectionAttempts.Increment() - ctx.cleanupFailedHandshake(h) - atomic.AddInt32(&e.synRcvdCount, -1) - return - } - ctx.cleanupCompletedHandshake(h) - h.ep.startAcceptedLoop() - e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - e.deliverAccepted(h.ep, false /*withSynCookie*/) - }() + // pendingEndpoints is a set of all endpoints for which a handshake is + // in progress. + pendingEndpoints map[*endpoint]struct{} - return nil + // capacity is the maximum number of endpoints that can be in endpoints. + capacity int } -func (e *endpoint) synRcvdBacklogFull() bool { - e.acceptMu.Lock() - acceptedCap := e.accepted.cap - e.acceptMu.Unlock() - // The capacity of the accepted queue would always be one greater than the - // listen backlog. But, the SYNRCVD connections count is always checked - // against the listen backlog value for Linux parity reason. - // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280 - // - // We maintain an equality check here as the synRcvdCount is incremented - // and compared only from a single listener context and the capacity of - // the accepted queue can only increase by a new listen call. - return int(atomic.LoadInt32(&e.synRcvdCount)) == acceptedCap-1 -} - -func (e *endpoint) acceptQueueIsFull() bool { - e.acceptMu.Lock() - full := e.accepted != (accepted{}) && e.accepted.endpoints.Len() == e.accepted.cap - e.acceptMu.Unlock() - return full +func (a *acceptQueue) isFull() bool { + return a.endpoints.Len() == a.capacity } // handleListenSegment is called when a listening endpoint receives a segment // and needs to handle it. // -// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. +// +checklocks:e.mu func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Error { e.rcvQueueInfo.rcvQueueMu.Lock() rcvClosed := e.rcvQueueInfo.RcvClosed @@ -580,11 +452,95 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err } opts := parseSynSegmentOptions(s) - if !ctx.useSynCookies() { - s.incRef() - atomic.AddInt32(&e.synRcvdCount, 1) - return e.handleSynSegment(ctx, s, opts) + + useSynCookies, err := func() (bool, tcpip.Error) { + var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies + if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil { + panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err)) + } + if alwaysUseSynCookies { + return true, nil + } + e.acceptMu.Lock() + defer e.acceptMu.Unlock() + + // The capacity of the accepted queue would always be one greater than the + // listen backlog. But, the SYNRCVD connections count is always checked + // against the listen backlog value for Linux parity reason. + // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280 + if len(e.acceptQueue.pendingEndpoints) == e.acceptQueue.capacity-1 { + return true, nil + } + + h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner) + if err != nil { + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + return false, err + } + + e.acceptQueue.pendingEndpoints[h.ep] = struct{}{} + e.pendingAccepted.Add(1) + + go func() { + defer func() { + e.pendingAccepted.Done() + + e.acceptMu.Lock() + defer e.acceptMu.Unlock() + delete(e.acceptQueue.pendingEndpoints, h.ep) + }() + + // Note that startHandshake returns a locked endpoint. The force call + // here just makes it so. + if err := h.complete(); err != nil { // +checklocksforce + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + ctx.cleanupFailedHandshake(h) + return + } + ctx.cleanupCompletedHandshake(h) + h.ep.startAcceptedLoop() + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() + + // Deliver the endpoint to the accept queue. + // + // Drop the lock before notifying to avoid deadlock in user-specified + // callbacks. + delivered := func() bool { + e.acceptMu.Lock() + defer e.acceptMu.Unlock() + for { + // The listener is transitioning out of the Listen state; bail. + if e.acceptQueue.capacity == 0 { + return false + } + if e.acceptQueue.isFull() { + e.acceptCond.Wait() + continue + } + + e.acceptQueue.endpoints.PushBack(h.ep) + return true + } + }() + + if delivered { + e.waiterQueue.Notify(waiter.ReadableEvents) + } else { + h.ep.notifyProtocolGoroutine(notifyReset) + } + }() + + return false, nil + }() + if err != nil { + return err + } + if !useSynCookies { + return nil } + route, err := e.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) if err != nil { return err @@ -627,18 +583,6 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err return nil case s.flags.Contains(header.TCPFlagAck): - if e.acceptQueueIsFull() { - // Silently drop the ack as the application can't accept - // the connection at this point. The ack will be - // retransmitted by the sender anyway and we can - // complete the connection at the time of retransmit if - // the backlog has space. - e.stack.Stats().TCP.ListenOverflowAckDrop.Increment() - e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment() - e.stack.Stats().DroppedPackets.Increment() - return nil - } - iss := s.ackNumber - 1 irs := s.sequenceNumber - 1 @@ -674,6 +618,24 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err // ACK was received from the sender. return replyWithReset(e.stack, s, e.sendTOS, e.ttl) } + + // Keep hold of acceptMu until the new endpoint is in the accept queue (or + // if there is an error), to guarantee that we will keep our spot in the + // queue even if another handshake from the syn queue completes. + e.acceptMu.Lock() + if e.acceptQueue.isFull() { + // Silently drop the ack as the application can't accept + // the connection at this point. The ack will be + // retransmitted by the sender anyway and we can + // complete the connection at the time of retransmit if + // the backlog has space. + e.acceptMu.Unlock() + e.stack.Stats().TCP.ListenOverflowAckDrop.Increment() + e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment() + e.stack.Stats().DroppedPackets.Increment() + return nil + } + e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment() // Create newly accepted endpoint and deliver it. rcvdSynOptions := header.TCPSynOptions{ @@ -695,6 +657,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err n, err := ctx.createConnectingEndpoint(s, rcvdSynOptions, &waiter.Queue{}) if err != nil { + e.acceptMu.Unlock() return err } @@ -706,6 +669,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err if !n.reserveTupleLocked() { n.mu.Unlock() + e.acceptMu.Unlock() n.Close() e.stack.Stats().TCP.FailedConnectionAttempts.Increment() @@ -723,6 +687,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err n.boundBindToDevice, ); err != nil { n.mu.Unlock() + e.acceptMu.Unlock() n.Close() e.stack.Stats().TCP.FailedConnectionAttempts.Increment() @@ -755,20 +720,15 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err n.newSegmentWaker.Assert() } - // Do the delivery in a separate goroutine so - // that we don't block the listen loop in case - // the application is slow to accept or stops - // accepting. - // - // NOTE: This won't result in an unbounded - // number of goroutines as we do check before - // entering here that there was at least some - // space available in the backlog. - // Start the protocol goroutine. n.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - go e.deliverAccepted(n, true /*withSynCookie*/) + + // Deliver the endpoint to the accept queue. + e.acceptQueue.endpoints.PushBack(n) + e.acceptMu.Unlock() + + e.waiterQueue.Notify(waiter.ReadableEvents) return nil default: @@ -785,14 +745,8 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) { ctx := newListenContext(e.stack, e.protocol, e, rcvWnd, v6Only, e.NetProto) defer func() { - // Mark endpoint as closed. This will prevent goroutines running - // handleSynSegment() from attempting to queue new connections - // to the endpoint. e.setEndpointState(StateClose) - // Close any endpoints in SYN-RCVD state. - ctx.closeAllPendingEndpoints() - // Do cleanup if needed. e.completeWorkerLocked() diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 5d8e18484..80cd07218 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -30,6 +30,10 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +// InitialRTO is the initial retransmission timeout. +// https://github.com/torvalds/linux/blob/7c636d4d20f/include/net/tcp.h#L142 +const InitialRTO = time.Second + // maxSegmentsPerWake is the maximum number of segments to process in the main // protocol goroutine per wake-up. Yielding [after this number of segments are // processed] allows other events to be processed as well (e.g., timeouts, @@ -532,7 +536,7 @@ func (h *handshake) complete() tcpip.Error { defer s.Done() // Initialize the resend timer. - timer, err := newBackoffTimer(h.ep.stack.Clock(), time.Second, MaxRTO, resendWaker.Assert) + timer, err := newBackoffTimer(h.ep.stack.Clock(), InitialRTO, MaxRTO, resendWaker.Assert) if err != nil { return err } @@ -578,6 +582,9 @@ func (h *handshake) complete() tcpip.Error { if (n¬ifyClose)|(n¬ifyAbort) != 0 { return &tcpip.ErrAborted{} } + if n¬ifyShutdown != 0 { + return &tcpip.ErrConnectionReset{} + } if n¬ifyDrain != 0 { for !h.ep.segmentQueue.empty() { s := h.ep.segmentQueue.dequeue() diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index d2b8f298f..066ffe051 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -15,7 +15,6 @@ package tcp import ( - "container/list" "encoding/binary" "fmt" "io" @@ -187,6 +186,8 @@ const ( // say TIME_WAIT. notifyTickleWorker notifyError + // notifyShutdown means that a connecting socket was shutdown. + notifyShutdown ) // SACKInfo holds TCP SACK related information for a given endpoint. @@ -203,6 +204,8 @@ type SACKInfo struct { } // ReceiveErrors collect segment receive errors within transport layer. +// +// +stateify savable type ReceiveErrors struct { tcpip.ReceiveErrors @@ -232,6 +235,8 @@ type ReceiveErrors struct { } // SendErrors collect segment send errors within the transport layer. +// +// +stateify savable type SendErrors struct { tcpip.SendErrors @@ -255,6 +260,8 @@ type SendErrors struct { } // Stats holds statistics about the endpoint. +// +// +stateify savable type Stats struct { // SegmentsReceived is the number of TCP segments received that // the transport layer successfully parsed. @@ -309,15 +316,6 @@ type rcvQueueInfo struct { rcvQueue segmentList `state:"wait"` } -// +stateify savable -type accepted struct { - // NB: this could be an endpointList, but ilist only permits endpoints to - // belong to one list at a time, and endpoints are already stored in the - // dispatcher's list. - endpoints list.List `state:".([]*endpoint)"` - cap int -} - // endpoint represents a TCP endpoint. This struct serves as the interface // between users of the endpoint and the protocol implementation; it is legal to // have concurrent goroutines make calls into the endpoint, they are properly @@ -333,7 +331,7 @@ type accepted struct { // The following three mutexes can be acquired independent of e.mu but if // acquired with e.mu then e.mu must be acquired first. // -// e.acceptMu -> protects accepted. +// e.acceptMu -> Protects e.acceptQueue. // e.rcvQueueMu -> Protects e.rcvQueue and associated fields. // e.sndQueueMu -> Protects the e.sndQueue and associated fields. // e.lastErrorMu -> Protects the lastError field. @@ -497,10 +495,6 @@ type endpoint struct { // and dropped when it is. segmentQueue segmentQueue `state:"wait"` - // synRcvdCount is the number of connections for this endpoint that are - // in SYN-RCVD state; this is only accessed atomically. - synRcvdCount int32 - // userMSS if non-zero is the MSS value explicitly set by the user // for this endpoint using the TCP_MAXSEG setsockopt. userMSS uint16 @@ -573,7 +567,8 @@ type endpoint struct { // accepted is used by a listening endpoint protocol goroutine to // send newly accepted connections to the endpoint so that they can be // read by Accept() calls. - accepted accepted + // +checklocks:acceptMu + acceptQueue acceptQueue // The following are only used from the protocol goroutine, and // therefore don't need locks to protect them. @@ -606,8 +601,7 @@ type endpoint struct { gso stack.GSO - // TODO(b/142022063): Add ability to save and restore per endpoint stats. - stats Stats `state:"nosave"` + stats Stats // tcpLingerTimeout is the maximum amount of a time a socket // a socket stays in TIME_WAIT state before being marked @@ -819,10 +813,9 @@ func newEndpoint(s *stack.Stack, protocol *protocol, netProto tcpip.NetworkProto waiterQueue: waiterQueue, state: uint32(StateInitial), keepalive: keepalive{ - // Linux defaults. - idle: 2 * time.Hour, - interval: 75 * time.Second, - count: 9, + idle: DefaultKeepaliveIdle, + interval: DefaultKeepaliveInterval, + count: DefaultKeepaliveCount, }, uniqueID: s.UniqueID(), txHash: s.Rand().Uint32(), @@ -904,7 +897,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // Check if there's anything in the accepted queue. if (mask & waiter.ReadableEvents) != 0 { e.acceptMu.Lock() - if e.accepted.endpoints.Len() != 0 { + if e.acceptQueue.endpoints.Len() != 0 { result |= waiter.ReadableEvents } e.acceptMu.Unlock() @@ -1087,20 +1080,20 @@ func (e *endpoint) closeNoShutdownLocked() { // handshake but not yet been delivered to the application. func (e *endpoint) closePendingAcceptableConnectionsLocked() { e.acceptMu.Lock() - acceptedCopy := e.accepted - e.accepted = accepted{} - e.acceptMu.Unlock() - - if acceptedCopy == (accepted{}) { - return + // Close any endpoints in SYN-RCVD state. + for n := range e.acceptQueue.pendingEndpoints { + n.notifyProtocolGoroutine(notifyClose) } - - e.acceptCond.Broadcast() - + e.acceptQueue.pendingEndpoints = nil // Reset all connections that are waiting to be accepted. - for n := acceptedCopy.endpoints.Front(); n != nil; n = n.Next() { + for n := e.acceptQueue.endpoints.Front(); n != nil; n = n.Next() { n.Value.(*endpoint).notifyProtocolGoroutine(notifyReset) } + e.acceptQueue.endpoints.Init() + e.acceptMu.Unlock() + + e.acceptCond.Broadcast() + // Wait for reset of all endpoints that are still waiting to be delivered to // the now closed accepted. e.pendingAccepted.Wait() @@ -2060,7 +2053,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { case *tcpip.OriginalDestinationOption: e.LockUser() ipt := e.stack.IPTables() - addr, port, err := ipt.OriginalDst(e.TransportEndpointInfo.ID, e.NetProto) + addr, port, err := ipt.OriginalDst(e.TransportEndpointInfo.ID, e.NetProto, ProtocolNumber) e.UnlockUser() if err != nil { return err @@ -2380,6 +2373,18 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error { func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { e.LockUser() defer e.UnlockUser() + + if e.EndpointState().connecting() { + // When calling shutdown(2) on a connecting socket, the endpoint must + // enter the error state. But this logic cannot belong to the shutdownLocked + // method because that method is called during a close(2) (and closing a + // connecting socket is not an error). + e.resetConnectionLocked(&tcpip.ErrConnectionReset{}) + e.notifyProtocolGoroutine(notifyShutdown) + e.waiterQueue.Notify(waiter.WritableEvents | waiter.EventHUp | waiter.EventErr) + return nil + } + return e.shutdownLocked(flags) } @@ -2480,22 +2485,23 @@ func (e *endpoint) listen(backlog int) tcpip.Error { if e.EndpointState() == StateListen && !e.closed { e.acceptMu.Lock() defer e.acceptMu.Unlock() - if e.accepted == (accepted{}) { - // listen is called after shutdown. - e.accepted.cap = backlog - e.shutdownFlags = 0 - e.rcvQueueInfo.rcvQueueMu.Lock() - e.rcvQueueInfo.RcvClosed = false - e.rcvQueueInfo.rcvQueueMu.Unlock() - } else { - // Adjust the size of the backlog iff we can fit - // existing pending connections into the new one. - if e.accepted.endpoints.Len() > backlog { - return &tcpip.ErrInvalidEndpointState{} - } - e.accepted.cap = backlog + + // Adjust the size of the backlog iff we can fit + // existing pending connections into the new one. + if e.acceptQueue.endpoints.Len() > backlog { + return &tcpip.ErrInvalidEndpointState{} + } + e.acceptQueue.capacity = backlog + + if e.acceptQueue.pendingEndpoints == nil { + e.acceptQueue.pendingEndpoints = make(map[*endpoint]struct{}) } + e.shutdownFlags = 0 + e.rcvQueueInfo.rcvQueueMu.Lock() + e.rcvQueueInfo.RcvClosed = false + e.rcvQueueInfo.rcvQueueMu.Unlock() + // Notify any blocked goroutines that they can attempt to // deliver endpoints again. e.acceptCond.Broadcast() @@ -2530,8 +2536,11 @@ func (e *endpoint) listen(backlog int) tcpip.Error { // may be pre-populated with some previously accepted (but not Accepted) // endpoints. e.acceptMu.Lock() - if e.accepted == (accepted{}) { - e.accepted.cap = backlog + if e.acceptQueue.pendingEndpoints == nil { + e.acceptQueue.pendingEndpoints = make(map[*endpoint]struct{}) + } + if e.acceptQueue.capacity == 0 { + e.acceptQueue.capacity = backlog } e.acceptMu.Unlock() @@ -2571,8 +2580,8 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter. // Get the new accepted endpoint. var n *endpoint e.acceptMu.Lock() - if element := e.accepted.endpoints.Front(); element != nil { - n = e.accepted.endpoints.Remove(element).(*endpoint) + if element := e.acceptQueue.endpoints.Front(); element != nil { + n = e.acceptQueue.endpoints.Remove(element).(*endpoint) } e.acceptMu.Unlock() if n == nil { @@ -2989,6 +2998,8 @@ func (e *endpoint) completeStateLocked() stack.TCPEndpointState { } s.Sender.RACKState = e.snd.rc.TCPRACKState + s.Sender.RetransmitTS = e.snd.retransmitTS + s.Sender.SpuriousRecovery = e.snd.spuriousRecovery return s } diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index f2e8b3840..94072a115 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -100,7 +100,7 @@ func (e *endpoint) beforeSave() { } // saveEndpoints is invoked by stateify. -func (a *accepted) saveEndpoints() []*endpoint { +func (a *acceptQueue) saveEndpoints() []*endpoint { acceptedEndpoints := make([]*endpoint, a.endpoints.Len()) for i, e := 0, a.endpoints.Front(); e != nil; i, e = i+1, e.Next() { acceptedEndpoints[i] = e.Value.(*endpoint) @@ -109,7 +109,7 @@ func (a *accepted) saveEndpoints() []*endpoint { } // loadEndpoints is invoked by stateify. -func (a *accepted) loadEndpoints(acceptedEndpoints []*endpoint) { +func (a *acceptQueue) loadEndpoints(acceptedEndpoints []*endpoint) { for _, ep := range acceptedEndpoints { a.endpoints.PushBack(ep) } @@ -251,7 +251,9 @@ func (e *endpoint) Resume(s *stack.Stack) { go func() { connectedLoading.Wait() bind() - backlog := e.accepted.cap + e.acceptMu.Lock() + backlog := e.acceptQueue.capacity + e.acceptMu.Unlock() if err := e.Listen(backlog); err != nil { panic("endpoint listening failed: " + err.String()) } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index e4410ad93..f122ea009 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -66,6 +66,18 @@ const ( // DefaultSynRetries is the default value for the number of SYN retransmits // before a connect is aborted. DefaultSynRetries = 6 + + // DefaultKeepaliveIdle is the idle time for a connection before keep-alive + // probes are sent. + DefaultKeepaliveIdle = 2 * time.Hour + + // DefaultKeepaliveInterval is the time between two successive keep-alive + // probes. + DefaultKeepaliveInterval = 75 * time.Second + + // DefaultKeepaliveCount is the number of keep-alive probes that are sent + // before declaring the connection dead. + DefaultKeepaliveCount = 9 ) const ( diff --git a/pkg/tcpip/transport/tcp/rcv_test.go b/pkg/tcpip/transport/tcp/rcv_test.go index 8a026ec46..e47a07030 100644 --- a/pkg/tcpip/transport/tcp/rcv_test.go +++ b/pkg/tcpip/transport/tcp/rcv_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package rcv_test +package tcp_test import ( "testing" diff --git a/pkg/tcpip/transport/tcp/segment_test.go b/pkg/tcpip/transport/tcp/segment_test.go index 2e6ea06f5..2d5fdda19 100644 --- a/pkg/tcpip/transport/tcp/segment_test.go +++ b/pkg/tcpip/transport/tcp/segment_test.go @@ -34,7 +34,7 @@ func checkSegmentSize(t *testing.T, name string, seg *segment, want segmentSizeW DataSize: seg.data.Size(), SegMemSize: seg.segMemSize(), } - if diff := cmp.Diff(got, want); diff != "" { + if diff := cmp.Diff(want, got); diff != "" { t.Errorf("%s differs (-want +got):\n%s", name, diff) } } diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 2fabf1594..4377f07a0 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -144,6 +144,15 @@ type sender struct { // probeTimer and probeWaker are used to schedule PTO for RACK TLP algorithm. probeTimer timer `state:"nosave"` probeWaker sleep.Waker `state:"nosave"` + + // spuriousRecovery indicates whether the sender entered recovery + // spuriously as described in RFC3522 Section 3.2. + spuriousRecovery bool + + // retransmitTS is the timestamp at which the sender sends retransmitted + // segment after entering an RTO for the first time as described in + // RFC3522 Section 3.2. + retransmitTS uint32 } // rtt is a synchronization wrapper used to appease stateify. See the comment @@ -425,6 +434,13 @@ func (s *sender) retransmitTimerExpired() bool { return true } + // Initialize the variables used to detect spurious recovery after + // entering RTO. + // + // See: https://www.rfc-editor.org/rfc/rfc3522.html#section-3.2 Step 1. + s.spuriousRecovery = false + s.retransmitTS = 0 + // TODO(b/147297758): Band-aid fix, retransmitTimer can fire in some edge cases // when writeList is empty. Remove this once we have a proper fix for this // issue. @@ -495,6 +511,10 @@ func (s *sender) retransmitTimerExpired() bool { s.leaveRecovery() } + // Record retransmitTS if the sender is not in recovery as per: + // https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 2 + s.recordRetransmitTS() + s.state = tcpip.RTORecovery s.cc.HandleRTOExpired() @@ -958,6 +978,13 @@ func (s *sender) sendData() { } func (s *sender) enterRecovery() { + // Initialize the variables used to detect spurious recovery after + // entering recovery. + // + // See: https://www.rfc-editor.org/rfc/rfc3522.html#section-3.2 Step 1. + s.spuriousRecovery = false + s.retransmitTS = 0 + s.FastRecovery.Active = true // Save state to reflect we're now in fast recovery. // @@ -972,6 +999,11 @@ func (s *sender) enterRecovery() { s.FastRecovery.MaxCwnd = s.SndCwnd + s.Outstanding s.FastRecovery.HighRxt = s.SndUna s.FastRecovery.RescueRxt = s.SndUna + + // Record retransmitTS if the sender is not in recovery as per: + // https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 2 + s.recordRetransmitTS() + if s.ep.SACKPermitted { s.state = tcpip.SACKRecovery s.ep.stack.Stats().TCP.SACKRecovery.Increment() @@ -1147,13 +1179,15 @@ func (s *sender) isDupAck(seg *segment) bool { // Iterate the writeList and update RACK for each segment which is newly acked // either cumulatively or selectively. Loop through the segments which are // sacked, and update the RACK related variables and check for reordering. +// Returns true when the DSACK block has been detected in the received ACK. // // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 // steps 2 and 3. -func (s *sender) walkSACK(rcvdSeg *segment) { +func (s *sender) walkSACK(rcvdSeg *segment) bool { s.rc.setDSACKSeen(false) // Look for DSACK block. + hasDSACK := false idx := 0 n := len(rcvdSeg.parsedOptions.SACKBlocks) if checkDSACK(rcvdSeg) { @@ -1167,10 +1201,11 @@ func (s *sender) walkSACK(rcvdSeg *segment) { s.rc.setDSACKSeen(true) idx = 1 n-- + hasDSACK = true } if n == 0 { - return + return hasDSACK } // Sort the SACK blocks. The first block is the most recent unacked @@ -1193,6 +1228,7 @@ func (s *sender) walkSACK(rcvdSeg *segment) { seg = seg.Next() } } + return hasDSACK } // checkDSACK checks if a DSACK is reported. @@ -1239,6 +1275,85 @@ func checkDSACK(rcvdSeg *segment) bool { return false } +func (s *sender) recordRetransmitTS() { + // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 + // + // The Eifel detection algorithm is used, only upon initiation of loss + // recovery, i.e., when either the timeout-based retransmit or the fast + // retransmit is sent. The Eifel detection algorithm MUST NOT be + // reinitiated after loss recovery has already started. In particular, + // it must not be reinitiated upon subsequent timeouts for the same + // segment, and not upon retransmitting segments other than the oldest + // outstanding segment, e.g., during selective loss recovery. + if s.inRecovery() { + return + } + + // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 2 + // + // Set a "RetransmitTS" variable to the value of the Timestamp Value + // field of the Timestamps option included in the retransmit sent when + // loss recovery is initiated. A TCP sender must ensure that + // RetransmitTS does not get overwritten as loss recovery progresses, + // e.g., in case of a second timeout and subsequent second retransmit of + // the same octet. + s.retransmitTS = s.ep.tsValNow() +} + +func (s *sender) detectSpuriousRecovery(hasDSACK bool, tsEchoReply uint32) { + // Return if the sender has already detected spurious recovery. + if s.spuriousRecovery { + return + } + + // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 4 + // + // If the value of the Timestamp Echo Reply field of the acceptable ACK's + // Timestamps option is smaller than the value of RetransmitTS, then + // proceed to next step, else return. + if tsEchoReply >= s.retransmitTS { + return + } + + // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 5 + // + // If the acceptable ACK carries a DSACK option [RFC2883], then return. + if hasDSACK { + return + } + + // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 5 + // + // If during the lifetime of the TCP connection the TCP sender has + // previously received an ACK with a DSACK option, or the acceptable ACK + // does not acknowledge all outstanding data, then proceed to next step, + // else return. + numDSACK := s.ep.stack.Stats().TCP.SegmentsAckedWithDSACK.Value() + if numDSACK == 0 && s.SndUna == s.SndNxt { + return + } + + // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 6 + // + // If the loss recovery has been initiated with a timeout-based + // retransmit, then set + // SpuriousRecovery <- SPUR_TO (equal 1), + // else set + // SpuriousRecovery <- dupacks+1 + // Set the spurious recovery variable to true as we do not differentiate + // between fast, SACK or RTO recovery. + s.spuriousRecovery = true + s.ep.stack.Stats().TCP.SpuriousRecovery.Increment() +} + +// Check if the sender is in RTORecovery, FastRecovery or SACKRecovery state. +func (s *sender) inRecovery() bool { + if s.state == tcpip.RTORecovery || s.state == tcpip.FastRecovery || s.state == tcpip.SACKRecovery { + return true + } + return false +} + // handleRcvdSegment is called when a segment is received; it is responsible for // updating the send-related state. func (s *sender) handleRcvdSegment(rcvdSeg *segment) { @@ -1254,6 +1369,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } // Insert SACKBlock information into our scoreboard. + hasDSACK := false if s.ep.SACKPermitted { for _, sb := range rcvdSeg.parsedOptions.SACKBlocks { // Only insert the SACK block if the following holds @@ -1288,7 +1404,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // RACK.fack, then the corresponding packet has been // reordered and RACK.reord is set to TRUE. if s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 { - s.walkSACK(rcvdSeg) + hasDSACK = s.walkSACK(rcvdSeg) } s.SetPipe() } @@ -1418,6 +1534,11 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // Clear SACK information for all acked data. s.ep.scoreboard.Delete(s.SndUna) + // Detect if the sender entered recovery spuriously. + if s.inRecovery() { + s.detectSpuriousRecovery(hasDSACK, rcvdSeg.parsedOptions.TSEcr) + } + // If we are not in fast recovery then update the congestion // window based on the number of acknowledged packets. if !s.FastRecovery.Active { diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go index c35db7c95..0d36d0dd0 100644 --- a/pkg/tcpip/transport/tcp/tcp_rack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go @@ -1059,16 +1059,17 @@ func TestRACKWithWindowFull(t *testing.T) { for i := 0; i < numPkts; i++ { c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) bytesRead += maxPayload - if i == 0 { - // Send ACK for the first packet to establish RTT. - c.SendAck(seq, maxPayload) - } } - // SACK for #10 packet. - start := c.IRS.Add(seqnum.Size(1 + (numPkts-1)*maxPayload)) + // Expect retransmission of last packet due to TLP. + c.ReceiveAndCheckPacketWithOptions(data, (numPkts-1)*maxPayload, maxPayload, tsOptionSize) + + // SACK for first and last packet. + start := c.IRS.Add(seqnum.Size(maxPayload)) end := start.Add(seqnum.Size(maxPayload)) - c.SendAckWithSACK(seq, 2*maxPayload, []header.SACKBlock{{start, end}}) + dsackStart := c.IRS.Add(seqnum.Size(1 + (numPkts-1)*maxPayload)) + dsackEnd := dsackStart.Add(seqnum.Size(maxPayload)) + c.SendAckWithSACK(seq, 2*maxPayload, []header.SACKBlock{{dsackStart, dsackEnd}, {start, end}}) var info tcpip.TCPInfoOption if err := c.EP.GetSockOpt(&info); err != nil { diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index 6255355bb..896249d2d 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -23,6 +23,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -702,3 +703,257 @@ func TestRecoveryEntry(t *testing.T) { t.Error(err) } } + +func verifySpuriousRecoveryMetric(t *testing.T, c *context.Context, numSpuriousRecovery uint64) { + t.Helper() + + metricPollFn := func() error { + tcpStats := c.Stack().Stats().TCP + stats := []struct { + stat *tcpip.StatCounter + name string + want uint64 + }{ + {tcpStats.SpuriousRecovery, "stats.TCP.SpuriousRecovery", numSpuriousRecovery}, + } + for _, s := range stats { + if got, want := s.stat.Value(), s.want; got != want { + return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) + } + } + return nil + } + + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) + } +} + +func checkReceivedPacket(t *testing.T, c *context.Context, tcpHdr header.TCP, bytesRead uint32, b, data []byte) { + payloadLen := uint32(len(tcpHdr.Payload())) + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1+bytesRead), + checker.TCPAckNum(context.TestInitialSequenceNumber+1), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), + ), + ) + pdata := data[bytesRead : bytesRead+payloadLen] + if p := tcpHdr.Payload(); !bytes.Equal(pdata, p) { + t.Fatalf("got data = %v, want = %v", p, pdata) + } +} + +func buildTSOptionFromHeader(tcpHdr header.TCP) []byte { + parsedOpts := tcpHdr.ParsedOptions() + tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} + header.EncodeTSOption(parsedOpts.TSEcr+1, parsedOpts.TSVal, tsOpt[2:]) + return tsOpt[:] +} + +func TestDetectSpuriousRecoveryWithRTO(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + probeDone := make(chan struct{}) + c.Stack().AddTCPProbe(func(s stack.TCPEndpointState) { + if s.Sender.RetransmitTS == 0 { + t.Fatalf("RetransmitTS did not get updated, got: 0 want > 0") + } + if !s.Sender.SpuriousRecovery { + t.Fatalf("Spurious recovery was not detected") + } + close(probeDone) + }) + + setStackSACKPermitted(t, c, true) + createConnectedWithSACKAndTS(c) + numPackets := 5 + data := make([]byte, numPackets*maxPayload) + for i := range data { + data[i] = byte(i) + } + // Write the data. + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + var options []byte + var bytesRead uint32 + for i := 0; i < numPackets; i++ { + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + checkReceivedPacket(t, c, tcpHdr, bytesRead, b, data) + + // Get options only for the first packet. This will be sent with + // the ACK to indicate the acknowledgement is for the original + // packet. + if i == 0 && c.TimeStampEnabled { + options = buildTSOptionFromHeader(tcpHdr) + } + bytesRead += uint32(len(tcpHdr.Payload())) + } + + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + // Expect #5 segment with TLP. + c.ReceiveAndCheckPacketWithOptions(data, 4*maxPayload, maxPayload, tsOptionSize) + + // Expect #1 segment because of RTO. + c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize) + + info := tcpip.TCPInfoOption{} + if err := c.EP.GetSockOpt(&info); err != nil { + t.Fatalf("c.EP.GetSockOpt(&%T) = %s", info, err) + } + + if info.CcState != tcpip.RTORecovery { + t.Fatalf("Loss recovery did not happen, got: %v want: %v", info.CcState, tcpip.RTORecovery) + } + + // Acknowledge the data. + rcvWnd := seqnum.Size(30000) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seq, + AckNum: c.IRS.Add(1 + seqnum.Size(maxPayload)), + RcvWnd: rcvWnd, + TCPOpts: options, + }) + + // Wait for the probe function to finish processing the + // ACK before the test completes. + <-probeDone + + verifySpuriousRecoveryMetric(t, c, 1 /* numSpuriousRecovery */) +} + +func TestSACKDetectSpuriousRecoveryWithDupACK(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + numAck := 0 + probeDone := make(chan struct{}) + c.Stack().AddTCPProbe(func(s stack.TCPEndpointState) { + if numAck < 3 { + numAck++ + return + } + + if s.Sender.RetransmitTS == 0 { + t.Fatalf("RetransmitTS did not get updated, got: 0 want > 0") + } + if !s.Sender.SpuriousRecovery { + t.Fatalf("Spurious recovery was not detected") + } + close(probeDone) + }) + + setStackSACKPermitted(t, c, true) + createConnectedWithSACKAndTS(c) + numPackets := 5 + data := make([]byte, numPackets*maxPayload) + for i := range data { + data[i] = byte(i) + } + // Write the data. + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + var options []byte + var bytesRead uint32 + for i := 0; i < numPackets; i++ { + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + checkReceivedPacket(t, c, tcpHdr, bytesRead, b, data) + + // Get options only for the first packet. This will be sent with + // the ACK to indicate the acknowledgement is for the original + // packet. + if i == 0 && c.TimeStampEnabled { + options = buildTSOptionFromHeader(tcpHdr) + } + bytesRead += uint32(len(tcpHdr.Payload())) + } + + // Receive the retransmitted packet after TLP. + c.ReceiveAndCheckPacketWithOptions(data, 4*maxPayload, maxPayload, tsOptionSize) + + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + // Send ACK for #3 and #4 segments to avoid entering TLP. + start := c.IRS.Add(3*maxPayload + 1) + end := start.Add(2 * maxPayload) + c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}}) + + c.SendAck(seq, 0 /* bytesReceived */) + c.SendAck(seq, 0 /* bytesReceived */) + + // Receive the retransmitted packet after three duplicate ACKs. + c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize) + + info := tcpip.TCPInfoOption{} + if err := c.EP.GetSockOpt(&info); err != nil { + t.Fatalf("c.EP.GetSockOpt(&%T) = %s", info, err) + } + + if info.CcState != tcpip.SACKRecovery { + t.Fatalf("Loss recovery did not happen, got: %v want: %v", info.CcState, tcpip.SACKRecovery) + } + + // Acknowledge the data. + rcvWnd := seqnum.Size(30000) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seq, + AckNum: c.IRS.Add(1 + seqnum.Size(maxPayload)), + RcvWnd: rcvWnd, + TCPOpts: options, + }) + + // Wait for the probe function to finish processing the + // ACK before the test completes. + <-probeDone + + verifySpuriousRecoveryMetric(t, c, 1 /* numSpuriousRecovery */) +} + +func TestNoSpuriousRecoveryWithDSACK(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + setStackSACKPermitted(t, c, true) + createConnectedWithSACKAndTS(c) + numPackets := 5 + data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) + + // Receive the retransmitted packet after TLP. + c.ReceiveAndCheckPacketWithOptions(data, 4*maxPayload, maxPayload, tsOptionSize) + + // Send ACK for #3 and #4 segments to avoid entering TLP. + start := c.IRS.Add(3*maxPayload + 1) + end := start.Add(2 * maxPayload) + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}}) + + c.SendAck(seq, 0 /* bytesReceived */) + c.SendAck(seq, 0 /* bytesReceived */) + + // Receive the retransmitted packet after three duplicate ACKs. + c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize) + + // Acknowledge the data with DSACK for #1 segment. + start = c.IRS.Add(maxPayload + 1) + end = start.Add(2 * maxPayload) + seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendAckWithSACK(seq, 6*maxPayload, []header.SACKBlock{{start, end}}) + + verifySpuriousRecoveryMetric(t, c, 0 /* numSpuriousRecovery */) +} diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index bc8708a5b..6f1ee3816 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -1382,8 +1382,12 @@ func TestListenerReadinessOnEvent(t *testing.T) { if err := s.CreateNIC(id, ep); err != nil { t.Fatalf("CreateNIC(%d, %T): %s", id, ep, err) } - if err := s.AddAddress(id, ipv4.ProtocolNumber, context.StackAddr); err != nil { - t.Fatalf("AddAddress(%d, ipv4.ProtocolNumber, %s): %s", id, context.StackAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.Address(context.StackAddr).WithPrefix(), + } + if err := s.AddProtocolAddress(id, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", id, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{ {Destination: header.IPv4EmptySubnet, NIC: id}, @@ -1652,6 +1656,71 @@ func TestConnectBindToDevice(t *testing.T) { } } +func TestShutdownConnectingSocket(t *testing.T) { + for _, test := range []struct { + name string + shutdownMode tcpip.ShutdownFlags + }{ + {"ShutdownRead", tcpip.ShutdownRead}, + {"ShutdownWrite", tcpip.ShutdownWrite}, + {"ShutdownReadWrite", tcpip.ShutdownRead | tcpip.ShutdownWrite}, + } { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create an endpoint, don't handshake because we want to interfere with + // the handshake process. + c.Create(-1) + + waitEntry, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&waitEntry, waiter.EventHUp) + defer c.WQ.EventUnregister(&waitEntry) + + // Start connection attempt. + addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, c.EP.Connect(addr)); d != "" { + t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d) + } + + // Check the SYN packet. + b := c.GetPacket() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + ), + ) + + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { + t.Fatalf("got State() = %s, want %s", got, want) + } + + if err := c.EP.Shutdown(test.shutdownMode); err != nil { + t.Fatalf("Shutdown failed: %s", err) + } + + // The endpoint internal state is updated immediately. + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { + t.Fatalf("got State() = %s, want %s", got, want) + } + + select { + case <-ch: + default: + t.Fatal("endpoint was not notified") + } + + ept := endpointTester{c.EP} + ept.CheckReadError(t, &tcpip.ErrConnectionReset{}) + + // If the endpoint is not properly shutdown, it'll re-attempt to connect + // by sending another ACK packet. + c.CheckNoPacketTimeout("got an unexpected packet", tcp.InitialRTO+(500*time.Millisecond)) + }) + } +} + func TestSynSent(t *testing.T) { for _, test := range []struct { name string @@ -1675,7 +1744,7 @@ func TestSynSent(t *testing.T) { addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} err := c.EP.Connect(addr) - if d := cmp.Diff(err, &tcpip.ErrConnectStarted{}); d != "" { + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d) } @@ -1991,7 +2060,9 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { ) // Cause a FIN to be generated. - c.EP.Shutdown(tcpip.ShutdownWrite) + if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Shutdown failed: %s", err) + } // Make sure we get the FIN but DON't ACK IT. checker.IPv4(t, c.GetPacket(), @@ -2007,7 +2078,9 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { // Cause a RST to be generated by closing the read end now since we have // unread data. - c.EP.Shutdown(tcpip.ShutdownRead) + if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { + t.Fatalf("Shutdown failed: %s", err) + } // Make sure we get the RST checker.IPv4(t, c.GetPacket(), @@ -2145,12 +2218,15 @@ func TestSmallReceiveBufferReadiness(t *testing.T) { t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %s", nicOpts, err) } - addr := tcpip.AddressWithPrefix{ - Address: tcpip.Address("\x7f\x00\x00\x01"), - PrefixLen: 8, + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address("\x7f\x00\x00\x01"), + PrefixLen: 8, + }, } - if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddressWithPrefix(_, _, %s) failed: %s", addr, err) + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) failed: %s", nicID, protocolAddr, err) } { @@ -4954,13 +5030,17 @@ func makeStack() (*stack.Stack, tcpip.Error) { } for _, ct := range []struct { - number tcpip.NetworkProtocolNumber - address tcpip.Address + number tcpip.NetworkProtocolNumber + addrWithPrefix tcpip.AddressWithPrefix }{ - {ipv4.ProtocolNumber, context.StackAddr}, - {ipv6.ProtocolNumber, context.StackV6Addr}, + {ipv4.ProtocolNumber, context.StackAddrWithPrefix}, + {ipv6.ProtocolNumber, context.StackV6AddrWithPrefix}, } { - if err := s.AddAddress(1, ct.number, ct.address); err != nil { + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ct.number, + AddressWithPrefix: ct.addrWithPrefix, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { return nil, err } } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 6e55a7a32..88bb99354 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -243,8 +243,8 @@ func NewWithOpts(t *testing.T, opts Options) *Context { Protocol: ipv4.ProtocolNumber, AddressWithPrefix: StackAddrWithPrefix, } - if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err) + if err := s.AddProtocolAddress(1, v4ProtocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", v4ProtocolAddr, err) } routeTable = append(routeTable, tcpip.Route{ Destination: header.IPv4EmptySubnet, @@ -257,8 +257,8 @@ func NewWithOpts(t *testing.T, opts Options) *Context { Protocol: ipv6.ProtocolNumber, AddressWithPrefix: StackV6AddrWithPrefix, } - if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err) + if err := s.AddProtocolAddress(1, v6ProtocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", v6ProtocolAddr, err) } routeTable = append(routeTable, tcpip.Route{ Destination: header.IPv6EmptySubnet, diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index 5cc7a2886..d2c0963b0 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -63,5 +63,6 @@ go_test( "//pkg/tcpip/transport/icmp", "//pkg/waiter", "@com_github_google_go_cmp//cmp:go_default_library", + "@org_golang_x_time//rate:go_default_library", ], ) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 4255457f9..077a2325a 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -60,9 +60,8 @@ type endpoint struct { waiterQueue *waiter.Queue uniqueID uint64 net network.Endpoint - // TODO(b/142022063): Add ability to save and restore per endpoint stats. - stats tcpip.TransportEndpointStats `state:"nosave"` - ops tcpip.SocketOptions + stats tcpip.TransportEndpointStats + ops tcpip.SocketOptions // The following fields are used to manage the receive queue, and are // protected by rcvMu. @@ -234,7 +233,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // Control Messages cm := tcpip.ControlMessages{ HasTimestamp: true, - Timestamp: p.receivedAt.UnixNano(), + Timestamp: p.receivedAt, } switch p.netProto { @@ -243,19 +242,29 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult cm.HasTOS = true cm.TOS = p.tos } + + if e.ops.GetReceivePacketInfo() { + cm.HasIPPacketInfo = true + cm.PacketInfo = p.packetInfo + } case header.IPv6ProtocolNumber: if e.ops.GetReceiveTClass() { cm.HasTClass = true // Although TClass is an 8-bit value it's read in the CMsg as a uint32. cm.TClass = uint32(p.tos) } + + if e.ops.GetIPv6ReceivePacketInfo() { + cm.HasIPv6PacketInfo = true + cm.IPv6PacketInfo = tcpip.IPv6PacketInfo{ + NIC: p.packetInfo.NIC, + Addr: p.packetInfo.DestinationAddr, + } + } default: panic(fmt.Sprintf("unrecognized network protocol = %d", p.netProto)) } - if e.ops.GetReceivePacketInfo() { - cm.HasIPPacketInfo = true - cm.PacketInfo = p.packetInfo - } + if e.ops.GetReceiveOriginalDstAddress() { cm.HasOriginalDstAddress = true cm.OriginalDstAddress = p.destinationAddress @@ -283,7 +292,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // reacquire the mutex in exclusive mode. // // Returns true for retry if preparation should be retried. -// +checklocks:e.mu +// +checklocksread:e.mu func (e *endpoint) prepareForWriteInner(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { switch e.net.State() { case transport.DatagramEndpointStateInitial: diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 554ce1de4..b3199489c 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "golang.org/x/time/rate" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -313,6 +314,9 @@ func newDualTestContextWithHandleLocal(t *testing.T, mtu uint32, handleLocal boo Clock: &faketime.NullClock{}, } s := stack.New(options) + // Disable ICMP rate limiter because we're using Null clock, which never advances time and thus + // never allows ICMP messages. + s.SetICMPLimit(rate.Inf) ep := channel.New(256, mtu, "") wep := stack.LinkEndpoint(ep) @@ -323,12 +327,20 @@ func newDualTestContextWithHandleLocal(t *testing.T, mtu uint32, handleLocal boo t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, stackAddr, err) + protocolAddrV4 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.Address(stackAddr).WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddrV4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV4, err) } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, stackV6Addr); err != nil { - t.Fatalf("AddAddress((%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, stackV6Addr, err) + protocolAddrV6 := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: tcpip.Address(stackV6Addr).WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddrV6, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV6, err) } s.SetRouteTable([]tcpip.Route{ @@ -1357,64 +1369,70 @@ func TestReadIncrementsPacketsReceived(t *testing.T) { func TestReadIPPacketInfo(t *testing.T) { tests := []struct { - name string - proto tcpip.NetworkProtocolNumber - flow testFlow - expectedLocalAddr tcpip.Address - expectedDestAddr tcpip.Address + name string + proto tcpip.NetworkProtocolNumber + flow testFlow + checker func(tcpip.NICID) checker.ControlMessagesChecker }{ { - name: "IPv4 unicast", - proto: header.IPv4ProtocolNumber, - flow: unicastV4, - expectedLocalAddr: stackAddr, - expectedDestAddr: stackAddr, + name: "IPv4 unicast", + proto: header.IPv4ProtocolNumber, + flow: unicastV4, + checker: func(id tcpip.NICID) checker.ControlMessagesChecker { + return checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{ + NIC: id, + LocalAddr: stackAddr, + DestinationAddr: stackAddr, + }) + }, }, { name: "IPv4 multicast", proto: header.IPv4ProtocolNumber, flow: multicastV4, - // This should actually be a unicast address assigned to the interface. - // - // TODO(gvisor.dev/issue/3556): This check is validating incorrect - // behaviour. We still include the test so that once the bug is - // resolved, this test will start to fail and the individual tasked - // with fixing this bug knows to also fix this test :). - expectedLocalAddr: multicastAddr, - expectedDestAddr: multicastAddr, + checker: func(id tcpip.NICID) checker.ControlMessagesChecker { + return checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{ + NIC: id, + // TODO(gvisor.dev/issue/3556): Check for a unicast address. + LocalAddr: multicastAddr, + DestinationAddr: multicastAddr, + }) + }, }, { name: "IPv4 broadcast", proto: header.IPv4ProtocolNumber, flow: broadcast, - // This should actually be a unicast address assigned to the interface. - // - // TODO(gvisor.dev/issue/3556): This check is validating incorrect - // behaviour. We still include the test so that once the bug is - // resolved, this test will start to fail and the individual tasked - // with fixing this bug knows to also fix this test :). - expectedLocalAddr: broadcastAddr, - expectedDestAddr: broadcastAddr, + checker: func(id tcpip.NICID) checker.ControlMessagesChecker { + return checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{ + NIC: id, + // TODO(gvisor.dev/issue/3556): Check for a unicast address. + LocalAddr: broadcastAddr, + DestinationAddr: broadcastAddr, + }) + }, }, { - name: "IPv6 unicast", - proto: header.IPv6ProtocolNumber, - flow: unicastV6, - expectedLocalAddr: stackV6Addr, - expectedDestAddr: stackV6Addr, + name: "IPv6 unicast", + proto: header.IPv6ProtocolNumber, + flow: unicastV6, + checker: func(id tcpip.NICID) checker.ControlMessagesChecker { + return checker.ReceiveIPv6PacketInfo(tcpip.IPv6PacketInfo{ + NIC: id, + Addr: stackV6Addr, + }) + }, }, { name: "IPv6 multicast", proto: header.IPv6ProtocolNumber, flow: multicastV6, - // This should actually be a unicast address assigned to the interface. - // - // TODO(gvisor.dev/issue/3556): This check is validating incorrect - // behaviour. We still include the test so that once the bug is - // resolved, this test will start to fail and the individual tasked - // with fixing this bug knows to also fix this test :). - expectedLocalAddr: multicastV6Addr, - expectedDestAddr: multicastV6Addr, + checker: func(id tcpip.NICID) checker.ControlMessagesChecker { + return checker.ReceiveIPv6PacketInfo(tcpip.IPv6PacketInfo{ + NIC: id, + Addr: multicastV6Addr, + }) + }, }, } @@ -1437,13 +1455,16 @@ func TestReadIPPacketInfo(t *testing.T) { } } - c.ep.SocketOptions().SetReceivePacketInfo(true) + switch f := test.flow.netProto(); f { + case header.IPv4ProtocolNumber: + c.ep.SocketOptions().SetReceivePacketInfo(true) + case header.IPv6ProtocolNumber: + c.ep.SocketOptions().SetIPv6ReceivePacketInfo(true) + default: + t.Fatalf("unhandled protocol number = %d", f) + } - testRead(c, test.flow, checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{ - NIC: 1, - LocalAddr: test.expectedLocalAddr, - DestinationAddr: test.expectedDestAddr, - })) + testRead(c, test.flow, test.checker(c.nicID)) if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 { t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got) @@ -2504,8 +2525,8 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { if err := s.CreateNIC(nicID1, e); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) } - if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err) + if err := s.AddProtocolAddress(nicID1, test.nicAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, test.nicAddr, err) } s.SetRouteTable(test.routes) |