diff options
Diffstat (limited to 'pkg/tcpip')
48 files changed, 3547 insertions, 585 deletions
diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 8ec5d5d5c..f01217c91 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -77,7 +77,8 @@ func NewVectorisedView(size int, views []View) VectorisedView { return VectorisedView{views: views, size: size} } -// TrimFront removes the first "count" bytes of the vectorised view. +// TrimFront removes the first "count" bytes of the vectorised view. It panics +// if count > vv.Size(). func (vv *VectorisedView) TrimFront(count int) { for count > 0 && len(vv.views) > 0 { if count < len(vv.views[0]) { @@ -86,7 +87,7 @@ func (vv *VectorisedView) TrimFront(count int) { return } count -= len(vv.views[0]) - vv.RemoveFirst() + vv.removeFirst() } } @@ -104,7 +105,7 @@ func (vv *VectorisedView) Read(v View) (copied int, err error) { count -= len(vv.views[0]) copy(v[copied:], vv.views[0]) copied += len(vv.views[0]) - vv.RemoveFirst() + vv.removeFirst() } if copied == 0 { return 0, io.EOF @@ -126,7 +127,7 @@ func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int count -= len(vv.views[0]) dstVV.AppendView(vv.views[0]) copied += len(vv.views[0]) - vv.RemoveFirst() + vv.removeFirst() } return copied } @@ -162,22 +163,37 @@ func (vv *VectorisedView) Clone(buffer []View) VectorisedView { return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size} } -// First returns the first view of the vectorised view. -func (vv *VectorisedView) First() View { +// PullUp returns the first "count" bytes of the vectorised view. If those +// bytes aren't already contiguous inside the vectorised view, PullUp will +// reallocate as needed to make them contiguous. PullUp fails and returns false +// when count > vv.Size(). +func (vv *VectorisedView) PullUp(count int) (View, bool) { if len(vv.views) == 0 { - return nil + return nil, count == 0 + } + if count <= len(vv.views[0]) { + return vv.views[0][:count], true + } + if count > vv.size { + return nil, false } - return vv.views[0] -} -// RemoveFirst removes the first view of the vectorised view. -func (vv *VectorisedView) RemoveFirst() { - if len(vv.views) == 0 { - return + newFirst := NewView(count) + i := 0 + for offset := 0; offset < count; i++ { + copy(newFirst[offset:], vv.views[i]) + if count-offset < len(vv.views[i]) { + vv.views[i].TrimFront(count - offset) + break + } + offset += len(vv.views[i]) + vv.views[i] = nil } - vv.size -= len(vv.views[0]) - vv.views[0] = nil - vv.views = vv.views[1:] + // We're guaranteed that i > 0, since count is too large for the first + // view. + vv.views[i-1] = newFirst + vv.views = vv.views[i-1:] + return newFirst, true } // Size returns the size in bytes of the entire content stored in the vectorised view. @@ -225,3 +241,10 @@ func (vv *VectorisedView) Readers() []bytes.Reader { } return readers } + +// removeFirst panics when len(vv.views) < 1. +func (vv *VectorisedView) removeFirst() { + vv.size -= len(vv.views[0]) + vv.views[0] = nil + vv.views = vv.views[1:] +} diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go index 106e1994c..c56795c7b 100644 --- a/pkg/tcpip/buffer/view_test.go +++ b/pkg/tcpip/buffer/view_test.go @@ -16,6 +16,7 @@ package buffer import ( + "bytes" "reflect" "testing" ) @@ -370,3 +371,115 @@ func TestVVRead(t *testing.T) { }) } } + +var pullUpTestCases = []struct { + comment string + in VectorisedView + count int + want []byte + result VectorisedView + ok bool +}{ + { + comment: "simple case", + in: vv(2, "12"), + count: 1, + want: []byte("1"), + result: vv(2, "12"), + ok: true, + }, + { + comment: "entire View", + in: vv(2, "1", "2"), + count: 1, + want: []byte("1"), + result: vv(2, "1", "2"), + ok: true, + }, + { + comment: "spanning across two Views", + in: vv(3, "1", "23"), + count: 2, + want: []byte("12"), + result: vv(3, "12", "3"), + ok: true, + }, + { + comment: "spanning across all Views", + in: vv(5, "1", "23", "45"), + count: 5, + want: []byte("12345"), + result: vv(5, "12345"), + ok: true, + }, + { + comment: "count = 0", + in: vv(1, "1"), + count: 0, + want: []byte{}, + result: vv(1, "1"), + ok: true, + }, + { + comment: "count = size", + in: vv(1, "1"), + count: 1, + want: []byte("1"), + result: vv(1, "1"), + ok: true, + }, + { + comment: "count too large", + in: vv(3, "1", "23"), + count: 4, + want: nil, + result: vv(3, "1", "23"), + ok: false, + }, + { + comment: "empty vv", + in: vv(0, ""), + count: 1, + want: nil, + result: vv(0, ""), + ok: false, + }, + { + comment: "empty vv, count = 0", + in: vv(0, ""), + count: 0, + want: nil, + result: vv(0, ""), + ok: true, + }, + { + comment: "empty views", + in: vv(3, "", "1", "", "23"), + count: 2, + want: []byte("12"), + result: vv(3, "12", "3"), + ok: true, + }, +} + +func TestPullUp(t *testing.T) { + for _, c := range pullUpTestCases { + got, ok := c.in.PullUp(c.count) + + // Is the return value right? + if ok != c.ok { + t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got an ok of %t. Want %t", + c.comment, c.count, c.in, ok, c.ok) + } + if bytes.Compare(got, View(c.want)) != 0 { + t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got %v. Want %v", + c.comment, c.count, c.in, got, c.want) + } + + // Is the underlying structure right? + if !reflect.DeepEqual(c.in, c.result) { + t.Errorf("Test %q failed when calling PullUp(%d). Got vv with structure %v. Wanted %v", + c.comment, c.count, c.in, c.result) + } + } +} diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index 0cac6c0a5..7908c5744 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -71,6 +71,7 @@ const ( // Values for ICMP code as defined in RFC 792. const ( + ICMPv4TTLExceeded = 0 ICMPv4PortUnreachable = 3 ICMPv4FragmentationNeeded = 4 ) diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index ba80b64a8..4f367fe4c 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -17,6 +17,7 @@ package header import ( "crypto/sha256" "encoding/binary" + "fmt" "strings" "gvisor.dev/gvisor/pkg/tcpip" @@ -445,3 +446,54 @@ func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, *tcpip.Error) { return GlobalScope, nil } } + +// InitialTempIID generates the initial temporary IID history value to generate +// temporary SLAAC addresses with. +// +// Panics if initialTempIIDHistory is not at least IIDSize bytes. +func InitialTempIID(initialTempIIDHistory []byte, seed []byte, nicID tcpip.NICID) { + h := sha256.New() + // h.Write never returns an error. + h.Write(seed) + var nicIDBuf [4]byte + binary.BigEndian.PutUint32(nicIDBuf[:], uint32(nicID)) + h.Write(nicIDBuf[:]) + + var sumBuf [sha256.Size]byte + sum := h.Sum(sumBuf[:0]) + + if n := copy(initialTempIIDHistory, sum[sha256.Size-IIDSize:]); n != IIDSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IIDSize)) + } +} + +// GenerateTempIPv6SLAACAddr generates a temporary SLAAC IPv6 address for an +// associated stable/permanent SLAAC address. +// +// GenerateTempIPv6SLAACAddr will update the temporary IID history value to be +// used when generating a new temporary IID. +// +// Panics if tempIIDHistory is not at least IIDSize bytes. +func GenerateTempIPv6SLAACAddr(tempIIDHistory []byte, stableAddr tcpip.Address) tcpip.AddressWithPrefix { + addrBytes := []byte(stableAddr) + h := sha256.New() + h.Write(tempIIDHistory) + h.Write(addrBytes[IIDOffsetInIPv6Address:]) + var sumBuf [sha256.Size]byte + sum := h.Sum(sumBuf[:0]) + + // The rightmost 64 bits of sum are saved for the next iteration. + if n := copy(tempIIDHistory, sum[sha256.Size-IIDSize:]); n != IIDSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IIDSize)) + } + + // The leftmost 64 bits of sum is used as the IID. + if n := copy(addrBytes[IIDOffsetInIPv6Address:], sum); n != IIDSize { + panic(fmt.Sprintf("copied %d IID bytes, expected %d bytes", n, IIDSize)) + } + + return tcpip.AddressWithPrefix{ + Address: tcpip.Address(addrBytes), + PrefixLen: IIDOffsetInIPv6Address * 8, + } +} diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go index 13480687d..21581257b 100644 --- a/pkg/tcpip/header/tcp.go +++ b/pkg/tcpip/header/tcp.go @@ -594,3 +594,20 @@ func AddTCPOptionPadding(options []byte, offset int) int { } return paddingToAdd } + +// Acceptable checks if a segment that starts at segSeq and has length segLen is +// "acceptable" for arriving in a receive window that starts at rcvNxt and ends +// before rcvAcc, according to the table on page 26 and 69 of RFC 793. +func Acceptable(segSeq seqnum.Value, segLen seqnum.Size, rcvNxt, rcvAcc seqnum.Value) bool { + if rcvNxt == rcvAcc { + return segLen == 0 && segSeq == rcvNxt + } + if segLen == 0 { + // rcvWnd is incremented by 1 because that is Linux's behavior despite the + // RFC. + return segSeq.InRange(rcvNxt, rcvAcc.Add(1)) + } + // Page 70 of RFC 793 allows packets that can be made "acceptable" by trimming + // the payload, so we'll accept any payload that overlaps the receieve window. + return rcvNxt.LessThan(segSeq.Add(segLen)) && segSeq.LessThan(rcvAcc) +} diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index abe725548..aa6db9aea 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -14,6 +14,7 @@ go_library( ], visibility = ["//visibility:public"], deps = [ + "//pkg/binary", "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index b857ce9d0..affa1bbdf 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -44,6 +44,7 @@ import ( "syscall" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -428,30 +429,24 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne } } - vnetHdrBuf := vnetHdrToByteSlice(&vnetHdr) + vnetHdrBuf := binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr) return rawfile.NonBlockingWrite3(fd, vnetHdrBuf, pkt.Header.View(), pkt.Data.ToView()) } if pkt.Data.Size() == 0 { return rawfile.NonBlockingWrite(fd, pkt.Header.View()) } + if pkt.Header.UsedLength() == 0 { + return rawfile.NonBlockingWrite(fd, pkt.Data.ToView()) + } return rawfile.NonBlockingWrite3(fd, pkt.Header.View(), pkt.Data.ToView(), nil) } -// WritePackets writes outbound packets to the file descriptor. If it is not -// currently writable, the packet is dropped. -// -// NOTE: This API uses sendmmsg to batch packets. As a result the underlying FD -// picked to write the packet out has to be the same for all packets in the -// list. In other words all packets in the batch should belong to the same -// flow. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - n := pkts.Len() - - mmsgHdrs := make([]rawfile.MMsgHdr, n) - i := 0 - for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { +func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tcpip.Error) { + // Send a batch of packets through batchFD. + mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch)) + for _, pkt := range batch { var ethHdrBuf []byte iovLen := 0 if e.hdrSize > 0 { @@ -459,13 +454,13 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe ethHdrBuf = make([]byte, header.EthernetMinimumSize) eth := header.Ethernet(ethHdrBuf) ethHdr := &header.EthernetFields{ - DstAddr: r.RemoteLinkAddress, - Type: protocol, + DstAddr: pkt.EgressRoute.RemoteLinkAddress, + Type: pkt.NetworkProtocolNumber, } // Preserve the src address if it's set in the route. - if r.LocalLinkAddress != "" { - ethHdr.SrcAddr = r.LocalLinkAddress + if pkt.EgressRoute.LocalLinkAddress != "" { + ethHdr.SrcAddr = pkt.EgressRoute.LocalLinkAddress } else { ethHdr.SrcAddr = e.addr } @@ -473,34 +468,34 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe iovLen++ } - var vnetHdrBuf []byte vnetHdr := virtioNetHdr{} + var vnetHdrBuf []byte if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { - if gso != nil { + if pkt.GSOOptions != nil { vnetHdr.hdrLen = uint16(pkt.Header.UsedLength()) - if gso.NeedsCsum { + if pkt.GSOOptions.NeedsCsum { vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM - vnetHdr.csumStart = header.EthernetMinimumSize + gso.L3HdrLen - vnetHdr.csumOffset = gso.CsumOffset + vnetHdr.csumStart = header.EthernetMinimumSize + pkt.GSOOptions.L3HdrLen + vnetHdr.csumOffset = pkt.GSOOptions.CsumOffset } - if gso.Type != stack.GSONone && uint16(pkt.Data.Size()) > gso.MSS { - switch gso.Type { + if pkt.GSOOptions.Type != stack.GSONone && uint16(pkt.Data.Size()) > pkt.GSOOptions.MSS { + switch pkt.GSOOptions.Type { case stack.GSOTCPv4: vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4 case stack.GSOTCPv6: vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6 default: - panic(fmt.Sprintf("Unknown gso type: %v", gso.Type)) + panic(fmt.Sprintf("Unknown gso type: %v", pkt.GSOOptions.Type)) } - vnetHdr.gsoSize = gso.MSS + vnetHdr.gsoSize = pkt.GSOOptions.MSS } } - vnetHdrBuf = vnetHdrToByteSlice(&vnetHdr) + vnetHdrBuf = binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr) iovLen++ } iovecs := make([]syscall.Iovec, iovLen+1+len(pkt.Data.Views())) - mmsgHdr := &mmsgHdrs[i] + var mmsgHdr rawfile.MMsgHdr mmsgHdr.Msg.Iov = &iovecs[0] iovecIdx := 0 if vnetHdrBuf != nil { @@ -535,22 +530,68 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe pktSize += vec.Len } mmsgHdr.Msg.Iovlen = uint64(iovecIdx) - i++ + mmsgHdrs = append(mmsgHdrs, mmsgHdr) } packets := 0 - for packets < n { - fd := e.fds[pkts.Front().Hash%uint32(len(e.fds))] - sent, err := rawfile.NonBlockingSendMMsg(fd, mmsgHdrs) + for len(mmsgHdrs) > 0 { + sent, err := rawfile.NonBlockingSendMMsg(batchFD, mmsgHdrs) if err != nil { return packets, err } packets += sent mmsgHdrs = mmsgHdrs[sent:] } + return packets, nil } +// WritePackets writes outbound packets to the underlying file descriptors. If +// one is not currently writable, the packet is dropped. +// +// Being a batch API, each packet in pkts should have the following +// fields populated: +// - pkt.EgressRoute +// - pkt.GSOOptions +// - pkt.NetworkProtocolNumber +func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + // Preallocate to avoid repeated reallocation as we append to batch. + // batchSz is 47 because when SWGSO is in use then a single 65KB TCP + // segment can get split into 46 segments of 1420 bytes and a single 216 + // byte segment. + const batchSz = 47 + batch := make([]*stack.PacketBuffer, 0, batchSz) + batchFD := -1 + sentPackets := 0 + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if len(batch) == 0 { + batchFD = e.fds[pkt.Hash%uint32(len(e.fds))] + } + pktFD := e.fds[pkt.Hash%uint32(len(e.fds))] + if sendNow := pktFD != batchFD; !sendNow { + batch = append(batch, pkt) + continue + } + n, err := e.sendBatch(batchFD, batch) + sentPackets += n + if err != nil { + return sentPackets, err + } + batch = batch[:0] + batch = append(batch, pkt) + batchFD = pktFD + } + + if len(batch) != 0 { + n, err := e.sendBatch(batchFD, batch) + sentPackets += n + if err != nil { + return sentPackets, err + } + } + return sentPackets, nil +} + // viewsEqual tests whether v1 and v2 refer to the same backing bytes. func viewsEqual(vs1, vs2 []buffer.View) bool { return len(vs1) == len(vs2) && (len(vs1) == 0 || &vs1[0] == &vs2[0]) diff --git a/pkg/tcpip/link/fdbased/endpoint_unsafe.go b/pkg/tcpip/link/fdbased/endpoint_unsafe.go index d81858353..df14eaad1 100644 --- a/pkg/tcpip/link/fdbased/endpoint_unsafe.go +++ b/pkg/tcpip/link/fdbased/endpoint_unsafe.go @@ -17,17 +17,7 @@ package fdbased import ( - "reflect" "unsafe" ) const virtioNetHdrSize = int(unsafe.Sizeof(virtioNetHdr{})) - -func vnetHdrToByteSlice(hdr *virtioNetHdr) (slice []byte) { - *(*reflect.SliceHeader)(unsafe.Pointer(&slice)) = reflect.SliceHeader{ - Data: uintptr((unsafe.Pointer(hdr))), - Len: virtioNetHdrSize, - Cap: virtioNetHdrSize, - } - return -} diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 1e2255bfa..073c84ef9 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -98,13 +98,13 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - // Reject the packet if it's shorter than an ethernet header. - if vv.Size() < header.EthernetMinimumSize { + // There should be an ethernet header at the beginning of vv. + hdr, ok := vv.PullUp(header.EthernetMinimumSize) + if !ok { + // Reject the packet if it's shorter than an ethernet header. return tcpip.ErrBadAddress } - - // There should be an ethernet header at the beginning of vv. - linkHeader := header.Ethernet(vv.First()[:header.EthernetMinimumSize]) + linkHeader := header.Ethernet(hdr) vv.TrimFront(len(linkHeader)) e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), stack.PacketBuffer{ Data: vv, diff --git a/pkg/tcpip/link/qdisc/fifo/BUILD b/pkg/tcpip/link/qdisc/fifo/BUILD new file mode 100644 index 000000000..054c213bc --- /dev/null +++ b/pkg/tcpip/link/qdisc/fifo/BUILD @@ -0,0 +1,19 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "fifo", + srcs = [ + "endpoint.go", + "packet_buffer_queue.go", + ], + visibility = ["//visibility:public"], + deps = [ + "//pkg/sleep", + "//pkg/sync", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go new file mode 100644 index 000000000..54432194d --- /dev/null +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -0,0 +1,209 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package fifo provides the implementation of data-link layer endpoints that +// wrap another endpoint and queues all outbound packets and asynchronously +// dispatches them to the lower endpoint. +package fifo + +import ( + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// endpoint represents a LinkEndpoint which implements a FIFO queue for all +// outgoing packets. endpoint can have 1 or more underlying queueDispatchers. +// All outgoing packets are consistenly hashed to a single underlying queue +// using the PacketBuffer.Hash if set, otherwise all packets are queued to the +// first queue to avoid reordering in case of missing hash. +type endpoint struct { + dispatcher stack.NetworkDispatcher + lower stack.LinkEndpoint + wg sync.WaitGroup + dispatchers []*queueDispatcher +} + +// queueDispatcher is responsible for dispatching all outbound packets in its +// queue. It will also smartly batch packets when possible and write them +// through the lower LinkEndpoint. +type queueDispatcher struct { + lower stack.LinkEndpoint + q *packetBufferQueue + newPacketWaker sleep.Waker + closeWaker sleep.Waker +} + +// New creates a new fifo link endpoint with the n queues with maximum +// capacity of queueLen. +func New(lower stack.LinkEndpoint, n int, queueLen int) stack.LinkEndpoint { + e := &endpoint{ + lower: lower, + } + // Create the required dispatchers + for i := 0; i < n; i++ { + qd := &queueDispatcher{ + q: &packetBufferQueue{limit: queueLen}, + lower: lower, + } + e.dispatchers = append(e.dispatchers, qd) + e.wg.Add(1) + go func() { + defer e.wg.Done() + qd.dispatchLoop() + }() + } + return e +} + +func (q *queueDispatcher) dispatchLoop() { + const newPacketWakerID = 1 + const closeWakerID = 2 + s := sleep.Sleeper{} + s.AddWaker(&q.newPacketWaker, newPacketWakerID) + s.AddWaker(&q.closeWaker, closeWakerID) + defer s.Done() + + const batchSize = 32 + var batch stack.PacketBufferList + for { + id, ok := s.Fetch(true) + if ok && id == closeWakerID { + return + } + for pkt := q.q.dequeue(); pkt != nil; pkt = q.q.dequeue() { + batch.PushBack(pkt) + if batch.Len() < batchSize && !q.q.empty() { + continue + } + // We pass a protocol of zero here because each packet carries its + // NetworkProtocol. + q.lower.WritePackets(nil /* route */, nil /* gso */, batch, 0 /* protocol */) + for pkt := batch.Front(); pkt != nil; pkt = pkt.Next() { + pkt.EgressRoute.Release() + batch.Remove(pkt) + } + batch.Reset() + } + } +} + +// DeliverNetworkPacket implements stack.NetworkDispatcher.DeliverNetworkPacket. +func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { + e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt) +} + +// Attach implements stack.LinkEndpoint.Attach. +func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.dispatcher = dispatcher + e.lower.Attach(e) +} + +// IsAttached implements stack.LinkEndpoint.IsAttached. +func (e *endpoint) IsAttached() bool { + return e.dispatcher != nil +} + +// MTU implements stack.LinkEndpoint.MTU. +func (e *endpoint) MTU() uint32 { + return e.lower.MTU() +} + +// Capabilities implements stack.LinkEndpoint.Capabilities. +func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.lower.Capabilities() +} + +// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. +func (e *endpoint) MaxHeaderLength() uint16 { + return e.lower.MaxHeaderLength() +} + +// LinkAddress implements stack.LinkEndpoint.LinkAddress. +func (e *endpoint) LinkAddress() tcpip.LinkAddress { + return e.lower.LinkAddress() +} + +// GSOMaxSize returns the maximum GSO packet size. +func (e *endpoint) GSOMaxSize() uint32 { + if gso, ok := e.lower.(stack.GSOEndpoint); ok { + return gso.GSOMaxSize() + } + return 0 +} + +// WritePacket implements stack.LinkEndpoint.WritePacket. +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { + // WritePacket caller's do not set the following fields in PacketBuffer + // so we populate them here. + newRoute := r.Clone() + pkt.EgressRoute = &newRoute + pkt.GSOOptions = gso + pkt.NetworkProtocolNumber = protocol + d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)] + if !d.q.enqueue(&pkt) { + return tcpip.ErrNoBufferSpace + } + d.newPacketWaker.Assert() + return nil +} + +// WritePackets implements stack.LinkEndpoint.WritePackets. +// +// Being a batch API, each packet in pkts should have the following fields +// populated: +// - pkt.EgressRoute +// - pkt.GSOOptions +// - pkt.NetworkProtocolNumber +func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + enqueued := 0 + for pkt := pkts.Front(); pkt != nil; { + d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)] + nxt := pkt.Next() + // Since qdisc can hold onto a packet for long we should Clone + // the route here to ensure it doesn't get released while the + // packet is still in our queue. + newRoute := pkt.EgressRoute.Clone() + pkt.EgressRoute = &newRoute + if !d.q.enqueue(pkt) { + if enqueued > 0 { + d.newPacketWaker.Assert() + } + return enqueued, tcpip.ErrNoBufferSpace + } + pkt = nxt + enqueued++ + d.newPacketWaker.Assert() + } + return enqueued, nil +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + return e.lower.WriteRawPacket(vv) +} + +// Wait implements stack.LinkEndpoint.Wait. +func (e *endpoint) Wait() { + e.lower.Wait() + + // The linkEP is gone. Teardown the outbound dispatcher goroutines. + for i := range e.dispatchers { + e.dispatchers[i].closeWaker.Assert() + } + + e.wg.Wait() +} diff --git a/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go b/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go new file mode 100644 index 000000000..eb5abb906 --- /dev/null +++ b/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go @@ -0,0 +1,84 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fifo + +import ( + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// packetBufferQueue is a bounded, thread-safe queue of PacketBuffers. +// +type packetBufferQueue struct { + mu sync.Mutex + list stack.PacketBufferList + limit int + used int +} + +// emptyLocked determines if the queue is empty. +// Preconditions: q.mu must be held. +func (q *packetBufferQueue) emptyLocked() bool { + return q.used == 0 +} + +// empty determines if the queue is empty. +func (q *packetBufferQueue) empty() bool { + q.mu.Lock() + r := q.emptyLocked() + q.mu.Unlock() + + return r +} + +// setLimit updates the limit. No PacketBuffers are immediately dropped in case +// the queue becomes full due to the new limit. +func (q *packetBufferQueue) setLimit(limit int) { + q.mu.Lock() + q.limit = limit + q.mu.Unlock() +} + +// enqueue adds the given packet to the queue. +// +// Returns true when the PacketBuffer is successfully added to the queue, in +// which case ownership of the reference is transferred to the queue. And +// returns false if the queue is full, in which case ownership is retained by +// the caller. +func (q *packetBufferQueue) enqueue(s *stack.PacketBuffer) bool { + q.mu.Lock() + r := q.used < q.limit + if r { + q.list.PushBack(s) + q.used++ + } + q.mu.Unlock() + + return r +} + +// dequeue removes and returns the next PacketBuffer from queue, if one exists. +// Ownership is transferred to the caller. +func (q *packetBufferQueue) dequeue() *stack.PacketBuffer { + q.mu.Lock() + s := q.list.Front() + if s != nil { + q.list.Remove(s) + q.used-- + } + q.mu.Unlock() + + return s +} diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 27ea3f531..33f640b85 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -674,7 +674,7 @@ func TestSimpleReceive(t *testing.T) { // Wait for packet to be received, then check it. c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet") c.mu.Lock() - rcvd := []byte(c.packets[0].vv.First()) + rcvd := []byte(c.packets[0].vv.ToView()) c.packets = c.packets[:0] c.mu.Unlock() diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index be2537a82..da1c520ae 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -171,11 +171,7 @@ func (e *endpoint) GSOMaxSize() uint32 { func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { writer := e.writer if writer == nil && atomic.LoadUint32(&LogPackets) == 1 { - first := pkt.Header.View() - if len(first) == 0 { - first = pkt.Data.First() - } - logPacket(prefix, protocol, first, gso) + logPacket(prefix, protocol, pkt, gso) } if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 { totalLength := pkt.Header.UsedLength() + pkt.Data.Size() @@ -238,7 +234,7 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { // Wait implements stack.LinkEndpoint.Wait. func (e *endpoint) Wait() { e.lower.Wait() } -func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) { +func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) { // Figure out the network layer info. var transProto uint8 src := tcpip.Address("unknown") @@ -247,28 +243,49 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie size := uint16(0) var fragmentOffset uint16 var moreFragments bool + + // Create a clone of pkt, including any headers if present. Avoid allocating + // backing memory for the clone. + views := [8]buffer.View{} + vv := buffer.NewVectorisedView(0, views[:0]) + vv.AppendView(pkt.Header.View()) + vv.Append(pkt.Data) + switch protocol { case header.IPv4ProtocolNumber: - ipv4 := header.IPv4(b) + hdr, ok := vv.PullUp(header.IPv4MinimumSize) + if !ok { + return + } + ipv4 := header.IPv4(hdr) fragmentOffset = ipv4.FragmentOffset() moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments src = ipv4.SourceAddress() dst = ipv4.DestinationAddress() transProto = ipv4.Protocol() size = ipv4.TotalLength() - uint16(ipv4.HeaderLength()) - b = b[ipv4.HeaderLength():] + vv.TrimFront(int(ipv4.HeaderLength())) id = int(ipv4.ID()) case header.IPv6ProtocolNumber: - ipv6 := header.IPv6(b) + hdr, ok := vv.PullUp(header.IPv6MinimumSize) + if !ok { + return + } + ipv6 := header.IPv6(hdr) src = ipv6.SourceAddress() dst = ipv6.DestinationAddress() transProto = ipv6.NextHeader() size = ipv6.PayloadLength() - b = b[header.IPv6MinimumSize:] + vv.TrimFront(header.IPv6MinimumSize) case header.ARPProtocolNumber: - arp := header.ARP(b) + hdr, ok := vv.PullUp(header.ARPSize) + if !ok { + return + } + vv.TrimFront(header.ARPSize) + arp := header.ARP(hdr) log.Infof( "%s arp %v (%v) -> %v (%v) valid:%v", prefix, @@ -284,7 +301,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie // We aren't guaranteed to have a transport header - it's possible for // writes via raw endpoints to contain only network headers. - if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && len(b) < minSize { + if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && vv.Size() < minSize { log.Infof("%s %v -> %v transport protocol: %d, but no transport header found (possible raw packet)", prefix, src, dst, transProto) return } @@ -297,7 +314,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie switch tcpip.TransportProtocolNumber(transProto) { case header.ICMPv4ProtocolNumber: transName = "icmp" - icmp := header.ICMPv4(b) + hdr, ok := vv.PullUp(header.ICMPv4MinimumSize) + if !ok { + break + } + icmp := header.ICMPv4(hdr) icmpType := "unknown" if fragmentOffset == 0 { switch icmp.Type() { @@ -330,7 +351,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie case header.ICMPv6ProtocolNumber: transName = "icmp" - icmp := header.ICMPv6(b) + hdr, ok := vv.PullUp(header.ICMPv6MinimumSize) + if !ok { + break + } + icmp := header.ICMPv6(hdr) icmpType := "unknown" switch icmp.Type() { case header.ICMPv6DstUnreachable: @@ -361,8 +386,12 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie case header.UDPProtocolNumber: transName = "udp" - udp := header.UDP(b) - if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize { + hdr, ok := vv.PullUp(header.UDPMinimumSize) + if !ok { + break + } + udp := header.UDP(hdr) + if fragmentOffset == 0 { srcPort = udp.SourcePort() dstPort = udp.DestinationPort() details = fmt.Sprintf("xsum: 0x%x", udp.Checksum()) @@ -371,15 +400,19 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie case header.TCPProtocolNumber: transName = "tcp" - tcp := header.TCP(b) - if fragmentOffset == 0 && len(tcp) >= header.TCPMinimumSize { + hdr, ok := vv.PullUp(header.TCPMinimumSize) + if !ok { + break + } + tcp := header.TCP(hdr) + if fragmentOffset == 0 { offset := int(tcp.DataOffset()) if offset < header.TCPMinimumSize { details += fmt.Sprintf("invalid packet: tcp data offset too small %d", offset) break } - if offset > len(tcp) && !moreFragments { - details += fmt.Sprintf("invalid packet: tcp data offset %d larger than packet buffer length %d", offset, len(tcp)) + if offset > vv.Size() && !moreFragments { + details += fmt.Sprintf("invalid packet: tcp data offset %d larger than packet buffer length %d", offset, vv.Size()) break } diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 7acbfa0a8..9d0797af7 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -42,6 +42,7 @@ const ( // endpoint implements stack.NetworkEndpoint. type endpoint struct { + protocol *protocol nicID tcpip.NICID linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache @@ -83,6 +84,11 @@ func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderPara return tcpip.ErrNotSupported } +// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. +func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { + return e.protocol.Number() +} + // WritePackets implements stack.NetworkEndpoint.WritePackets. func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, *tcpip.Error) { return 0, tcpip.ErrNotSupported @@ -93,7 +99,10 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf } func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - v := pkt.Data.First() + v, ok := pkt.Data.PullUp(header.ARPSize) + if !ok { + return + } h := header.ARP(v) if !h.IsValid() { return @@ -142,6 +151,7 @@ func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWi return nil, tcpip.ErrBadLocalAddress } return &endpoint{ + protocol: p, nicID: nicID, linkEP: sender, linkAddrCache: linkAddrCache, diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index c4bf1ba5c..4cbefe5ab 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -25,7 +25,11 @@ import ( // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { - h := header.IPv4(pkt.Data.First()) + h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return + } + hdr := header.IPv4(h) // We don't use IsValid() here because ICMP only requires that the IP // header plus 8 bytes of the transport header be included. So it's @@ -34,12 +38,12 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // // Drop packet if it doesn't have the basic IPv4 header or if the // original source address doesn't match the endpoint's address. - if len(h) < header.IPv4MinimumSize || h.SourceAddress() != e.id.LocalAddress { + if hdr.SourceAddress() != e.id.LocalAddress { return } - hlen := int(h.HeaderLength()) - if pkt.Data.Size() < hlen || h.FragmentOffset() != 0 { + hlen := int(hdr.HeaderLength()) + if pkt.Data.Size() < hlen || hdr.FragmentOffset() != 0 { // We won't be able to handle this if it doesn't contain the // full IPv4 header, or if it's a fragment not at offset 0 // (because it won't have the transport header). @@ -48,15 +52,15 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip the ip header, then deliver control message. pkt.Data.TrimFront(hlen) - p := h.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + p := hdr.TransportProtocol() + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived - v := pkt.Data.First() - if len(v) < header.ICMPv4MinimumSize { + v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) + if !ok { received.Invalid.Increment() return } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 104aafbed..9db42b2a4 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -118,6 +118,11 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } +// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. +func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { + return e.protocol.Number() +} + // writePacketFragments calls e.linkEP.WritePacket with each packet fragment to // write. It assumes that the IP header is entirely in pkt.Header but does not // assume that only the IP header is in pkt.Header. It assumes that the input @@ -247,11 +252,31 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // iptables filtering. All packets that reach here are locally // generated. ipt := e.stack.IPTables() - if ok := ipt.Check(stack.Output, pkt); !ok { + if ok := ipt.Check(stack.Output, &pkt, gso, r, ""); !ok { // iptables is telling us to drop the packet. return nil } + if pkt.NatDone { + // If the packet is manipulated as per NAT Ouput rules, handle packet + // based on destination address and do not send the packet to link layer. + netHeader := header.IPv4(pkt.NetworkHeader) + ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) + if err == nil { + src := netHeader.SourceAddress() + dst := netHeader.DestinationAddress() + route := r.ReverseRoute(src, dst) + + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) + packet := stack.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views)} + ep.HandlePacket(&route, packet) + return nil + } + } + if r.Loop&stack.PacketLoop != 0 { // The inbound path expects the network header to still be in // the PacketBuffer's Data field. @@ -297,8 +322,8 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // iptables filtering. All packets that reach here are locally // generated. ipt := e.stack.IPTables() - dropped := ipt.CheckPackets(stack.Output, pkts) - if len(dropped) == 0 { + dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r) + if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) @@ -313,6 +338,24 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe if _, ok := dropped[pkt]; ok { continue } + if _, ok := natPkts[pkt]; ok { + netHeader := header.IPv4(pkt.NetworkHeader) + ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) + if err == nil { + src := netHeader.SourceAddress() + dst := netHeader.DestinationAddress() + route := r.ReverseRoute(src, dst) + + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) + packet := stack.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views)} + ep.HandlePacket(&route, packet) + n++ + continue + } + } if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, *pkt); err != nil { r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) return n, err @@ -328,7 +371,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. - ip := header.IPv4(pkt.Data.First()) + h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return tcpip.ErrInvalidOptionValue + } + ip := header.IPv4(h) if !ip.IsValid(pkt.Data.Size()) { return tcpip.ErrInvalidOptionValue } @@ -378,7 +425,11 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView := pkt.Data.First() + headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } h := header.IPv4(headerView) if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() @@ -394,7 +445,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. ipt := e.stack.IPTables() - if ok := ipt.Check(stack.Input, pkt); !ok { + if ok := ipt.Check(stack.Input, &pkt, nil, nil, ""); !ok { // iptables is telling us to drop the packet. return } diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index b68983d10..bdf3a0d25 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -28,7 +28,11 @@ import ( // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { - h := header.IPv6(pkt.Data.First()) + h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + return + } + hdr := header.IPv6(h) // We don't use IsValid() here because ICMP only requires that up to // 1280 bytes of the original packet be included. So it's likely that it @@ -36,17 +40,21 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // // Drop packet if it doesn't have the basic IPv6 header or if the // original source address doesn't match the endpoint's address. - if len(h) < header.IPv6MinimumSize || h.SourceAddress() != e.id.LocalAddress { + if hdr.SourceAddress() != e.id.LocalAddress { return } // Skip the IP header, then handle the fragmentation header if there // is one. pkt.Data.TrimFront(header.IPv6MinimumSize) - p := h.TransportProtocol() + p := hdr.TransportProtocol() if p == header.IPv6FragmentHeader { - f := header.IPv6Fragment(pkt.Data.First()) - if !f.IsValid() || f.FragmentOffset() != 0 { + f, ok := pkt.Data.PullUp(header.IPv6FragmentHeaderSize) + if !ok { + return + } + fragHdr := header.IPv6Fragment(f) + if !fragHdr.IsValid() || fragHdr.FragmentOffset() != 0 { // We can't handle fragments that aren't at offset 0 // because they don't have the transport headers. return @@ -55,19 +63,19 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip fragmentation header and find out the actual protocol // number. pkt.Data.TrimFront(header.IPv6FragmentHeaderSize) - p = f.TransportProtocol() + p = fragHdr.TransportProtocol() } // Deliver the control packet to the transport endpoint. - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer, hasFragmentHeader bool) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived - v := pkt.Data.First() - if len(v) < header.ICMPv6MinimumSize { + v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize) + if !ok { received.Invalid.Increment() return } @@ -76,11 +84,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // Validate ICMPv6 checksum before processing the packet. // - // Only the first view in vv is accounted for by h. To account for the - // rest of vv, a shallow copy is made and the first view is removed. // This copy is used as extra payload during the checksum calculation. payload := pkt.Data.Clone(nil) - payload.RemoveFirst() + payload.TrimFront(len(h)) if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { received.Invalid.Increment() return @@ -101,34 +107,40 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P switch h.Type() { case header.ICMPv6PacketTooBig: received.PacketTooBig.Increment() - if len(v) < header.ICMPv6PacketTooBigMinimumSize { + hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize) + if !ok { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) - mtu := h.MTU() + mtu := header.ICMPv6(hdr).MTU() e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) case header.ICMPv6DstUnreachable: received.DstUnreachable.Increment() - if len(v) < header.ICMPv6DstUnreachableMinimumSize { + hdr, ok := pkt.Data.PullUp(header.ICMPv6DstUnreachableMinimumSize) + if !ok { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) - switch h.Code() { + switch header.ICMPv6(hdr).Code() { case header.ICMPv6PortUnreachable: e.handleControl(stack.ControlPortUnreachable, 0, pkt) } case header.ICMPv6NeighborSolicit: received.NeighborSolicit.Increment() - if len(v) < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { + if pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { received.Invalid.Increment() return } - ns := header.NDPNeighborSolicit(h.NDPPayload()) + // The remainder of payload must be only the neighbor solicitation, so + // payload.ToView() always returns the solicitation. Per RFC 6980 section 5, + // NDP messages cannot be fragmented. Also note that in the common case NDP + // datagrams are very small and ToView() will not incur allocations. + ns := header.NDPNeighborSolicit(payload.ToView()) it, err := ns.Options().Iter(true) if err != nil { // If we have a malformed NDP NS option, drop the packet. @@ -286,12 +298,16 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6NeighborAdvert: received.NeighborAdvert.Increment() - if len(v) < header.ICMPv6NeighborAdvertSize || !isNDPValid() { + if pkt.Data.Size() < header.ICMPv6NeighborAdvertSize || !isNDPValid() { received.Invalid.Increment() return } - na := header.NDPNeighborAdvert(h.NDPPayload()) + // The remainder of payload must be only the neighbor advertisement, so + // payload.ToView() always returns the advertisement. Per RFC 6980 section + // 5, NDP messages cannot be fragmented. Also note that in the common case + // NDP datagrams are very small and ToView() will not incur allocations. + na := header.NDPNeighborAdvert(payload.ToView()) it, err := na.Options().Iter(true) if err != nil { // If we have a malformed NDP NA option, drop the packet. @@ -363,14 +379,15 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6EchoRequest: received.EchoRequest.Increment() - if len(v) < header.ICMPv6EchoMinimumSize { + icmpHdr, ok := pkt.Data.PullUp(header.ICMPv6EchoMinimumSize) + if !ok { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize) packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) - copy(packet, h) + copy(packet, icmpHdr) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ @@ -384,7 +401,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6EchoReply: received.EchoReply.Increment() - if len(v) < header.ICMPv6EchoMinimumSize { + if pkt.Data.Size() < header.ICMPv6EchoMinimumSize { received.Invalid.Increment() return } @@ -406,8 +423,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6RouterAdvert: received.RouterAdvert.Increment() - p := h.NDPPayload() - if len(p) < header.NDPRAMinimumSize || !isNDPValid() { + // Is the NDP payload of sufficient size to hold a Router + // Advertisement? + if pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize || !isNDPValid() { received.Invalid.Increment() return } @@ -425,7 +443,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P return } - ra := header.NDPRouterAdvert(p) + // The remainder of payload must be only the router advertisement, so + // payload.ToView() always returns the advertisement. Per RFC 6980 section + // 5, NDP messages cannot be fragmented. Also note that in the common case + // NDP datagrams are very small and ToView() will not incur allocations. + ra := header.NDPRouterAdvert(payload.ToView()) opts := ra.Options() // Are options valid as per the wire format? diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index bd099a7f8..d412ff688 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -166,7 +166,8 @@ func TestICMPCounts(t *testing.T) { }, { typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize}, + size: header.ICMPv6NeighborSolicitMinimumSize, + }, { typ: header.ICMPv6NeighborAdvert, size: header.ICMPv6NeighborAdvertMinimumSize, diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 331b0817b..daf1fcbc6 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -171,7 +171,11 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffe // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView := pkt.Data.First() + headerView, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } h := header.IPv6(headerView) if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() @@ -416,6 +420,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { // Close cleans up resources associated with the endpoint. func (*endpoint) Close() {} +// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. +func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { + return e.protocol.Number() +} + type protocol struct { // defaultTTL is the current default TTL for the protocol. Only the // uint8 portion of it is meaningful and it must be accessed diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 5e963a4af..f71073207 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -30,6 +30,7 @@ go_template_instance( go_library( name = "stack", srcs = [ + "conntrack.go", "dhcpv6configurationfromndpra_string.go", "forwarder.go", "icmp_rate_limit.go", @@ -62,6 +63,7 @@ go_library( "//pkg/tcpip/header", "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", + "//pkg/tcpip/transport/tcpconntrack", "//pkg/waiter", "@org_golang_x_time//rate:go_default_library", ], diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go new file mode 100644 index 000000000..7d1ede1f2 --- /dev/null +++ b/pkg/tcpip/stack/conntrack.go @@ -0,0 +1,480 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "encoding/binary" + "sync" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack" +) + +// Connection tracking is used to track and manipulate packets for NAT rules. +// The connection is created for a packet if it does not exist. Every connection +// contains two tuples (original and reply). The tuples are manipulated if there +// is a matching NAT rule. The packet is modified by looking at the tuples in the +// Prerouting and Output hooks. + +// Direction of the tuple. +type ctDirection int + +const ( + dirOriginal ctDirection = iota + dirReply +) + +// Status of connection. +// TODO(gvisor.dev/issue/170): Add other states of connection. +type connStatus int + +const ( + connNew connStatus = iota + connEstablished +) + +// Manipulation type for the connection. +type manipType int + +const ( + manipDstPrerouting manipType = iota + manipDstOutput +) + +// connTrackMutable is the manipulatable part of the tuple. +type connTrackMutable struct { + // addr is source address of the tuple. + addr tcpip.Address + + // port is source port of the tuple. + port uint16 + + // protocol is network layer protocol. + protocol tcpip.NetworkProtocolNumber +} + +// connTrackImmutable is the non-manipulatable part of the tuple. +type connTrackImmutable struct { + // addr is destination address of the tuple. + addr tcpip.Address + + // direction is direction (original or reply) of the tuple. + direction ctDirection + + // port is destination port of the tuple. + port uint16 + + // protocol is transport layer protocol. + protocol tcpip.TransportProtocolNumber +} + +// connTrackTuple represents the tuple which is created from the +// packet. +type connTrackTuple struct { + // dst is non-manipulatable part of the tuple. + dst connTrackImmutable + + // src is manipulatable part of the tuple. + src connTrackMutable +} + +// connTrackTupleHolder is the container of tuple and connection. +type ConnTrackTupleHolder struct { + // conn is pointer to the connection tracking entry. + conn *connTrack + + // tuple is original or reply tuple. + tuple connTrackTuple +} + +// connTrack is the connection. +type connTrack struct { + // originalTupleHolder contains tuple in original direction. + originalTupleHolder ConnTrackTupleHolder + + // replyTupleHolder contains tuple in reply direction. + replyTupleHolder ConnTrackTupleHolder + + // status indicates connection is new or established. + status connStatus + + // timeout indicates the time connection should be active. + timeout time.Duration + + // manip indicates if the packet should be manipulated. + manip manipType + + // tcb is TCB control block. It is used to keep track of states + // of tcp connection. + tcb tcpconntrack.TCB + + // tcbHook indicates if the packet is inbound or outbound to + // update the state of tcb. + tcbHook Hook +} + +// ConnTrackTable contains a map of all existing connections created for +// NAT rules. +type ConnTrackTable struct { + // connMu protects connTrackTable. + connMu sync.RWMutex + + // connTrackTable maintains a map of tuples needed for connection tracking + // for iptables NAT rules. The key for the map is an integer calculated + // using seed, source address, destination address, source port and + // destination port. + CtMap map[uint32]ConnTrackTupleHolder + + // seed is a one-time random value initialized at stack startup + // and is used in calculation of hash key for connection tracking + // table. + Seed uint32 +} + +// parseHeaders sets headers in the packet. +func parseHeaders(pkt *PacketBuffer) { + newPkt := pkt.Clone() + + // Set network header. + hdr, ok := newPkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return + } + netHeader := header.IPv4(hdr) + newPkt.NetworkHeader = hdr + length := int(netHeader.HeaderLength()) + + // TODO(gvisor.dev/issue/170): Need to support for other + // protocols as well. + // Set transport header. + switch protocol := netHeader.TransportProtocol(); protocol { + case header.UDPProtocolNumber: + if newPkt.TransportHeader == nil { + h, ok := newPkt.Data.PullUp(length + header.UDPMinimumSize) + if !ok { + return + } + newPkt.TransportHeader = buffer.View(header.UDP(h[length:])) + } + case header.TCPProtocolNumber: + if newPkt.TransportHeader == nil { + h, ok := newPkt.Data.PullUp(length + header.TCPMinimumSize) + if !ok { + return + } + newPkt.TransportHeader = buffer.View(header.TCP(h[length:])) + } + } + pkt.NetworkHeader = newPkt.NetworkHeader + pkt.TransportHeader = newPkt.TransportHeader +} + +// packetToTuple converts packet to a tuple in original direction. +func packetToTuple(pkt PacketBuffer, hook Hook) (connTrackTuple, *tcpip.Error) { + var tuple connTrackTuple + + netHeader := header.IPv4(pkt.NetworkHeader) + // TODO(gvisor.dev/issue/170): Need to support for other + // protocols as well. + if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber { + return tuple, tcpip.ErrUnknownProtocol + } + tcpHeader := header.TCP(pkt.TransportHeader) + if tcpHeader == nil { + return tuple, tcpip.ErrUnknownProtocol + } + + tuple.src.addr = netHeader.SourceAddress() + tuple.src.port = tcpHeader.SourcePort() + tuple.src.protocol = header.IPv4ProtocolNumber + + tuple.dst.addr = netHeader.DestinationAddress() + tuple.dst.port = tcpHeader.DestinationPort() + tuple.dst.protocol = netHeader.TransportProtocol() + + return tuple, nil +} + +// getReplyTuple creates reply tuple for the given tuple. +func getReplyTuple(tuple connTrackTuple) connTrackTuple { + var replyTuple connTrackTuple + replyTuple.src.addr = tuple.dst.addr + replyTuple.src.port = tuple.dst.port + replyTuple.src.protocol = tuple.src.protocol + replyTuple.dst.addr = tuple.src.addr + replyTuple.dst.port = tuple.src.port + replyTuple.dst.protocol = tuple.dst.protocol + replyTuple.dst.direction = dirReply + + return replyTuple +} + +// makeNewConn creates new connection. +func makeNewConn(tuple, replyTuple connTrackTuple) connTrack { + var conn connTrack + conn.status = connNew + conn.originalTupleHolder.tuple = tuple + conn.originalTupleHolder.conn = &conn + conn.replyTupleHolder.tuple = replyTuple + conn.replyTupleHolder.conn = &conn + + return conn +} + +// getTupleHash returns hash of the tuple. The fields used for +// generating hash are seed (generated once for stack), source address, +// destination address, source port and destination ports. +func (ct *ConnTrackTable) getTupleHash(tuple connTrackTuple) uint32 { + h := jenkins.Sum32(ct.Seed) + h.Write([]byte(tuple.src.addr)) + h.Write([]byte(tuple.dst.addr)) + portBuf := make([]byte, 2) + binary.LittleEndian.PutUint16(portBuf, tuple.src.port) + h.Write([]byte(portBuf)) + binary.LittleEndian.PutUint16(portBuf, tuple.dst.port) + h.Write([]byte(portBuf)) + + return h.Sum32() +} + +// connTrackForPacket returns connTrack for packet. +// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support other +// transport protocols. +func (ct *ConnTrackTable) connTrackForPacket(pkt *PacketBuffer, hook Hook, createConn bool) (*connTrack, ctDirection) { + if hook == Prerouting { + // Headers will not be set in Prerouting. + // TODO(gvisor.dev/issue/170): Change this after parsing headers + // code is added. + parseHeaders(pkt) + } + + var dir ctDirection + tuple, err := packetToTuple(*pkt, hook) + if err != nil { + return nil, dir + } + + ct.connMu.Lock() + defer ct.connMu.Unlock() + + connTrackTable := ct.CtMap + hash := ct.getTupleHash(tuple) + + var conn *connTrack + switch createConn { + case true: + // If connection does not exist for the hash, create a new + // connection. + replyTuple := getReplyTuple(tuple) + replyHash := ct.getTupleHash(replyTuple) + newConn := makeNewConn(tuple, replyTuple) + conn = &newConn + + // Add tupleHolders to the map. + // TODO(gvisor.dev/issue/170): Need to support collisions using linked list. + ct.CtMap[hash] = conn.originalTupleHolder + ct.CtMap[replyHash] = conn.replyTupleHolder + default: + tupleHolder, ok := connTrackTable[hash] + if !ok { + return nil, dir + } + + // If this is the reply of new connection, set the connection + // status as ESTABLISHED. + conn = tupleHolder.conn + if conn.status == connNew && tupleHolder.tuple.dst.direction == dirReply { + conn.status = connEstablished + } + if tupleHolder.conn == nil { + panic("tupleHolder has null connection tracking entry") + } + + dir = tupleHolder.tuple.dst.direction + } + return conn, dir +} + +// SetNatInfo will manipulate the tuples according to iptables NAT rules. +func (ct *ConnTrackTable) SetNatInfo(pkt *PacketBuffer, rt RedirectTarget, hook Hook) { + // Get the connection. Connection is always created before this + // function is called. + conn, _ := ct.connTrackForPacket(pkt, hook, false) + if conn == nil { + panic("connection should be created to manipulate tuples.") + } + replyTuple := conn.replyTupleHolder.tuple + replyHash := ct.getTupleHash(replyTuple) + + // TODO(gvisor.dev/issue/170): Support only redirect of ports. Need to + // support changing of address for Prerouting. + + // Change the port as per the iptables rule. This tuple will be used + // to manipulate the packet in HandlePacket. + conn.replyTupleHolder.tuple.src.addr = rt.MinIP + conn.replyTupleHolder.tuple.src.port = rt.MinPort + newHash := ct.getTupleHash(conn.replyTupleHolder.tuple) + + // Add the changed tuple to the map. + ct.connMu.Lock() + defer ct.connMu.Unlock() + ct.CtMap[newHash] = conn.replyTupleHolder + if hook == Output { + conn.replyTupleHolder.conn.manip = manipDstOutput + } + + // Delete the old tuple. + delete(ct.CtMap, replyHash) +} + +// handlePacketPrerouting manipulates ports for packets in Prerouting hook. +// TODO(gvisor.dev/issue/170): Change address for Prerouting hook.. +func handlePacketPrerouting(pkt *PacketBuffer, conn *connTrack, dir ctDirection) { + netHeader := header.IPv4(pkt.NetworkHeader) + tcpHeader := header.TCP(pkt.TransportHeader) + + // For prerouting redirection, packets going in the original direction + // have their destinations modified and replies have their sources + // modified. + switch dir { + case dirOriginal: + port := conn.replyTupleHolder.tuple.src.port + tcpHeader.SetDestinationPort(port) + netHeader.SetDestinationAddress(conn.replyTupleHolder.tuple.src.addr) + case dirReply: + port := conn.originalTupleHolder.tuple.dst.port + tcpHeader.SetSourcePort(port) + netHeader.SetSourceAddress(conn.originalTupleHolder.tuple.dst.addr) + } + + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) +} + +// handlePacketOutput manipulates ports for packets in Output hook. +func handlePacketOutput(pkt *PacketBuffer, conn *connTrack, gso *GSO, r *Route, dir ctDirection) { + netHeader := header.IPv4(pkt.NetworkHeader) + tcpHeader := header.TCP(pkt.TransportHeader) + + // For output redirection, packets going in the original direction + // have their destinations modified and replies have their sources + // modified. For prerouting redirection, we only reach this point + // when replying, so packet sources are modified. + if conn.manip == manipDstOutput && dir == dirOriginal { + port := conn.replyTupleHolder.tuple.src.port + tcpHeader.SetDestinationPort(port) + netHeader.SetDestinationAddress(conn.replyTupleHolder.tuple.src.addr) + } else { + port := conn.originalTupleHolder.tuple.dst.port + tcpHeader.SetSourcePort(port) + netHeader.SetSourceAddress(conn.originalTupleHolder.tuple.dst.addr) + } + + // Calculate the TCP checksum and set it. + tcpHeader.SetChecksum(0) + hdr := &pkt.Header + length := uint16(pkt.Data.Size()+hdr.UsedLength()) - uint16(netHeader.HeaderLength()) + xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length) + if gso != nil && gso.NeedsCsum { + tcpHeader.SetChecksum(xsum) + } else if r.Capabilities()&CapabilityTXChecksumOffload == 0 { + xsum = header.ChecksumVVWithOffset(pkt.Data, xsum, int(tcpHeader.DataOffset()), pkt.Data.Size()) + tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) + } + + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) +} + +// HandlePacket will manipulate the port and address of the packet if the +// connection exists. +func (ct *ConnTrackTable) HandlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) { + if pkt.NatDone { + return + } + + if hook != Prerouting && hook != Output { + return + } + + conn, dir := ct.connTrackForPacket(pkt, hook, false) + // Connection or Rule not found for the packet. + if conn == nil { + return + } + + netHeader := header.IPv4(pkt.NetworkHeader) + // TODO(gvisor.dev/issue/170): Need to support for other transport + // protocols as well. + if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber { + return + } + + tcpHeader := header.TCP(pkt.TransportHeader) + if tcpHeader == nil { + return + } + + switch hook { + case Prerouting: + handlePacketPrerouting(pkt, conn, dir) + case Output: + handlePacketOutput(pkt, conn, gso, r, dir) + } + pkt.NatDone = true + + // Update the state of tcb. + // TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle + // other tcp states. + var st tcpconntrack.Result + if conn.tcb.IsEmpty() { + conn.tcb.Init(tcpHeader) + conn.tcbHook = hook + } else { + switch hook { + case conn.tcbHook: + st = conn.tcb.UpdateStateOutbound(tcpHeader) + default: + st = conn.tcb.UpdateStateInbound(tcpHeader) + } + } + + // Delete conntrack if tcp connection is closed. + if st == tcpconntrack.ResultClosedByPeer || st == tcpconntrack.ResultClosedBySelf || st == tcpconntrack.ResultReset { + ct.deleteConnTrack(conn) + } +} + +// deleteConnTrack deletes the connection. +func (ct *ConnTrackTable) deleteConnTrack(conn *connTrack) { + if conn == nil { + return + } + + tuple := conn.originalTupleHolder.tuple + hash := ct.getTupleHash(tuple) + replyTuple := conn.replyTupleHolder.tuple + replyHash := ct.getTupleHash(replyTuple) + + ct.connMu.Lock() + defer ct.connMu.Unlock() + + delete(ct.CtMap, hash) + delete(ct.CtMap, replyHash) +} diff --git a/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go b/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go index 8b4213eec..d199ded6a 100644 --- a/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go +++ b/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Code generated by "stringer -type=DHCPv6ConfigurationFromNDPRA"; DO NOT EDIT. +// Code generated by "stringer -type DHCPv6ConfigurationFromNDPRA"; DO NOT EDIT. package stack @@ -22,9 +22,9 @@ func _() { // An "invalid array index" compiler error signifies that the constant values have changed. // Re-run the stringer command to generate them again. var x [1]struct{} - _ = x[DHCPv6NoConfiguration-0] - _ = x[DHCPv6ManagedAddress-1] - _ = x[DHCPv6OtherConfigurations-2] + _ = x[DHCPv6NoConfiguration-1] + _ = x[DHCPv6ManagedAddress-2] + _ = x[DHCPv6OtherConfigurations-3] } const _DHCPv6ConfigurationFromNDPRA_name = "DHCPv6NoConfigurationDHCPv6ManagedAddressDHCPv6OtherConfigurations" @@ -32,8 +32,9 @@ const _DHCPv6ConfigurationFromNDPRA_name = "DHCPv6NoConfigurationDHCPv6ManagedAd var _DHCPv6ConfigurationFromNDPRA_index = [...]uint8{0, 21, 41, 66} func (i DHCPv6ConfigurationFromNDPRA) String() string { + i -= 1 if i < 0 || i >= DHCPv6ConfigurationFromNDPRA(len(_DHCPv6ConfigurationFromNDPRA_index)-1) { - return "DHCPv6ConfigurationFromNDPRA(" + strconv.FormatInt(int64(i), 10) + ")" + return "DHCPv6ConfigurationFromNDPRA(" + strconv.FormatInt(int64(i+1), 10) + ")" } return _DHCPv6ConfigurationFromNDPRA_name[_DHCPv6ConfigurationFromNDPRA_index[i]:_DHCPv6ConfigurationFromNDPRA_index[i+1]] } diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index e9c652042..8084d50bc 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -70,7 +70,10 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) { // Consume the network header. - b := pkt.Data.First() + b, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) + if !ok { + return + } pkt.Data.TrimFront(fwdTestNetHeaderLen) // Dispatch the packet to the transport protocol. @@ -89,6 +92,10 @@ func (f *fwdTestNetworkEndpoint) Capabilities() LinkEndpointCapabilities { return f.ep.Capabilities() } +func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { + return f.proto.Number() +} + func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error { // Add the protocol's header to the packet and send it to the link // endpoint. @@ -473,7 +480,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -517,7 +524,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -564,7 +571,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -619,7 +626,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // The first 5 packets (address 3 to 7) should not be forwarded // because their address resolutions are interrupted. - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] < 8 { t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0]) } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 6c0a4b24d..7c3c47d50 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -17,6 +17,7 @@ package stack import ( "fmt" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -110,6 +111,10 @@ func DefaultTables() IPTables { Prerouting: []string{TablenameMangle, TablenameNat}, Output: []string{TablenameMangle, TablenameNat, TablenameFilter}, }, + connections: ConnTrackTable{ + CtMap: make(map[uint32]ConnTrackTupleHolder), + Seed: generateRandUint32(), + }, } } @@ -173,12 +178,16 @@ const ( // dropped. // // Precondition: pkt.NetworkHeader is set. -func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool { +func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, address tcpip.Address) bool { + // Packets are manipulated only if connection and matching + // NAT rule exists. + it.connections.HandlePacket(pkt, hook, gso, r) + // Go through each table containing the hook. for _, tablename := range it.Priorities[hook] { table := it.Tables[tablename] ruleIdx := table.BuiltinChains[hook] - switch verdict := it.checkChain(hook, pkt, table, ruleIdx); verdict { + switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address); verdict { // If the table returns Accept, move on to the next table. case chainAccept: continue @@ -189,7 +198,7 @@ func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool { // Any Return from a built-in chain means we have to // call the underflow. underflow := table.Rules[table.Underflows[hook]] - switch v, _ := underflow.Target.Action(pkt); v { + switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, gso, r, address); v { case RuleAccept: continue case RuleDrop: @@ -212,26 +221,41 @@ func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool { // CheckPackets runs pkts through the rules for hook and returns a map of packets that // should not go forward. // +// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// +// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// precondition. +// // NOTE: unlike the Check API the returned map contains packets that should be // dropped. -func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*PacketBuffer]struct{}) { +func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - if ok := it.Check(hook, *pkt); !ok { - if drop == nil { - drop = make(map[*PacketBuffer]struct{}) + if !pkt.NatDone { + if ok := it.Check(hook, pkt, gso, r, ""); !ok { + if drop == nil { + drop = make(map[*PacketBuffer]struct{}) + } + drop[pkt] = struct{}{} + } + if pkt.NatDone { + if natPkts == nil { + natPkts = make(map[*PacketBuffer]struct{}) + } + natPkts[pkt] = struct{}{} } - drop[pkt] = struct{}{} } } - return drop + return drop, natPkts } -// Precondition: pkt.NetworkHeader is set. -func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict { +// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a +// precondition. +func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. for ruleIdx < len(table.Rules) { - switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx); verdict { + switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, address); verdict { case RuleAccept: return chainAccept @@ -248,7 +272,7 @@ func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx ruleIdx++ continue } - switch verdict := it.checkChain(hook, pkt, table, jumpTo); verdict { + switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, address); verdict { case chainAccept: return chainAccept case chainDrop: @@ -271,14 +295,21 @@ func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx return chainDrop } -// Precondition: pk.NetworkHeader is set. -func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) { +// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a +// precondition. +func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // If pkt.NetworkHeader hasn't been set yet, it will be contained in - // pkt.Data.First(). + // pkt.Data. if pkt.NetworkHeader == nil { - pkt.NetworkHeader = pkt.Data.First() + var ok bool + pkt.NetworkHeader, ok = pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + // Precondition has been violated. + panic(fmt.Sprintf("iptables checks require IPv4 headers of at least %d bytes", header.IPv4MinimumSize)) + } } // Check whether the packet matches the IP header filter. @@ -290,7 +321,7 @@ func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx // Go through each rule matcher. If they all match, run // the rule target. for _, matcher := range rule.Matchers { - matches, hotdrop := matcher.Match(hook, pkt, "") + matches, hotdrop := matcher.Match(hook, *pkt, "") if hotdrop { return RuleDrop, 0 } @@ -301,7 +332,7 @@ func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx } // All the matchers matched, so run the target. - return rule.Target.Action(pkt) + return rule.Target.Action(pkt, &it.connections, hook, gso, r, address) } func filterMatch(filter IPHeaderFilter, hdr header.IPv4) bool { diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 7b4543caf..36cc6275d 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -24,7 +24,7 @@ import ( type AcceptTarget struct{} // Action implements Target.Action. -func (AcceptTarget) Action(packet PacketBuffer) (RuleVerdict, int) { +func (AcceptTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleAccept, 0 } @@ -32,7 +32,7 @@ func (AcceptTarget) Action(packet PacketBuffer) (RuleVerdict, int) { type DropTarget struct{} // Action implements Target.Action. -func (DropTarget) Action(packet PacketBuffer) (RuleVerdict, int) { +func (DropTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleDrop, 0 } @@ -41,7 +41,7 @@ func (DropTarget) Action(packet PacketBuffer) (RuleVerdict, int) { type ErrorTarget struct{} // Action implements Target.Action. -func (ErrorTarget) Action(packet PacketBuffer) (RuleVerdict, int) { +func (ErrorTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { log.Debugf("ErrorTarget triggered.") return RuleDrop, 0 } @@ -52,7 +52,7 @@ type UserChainTarget struct { } // Action implements Target.Action. -func (UserChainTarget) Action(PacketBuffer) (RuleVerdict, int) { +func (UserChainTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { panic("UserChainTarget should never be called.") } @@ -61,7 +61,7 @@ func (UserChainTarget) Action(PacketBuffer) (RuleVerdict, int) { type ReturnTarget struct{} // Action implements Target.Action. -func (ReturnTarget) Action(PacketBuffer) (RuleVerdict, int) { +func (ReturnTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleReturn, 0 } @@ -75,16 +75,16 @@ type RedirectTarget struct { // redirect. RangeProtoSpecified bool - // Min address used to redirect. + // MinIP indicates address used to redirect. MinIP tcpip.Address - // Max address used to redirect. + // MaxIP indicates address used to redirect. MaxIP tcpip.Address - // Min port used to redirect. + // MinPort indicates port used to redirect. MinPort uint16 - // Max port used to redirect. + // MaxPort indicates port used to redirect. MaxPort uint16 } @@ -92,50 +92,76 @@ type RedirectTarget struct { // TODO(gvisor.dev/issue/170): Parse headers without copying. The current // implementation only works for PREROUTING and calls pkt.Clone(), neither // of which should be the case. -func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { - newPkt := pkt.Clone() +func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrackTable, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { + // Packet is already manipulated. + if pkt.NatDone { + return RuleAccept, 0 + } // Set network header. - headerView := newPkt.Data.First() - netHeader := header.IPv4(headerView) - newPkt.NetworkHeader = headerView[:header.IPv4MinimumSize] + if hook == Prerouting { + parseHeaders(pkt) + } - hlen := int(netHeader.HeaderLength()) - tlen := int(netHeader.TotalLength()) - newPkt.Data.TrimFront(hlen) - newPkt.Data.CapLength(tlen - hlen) + // Drop the packet if network and transport header are not set. + if pkt.NetworkHeader == nil || pkt.TransportHeader == nil { + return RuleDrop, 0 + } - // TODO(gvisor.dev/issue/170): Change destination address to - // loopback or interface address on which the packet was - // received. + // Change the address to localhost (127.0.0.1) in Output and + // to primary address of the incoming interface in Prerouting. + switch hook { + case Output: + rt.MinIP = tcpip.Address([]byte{127, 0, 0, 1}) + rt.MaxIP = tcpip.Address([]byte{127, 0, 0, 1}) + case Prerouting: + rt.MinIP = address + rt.MaxIP = address + default: + panic("redirect target is supported only on output and prerouting hooks") + } // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if // we need to change dest address (for OUTPUT chain) or ports. + netHeader := header.IPv4(pkt.NetworkHeader) switch protocol := netHeader.TransportProtocol(); protocol { case header.UDPProtocolNumber: - var udpHeader header.UDP - if newPkt.TransportHeader != nil { - udpHeader = header.UDP(newPkt.TransportHeader) - } else { - if len(pkt.Data.First()) < header.UDPMinimumSize { - return RuleDrop, 0 + udpHeader := header.UDP(pkt.TransportHeader) + udpHeader.SetDestinationPort(rt.MinPort) + + // Calculate UDP checksum and set it. + if hook == Output { + udpHeader.SetChecksum(0) + hdr := &pkt.Header + length := uint16(pkt.Data.Size()+hdr.UsedLength()) - uint16(netHeader.HeaderLength()) + + // Only calculate the checksum if offloading isn't supported. + if r.Capabilities()&CapabilityTXChecksumOffload == 0 { + xsum := r.PseudoHeaderChecksum(protocol, length) + for _, v := range pkt.Data.Views() { + xsum = header.Checksum(v, xsum) + } + udpHeader.SetChecksum(0) + udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) } - udpHeader = header.UDP(newPkt.Data.First()) } - udpHeader.SetDestinationPort(rt.MinPort) + // Change destination address. + netHeader.SetDestinationAddress(rt.MinIP) + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) + pkt.NatDone = true case header.TCPProtocolNumber: - var tcpHeader header.TCP - if newPkt.TransportHeader != nil { - tcpHeader = header.TCP(newPkt.TransportHeader) - } else { - if len(pkt.Data.First()) < header.TCPMinimumSize { - return RuleDrop, 0 - } - tcpHeader = header.TCP(newPkt.TransportHeader) + if ct == nil { + return RuleAccept, 0 + } + + // Set up conection for matching NAT rule. + // Only the first packet of the connection comes here. + // Other packets will be manipulated in connection tracking. + if conn, _ := ct.connTrackForPacket(pkt, hook, true); conn != nil { + ct.SetNatInfo(pkt, rt, hook) + ct.HandlePacket(pkt, hook, gso, r) } - // TODO(gvisor.dev/issue/170): Need to recompute checksum - // and implement nat connection tracking to support TCP. - tcpHeader.SetDestinationPort(rt.MinPort) default: return RuleDrop, 0 } diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 2ffb55f2a..1bb0ba1bd 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -82,6 +82,8 @@ type IPTables struct { // list is the order in which each table should be visited for that // hook. Priorities map[Hook][]string + + connections ConnTrackTable } // A Table defines a set of chains and hooks into the network stack. It is @@ -176,5 +178,5 @@ type Target interface { // Action takes an action on the packet and returns a verdict on how // traversal should (or should not) continue. If the return value is // Jump, it also returns the index of the rule to jump to. - Action(packet PacketBuffer) (RuleVerdict, int) + Action(packet *PacketBuffer, connections *ConnTrackTable, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) } diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index c11d62f97..526c7d6ff 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -119,6 +119,36 @@ const ( // identifier (IID) is 64 bits and an IPv6 address is 128 bits, so // 128 - 64 = 64. validPrefixLenForAutoGen = 64 + + // defaultAutoGenTempGlobalAddresses is the default configuration for whether + // or not to generate temporary SLAAC addresses. + defaultAutoGenTempGlobalAddresses = true + + // defaultMaxTempAddrValidLifetime is the default maximum valid lifetime + // for temporary SLAAC addresses generated as part of RFC 4941. + // + // Default = 7 days (from RFC 4941 section 5). + defaultMaxTempAddrValidLifetime = 7 * 24 * time.Hour + + // defaultMaxTempAddrPreferredLifetime is the default preferred lifetime + // for temporary SLAAC addresses generated as part of RFC 4941. + // + // Default = 1 day (from RFC 4941 section 5). + defaultMaxTempAddrPreferredLifetime = 24 * time.Hour + + // defaultRegenAdvanceDuration is the default duration before the deprecation + // of a temporary address when a new address will be generated. + // + // Default = 5s (from RFC 4941 section 5). + defaultRegenAdvanceDuration = 5 * time.Second + + // minRegenAdvanceDuration is the minimum duration before the deprecation + // of a temporary address when a new address will be generated. + minRegenAdvanceDuration = time.Duration(0) + + // maxSLAACAddrLocalRegenAttempts is the maximum number of times to attempt + // SLAAC address regenerations in response to a NIC-local conflict. + maxSLAACAddrLocalRegenAttempts = 10 ) var ( @@ -131,6 +161,37 @@ var ( // // Min = 2hrs. MinPrefixInformationValidLifetimeForUpdate = 2 * time.Hour + + // MaxDesyncFactor is the upper bound for the preferred lifetime's desync + // factor for temporary SLAAC addresses. + // + // This is exported as a variable (instead of a constant) so tests + // can update it to a smaller value. + // + // Must be greater than 0. + // + // Max = 10m (from RFC 4941 section 5). + MaxDesyncFactor = 10 * time.Minute + + // MinMaxTempAddrPreferredLifetime is the minimum value allowed for the + // maximum preferred lifetime for temporary SLAAC addresses. + // + // This is exported as a variable (instead of a constant) so tests + // can update it to a smaller value. + // + // This value guarantees that a temporary address will be preferred for at + // least 1hr if the SLAAC prefix is valid for at least that time. + MinMaxTempAddrPreferredLifetime = defaultRegenAdvanceDuration + MaxDesyncFactor + time.Hour + + // MinMaxTempAddrValidLifetime is the minimum value allowed for the + // maximum valid lifetime for temporary SLAAC addresses. + // + // This is exported as a variable (instead of a constant) so tests + // can update it to a smaller value. + // + // This value guarantees that a temporary address will be valid for at least + // 2hrs if the SLAAC prefix is valid for at least that time. + MinMaxTempAddrValidLifetime = 2 * time.Hour ) // DHCPv6ConfigurationFromNDPRA is a configuration available via DHCPv6 that an @@ -138,9 +199,11 @@ var ( type DHCPv6ConfigurationFromNDPRA int const ( + _ DHCPv6ConfigurationFromNDPRA = iota + // DHCPv6NoConfiguration indicates that no configurations are available via // DHCPv6. - DHCPv6NoConfiguration DHCPv6ConfigurationFromNDPRA = iota + DHCPv6NoConfiguration // DHCPv6ManagedAddress indicates that addresses are available via DHCPv6. // @@ -254,9 +317,6 @@ type NDPDispatcher interface { // OnDHCPv6Configuration will be called with an updated configuration that is // available via DHCPv6 for a specified NIC. // - // NDPDispatcher assumes that the initial configuration available by DHCPv6 is - // DHCPv6NoConfiguration. - // // This function is not permitted to block indefinitely. It must not // call functions on the stack itself. OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA) @@ -324,35 +384,49 @@ type NDPConfigurations struct { // alternative addresses (e.g. IIDs based on the modified EUI64 of a NIC's // MAC address), then no attempt will be made to resolve the conflict. AutoGenAddressConflictRetries uint8 + + // AutoGenTempGlobalAddresses determines whether or not temporary SLAAC + // addresses will be generated for a NIC as part of SLAAC privacy extensions, + // RFC 4941. + // + // Ignored if AutoGenGlobalAddresses is false. + AutoGenTempGlobalAddresses bool + + // MaxTempAddrValidLifetime is the maximum valid lifetime for temporary + // SLAAC addresses. + MaxTempAddrValidLifetime time.Duration + + // MaxTempAddrPreferredLifetime is the maximum preferred lifetime for + // temporary SLAAC addresses. + MaxTempAddrPreferredLifetime time.Duration + + // RegenAdvanceDuration is the duration before the deprecation of a temporary + // address when a new address will be generated. + RegenAdvanceDuration time.Duration } // DefaultNDPConfigurations returns an NDPConfigurations populated with // default values. func DefaultNDPConfigurations() NDPConfigurations { return NDPConfigurations{ - DupAddrDetectTransmits: defaultDupAddrDetectTransmits, - RetransmitTimer: defaultRetransmitTimer, - MaxRtrSolicitations: defaultMaxRtrSolicitations, - RtrSolicitationInterval: defaultRtrSolicitationInterval, - MaxRtrSolicitationDelay: defaultMaxRtrSolicitationDelay, - HandleRAs: defaultHandleRAs, - DiscoverDefaultRouters: defaultDiscoverDefaultRouters, - DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes, - AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses, + DupAddrDetectTransmits: defaultDupAddrDetectTransmits, + RetransmitTimer: defaultRetransmitTimer, + MaxRtrSolicitations: defaultMaxRtrSolicitations, + RtrSolicitationInterval: defaultRtrSolicitationInterval, + MaxRtrSolicitationDelay: defaultMaxRtrSolicitationDelay, + HandleRAs: defaultHandleRAs, + DiscoverDefaultRouters: defaultDiscoverDefaultRouters, + DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes, + AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses, + AutoGenTempGlobalAddresses: defaultAutoGenTempGlobalAddresses, + MaxTempAddrValidLifetime: defaultMaxTempAddrValidLifetime, + MaxTempAddrPreferredLifetime: defaultMaxTempAddrPreferredLifetime, + RegenAdvanceDuration: defaultRegenAdvanceDuration, } } // validate modifies an NDPConfigurations with valid values. If invalid values // are present in c, the corresponding default values will be used instead. -// -// If RetransmitTimer is less than minimumRetransmitTimer, then a value of -// defaultRetransmitTimer will be used. -// -// If RtrSolicitationInterval is less than minimumRtrSolicitationInterval, then -// a value of defaultRtrSolicitationInterval will be used. -// -// If MaxRtrSolicitationDelay is less than minimumMaxRtrSolicitationDelay, then -// a value of defaultMaxRtrSolicitationDelay will be used. func (c *NDPConfigurations) validate() { if c.RetransmitTimer < minimumRetransmitTimer { c.RetransmitTimer = defaultRetransmitTimer @@ -365,6 +439,18 @@ func (c *NDPConfigurations) validate() { if c.MaxRtrSolicitationDelay < minimumMaxRtrSolicitationDelay { c.MaxRtrSolicitationDelay = defaultMaxRtrSolicitationDelay } + + if c.MaxTempAddrValidLifetime < MinMaxTempAddrValidLifetime { + c.MaxTempAddrValidLifetime = MinMaxTempAddrValidLifetime + } + + if c.MaxTempAddrPreferredLifetime < MinMaxTempAddrPreferredLifetime || c.MaxTempAddrPreferredLifetime > c.MaxTempAddrValidLifetime { + c.MaxTempAddrPreferredLifetime = MinMaxTempAddrPreferredLifetime + } + + if c.RegenAdvanceDuration < minRegenAdvanceDuration { + c.RegenAdvanceDuration = minRegenAdvanceDuration + } } // ndpState is the per-interface NDP state. @@ -394,6 +480,14 @@ type ndpState struct { // The last learned DHCPv6 configuration from an NDP RA. dhcpv6Configuration DHCPv6ConfigurationFromNDPRA + + // temporaryIIDHistory is the history value used to generate a new temporary + // IID. + temporaryIIDHistory [header.IIDSize]byte + + // temporaryAddressDesyncFactor is the preferred lifetime's desync factor for + // temporary SLAAC addresses. + temporaryAddressDesyncFactor time.Duration } // dadState holds the Duplicate Address Detection timer and channel to signal @@ -414,7 +508,7 @@ type dadState struct { type defaultRouterState struct { // Timer to invalidate the default router. // - // May not be nil. + // Must not be nil. invalidationTimer *tcpip.CancellableTimer } @@ -424,20 +518,48 @@ type defaultRouterState struct { type onLinkPrefixState struct { // Timer to invalidate the on-link prefix. // - // May not be nil. + // Must not be nil. invalidationTimer *tcpip.CancellableTimer } +// tempSLAACAddrState holds state associated with a temporary SLAAC address. +type tempSLAACAddrState struct { + // Timer to deprecate the temporary SLAAC address. + // + // Must not be nil. + deprecationTimer *tcpip.CancellableTimer + + // Timer to invalidate the temporary SLAAC address. + // + // Must not be nil. + invalidationTimer *tcpip.CancellableTimer + + // Timer to regenerate the temporary SLAAC address. + // + // Must not be nil. + regenTimer *tcpip.CancellableTimer + + createdAt time.Time + + // The address's endpoint. + // + // Must not be nil. + ref *referencedNetworkEndpoint + + // Has a new temporary SLAAC address already been regenerated? + regenerated bool +} + // slaacPrefixState holds state associated with a SLAAC prefix. type slaacPrefixState struct { // Timer to deprecate the prefix. // - // May not be nil. + // Must not be nil. deprecationTimer *tcpip.CancellableTimer // Timer to invalidate the prefix. // - // May not be nil. + // Must not be nil. invalidationTimer *tcpip.CancellableTimer // Nonzero only when the address is not valid forever. @@ -446,19 +568,34 @@ type slaacPrefixState struct { // Nonzero only when the address is not preferred forever. preferredUntil time.Time - // The prefix's permanent address endpoint. - // - // May only be nil when a SLAAC address is being (re-)generated. Otherwise, - // must not be nil as all SLAAC prefixes must have a SLAAC address. - ref *referencedNetworkEndpoint + // State associated with the stable address generated for the prefix. + stableAddr struct { + // The address's endpoint. + // + // May only be nil when the address is being (re-)generated. Otherwise, + // must not be nil as all SLAAC prefixes must have a stable address. + ref *referencedNetworkEndpoint + + // The number of times an address has been generated locally where the NIC + // already had the generated address. + localGenerationFailures uint8 + } - // The number of times a permanent address has been generated for the prefix. + // The temporary (short-lived) addresses generated for the SLAAC prefix. + tempAddrs map[tcpip.Address]tempSLAACAddrState + + // The next two fields are used by both stable and temporary addresses + // generated for a SLAAC prefix. This is safe as only 1 address will be + // in the generation and DAD process at any time. That is, no two addresses + // will be generated at the same time for a given SLAAC prefix. + + // The number of times an address has been generated and added to the NIC. // // Addresses may be regenerated in reseponse to a DAD conflicts. generationAttempts uint8 - // The maximum number of times to attempt regeneration of a permanent SLAAC - // address in response to DAD conflicts. + // The maximum number of times to attempt regeneration of a SLAAC address + // in response to DAD conflicts. maxGenerationAttempts uint8 } @@ -536,10 +673,10 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref } ndp.nic.mu.Lock() + defer ndp.nic.mu.Unlock() if done { // If we reach this point, it means that DAD was stopped after we released // the NIC's read lock and before we obtained the write lock. - ndp.nic.mu.Unlock() return } @@ -551,8 +688,6 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref // schedule the next DAD timer. remaining-- timer.Reset(ndp.nic.stack.ndpConfigs.RetransmitTimer) - - ndp.nic.mu.Unlock() return } @@ -560,15 +695,18 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref // the last NDP NS. Either way, clean up addr's DAD state and let the // integrator know DAD has completed. delete(ndp.dad, addr) - ndp.nic.mu.Unlock() - - if err != nil { - log.Printf("ndpdad: error occured during DAD iteration for addr (%s) on NIC(%d); err = %s", addr, ndp.nic.ID(), err) - } if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, dadDone, err) } + + // If DAD resolved for a stable SLAAC address, attempt generation of a + // temporary SLAAC address. + if dadDone && ref.configType == slaac { + // Reset the generation attempts counter as we are starting the generation + // of a new address for the SLAAC prefix. + ndp.regenerateTempSLAACAddr(ref.addrWithPrefix().Subnet(), true /* resetGenAttempts */) + } }) ndp.dad[addr] = dadState{ @@ -953,9 +1091,10 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform prefix := pi.Subnet() // Check if we already maintain SLAAC state for prefix. - if _, ok := ndp.slaacPrefixes[prefix]; ok { + if state, ok := ndp.slaacPrefixes[prefix]; ok { // As per RFC 4862 section 5.5.3.e, refresh prefix's SLAAC lifetimes. - ndp.refreshSLAACPrefixLifetimes(prefix, pl, vl) + ndp.refreshSLAACPrefixLifetimes(prefix, &state, pl, vl) + ndp.slaacPrefixes[prefix] = state return } @@ -996,7 +1135,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix)) } - ndp.deprecateSLAACAddress(state.ref) + ndp.deprecateSLAACAddress(state.stableAddr.ref) }), invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { state, ok := ndp.slaacPrefixes[prefix] @@ -1006,6 +1145,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { ndp.invalidateSLAACPrefix(prefix, state) }), + tempAddrs: make(map[tcpip.Address]tempSLAACAddrState), maxGenerationAttempts: ndp.configs.AutoGenAddressConflictRetries + 1, } @@ -1035,9 +1175,44 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { state.validUntil = now.Add(vl) } + // If the address is assigned (DAD resolved), generate a temporary address. + if state.stableAddr.ref.getKind() == permanent { + // Reset the generation attempts counter as we are starting the generation + // of a new address for the SLAAC prefix. + ndp.generateTempSLAACAddr(prefix, &state, true /* resetGenAttempts */) + } + ndp.slaacPrefixes[prefix] = state } +// addSLAACAddr adds a SLAAC address to the NIC. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) addSLAACAddr(addr tcpip.AddressWithPrefix, configType networkEndpointConfigType, deprecated bool) *referencedNetworkEndpoint { + // Inform the integrator that we have a new SLAAC address. + ndpDisp := ndp.nic.stack.ndpDisp + if ndpDisp == nil { + return nil + } + + if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), addr) { + // Informed by the integrator not to add the address. + return nil + } + + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addr, + } + + ref, err := ndp.nic.addAddressLocked(protocolAddr, FirstPrimaryEndpoint, permanent, configType, deprecated) + if err != nil { + panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", protocolAddr, err)) + } + + return ref +} + // generateSLAACAddr generates a SLAAC address for prefix. // // Returns true if an address was successfully generated. @@ -1046,7 +1221,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { // // The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixState) bool { - if r := state.ref; r != nil { + if r := state.stableAddr.ref; r != nil { panic(fmt.Sprintf("ndp: SLAAC prefix %s already has a permenant address %s", prefix, r.addrWithPrefix())) } @@ -1056,68 +1231,67 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt return false } + var generatedAddr tcpip.AddressWithPrefix addrBytes := []byte(prefix.ID()) - if oIID := ndp.nic.stack.opaqueIIDOpts; oIID.NICNameFromID != nil { - addrBytes = header.AppendOpaqueInterfaceIdentifier( - addrBytes[:header.IIDOffsetInIPv6Address], - prefix, - oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name), - state.generationAttempts, - oIID.SecretKey, - ) - } else if state.generationAttempts == 0 { - // Only attempt to generate an interface-specific IID if we have a valid - // link address. - // - // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by - // LinkEndpoint.LinkAddress) before reaching this point. - linkAddr := ndp.nic.linkEP.LinkAddress() - if !header.IsValidUnicastEthernetAddress(linkAddr) { + + for i := 0; ; i++ { + // If we were unable to generate an address after the maximum SLAAC address + // local regeneration attempts, do nothing further. + if i == maxSLAACAddrLocalRegenAttempts { return false } - // Generate an address within prefix from the modified EUI-64 of ndp's NIC's - // Ethernet MAC address. - header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) - } else { - // We have no way to regenerate an address when addresses are not generated - // with opaque IIDs. - return false - } + dadCounter := state.generationAttempts + state.stableAddr.localGenerationFailures + if oIID := ndp.nic.stack.opaqueIIDOpts; oIID.NICNameFromID != nil { + addrBytes = header.AppendOpaqueInterfaceIdentifier( + addrBytes[:header.IIDOffsetInIPv6Address], + prefix, + oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name), + dadCounter, + oIID.SecretKey, + ) + } else if dadCounter == 0 { + // Modified-EUI64 based IIDs have no way to resolve DAD conflicts, so if + // the DAD counter is non-zero, we cannot use this method. + // + // Only attempt to generate an interface-specific IID if we have a valid + // link address. + // + // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by + // LinkEndpoint.LinkAddress) before reaching this point. + linkAddr := ndp.nic.linkEP.LinkAddress() + if !header.IsValidUnicastEthernetAddress(linkAddr) { + return false + } - generatedAddr := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ + // Generate an address within prefix from the modified EUI-64 of ndp's + // NIC's Ethernet MAC address. + header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) + } else { + // We have no way to regenerate an address in response to an address + // conflict when addresses are not generated with opaque IIDs. + return false + } + + generatedAddr = tcpip.AddressWithPrefix{ Address: tcpip.Address(addrBytes), PrefixLen: validPrefixLenForAutoGen, - }, - } - - // If the nic already has this address, do nothing further. - if ndp.nic.hasPermanentAddrLocked(generatedAddr.AddressWithPrefix.Address) { - return false - } + } - // Inform the integrator that we have a new SLAAC address. - ndpDisp := ndp.nic.stack.ndpDisp - if ndpDisp == nil { - return false - } + if !ndp.nic.hasPermanentAddrLocked(generatedAddr.Address) { + break + } - if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), generatedAddr.AddressWithPrefix) { - // Informed by the integrator not to add the address. - return false + state.stableAddr.localGenerationFailures++ } - deprecated := time.Since(state.preferredUntil) >= 0 - ref, err := ndp.nic.addAddressLocked(generatedAddr, FirstPrimaryEndpoint, permanent, slaac, deprecated) - if err != nil { - panic(fmt.Sprintf("ndp: error when adding address %+v: %s", generatedAddr, err)) + if ref := ndp.addSLAACAddr(generatedAddr, slaac, time.Since(state.preferredUntil) >= 0 /* deprecated */); ref != nil { + state.stableAddr.ref = ref + state.generationAttempts++ + return true } - state.generationAttempts++ - state.ref = ref - return true + return false } // regenerateSLAACAddr regenerates an address for a SLAAC prefix. @@ -1143,24 +1317,180 @@ func (ndp *ndpState) regenerateSLAACAddr(prefix tcpip.Subnet) { ndp.invalidateSLAACPrefix(prefix, state) } -// refreshSLAACPrefixLifetimes refreshes the lifetimes of a SLAAC prefix. +// generateTempSLAACAddr generates a new temporary SLAAC address. // -// pl is the new preferred lifetime. vl is the new valid lifetime. +// If resetGenAttempts is true, the prefix's generation counter will be reset. +// +// Returns true if a new address was generated. +func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *slaacPrefixState, resetGenAttempts bool) bool { + // Are we configured to auto-generate new temporary global addresses for the + // prefix? + if !ndp.configs.AutoGenTempGlobalAddresses || prefix == header.IPv6LinkLocalPrefix.Subnet() { + return false + } + + if resetGenAttempts { + prefixState.generationAttempts = 0 + prefixState.maxGenerationAttempts = ndp.configs.AutoGenAddressConflictRetries + 1 + } + + // If we have already reached the maximum address generation attempts for the + // prefix, do not generate another address. + if prefixState.generationAttempts == prefixState.maxGenerationAttempts { + return false + } + + stableAddr := prefixState.stableAddr.ref.ep.ID().LocalAddress + now := time.Now() + + // As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary + // address is the lower of the valid lifetime of the stable address or the + // maximum temporary address valid lifetime. + vl := ndp.configs.MaxTempAddrValidLifetime + if prefixState.validUntil != (time.Time{}) { + if prefixVL := prefixState.validUntil.Sub(now); vl > prefixVL { + vl = prefixVL + } + } + + if vl <= 0 { + // Cannot create an address without a valid lifetime. + return false + } + + // As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary + // address is the lower of the preferred lifetime of the stable address or the + // maximum temporary address preferred lifetime - the temporary address desync + // factor. + pl := ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor + if prefixState.preferredUntil != (time.Time{}) { + if prefixPL := prefixState.preferredUntil.Sub(now); pl > prefixPL { + // Respect the preferred lifetime of the prefix, as per RFC 4941 section + // 3.3 step 4. + pl = prefixPL + } + } + + // As per RFC 4941 section 3.3 step 5, a temporary address is created only if + // the calculated preferred lifetime is greater than the advance regeneration + // duration. In particular, we MUST NOT create a temporary address with a zero + // Preferred Lifetime. + if pl <= ndp.configs.RegenAdvanceDuration { + return false + } + + // Attempt to generate a new address that is not already assigned to the NIC. + var generatedAddr tcpip.AddressWithPrefix + for i := 0; ; i++ { + // If we were unable to generate an address after the maximum SLAAC address + // local regeneration attempts, do nothing further. + if i == maxSLAACAddrLocalRegenAttempts { + return false + } + + generatedAddr = header.GenerateTempIPv6SLAACAddr(ndp.temporaryIIDHistory[:], stableAddr) + if !ndp.nic.hasPermanentAddrLocked(generatedAddr.Address) { + break + } + } + + // As per RFC RFC 4941 section 3.3 step 5, we MUST NOT create a temporary + // address with a zero preferred lifetime. The checks above ensure this + // so we know the address is not deprecated. + ref := ndp.addSLAACAddr(generatedAddr, slaacTemp, false /* deprecated */) + if ref == nil { + return false + } + + state := tempSLAACAddrState{ + deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + prefixState, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to deprecate temporary address %s", prefix, generatedAddr)) + } + + tempAddrState, ok := prefixState.tempAddrs[generatedAddr.Address] + if !ok { + panic(fmt.Sprintf("ndp: must have a tempAddr entry to deprecate temporary address %s", generatedAddr)) + } + + ndp.deprecateSLAACAddress(tempAddrState.ref) + }), + invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + prefixState, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to invalidate temporary address %s", prefix, generatedAddr)) + } + + tempAddrState, ok := prefixState.tempAddrs[generatedAddr.Address] + if !ok { + panic(fmt.Sprintf("ndp: must have a tempAddr entry to invalidate temporary address %s", generatedAddr)) + } + + ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, generatedAddr.Address, tempAddrState) + }), + regenTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + prefixState, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to regenerate temporary address after %s", prefix, generatedAddr)) + } + + tempAddrState, ok := prefixState.tempAddrs[generatedAddr.Address] + if !ok { + panic(fmt.Sprintf("ndp: must have a tempAddr entry to regenerate temporary address after %s", generatedAddr)) + } + + // If an address has already been regenerated for this address, don't + // regenerate another address. + if tempAddrState.regenerated { + return + } + + // Reset the generation attempts counter as we are starting the generation + // of a new address for the SLAAC prefix. + tempAddrState.regenerated = ndp.generateTempSLAACAddr(prefix, &prefixState, true /* resetGenAttempts */) + prefixState.tempAddrs[generatedAddr.Address] = tempAddrState + ndp.slaacPrefixes[prefix] = prefixState + }), + createdAt: now, + ref: ref, + } + + state.deprecationTimer.Reset(pl) + state.invalidationTimer.Reset(vl) + state.regenTimer.Reset(pl - ndp.configs.RegenAdvanceDuration) + + prefixState.generationAttempts++ + prefixState.tempAddrs[generatedAddr.Address] = state + + return true +} + +// regenerateTempSLAACAddr regenerates a temporary address for a SLAAC prefix. // // The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, pl, vl time.Duration) { - prefixState, ok := ndp.slaacPrefixes[prefix] +func (ndp *ndpState) regenerateTempSLAACAddr(prefix tcpip.Subnet, resetGenAttempts bool) { + state, ok := ndp.slaacPrefixes[prefix] if !ok { - panic(fmt.Sprintf("ndp: SLAAC prefix state not found to refresh lifetimes for %s", prefix)) + panic(fmt.Sprintf("ndp: SLAAC prefix state not found to regenerate temporary address for %s", prefix)) } - defer func() { ndp.slaacPrefixes[prefix] = prefixState }() + ndp.generateTempSLAACAddr(prefix, &state, resetGenAttempts) + ndp.slaacPrefixes[prefix] = state +} + +// refreshSLAACPrefixLifetimes refreshes the lifetimes of a SLAAC prefix. +// +// pl is the new preferred lifetime. vl is the new valid lifetime. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixState *slaacPrefixState, pl, vl time.Duration) { // If the preferred lifetime is zero, then the prefix should be deprecated. deprecated := pl == 0 if deprecated { - ndp.deprecateSLAACAddress(prefixState.ref) + ndp.deprecateSLAACAddress(prefixState.stableAddr.ref) } else { - prefixState.ref.deprecated = false + prefixState.stableAddr.ref.deprecated = false } // If prefix was preferred for some finite lifetime before, stop the @@ -1190,36 +1520,118 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, pl, vl tim // // 3) Otherwise, reset the valid lifetime of the prefix to 2 hours. - // Handle the infinite valid lifetime separately as we do not keep a timer in - // this case. if vl >= header.NDPInfiniteLifetime { + // Handle the infinite valid lifetime separately as we do not keep a timer + // in this case. prefixState.invalidationTimer.StopLocked() prefixState.validUntil = time.Time{} - return - } + } else { + var effectiveVl time.Duration + var rl time.Duration - var effectiveVl time.Duration - var rl time.Duration + // If the prefix was originally set to be valid forever, assume the + // remaining time to be the maximum possible value. + if prefixState.validUntil == (time.Time{}) { + rl = header.NDPInfiniteLifetime + } else { + rl = time.Until(prefixState.validUntil) + } - // If the prefix was originally set to be valid forever, assume the remaining - // time to be the maximum possible value. - if prefixState.validUntil == (time.Time{}) { - rl = header.NDPInfiniteLifetime - } else { - rl = time.Until(prefixState.validUntil) + if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl { + effectiveVl = vl + } else if rl > MinPrefixInformationValidLifetimeForUpdate { + effectiveVl = MinPrefixInformationValidLifetimeForUpdate + } + + if effectiveVl != 0 { + prefixState.invalidationTimer.StopLocked() + prefixState.invalidationTimer.Reset(effectiveVl) + prefixState.validUntil = now.Add(effectiveVl) + } } - if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl { - effectiveVl = vl - } else if rl <= MinPrefixInformationValidLifetimeForUpdate { + // If DAD is not yet complete on the stable address, there is no need to do + // work with temporary addresses. + if prefixState.stableAddr.ref.getKind() != permanent { return - } else { - effectiveVl = MinPrefixInformationValidLifetimeForUpdate } - prefixState.invalidationTimer.StopLocked() - prefixState.invalidationTimer.Reset(effectiveVl) - prefixState.validUntil = now.Add(effectiveVl) + // Note, we do not need to update the entries in the temporary address map + // after updating the timers because the timers are held as pointers. + var regenForAddr tcpip.Address + allAddressesRegenerated := true + for tempAddr, tempAddrState := range prefixState.tempAddrs { + // As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary + // address is the lower of the valid lifetime of the stable address or the + // maximum temporary address valid lifetime. Note, the valid lifetime of a + // temporary address is relative to the address's creation time. + validUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrValidLifetime) + if prefixState.validUntil != (time.Time{}) && validUntil.Sub(prefixState.validUntil) > 0 { + validUntil = prefixState.validUntil + } + + // If the address is no longer valid, invalidate it immediately. Otherwise, + // reset the invalidation timer. + newValidLifetime := validUntil.Sub(now) + if newValidLifetime <= 0 { + ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, tempAddr, tempAddrState) + continue + } + tempAddrState.invalidationTimer.StopLocked() + tempAddrState.invalidationTimer.Reset(newValidLifetime) + + // As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary + // address is the lower of the preferred lifetime of the stable address or + // the maximum temporary address preferred lifetime - the temporary address + // desync factor. Note, the preferred lifetime of a temporary address is + // relative to the address's creation time. + preferredUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor) + if prefixState.preferredUntil != (time.Time{}) && preferredUntil.Sub(prefixState.preferredUntil) > 0 { + preferredUntil = prefixState.preferredUntil + } + + // If the address is no longer preferred, deprecate it immediately. + // Otherwise, reset the deprecation timer. + newPreferredLifetime := preferredUntil.Sub(now) + tempAddrState.deprecationTimer.StopLocked() + if newPreferredLifetime <= 0 { + ndp.deprecateSLAACAddress(tempAddrState.ref) + } else { + tempAddrState.ref.deprecated = false + tempAddrState.deprecationTimer.Reset(newPreferredLifetime) + } + + tempAddrState.regenTimer.StopLocked() + if tempAddrState.regenerated { + } else { + allAddressesRegenerated = false + + if newPreferredLifetime <= ndp.configs.RegenAdvanceDuration { + // The new preferred lifetime is less than the advance regeneration + // duration so regenerate an address for this temporary address + // immediately after we finish iterating over the temporary addresses. + regenForAddr = tempAddr + } else { + tempAddrState.regenTimer.Reset(newPreferredLifetime - ndp.configs.RegenAdvanceDuration) + } + } + } + + // Generate a new temporary address if all of the existing temporary addresses + // have been regenerated, or we need to immediately regenerate an address + // due to an update in preferred lifetime. + // + // If each temporay address has already been regenerated, no new temporary + // address will be generated. To ensure continuation of temporary SLAAC + // addresses, we manually try to regenerate an address here. + if len(regenForAddr) != 0 || allAddressesRegenerated { + // Reset the generation attempts counter as we are starting the generation + // of a new address for the SLAAC prefix. + if state, ok := prefixState.tempAddrs[regenForAddr]; ndp.generateTempSLAACAddr(prefix, prefixState, true /* resetGenAttempts */) && ok { + state.regenerated = true + prefixState.tempAddrs[regenForAddr] = state + } + } } // deprecateSLAACAddress marks ref as deprecated and notifies the stack's NDP @@ -1243,11 +1655,11 @@ func (ndp *ndpState) deprecateSLAACAddress(ref *referencedNetworkEndpoint) { // // The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefixState) { - if r := state.ref; r != nil { + if r := state.stableAddr.ref; r != nil { // Since we are already invalidating the prefix, do not invalidate the // prefix when removing the address. - if err := ndp.nic.removePermanentIPv6EndpointLocked(r, false /* allowSLAACPrefixInvalidation */); err != nil { - panic(fmt.Sprintf("ndp: removePermanentIPv6EndpointLocked(%s, false): %s", r.addrWithPrefix(), err)) + if err := ndp.nic.removePermanentIPv6EndpointLocked(r, false /* allowSLAACInvalidation */); err != nil { + panic(fmt.Sprintf("ndp: error removing stable SLAAC address %s: %s", r.addrWithPrefix(), err)) } } @@ -1265,14 +1677,14 @@ func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPr prefix := addr.Subnet() state, ok := ndp.slaacPrefixes[prefix] - if !ok || state.ref == nil || addr.Address != state.ref.ep.ID().LocalAddress { + if !ok || state.stableAddr.ref == nil || addr.Address != state.stableAddr.ref.ep.ID().LocalAddress { return } if !invalidatePrefix { // If the prefix is not being invalidated, disassociate the address from the // prefix and do nothing further. - state.ref = nil + state.stableAddr.ref = nil ndp.slaacPrefixes[prefix] = state return } @@ -1286,11 +1698,68 @@ func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPr // // The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaacPrefixState) { + // Invalidate all temporary addresses. + for tempAddr, tempAddrState := range state.tempAddrs { + ndp.invalidateTempSLAACAddr(state.tempAddrs, tempAddr, tempAddrState) + } + + state.stableAddr.ref = nil state.deprecationTimer.StopLocked() state.invalidationTimer.StopLocked() delete(ndp.slaacPrefixes, prefix) } +// invalidateTempSLAACAddr invalidates a temporary SLAAC address. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) { + // Since we are already invalidating the address, do not invalidate the + // address when removing the address. + if err := ndp.nic.removePermanentIPv6EndpointLocked(tempAddrState.ref, false /* allowSLAACInvalidation */); err != nil { + panic(fmt.Sprintf("error removing temporary SLAAC address %s: %s", tempAddrState.ref.addrWithPrefix(), err)) + } + + ndp.cleanupTempSLAACAddrResources(tempAddrs, tempAddr, tempAddrState) +} + +// cleanupTempSLAACAddrResourcesAndNotify cleans up an invalidated temporary +// SLAAC address's resources from ndp. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidateAddr bool) { + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), addr) + } + + if !invalidateAddr { + return + } + + prefix := addr.Subnet() + state, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry to clean up temp addr %s resources", addr)) + } + + tempAddrState, ok := state.tempAddrs[addr.Address] + if !ok { + panic(fmt.Sprintf("ndp: must have a tempAddr entry to clean up temp addr %s resources", addr)) + } + + ndp.cleanupTempSLAACAddrResources(state.tempAddrs, addr.Address, tempAddrState) +} + +// cleanupTempSLAACAddrResourcesAndNotify cleans up a temporary SLAAC address's +// timers and entry. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) cleanupTempSLAACAddrResources(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) { + tempAddrState.deprecationTimer.StopLocked() + tempAddrState.invalidationTimer.StopLocked() + tempAddrState.regenTimer.StopLocked() + delete(tempAddrs, tempAddr) +} + // cleanupState cleans up ndp's state. // // If hostOnly is true, then only host-specific state will be cleaned up. @@ -1338,6 +1807,8 @@ func (ndp *ndpState) cleanupState(hostOnly bool) { if got := len(ndp.defaultRouters); got != 0 { panic(fmt.Sprintf("ndp: still have discovered default routers after cleaning up; found = %d", got)) } + + ndp.dhcpv6Configuration = 0 } // startSolicitingRouters starts soliciting routers, as per RFC 4861 section @@ -1450,3 +1921,13 @@ func (ndp *ndpState) stopSolicitingRouters() { ndp.rtrSolicitTimer.Stop() ndp.rtrSolicitTimer = nil } + +// initializeTempAddrState initializes state related to temporary SLAAC +// addresses. +func (ndp *ndpState) initializeTempAddrState() { + header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.nic.stack.tempIIDSeed, ndp.nic.ID()) + + if MaxDesyncFactor != 0 { + ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor))) + } +} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 6dd460984..b3d174cdd 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -1801,6 +1801,935 @@ func TestAutoGenAddr(t *testing.T) { } } +func addressCheck(addrs []tcpip.ProtocolAddress, containList, notContainList []tcpip.AddressWithPrefix) string { + ret := "" + for _, c := range containList { + if !containsV6Addr(addrs, c) { + ret += fmt.Sprintf("should have %s in the list of addresses\n", c) + } + } + for _, c := range notContainList { + if containsV6Addr(addrs, c) { + ret += fmt.Sprintf("should not have %s in the list of addresses\n", c) + } + } + return ret +} + +// TestAutoGenTempAddr tests that temporary SLAAC addresses are generated when +// configured to do so as part of IPv6 Privacy Extensions. +func TestAutoGenTempAddr(t *testing.T) { + const ( + nicID = 1 + newMinVL = 5 + newMinVLDuration = newMinVL * time.Second + ) + + savedMinPrefixInformationValidLifetimeForUpdate := stack.MinPrefixInformationValidLifetimeForUpdate + savedMaxDesync := stack.MaxDesyncFactor + defer func() { + stack.MinPrefixInformationValidLifetimeForUpdate = savedMinPrefixInformationValidLifetimeForUpdate + stack.MaxDesyncFactor = savedMaxDesync + }() + stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration + stack.MaxDesyncFactor = time.Nanosecond + + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + + tests := []struct { + name string + dupAddrTransmits uint8 + retransmitTimer time.Duration + }{ + { + name: "DAD disabled", + }, + { + name: "DAD enabled", + dupAddrTransmits: 1, + retransmitTimer: time.Second, + }, + } + + // This Run will not return until the parallel tests finish. + // + // We need this because we need to do some teardown work after the + // parallel tests complete. + // + // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for + // more details. + t.Run("group", func(t *testing.T) { + for i, test := range tests { + i := i + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + seed := []byte{uint8(i)} + var tempIIDHistory [header.IIDSize]byte + header.InitialTempIID(tempIIDHistory[:], seed, nicID) + newTempAddr := func(stableAddr tcpip.Address) tcpip.AddressWithPrefix { + return header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableAddr) + } + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 2), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: test.dupAddrTransmits, + RetransmitTimer: test.retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + TempIIDSeed: seed, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultAsyncEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } + + expectDADEventAsync := func(addr tcpip.Address) { + t.Helper() + + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncEventTimeout): + t.Fatal("timed out waiting for DAD event") + } + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly auto-generated an address with 0 lifetime; event = %+v", e) + default: + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, newAddr) + expectDADEventAsync(addr1.Address) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly got an auto gen addr event = %+v", e) + default: + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero valid & preferred lifetimes. + tempAddr1 := newTempAddr(addr1.Address) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + expectAutoGenAddrEvent(tempAddr1, newAddr) + expectDADEventAsync(tempAddr1.Address) + if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Receive an RA with prefix2 in an NDP Prefix Information option (PI) + // with preferred lifetime > valid lifetime + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly auto-generated an address with preferred lifetime > valid lifetime; event = %+v", e) + default: + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Receive an RA with prefix2 in a PI w/ non-zero valid and preferred + // lifetimes. + tempAddr2 := newTempAddr(addr2.Address) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + expectAutoGenAddrEvent(addr2, newAddr) + expectDADEventAsync(addr2.Address) + expectAutoGenAddrEventAsync(tempAddr2, newAddr) + expectDADEventAsync(tempAddr2.Address) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Deprecate prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, deprecatedAddr) + expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Refresh lifetimes for prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Reduce valid lifetime and deprecate addresses of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + expectAutoGenAddrEvent(addr1, deprecatedAddr) + expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Wait for addrs of prefix1 to be invalidated. They should be + // invalidated at the same time. + select { + case e := <-ndpDisp.autoGenAddrC: + var nextAddr tcpip.AddressWithPrefix + if e.addr == addr1 { + if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + nextAddr = tempAddr1 + } else { + if diff := checkAutoGenAddrEvent(e, tempAddr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + nextAddr = addr1 + } + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, nextAddr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + case <-time.After(newMinVLDuration + defaultTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { + t.Fatal(mismatch) + } + + // Receive an RA with prefix2 in a PI w/ 0 lifetimes. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 0, 0)) + expectAutoGenAddrEvent(addr2, deprecatedAddr) + expectAutoGenAddrEvent(tempAddr2, deprecatedAddr) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Errorf("got unexpected auto gen addr event = %+v", e) + default: + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { + t.Fatal(mismatch) + } + }) + } + }) +} + +// TestNoAutoGenTempAddrForLinkLocal test that temporary SLAAC addresses are not +// generated for auto generated link-local addresses. +func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { + const nicID = 1 + + savedMaxDesyncFactor := stack.MaxDesyncFactor + defer func() { + stack.MaxDesyncFactor = savedMaxDesyncFactor + }() + stack.MaxDesyncFactor = time.Nanosecond + + tests := []struct { + name string + dupAddrTransmits uint8 + retransmitTimer time.Duration + }{ + { + name: "DAD disabled", + }, + { + name: "DAD enabled", + dupAddrTransmits: 1, + retransmitTimer: time.Second, + }, + } + + // This Run will not return until the parallel tests finish. + // + // We need this because we need to do some teardown work after the + // parallel tests complete. + // + // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for + // more details. + t.Run("group", func(t *testing.T) { + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + AutoGenIPv6LinkLocal: true, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + // The stable link-local address should auto-generate and resolve DAD. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IIDOffsetInIPv6Address * 8}, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, llAddr1, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncEventTimeout): + t.Fatal("timed out waiting for DAD event") + } + + // No new addresses should be generated. + select { + case e := <-ndpDisp.autoGenAddrC: + t.Errorf("got unxpected auto gen addr event = %+v", e) + case <-time.After(defaultAsyncEventTimeout): + } + }) + } + }) +} + +// TestNoAutoGenTempAddrWithoutStableAddr tests that a temporary SLAAC address +// will not be generated until after DAD completes, even if a new Router +// Advertisement is received to refresh lifetimes. +func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { + const ( + nicID = 1 + dadTransmits = 1 + retransmitTimer = 2 * time.Second + ) + + savedMaxDesyncFactor := stack.MaxDesyncFactor + defer func() { + stack.MaxDesyncFactor = savedMaxDesyncFactor + }() + stack.MaxDesyncFactor = 0 + + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + var tempIIDHistory [header.IIDSize]byte + header.InitialTempIID(tempIIDHistory[:], nil, nicID) + tempAddr := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + // Receive an RA to trigger SLAAC for prefix. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + + // DAD on the stable address for prefix has not yet completed. Receiving a new + // RA that would refresh lifetimes should not generate a temporary SLAAC + // address for the prefix. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpected auto gen addr event = %+v", e) + default: + } + + // Wait for DAD to complete for the stable address then expect the temporary + // address to be generated. + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout): + t.Fatal("timed out waiting for DAD event") + } + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, tempAddr, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultAsyncEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } +} + +// TestAutoGenTempAddrRegen tests that temporary SLAAC addresses are +// regenerated. +func TestAutoGenTempAddrRegen(t *testing.T) { + const ( + nicID = 1 + regenAfter = 2 * time.Second + newMinVL = 10 + newMinVLDuration = newMinVL * time.Second + ) + + savedMaxDesyncFactor := stack.MaxDesyncFactor + savedMinMaxTempAddrPreferredLifetime := stack.MinMaxTempAddrPreferredLifetime + savedMinMaxTempAddrValidLifetime := stack.MinMaxTempAddrValidLifetime + defer func() { + stack.MaxDesyncFactor = savedMaxDesyncFactor + stack.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime + stack.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime + }() + stack.MaxDesyncFactor = 0 + stack.MinMaxTempAddrPreferredLifetime = newMinVLDuration + stack.MinMaxTempAddrValidLifetime = newMinVLDuration + + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + var tempIIDHistory [header.IIDSize]byte + header.InitialTempIID(tempIIDHistory[:], nil, nicID) + tempAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + tempAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + tempAddr3 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), + } + e := channel.New(0, 1280, linkAddr1) + ndpConfigs := stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + RegenAdvanceDuration: newMinVLDuration - regenAfter, + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(timeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero valid & preferred lifetimes. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + expectAutoGenAddrEvent(addr, newAddr) + expectAutoGenAddrEvent(tempAddr1, newAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Wait for regeneration + expectAutoGenAddrEventAsync(tempAddr2, newAddr, regenAfter+defaultAsyncEventTimeout) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Wait for regeneration + expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncEventTimeout) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2, tempAddr3}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Stop generating temporary addresses + ndpConfigs.AutoGenTempGlobalAddresses = false + if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil { + t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err) + } + + // Wait for all the temporary addresses to get invalidated. + tempAddrs := []tcpip.AddressWithPrefix{tempAddr1, tempAddr2, tempAddr3} + invalidateAfter := newMinVLDuration - 2*regenAfter + for _, addr := range tempAddrs { + // Wait for a deprecation then invalidation event, or just an invalidation + // event. We need to cover both cases but cannot deterministically hit both + // cases because the deprecation and invalidation timers could fire in any + // order. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, deprecatedAddr); diff == "" { + // If we get a deprecation event first, we should get an invalidation + // event almost immediately after. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultAsyncEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } else if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff == "" { + // If we get an invalidation event first, we shouldn't get a deprecation + // event after. + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly got an auto-generated event = %+v", e) + case <-time.After(defaultTimeout): + } + } else { + t.Fatalf("got unexpected auto-generated event = %+v", e) + } + case <-time.After(invalidateAfter + defaultAsyncEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + + invalidateAfter = regenAfter + } + if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr}, tempAddrs); mismatch != "" { + t.Fatal(mismatch) + } +} + +// TestAutoGenTempAddrRegenTimerUpdates tests that a temporary address's +// regeneration timer gets updated when refreshing the address's lifetimes. +func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) { + const ( + nicID = 1 + regenAfter = 2 * time.Second + newMinVL = 10 + newMinVLDuration = newMinVL * time.Second + ) + + savedMaxDesyncFactor := stack.MaxDesyncFactor + savedMinMaxTempAddrPreferredLifetime := stack.MinMaxTempAddrPreferredLifetime + savedMinMaxTempAddrValidLifetime := stack.MinMaxTempAddrValidLifetime + defer func() { + stack.MaxDesyncFactor = savedMaxDesyncFactor + stack.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime + stack.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime + }() + stack.MaxDesyncFactor = 0 + stack.MinMaxTempAddrPreferredLifetime = newMinVLDuration + stack.MinMaxTempAddrValidLifetime = newMinVLDuration + + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + var tempIIDHistory [header.IIDSize]byte + header.InitialTempIID(tempIIDHistory[:], nil, nicID) + tempAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + tempAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + tempAddr3 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), + } + e := channel.New(0, 1280, linkAddr1) + ndpConfigs := stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + RegenAdvanceDuration: newMinVLDuration - regenAfter, + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(timeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero valid & preferred lifetimes. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + expectAutoGenAddrEvent(addr, newAddr) + expectAutoGenAddrEvent(tempAddr1, newAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Deprecate the prefix. + // + // A new temporary address should be generated after the regeneration + // time has passed since the prefix is deprecated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 0)) + expectAutoGenAddrEvent(addr, deprecatedAddr) + expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpected auto gen addr event = %+v", e) + case <-time.After(regenAfter + defaultAsyncEventTimeout): + } + + // Prefer the prefix again. + // + // A new temporary address should immediately be generated since the + // regeneration time has already passed since the last address was generated + // - this regeneration does not depend on a timer. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + expectAutoGenAddrEvent(tempAddr2, newAddr) + + // Increase the maximum lifetimes for temporary addresses to large values + // then refresh the lifetimes of the prefix. + // + // A new address should not be generated after the regeneration time that was + // expected for the previous check. This is because the preferred lifetime for + // the temporary addresses has increased, so it will take more time to + // regenerate a new temporary address. Note, new addresses are only + // regenerated after the preferred lifetime - the regenerate advance duration + // as paased. + ndpConfigs.MaxTempAddrValidLifetime = 100 * time.Second + ndpConfigs.MaxTempAddrPreferredLifetime = 100 * time.Second + if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil { + t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err) + } + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpected auto gen addr event = %+v", e) + case <-time.After(regenAfter + defaultAsyncEventTimeout): + } + + // Set the maximum lifetimes for temporary addresses such that on the next + // RA, the regeneration timer gets reset. + // + // The maximum lifetime is the sum of the minimum lifetimes for temporary + // addresses + the time that has already passed since the last address was + // generated so that the regeneration timer is needed to generate the next + // address. + newLifetimes := newMinVLDuration + regenAfter + defaultAsyncEventTimeout + ndpConfigs.MaxTempAddrValidLifetime = newLifetimes + ndpConfigs.MaxTempAddrPreferredLifetime = newLifetimes + if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil { + t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err) + } + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncEventTimeout) +} + +// TestMixedSLAACAddrConflictRegen tests SLAAC address regeneration in response +// to a mix of DAD conflicts and NIC-local conflicts. +func TestMixedSLAACAddrConflictRegen(t *testing.T) { + const ( + nicID = 1 + nicName = "nic" + lifetimeSeconds = 9999 + // From stack.maxSLAACAddrLocalRegenAttempts + maxSLAACAddrLocalRegenAttempts = 10 + // We use 2 more addreses than the maximum local regeneration attempts + // because we want to also trigger regeneration in response to a DAD + // conflicts for this test. + maxAddrs = maxSLAACAddrLocalRegenAttempts + 2 + dupAddrTransmits = 1 + retransmitTimer = time.Second + ) + + var tempIIDHistoryWithModifiedEUI64 [header.IIDSize]byte + header.InitialTempIID(tempIIDHistoryWithModifiedEUI64[:], nil, nicID) + + var tempIIDHistoryWithOpaqueIID [header.IIDSize]byte + header.InitialTempIID(tempIIDHistoryWithOpaqueIID[:], nil, nicID) + + prefix, subnet, stableAddrWithModifiedEUI64 := prefixSubnetAddr(0, linkAddr1) + var stableAddrsWithOpaqueIID [maxAddrs]tcpip.AddressWithPrefix + var tempAddrsWithOpaqueIID [maxAddrs]tcpip.AddressWithPrefix + var tempAddrsWithModifiedEUI64 [maxAddrs]tcpip.AddressWithPrefix + addrBytes := []byte(subnet.ID()) + for i := 0; i < maxAddrs; i++ { + stableAddrsWithOpaqueIID[i] = tcpip.AddressWithPrefix{ + Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, uint8(i), nil)), + PrefixLen: header.IIDOffsetInIPv6Address * 8, + } + // When generating temporary addresses, the resolved stable address for the + // SLAAC prefix will be the first address stable address generated for the + // prefix as we will not simulate address conflicts for the stable addresses + // in tests involving temporary addresses. Address conflicts for stable + // addresses will be done in their own tests. + tempAddrsWithOpaqueIID[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistoryWithOpaqueIID[:], stableAddrsWithOpaqueIID[0].Address) + tempAddrsWithModifiedEUI64[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistoryWithModifiedEUI64[:], stableAddrWithModifiedEUI64.Address) + } + + tests := []struct { + name string + addrs []tcpip.AddressWithPrefix + tempAddrs bool + initialExpect tcpip.AddressWithPrefix + nicNameFromID func(tcpip.NICID, string) string + }{ + { + name: "Stable addresses with opaque IIDs", + addrs: stableAddrsWithOpaqueIID[:], + nicNameFromID: func(tcpip.NICID, string) string { + return nicName + }, + }, + { + name: "Temporary addresses with opaque IIDs", + addrs: tempAddrsWithOpaqueIID[:], + tempAddrs: true, + initialExpect: stableAddrsWithOpaqueIID[0], + nicNameFromID: func(tcpip.NICID, string) string { + return nicName + }, + }, + { + name: "Temporary addresses with modified EUI64", + addrs: tempAddrsWithModifiedEUI64[:], + tempAddrs: true, + initialExpect: stableAddrWithModifiedEUI64, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), + } + e := channel.New(0, 1280, linkAddr1) + ndpConfigs := stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: test.tempAddrs, + AutoGenAddressConflictRetries: 1, + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: test.nicNameFromID, + }, + }) + + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv6EmptySubnet, + Gateway: llAddr2, + NIC: nicID, + }}) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + for j := 0; j < len(test.addrs)-1; j++ { + // The NIC will not attempt to generate an address in response to a + // NIC-local conflict after some maximum number of attempts. We skip + // creating a conflict for the address that would be generated as part + // of the last attempt so we can simulate a DAD conflict for this + // address and restart the NIC-local generation process. + if j == maxSLAACAddrLocalRegenAttempts-1 { + continue + } + + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, test.addrs[j].Address); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, test.addrs[j].Address, err) + } + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + expectAutoGenAddrAsyncEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultAsyncEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } + + expectDADEventAsync := func(addr tcpip.Address) { + t.Helper() + + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + case <-time.After(dupAddrTransmits*retransmitTimer + defaultAsyncEventTimeout): + t.Fatal("timed out waiting for DAD event") + } + } + + // Enable DAD. + ndpDisp.dadC = make(chan ndpDADEvent, 2) + ndpConfigs.DupAddrDetectTransmits = dupAddrTransmits + ndpConfigs.RetransmitTimer = retransmitTimer + if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil { + t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err) + } + + // Do SLAAC for prefix. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) + if test.initialExpect != (tcpip.AddressWithPrefix{}) { + expectAutoGenAddrEvent(test.initialExpect, newAddr) + expectDADEventAsync(test.initialExpect.Address) + } + + // The last local generation attempt should succeed, but we introduce a + // DAD failure to restart the local generation process. + addr := test.addrs[maxSLAACAddrLocalRegenAttempts-1] + expectAutoGenAddrAsyncEvent(addr, newAddr) + if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { + t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) + } + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected DAD event") + } + expectAutoGenAddrEvent(addr, invalidatedAddr) + + // The last address generated should resolve DAD. + addr = test.addrs[len(test.addrs)-1] + expectAutoGenAddrAsyncEvent(addr, newAddr) + expectDADEventAsync(addr.Address) + + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpected auto gen addr event = %+v", e) + default: + } + }) + } +} + // stackAndNdpDispatcherWithDefaultRoute returns an ndpDispatcher, // channel.Endpoint and stack.Stack. // @@ -2196,7 +3125,6 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { } else { t.Fatalf("got unexpected auto-generated event") } - case <-time.After(newMinVLDuration + defaultAsyncEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } @@ -2808,9 +3736,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { } } -// TestAutoGenAddrWithOpaqueIIDDADRetries tests the regeneration of an -// auto-generated IPv6 address in response to a DAD conflict. -func TestAutoGenAddrWithOpaqueIIDDADRetries(t *testing.T) { +func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { const nicID = 1 const nicName = "nic" const dadTransmits = 1 @@ -2818,6 +3744,13 @@ func TestAutoGenAddrWithOpaqueIIDDADRetries(t *testing.T) { const maxMaxRetries = 3 const lifetimeSeconds = 10 + // Needed for the temporary address sub test. + savedMaxDesync := stack.MaxDesyncFactor + defer func() { + stack.MaxDesyncFactor = savedMaxDesync + }() + stack.MaxDesyncFactor = time.Nanosecond + var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte secretKey := secretKeyBuf[:] n, err := rand.Read(secretKey) @@ -2830,185 +3763,234 @@ func TestAutoGenAddrWithOpaqueIIDDADRetries(t *testing.T) { prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) - for maxRetries := uint8(0); maxRetries <= maxMaxRetries; maxRetries++ { - for numFailures := uint8(0); numFailures <= maxRetries+1; numFailures++ { - addrTypes := []struct { - name string - ndpConfigs stack.NDPConfigurations - autoGenLinkLocal bool - subnet tcpip.Subnet - triggerSLAACFn func(e *channel.Endpoint) - }{ - { - name: "Global address", - ndpConfigs: stack.NDPConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - HandleRAs: true, - AutoGenGlobalAddresses: true, - AutoGenAddressConflictRetries: maxRetries, - }, - subnet: subnet, - triggerSLAACFn: func(e *channel.Endpoint) { - // Receive an RA with prefix1 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) + addrForSubnet := func(subnet tcpip.Subnet, dadCounter uint8) tcpip.AddressWithPrefix { + addrBytes := []byte(subnet.ID()) + return tcpip.AddressWithPrefix{ + Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, dadCounter, secretKey)), + PrefixLen: 64, + } + } - }, - }, - { - name: "LinkLocal address", - ndpConfigs: stack.NDPConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - AutoGenAddressConflictRetries: maxRetries, - }, - autoGenLinkLocal: true, - subnet: header.IPv6LinkLocalPrefix.Subnet(), - triggerSLAACFn: func(e *channel.Endpoint) {}, - }, + expectAutoGenAddrEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } + default: + t.Fatal("expected addr auto gen event") + } + } - for _, addrType := range addrTypes { - maxRetries := maxRetries - numFailures := numFailures - addrType := addrType + expectAutoGenAddrEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() - t.Run(fmt.Sprintf("%s with %d max retries and %d failures", addrType.name, maxRetries, numFailures), func(t *testing.T) { - t.Parallel() + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultAsyncEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, - NDPConfigs: addrType.ndpConfigs, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(_ tcpip.NICID, nicName string) string { - return nicName - }, - SecretKey: secretKey, - }, - }) - opts := stack.NICOptions{Name: nicName} - if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) - } + expectDADEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, resolved bool) { + t.Helper() - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr, resolved, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected DAD event") + } + } - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } + expectDADEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, resolved bool) { + t.Helper() - addrType.triggerSLAACFn(e) + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr, resolved, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout): + t.Fatal("timed out waiting for DAD event") + } + } - // Simulate DAD conflicts so the address is regenerated. - for i := uint8(0); i < numFailures; i++ { - addrBytes := []byte(addrType.subnet.ID()) - addr := tcpip.AddressWithPrefix{ - Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], addrType.subnet, nicName, i, secretKey)), - PrefixLen: 64, - } - expectAutoGenAddrEvent(addr, newAddr) + stableAddrForTempAddrTest := addrForSubnet(subnet, 0) + + addrTypes := []struct { + name string + ndpConfigs stack.NDPConfigurations + autoGenLinkLocal bool + prepareFn func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix + addrGenFn func(dadCounter uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix + }{ + { + name: "Global address", + ndpConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + prepareFn: func(_ *testing.T, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix { + // Receive an RA with prefix1 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) + return nil - // Should not have any addresses assigned to the NIC. - mainAddr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) + }, + addrGenFn: func(dadCounter uint8, _ []byte) tcpip.AddressWithPrefix { + return addrForSubnet(subnet, dadCounter) + }, + }, + { + name: "LinkLocal address", + ndpConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + }, + autoGenLinkLocal: true, + prepareFn: func(*testing.T, *ndpDispatcher, *channel.Endpoint, []byte) []tcpip.AddressWithPrefix { + return nil + }, + addrGenFn: func(dadCounter uint8, _ []byte) tcpip.AddressWithPrefix { + return addrForSubnet(header.IPv6LinkLocalPrefix.Subnet(), dadCounter) + }, + }, + { + name: "Temporary address", + ndpConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + }, + prepareFn: func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix { + header.InitialTempIID(tempIIDHistory, nil, nicID) + + // Generate a stable SLAAC address so temporary addresses will be + // generated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + expectAutoGenAddrEvent(t, ndpDisp, stableAddrForTempAddrTest, newAddr) + expectDADEventAsync(t, ndpDisp, stableAddrForTempAddrTest.Address, true) + + // The stable address will be assigned throughout the test. + return []tcpip.AddressWithPrefix{stableAddrForTempAddrTest} + }, + addrGenFn: func(_ uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix { + return header.GenerateTempIPv6SLAACAddr(tempIIDHistory, stableAddrForTempAddrTest.Address) + }, + }, + } + + for _, addrType := range addrTypes { + // This Run will not return until the parallel tests finish. + // + // We need this because we need to do some teardown work after the parallel + // tests complete and limit the number of parallel tests running at the same + // time to reduce flakes. + // + // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for + // more details. + t.Run(addrType.name, func(t *testing.T) { + for maxRetries := uint8(0); maxRetries <= maxMaxRetries; maxRetries++ { + for numFailures := uint8(0); numFailures <= maxRetries+1; numFailures++ { + maxRetries := maxRetries + numFailures := numFailures + addrType := addrType + + t.Run(fmt.Sprintf("%d max retries and %d failures", maxRetries, numFailures), func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } - if want := (tcpip.AddressWithPrefix{}); mainAddr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", mainAddr, want) + e := channel.New(0, 1280, linkAddr1) + ndpConfigs := addrType.ndpConfigs + ndpConfigs.AutoGenAddressConflictRetries = maxRetries + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + SecretKey: secretKey, + }, + }) + opts := stack.NICOptions{Name: nicName} + if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) } - // Simulate a DAD conflict. - if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { - t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) - } - expectAutoGenAddrEvent(addr, invalidatedAddr) - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + var tempIIDHistory [header.IIDSize]byte + stableAddrs := addrType.prepareFn(t, &ndpDisp, e, tempIIDHistory[:]) + + // Simulate DAD conflicts so the address is regenerated. + for i := uint8(0); i < numFailures; i++ { + addr := addrType.addrGenFn(i, tempIIDHistory[:]) + expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr) + + // Should not have any new addresses assigned to the NIC. + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, stableAddrs, nil); mismatch != "" { + t.Fatal(mismatch) } - default: - t.Fatal("expected DAD event") - } - // Attempting to add the address manually should not fail if the - // address's state was cleaned up when DAD failed. - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr.Address); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr.Address, err) - } - if err := s.RemoveAddress(nicID, addr.Address); err != nil { - t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err) - } - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + // Simulate a DAD conflict. + if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { + t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) } - default: - t.Fatal("expected DAD event") - } - } + expectAutoGenAddrEvent(t, &ndpDisp, addr, invalidatedAddr) + expectDADEvent(t, &ndpDisp, addr.Address, false) - // Should not have any addresses assigned to the NIC. - mainAddr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) - } - if want := (tcpip.AddressWithPrefix{}); mainAddr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", mainAddr, want) - } + // Attempting to add the address manually should not fail if the + // address's state was cleaned up when DAD failed. + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr.Address); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr.Address, err) + } + if err := s.RemoveAddress(nicID, addr.Address); err != nil { + t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err) + } + expectDADEvent(t, &ndpDisp, addr.Address, false) + } - // If we had less failures than generation attempts, we should have an - // address after DAD resolves. - if maxRetries+1 > numFailures { - addrBytes := []byte(addrType.subnet.ID()) - addr := tcpip.AddressWithPrefix{ - Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], addrType.subnet, nicName, numFailures, secretKey)), - PrefixLen: 64, + // Should not have any new addresses assigned to the NIC. + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, stableAddrs, nil); mismatch != "" { + t.Fatal(mismatch) } - expectAutoGenAddrEvent(addr, newAddr) - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + // If we had less failures than generation attempts, we should have + // an address after DAD resolves. + if maxRetries+1 > numFailures { + addr := addrType.addrGenFn(numFailures, tempIIDHistory[:]) + expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr) + expectDADEventAsync(t, &ndpDisp, addr.Address, true) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, append(stableAddrs, addr), nil); mismatch != "" { + t.Fatal(mismatch) } - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout): - t.Fatal("timed out waiting for DAD event") } - mainAddr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) - } - if mainAddr != addr { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", mainAddr, addr) + // Should not attempt address generation again. + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) + case <-time.After(defaultAsyncEventTimeout): } - } - - // Should not attempt address regeneration again. - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) - case <-time.After(defaultAsyncEventTimeout): - } - }) + }) + } } - } + }) } } @@ -3906,7 +4888,12 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { } } - // The initial DHCPv6 configuration should be stack.DHCPv6NoConfiguration. + // Even if the first RA reports no DHCPv6 configurations are available, the + // dispatcher should get an event. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) + expectDHCPv6Event(stack.DHCPv6NoConfiguration) + // Receiving the same update again should not result in an event to the + // dispatcher. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) expectNoDHCPv6Event() @@ -3914,8 +4901,6 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { // Configurations. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) expectDHCPv6Event(stack.DHCPv6OtherConfigurations) - // Receiving the same update again should not result in an event to the - // NDPDispatcher. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) expectNoDHCPv6Event() @@ -3951,6 +4936,21 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { expectDHCPv6Event(stack.DHCPv6OtherConfigurations) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) expectNoDHCPv6Event() + + // Cycling the NIC should cause the last DHCPv6 configuration to be cleared. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + + // Receive an RA that updates the DHCPv6 configuration to Other + // Configurations. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) + expectDHCPv6Event(stack.DHCPv6OtherConfigurations) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) + expectNoDHCPv6Event() } // TestRouterSolicitation tests the initial Router Solicitations that are sent diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 016dbe15e..8f4c1fe42 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -131,6 +131,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState), } + nic.mu.ndp.initializeTempAddrState() // Register supported packet endpoint protocols. for _, netProto := range header.Ethertypes { @@ -451,7 +452,7 @@ type ipv6AddrCandidate struct { // primaryIPv6Endpoint returns an IPv6 endpoint following Source Address // Selection (RFC 6724 section 5). // -// Note, only rules 1-3 are followed. +// Note, only rules 1-3 and 7 are followed. // // remoteAddr must be a valid IPv6 address. func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEndpoint { @@ -522,6 +523,11 @@ func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEn return sbDep } + // Prefer temporary addresses as per RFC 6724 section 5 rule 7. + if saTemp, sbTemp := sa.ref.configType == slaacTemp, sb.ref.configType == slaacTemp; saTemp != sbTemp { + return saTemp + } + // sa and sb are equal, return the endpoint that is closest to the front of // the primary endpoint list. return i < j @@ -1014,14 +1020,14 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { switch r.protocol { case header.IPv6ProtocolNumber: - return n.removePermanentIPv6EndpointLocked(r, true /* allowSLAAPrefixInvalidation */) + return n.removePermanentIPv6EndpointLocked(r, true /* allowSLAACInvalidation */) default: r.expireLocked() return nil } } -func (n *NIC) removePermanentIPv6EndpointLocked(r *referencedNetworkEndpoint, allowSLAACPrefixInvalidation bool) *tcpip.Error { +func (n *NIC) removePermanentIPv6EndpointLocked(r *referencedNetworkEndpoint, allowSLAACInvalidation bool) *tcpip.Error { addr := r.addrWithPrefix() isIPv6Unicast := header.IsV6UnicastAddress(addr.Address) @@ -1031,8 +1037,11 @@ func (n *NIC) removePermanentIPv6EndpointLocked(r *referencedNetworkEndpoint, al // If we are removing an address generated via SLAAC, cleanup // its SLAAC resources and notify the integrator. - if r.configType == slaac { - n.mu.ndp.cleanupSLAACAddrResourcesAndNotify(addr, allowSLAACPrefixInvalidation) + switch r.configType { + case slaac: + n.mu.ndp.cleanupSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation) + case slaacTemp: + n.mu.ndp.cleanupTempSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation) } } @@ -1203,12 +1212,12 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link n.stack.stats.IP.PacketsReceived.Increment() } - if len(pkt.Data.First()) < netProto.MinimumPacketSize() { + netHeader, ok := pkt.Data.PullUp(netProto.MinimumPacketSize()) + if !ok { n.stack.stats.MalformedRcvdPackets.Increment() return } - - src, dst := netProto.ParseAddresses(pkt.Data.First()) + src, dst := netProto.ParseAddresses(netHeader) if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { // The source address is one of our own, so we never should have gotten a @@ -1221,8 +1230,10 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link // TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet. if protocol == header.IPv4ProtocolNumber { + // iptables filtering. ipt := n.stack.IPTables() - if ok := ipt.Check(Prerouting, pkt); !ok { + address := n.primaryAddress(protocol) + if ok := ipt.Check(Prerouting, &pkt, nil, nil, address.Address); !ok { // iptables is telling us to drop the packet. return } @@ -1289,22 +1300,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. - - firstData := pkt.Data.First() - pkt.Data.RemoveFirst() - - if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen == 0 { - pkt.Header = buffer.NewPrependableFromView(firstData) - } else { - firstDataLen := len(firstData) - - // pkt.Header should have enough capacity to hold n.linkEP's headers. - pkt.Header = buffer.NewPrependable(firstDataLen + linkHeaderLen) - - // TODO(b/151227689): avoid copying the packet when forwarding - if n := copy(pkt.Header.Prepend(firstDataLen), firstData); n != firstDataLen { - panic(fmt.Sprintf("copied %d bytes, expected %d", n, firstDataLen)) - } + if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen != 0 { + pkt.Header = buffer.NewPrependable(linkHeaderLen) } if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil { @@ -1332,12 +1329,13 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // validly formed. n.stack.demux.deliverRawPacket(r, protocol, pkt) - if len(pkt.Data.First()) < transProto.MinimumPacketSize() { + transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) + if !ok { n.stack.stats.MalformedRcvdPackets.Increment() return } - srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) + srcPort, dstPort, err := transProto.ParsePorts(transHeader) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() return @@ -1375,11 +1373,12 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp // ICMPv4 only guarantees that 8 bytes of the transport protocol will // be present in the payload. We know that the ports are within the // first 8 bytes for all known transport protocols. - if len(pkt.Data.First()) < 8 { + transHeader, ok := pkt.Data.PullUp(8) + if !ok { return } - srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) + srcPort, dstPort, err := transProto.ParsePorts(transHeader) if err != nil { return } @@ -1448,12 +1447,19 @@ func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { // If the address is a SLAAC address, do not invalidate its SLAAC prefix as a // new address will be generated for it. - if err := n.removePermanentIPv6EndpointLocked(ref, false /* allowSLAACPrefixInvalidation */); err != nil { + if err := n.removePermanentIPv6EndpointLocked(ref, false /* allowSLAACInvalidation */); err != nil { return err } - if ref.configType == slaac { - n.mu.ndp.regenerateSLAACAddr(ref.addrWithPrefix().Subnet()) + prefix := ref.addrWithPrefix().Subnet() + + switch ref.configType { + case slaac: + n.mu.ndp.regenerateSLAACAddr(prefix) + case slaacTemp: + // Do not reset the generation attempts counter for the prefix as the + // temporary address is being regenerated in response to a DAD conflict. + n.mu.ndp.regenerateTempSLAACAddr(prefix, false /* resetGenAttempts */) } return nil @@ -1552,9 +1558,14 @@ const ( // multicast group). static networkEndpointConfigType = iota - // A slaac configured endpoint is an IPv6 endpoint that was - // added by SLAAC as per RFC 4862 section 5.5.3. + // A SLAAC configured endpoint is an IPv6 endpoint that was added by + // SLAAC as per RFC 4862 section 5.5.3. slaac + + // A temporary SLAAC configured endpoint is an IPv6 endpoint that was added by + // SLAAC as per RFC 4941. Temporary SLAAC addresses are short-lived and are + // not expected to be valid (or preferred) forever; hence the term temporary. + slaacTemp ) type referencedNetworkEndpoint struct { diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index dc125f25e..926df4d7b 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -37,7 +37,13 @@ type PacketBuffer struct { Data buffer.VectorisedView // Header holds the headers of outbound packets. As a packet is passed - // down the stack, each layer adds to Header. + // down the stack, each layer adds to Header. Note that forwarded + // packets don't populate Headers on their way out -- their headers and + // payload are never parsed out and remain in Data. + // + // TODO(gvisor.dev/issue/170): Forwarded packets don't currently + // populate Header, but should. This will be doable once early parsing + // (https://github.com/google/gvisor/pull/1995) is supported. Header buffer.Prependable // These fields are used by both inbound and outbound packets. They @@ -60,6 +66,16 @@ type PacketBuffer struct { // Owner is implemented by task to get the uid and gid. // Only set for locally generated packets. Owner tcpip.PacketOwner + + // The following fields are only set by the qdisc layer when the packet + // is added to a queue. + EgressRoute *Route + GSOOptions *GSO + NetworkProtocolNumber tcpip.NetworkProtocolNumber + + // NatDone indicates if the packet has been manipulated as per NAT + // iptables rule. + NatDone bool } // Clone makes a copy of pk. It clones the Data field, which creates a new diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 23ca9ee03..b331427c6 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -269,6 +269,10 @@ type NetworkEndpoint interface { // Close is called when the endpoint is reomved from a stack. Close() + + // NetworkProtocolNumber returns the tcpip.NetworkProtocolNumber for + // this endpoint. + NetworkProtocolNumber() tcpip.NetworkProtocolNumber } // NetworkProtocol is the interface that needs to be implemented by network diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index a0e5e0300..150297ab9 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -217,6 +217,12 @@ func (r *Route) MTU() uint32 { return r.ref.ep.MTU() } +// NetworkProtocolNumber returns the NetworkProtocolNumber of the underlying +// network endpoint. +func (r *Route) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { + return r.ref.ep.NetworkProtocolNumber() +} + // Release frees all resources associated with the route. func (r *Route) Release() { if r.ref != nil { @@ -255,3 +261,16 @@ func (r *Route) MakeLoopedRoute() Route { func (r *Route) Stack() *Stack { return r.ref.stack() } + +// ReverseRoute returns new route with given source and destination address. +func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route { + return Route{ + NetProto: r.NetProto, + LocalAddress: dst, + LocalLinkAddress: r.RemoteLinkAddress, + RemoteAddress: src, + RemoteLinkAddress: r.LocalLinkAddress, + ref: r.ref, + Loop: r.Loop, + } +} diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 41398a1b6..e33fae4eb 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -464,6 +464,10 @@ type Stack struct { // (IIDs) as outlined by RFC 7217. opaqueIIDOpts OpaqueInterfaceIdentifierOptions + // tempIIDSeed is used to seed the initial temporary interface identifier + // history value used to generate IIDs for temporary SLAAC addresses. + tempIIDSeed []byte + // forwarder holds the packets that wait for their link-address resolutions // to complete, and forwards them when each resolution is done. forwarder *forwardQueue @@ -541,6 +545,21 @@ type Options struct { // // RandSource must be thread-safe. RandSource mathrand.Source + + // TempIIDSeed is used to seed the initial temporary interface identifier + // history value used to generate IIDs for temporary SLAAC addresses. + // + // Temporary SLAAC adresses are short-lived addresses which are unpredictable + // and random from the perspective of other nodes on the network. It is + // recommended that the seed be a random byte buffer of at least + // header.IIDSize bytes to make sure that temporary SLAAC addresses are + // sufficiently random. It should follow minimum randomness requirements for + // security as outlined by RFC 4086. + // + // Note: using a nil value, the same seed across netstack program runs, or a + // seed that is too small would reduce randomness and increase predictability, + // defeating the purpose of temporary SLAAC addresses. + TempIIDSeed []byte } // TransportEndpointInfo holds useful information about a transport endpoint @@ -664,6 +683,7 @@ func New(opts Options) *Stack { uniqueIDGenerator: opts.UniqueID, ndpDisp: opts.NDPDisp, opaqueIIDOpts: opts.OpaqueIIDOpts, + tempIIDSeed: opts.TempIIDSeed, forwarder: newForwardQueue(), randomGenerator: mathrand.New(randSrc), } @@ -1865,3 +1885,22 @@ func generateRandInt64() int64 { } return v } + +// FindNetworkEndpoint returns the network endpoint for the given address. +func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) (NetworkEndpoint, *tcpip.Error) { + s.mu.Lock() + defer s.mu.Unlock() + + for _, nic := range s.nics { + id := NetworkEndpointID{address} + + if ref, ok := nic.mu.endpoints[id]; ok { + nic.mu.RLock() + defer nic.mu.RUnlock() + + // An endpoint with this id exists, check if it can be used and return it. + return ref.ep, nil + } + } + return nil, tcpip.ErrBadAddress +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index c7634ceb1..1a2cf007c 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -95,16 +95,18 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffe f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ // Consume the network header. - b := pkt.Data.First() + b, ok := pkt.Data.PullUp(fakeNetHeaderLen) + if !ok { + return + } pkt.Data.TrimFront(fakeNetHeaderLen) // Handle control packets. if b[2] == uint8(fakeControlProtocol) { - nb := pkt.Data.First() - if len(nb) < fakeNetHeaderLen { + nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) + if !ok { return } - pkt.Data.TrimFront(fakeNetHeaderLen) f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt) return @@ -126,6 +128,10 @@ func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return f.ep.Capabilities() } +func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { + return f.proto.Number() +} + func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error { // Increment the sent packet count in the protocol descriptor. f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ @@ -2870,14 +2876,25 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") nicID = 1 + lifetimeSeconds = 9999 ) + prefix1, _, stableGlobalAddr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, stableGlobalAddr2 := prefixSubnetAddr(1, linkAddr1) + + var tempIIDHistory [header.IIDSize]byte + header.InitialTempIID(tempIIDHistory[:], nil, nicID) + tempGlobalAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableGlobalAddr1.Address).Address + tempGlobalAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableGlobalAddr2.Address).Address + // Rule 3 is not tested here, and is instead tested by NDP's AutoGenAddr test. tests := []struct { - name string - nicAddrs []tcpip.Address - connectAddr tcpip.Address - expectedLocalAddr tcpip.Address + name string + slaacPrefixForTempAddrBeforeNICAddrAdd tcpip.AddressWithPrefix + nicAddrs []tcpip.Address + slaacPrefixForTempAddrAfterNICAddrAdd tcpip.AddressWithPrefix + connectAddr tcpip.Address + expectedLocalAddr tcpip.Address }{ // Test Rule 1 of RFC 6724 section 5. { @@ -2967,6 +2984,22 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { expectedLocalAddr: uniqueLocalAddr1, }, + // Test Rule 7 of RFC 6724 section 5. + { + name: "Temp Global most preferred (last address)", + slaacPrefixForTempAddrBeforeNICAddrAdd: prefix1, + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, + connectAddr: globalAddr2, + expectedLocalAddr: tempGlobalAddr1, + }, + { + name: "Temp Global most preferred (first address)", + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, + slaacPrefixForTempAddrAfterNICAddrAdd: prefix1, + connectAddr: globalAddr2, + expectedLocalAddr: tempGlobalAddr1, + }, + // Test returning the endpoint that is closest to the front when // candidate addresses are "equal" from the perspective of RFC 6724 // section 5. @@ -2988,6 +3021,13 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { connectAddr: uniqueLocalAddr2, expectedLocalAddr: linkLocalAddr1, }, + { + name: "Temp Global for Global", + slaacPrefixForTempAddrBeforeNICAddrAdd: prefix1, + slaacPrefixForTempAddrAfterNICAddrAdd: prefix2, + connectAddr: globalAddr1, + expectedLocalAddr: tempGlobalAddr2, + }, } for _, test := range tests { @@ -2996,6 +3036,12 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDispatcher{}, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -3007,12 +3053,20 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { }}) s.AddLinkAddress(nicID, llAddr3, linkAddr3) + if test.slaacPrefixForTempAddrBeforeNICAddrAdd != (tcpip.AddressWithPrefix{}) { + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, test.slaacPrefixForTempAddrBeforeNICAddrAdd, true, true, lifetimeSeconds, lifetimeSeconds)) + } + for _, a := range test.nicAddrs { if err := s.AddAddress(nicID, ipv6.ProtocolNumber, a); err != nil { t.Errorf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, a, err) } } + if test.slaacPrefixForTempAddrAfterNICAddrAdd != (tcpip.AddressWithPrefix{}) { + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, test.slaacPrefixForTempAddrAfterNICAddrAdd, true, true, lifetimeSeconds, lifetimeSeconds)) + } + if t.Failed() { t.FailNow() } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 3084e6593..a611e44ab 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -642,10 +642,11 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - if dst := p.Pkt.Header.View()[0]; dst != 3 { + hdrs := p.Pkt.Data.ToView() + if dst := hdrs[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := p.Pkt.Header.View()[1]; src != 1 { + if src := hdrs[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index feef8dca0..b1d820372 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -747,15 +747,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: - h := header.ICMPv4(pkt.Data.First()) - if h.Type() != header.ICMPv4EchoReply { + h, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) + if !ok || header.ICMPv4(h).Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h := header.ICMPv6(pkt.Data.First()) - if h.Type() != header.ICMPv6EchoReply { + h, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize) + if !ok || header.ICMPv6(h).Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index f2aa69069..f38eb6833 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -115,7 +115,7 @@ go_test( size = "small", srcs = ["rcv_test.go"], deps = [ - ":tcp", + "//pkg/tcpip/header", "//pkg/tcpip/seqnum", ], ) diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 76e27bf26..a7e088d4e 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -801,6 +801,9 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso pkt.Header = buffer.NewPrependable(hdrSize) pkt.Hash = tf.txHash pkt.Owner = owner + pkt.EgressRoute = r + pkt.GSOOptions = gso + pkt.NetworkProtocolNumber = r.NetworkProtocolNumber() data.ReadToVV(&pkt.Data, packetSize) buildTCPHdr(r, tf, &pkt, gso) tf.seq = tf.seq.Add(seqnum.Size(packetSize)) diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index a4b73b588..6fe97fefd 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -70,24 +70,7 @@ func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale // acceptable checks if the segment sequence number range is acceptable // according to the table on page 26 of RFC 793. func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { - return Acceptable(segSeq, segLen, r.rcvNxt, r.rcvAcc) -} - -// Acceptable checks if a segment that starts at segSeq and has length segLen is -// "acceptable" for arriving in a receive window that starts at rcvNxt and ends -// before rcvAcc, according to the table on page 26 and 69 of RFC 793. -func Acceptable(segSeq seqnum.Value, segLen seqnum.Size, rcvNxt, rcvAcc seqnum.Value) bool { - if rcvNxt == rcvAcc { - return segLen == 0 && segSeq == rcvNxt - } - if segLen == 0 { - // rcvWnd is incremented by 1 because that is Linux's behavior despite the - // RFC. - return segSeq.InRange(rcvNxt, rcvAcc.Add(1)) - } - // Page 70 of RFC 793 allows packets that can be made "acceptable" by trimming - // the payload, so we'll accept any payload that overlaps the receieve window. - return rcvNxt.LessThan(segSeq.Add(segLen)) && segSeq.LessThan(rcvAcc) + return header.Acceptable(segSeq, segLen, r.rcvNxt, r.rcvAcc) } // getSendParams returns the parameters needed by the sender when building diff --git a/pkg/tcpip/transport/tcp/rcv_test.go b/pkg/tcpip/transport/tcp/rcv_test.go index dc02729ce..c9eeff935 100644 --- a/pkg/tcpip/transport/tcp/rcv_test.go +++ b/pkg/tcpip/transport/tcp/rcv_test.go @@ -17,8 +17,8 @@ package rcv_test import ( "testing" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" ) func TestAcceptable(t *testing.T) { @@ -67,8 +67,8 @@ func TestAcceptable(t *testing.T) { {105, 2, 108, 108, false}, {105, 2, 109, 109, false}, } { - if got := tcp.Acceptable(tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc); got != tt.want { - t.Errorf("tcp.Acceptable(%d, %d, %d, %d) = %t, want %t", tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc, got, tt.want) + if got := header.Acceptable(tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc); got != tt.want { + t.Errorf("header.Acceptable(%d, %d, %d, %d) = %t, want %t", tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc, got, tt.want) } } } diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 40461fd31..7712ce652 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -144,7 +144,11 @@ func (s *segment) logicalLen() seqnum.Size { // TCP checksum and stores the checksum and result of checksum verification in // the csum and csumValid fields of the segment. func (s *segment) parse() bool { - h := header.TCP(s.data.First()) + h, ok := s.data.PullUp(header.TCPMinimumSize) + if !ok { + return false + } + hdr := header.TCP(h) // h is the header followed by the payload. We check that the offset to // the data respects the following constraints: @@ -156,12 +160,16 @@ func (s *segment) parse() bool { // N.B. The segment has already been validated as having at least the // minimum TCP size before reaching here, so it's safe to read the // fields. - offset := int(h.DataOffset()) - if offset < header.TCPMinimumSize || offset > len(h) { + offset := int(hdr.DataOffset()) + if offset < header.TCPMinimumSize { + return false + } + hdrWithOpts, ok := s.data.PullUp(offset) + if !ok { return false } - s.options = []byte(h[header.TCPMinimumSize:offset]) + s.options = []byte(hdrWithOpts[header.TCPMinimumSize:]) s.parsedOptions = header.ParseTCPOptions(s.options) // Query the link capabilities to decide if checksum validation is @@ -173,18 +181,19 @@ func (s *segment) parse() bool { s.data.TrimFront(offset) } if verifyChecksum { - s.csum = h.Checksum() + hdr = header.TCP(hdrWithOpts) + s.csum = hdr.Checksum() xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size())) - xsum = h.CalculateChecksum(xsum) + xsum = hdr.CalculateChecksum(xsum) s.data.TrimFront(offset) xsum = header.ChecksumVV(s.data, xsum) s.csumValid = xsum == 0xffff } - s.sequenceNumber = seqnum.Value(h.SequenceNumber()) - s.ackNumber = seqnum.Value(h.AckNumber()) - s.flags = h.Flags() - s.window = seqnum.Size(h.WindowSize()) + s.sequenceNumber = seqnum.Value(hdr.SequenceNumber()) + s.ackNumber = seqnum.Value(hdr.AckNumber()) + s.flags = hdr.Flags() + s.window = seqnum.Size(hdr.WindowSize()) return true } diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index d8cfe3115..a3018914b 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -41,6 +41,10 @@ const ( // nDupAckThreshold is the number of duplicate ACK's required // before fast-retransmit is entered. nDupAckThreshold = 3 + + // MaxRetries is the maximum number of probe retries sender does + // before timing out the connection, Linux default TCP_RETR2. + MaxRetries = 15 ) // ccState indicates the current congestion control state for this sender. @@ -138,6 +142,14 @@ type sender struct { // the first segment that was retransmitted due to RTO expiration. firstRetransmittedSegXmitTime time.Time `state:".(unixTime)"` + // zeroWindowProbing is set if the sender is currently probing + // for zero receive window. + zeroWindowProbing bool `state:"nosave"` + + // unackZeroWindowProbes is the number of unacknowledged zero + // window probes. + unackZeroWindowProbes uint32 `state:"nosave"` + closed bool writeNext *segment writeList segmentList @@ -479,10 +491,24 @@ func (s *sender) retransmitTimerExpired() bool { remaining = uto - elapsed } - if remaining <= 0 || s.rto >= MaxRTO { + // Always honor the user-timeout irrespective of whether the zero + // window probes were acknowledged. + // net/ipv4/tcp_timer.c::tcp_probe_timer() + if remaining <= 0 || s.unackZeroWindowProbes >= MaxRetries { return false } + if s.rto >= MaxRTO { + // RFC 1122 section: 4.2.2.17 + // A TCP MAY keep its offered receive window closed + // indefinitely. As long as the receiving TCP continues to + // send acknowledgments in response to the probe segments, the + // sending TCP MUST allow the connection to stay open. + if !(s.zeroWindowProbing && s.unackZeroWindowProbes == 0) { + return false + } + } + // Set new timeout. The timer will be restarted by the call to sendData // below. s.rto *= 2 @@ -533,6 +559,15 @@ func (s *sender) retransmitTimerExpired() bool { // information is usable after an RTO. s.ep.scoreboard.Reset() s.writeNext = s.writeList.Front() + + // RFC 1122 4.2.2.17: Start sending zero window probes when we still see a + // zero receive window after retransmission interval and we have data to + // send. + if s.zeroWindowProbing { + s.sendZeroWindowProbe() + return true + } + s.sendData() return true @@ -827,6 +862,34 @@ func (s *sender) handleSACKRecovery(limit int, end seqnum.Value) (dataSent bool) return dataSent } +func (s *sender) sendZeroWindowProbe() { + ack, win := s.ep.rcv.getSendParams() + s.unackZeroWindowProbes++ + // Send a zero window probe with sequence number pointing to + // the last acknowledged byte. + s.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, s.sndUna-1, ack, win) + // Rearm the timer to continue probing. + s.resendTimer.enable(s.rto) +} + +func (s *sender) enableZeroWindowProbing() { + s.zeroWindowProbing = true + // We piggyback the probing on the retransmit timer with the + // current retranmission interval, as we may start probing while + // segment retransmissions. + if s.firstRetransmittedSegXmitTime.IsZero() { + s.firstRetransmittedSegXmitTime = time.Now() + } + s.resendTimer.enable(s.rto) +} + +func (s *sender) disableZeroWindowProbing() { + s.zeroWindowProbing = false + s.unackZeroWindowProbes = 0 + s.firstRetransmittedSegXmitTime = time.Time{} + s.resendTimer.disable() +} + // sendData sends new data segments. It is called when data becomes available or // when the send window opens up. func (s *sender) sendData() { @@ -875,6 +938,13 @@ func (s *sender) sendData() { s.ep.disableKeepaliveTimer() } + // If the sender has advertized zero receive window and we have + // data to be sent out, start zero window probing to query the + // the remote for it's receive window size. + if s.writeNext != nil && s.sndWnd == 0 { + s.enableZeroWindowProbing() + } + // Enable the timer if we have pending data and it's not enabled yet. if !s.resendTimer.enabled() && s.sndUna != s.sndNxt { s.resendTimer.enable(s.rto) @@ -1122,8 +1192,26 @@ func (s *sender) handleRcvdSegment(seg *segment) { // Stash away the current window size. s.sndWnd = seg.window - // Ignore ack if it doesn't acknowledge any new data. ack := seg.ackNumber + + // Disable zero window probing if remote advertizes a non-zero receive + // window. This can be with an ACK to the zero window probe (where the + // acknumber refers to the already acknowledged byte) OR to any previously + // unacknowledged segment. + if s.zeroWindowProbing && seg.window > 0 && + (ack == s.sndUna || (ack-1).InRange(s.sndUna, s.sndNxt)) { + s.disableZeroWindowProbing() + } + + // On receiving the ACK for the zero window probe, account for it and + // skip trying to send any segment as we are still probing for + // receive window to become non-zero. + if s.zeroWindowProbing && s.unackZeroWindowProbes > 0 && ack == s.sndUna { + s.unackZeroWindowProbes-- + return + } + + // Ignore ack if it doesn't acknowledge any new data. if (ack - 1).InRange(s.sndUna, s.sndNxt) { s.dupAckCount = 0 @@ -1143,7 +1231,7 @@ func (s *sender) handleRcvdSegment(seg *segment) { } // When an ack is received we must rearm the timer. - // RFC 6298 5.2 + // RFC 6298 5.3 s.resendTimer.enable(s.rto) // Remove all acknowledged data from the write list. diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 7e574859b..49e4ba214 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -1900,7 +1900,7 @@ func TestZeroWindowSend(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 0, -1 /* epRcvBuf */) + c.CreateConnected(789 /* iss */, 0 /* rcvWnd */, -1 /* epRcvBuf */) data := []byte{1, 2, 3} view := buffer.NewView(len(data)) @@ -1911,8 +1911,17 @@ func TestZeroWindowSend(t *testing.T) { t.Fatalf("Write failed: %v", err) } - // Since the window is currently zero, check that no packet is received. - c.CheckNoPacket("Packet received when window is zero") + // Check if we got a zero-window probe. + b := c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) // Open up the window. Data should be received now. c.SendPacket(nil, &context.Headers{ @@ -1925,7 +1934,7 @@ func TestZeroWindowSend(t *testing.T) { }) // Check that data is received. - b := c.GetPacket() + b = c.GetPacket() checker.IPv4(t, b, checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( @@ -3556,7 +3565,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.First()[header.IPv4MinimumSize:] + tcpbuf := vv.ToView()[header.IPv4MinimumSize:] tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4 c.SendSegment(vv) @@ -3583,7 +3592,7 @@ func TestReceivedIncorrectChecksumIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.First()[header.IPv4MinimumSize:] + tcpbuf := vv.ToView()[header.IPv4MinimumSize:] // Overwrite a byte in the payload which should cause checksum // verification to fail. tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4 diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD index 2025ff757..3ad6994a7 100644 --- a/pkg/tcpip/transport/tcpconntrack/BUILD +++ b/pkg/tcpip/transport/tcpconntrack/BUILD @@ -9,7 +9,6 @@ go_library( deps = [ "//pkg/tcpip/header", "//pkg/tcpip/seqnum", - "//pkg/tcpip/transport/tcp", ], ) diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go index 30d05200f..12bc1b5b5 100644 --- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go +++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go @@ -20,7 +20,6 @@ package tcpconntrack import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" ) // Result is returned when the state of a TCB is updated in response to an @@ -312,7 +311,7 @@ type stream struct { // the window is zero, if it's a packet with no payload and sequence number // equal to una. func (s *stream) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { - return tcp.Acceptable(segSeq, segLen, s.una, s.end) + return header.Acceptable(segSeq, segLen, s.una, s.end) } // closed determines if the stream has already been closed. This happens when @@ -338,3 +337,16 @@ func logicalLen(tcp header.TCP) seqnum.Size { } return l } + +// IsEmpty returns true if tcb is not initialized. +func (t *TCB) IsEmpty() bool { + if t.inbound != (stream{}) || t.outbound != (stream{}) { + return false + } + + if t.firstFin != nil || t.state != ResultDrop { + return false + } + + return true +} diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index edb54f0be..756ab913a 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1250,8 +1250,8 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // endpoint. func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { // Get the header then trim it from the view. - hdr := header.UDP(pkt.Data.First()) - if int(hdr.Length()) > pkt.Data.Size() { + hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) + if !ok || int(header.UDP(hdr).Length()) > pkt.Data.Size() { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -1286,7 +1286,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk senderAddress: tcpip.FullAddress{ NIC: r.NICID(), Addr: id.RemoteAddress, - Port: hdr.SourcePort(), + Port: header.UDP(hdr).SourcePort(), }, } packet.data = pkt.Data diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 6e31a9bac..52af6de22 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -68,8 +68,13 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // that don't match any existing endpoint. func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { // Get the header then trim it from the view. - hdr := header.UDP(pkt.Data.First()) - if int(hdr.Length()) > pkt.Data.Size() { + h, ok := pkt.Data.PullUp(header.UDPMinimumSize) + if !ok { + // Malformed packet. + r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() + return true + } + if int(header.UDP(h).Length()) > pkt.Data.Size() { // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() return true |