diff options
Diffstat (limited to 'pkg/tcpip/network')
-rw-r--r-- | pkg/tcpip/network/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/network/arp/arp.go | 34 | ||||
-rw-r--r-- | pkg/tcpip/network/arp/arp_test.go | 7 | ||||
-rw-r--r-- | pkg/tcpip/network/fragmentation/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/network/fragmentation/fragmentation.go | 71 | ||||
-rw-r--r-- | pkg/tcpip/network/fragmentation/fragmentation_test.go | 207 | ||||
-rw-r--r-- | pkg/tcpip/network/fragmentation/reassembler.go | 50 | ||||
-rw-r--r-- | pkg/tcpip/network/fragmentation/reassembler_test.go | 23 | ||||
-rw-r--r-- | pkg/tcpip/network/ip_test.go | 313 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/icmp.go | 52 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 174 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4_test.go | 280 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp.go | 45 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp_test.go | 85 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 148 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6_test.go | 158 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ndp_test.go | 38 | ||||
-rw-r--r-- | pkg/tcpip/network/testutil/testutil.go | 15 |
18 files changed, 1183 insertions, 519 deletions
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index c118a2929..b38aff0b8 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -14,6 +14,7 @@ go_test( "//pkg/tcpip/buffer", "//pkg/tcpip/checker", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", "//pkg/tcpip/network/ipv4", diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 33a4a0720..3d5c0d270 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -31,17 +31,15 @@ import ( const ( // ProtocolNumber is the ARP protocol number. ProtocolNumber = header.ARPProtocolNumber - - // ProtocolAddress is the address expected by the ARP endpoint. - ProtocolAddress = tcpip.Address("arp") ) -var _ stack.AddressableEndpoint = (*endpoint)(nil) +// ARP endpoints need to implement stack.NetworkEndpoint because the stack +// considers the layer above the link-layer a network layer; the only +// facility provided by the stack to deliver packets to a layer above +// the link-layer is via stack.NetworkEndpoint.HandlePacket. var _ stack.NetworkEndpoint = (*endpoint)(nil) type endpoint struct { - stack.AddressableEndpointState - protocol *protocol // enabled is set to 1 when the NIC is enabled and 0 when it is disabled. @@ -87,7 +85,7 @@ func (e *endpoint) Disable() { } // DefaultTTL is unused for ARP. It implements stack.NetworkEndpoint. -func (e *endpoint) DefaultTTL() uint8 { +func (*endpoint) DefaultTTL() uint8 { return 0 } @@ -100,25 +98,23 @@ func (e *endpoint) MaxHeaderLength() uint16 { return e.nic.MaxHeaderLength() + header.ARPSize } -func (e *endpoint) Close() { - e.AddressableEndpointState.Cleanup() -} +func (*endpoint) Close() {} -func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error { +func (*endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } // NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. -func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { +func (*endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return ProtocolNumber } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (*endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, *tcpip.Error) { return 0, tcpip.ErrNotSupported } -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { +func (*endpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -216,9 +212,8 @@ func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber func (p *protocol) MinimumPacketSize() int { return header.ARPSize } func (p *protocol) DefaultPrefixLen() int { return 0 } -func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { - h := header.ARP(v) - return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress +func (*protocol) ParseAddresses(buffer.View) (src, dst tcpip.Address) { + return "", "" } func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { @@ -228,7 +223,6 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L linkAddrCache: linkAddrCache, nud: nud, } - e.AddressableEndpointState.Init(e) return e } @@ -311,10 +305,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu } // NewProtocol returns an ARP network protocol. -// -// Note, to make sure that the ARP endpoint receives ARP packets, the "arp" -// address must be added to every NIC that should respond to ARP requests. See -// ProtocolAddress for more details. func NewProtocol(s *stack.Stack) stack.NetworkProtocol { return &protocol{stack: s} } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 087ee9c66..f462524c9 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -200,9 +200,6 @@ func newTestContext(t *testing.T, useNeighborCache bool) *testContext { t.Fatalf("AddAddress for ipv4 failed: %v", err) } } - if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - t.Fatalf("AddAddress for arp failed: %v", err) - } s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv4EmptySubnet, @@ -439,6 +436,10 @@ func (*testInterface) Enabled() bool { return true } +func (*testInterface) Promiscuous() bool { + return false +} + func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { r := stack.Route{ NetProto: protocol, diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index 47fb63290..d8e4a3b54 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -47,6 +47,7 @@ go_test( "//pkg/tcpip/buffer", "//pkg/tcpip/faketime", "//pkg/tcpip/network/testutil", + "//pkg/tcpip/stack", "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index 936601287..c75ca7d71 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -71,16 +71,25 @@ type FragmentID struct { // Fragmentation is the main structure that other modules // of the stack should use to implement IP Fragmentation. type Fragmentation struct { - mu sync.Mutex - highLimit int - lowLimit int - reassemblers map[FragmentID]*reassembler - rList reassemblerList - size int - timeout time.Duration - blockSize uint16 - clock tcpip.Clock - releaseJob *tcpip.Job + mu sync.Mutex + highLimit int + lowLimit int + reassemblers map[FragmentID]*reassembler + rList reassemblerList + size int + timeout time.Duration + blockSize uint16 + clock tcpip.Clock + releaseJob *tcpip.Job + timeoutHandler TimeoutHandler +} + +// TimeoutHandler is consulted if a packet reassembly has timed out. +type TimeoutHandler interface { + // OnReassemblyTimeout will be called with the first fragment (or nil, if the + // first fragment has not been received) of a packet whose reassembly has + // timed out. + OnReassemblyTimeout(pkt *stack.PacketBuffer) } // NewFragmentation creates a new Fragmentation. @@ -97,7 +106,7 @@ type Fragmentation struct { // reassemblingTimeout specifies the maximum time allowed to reassemble a packet. // Fragments are lazily evicted only when a new a packet with an // already existing fragmentation-id arrives after the timeout. -func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration, clock tcpip.Clock) *Fragmentation { +func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration, clock tcpip.Clock, timeoutHandler TimeoutHandler) *Fragmentation { if lowMemoryLimit >= highMemoryLimit { lowMemoryLimit = highMemoryLimit } @@ -111,12 +120,13 @@ func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, rea } f := &Fragmentation{ - reassemblers: make(map[FragmentID]*reassembler), - highLimit: highMemoryLimit, - lowLimit: lowMemoryLimit, - timeout: reassemblingTimeout, - blockSize: blockSize, - clock: clock, + reassemblers: make(map[FragmentID]*reassembler), + highLimit: highMemoryLimit, + lowLimit: lowMemoryLimit, + timeout: reassemblingTimeout, + blockSize: blockSize, + clock: clock, + timeoutHandler: timeoutHandler, } f.releaseJob = tcpip.NewJob(f.clock, &f.mu, f.releaseReassemblersLocked) @@ -136,16 +146,8 @@ func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, rea // proto is the protocol number marked in the fragment being processed. It has // to be given here outside of the FragmentID struct because IPv6 should not use // the protocol to identify a fragment. -// -// releaseCB is a callback that will run when the fragment reassembly of a -// packet is complete or cancelled. releaseCB take a a boolean argument which is -// true iff the reassembly is cancelled due to timeout. releaseCB should be -// passed only with the first fragment of a packet. If more than one releaseCB -// are passed for the same packet, only the first releaseCB will be saved for -// the packet and the succeeding ones will be dropped by running them -// immediately with a false argument. func (f *Fragmentation) Process( - id FragmentID, first, last uint16, more bool, proto uint8, vv buffer.VectorisedView, releaseCB func(bool)) ( + id FragmentID, first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) ( buffer.VectorisedView, uint8, bool, error) { if first > last { return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs) @@ -160,10 +162,9 @@ func (f *Fragmentation) Process( return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs) } - if l := vv.Size(); l < int(fragmentSize) { - return buffer.VectorisedView{}, 0, false, fmt.Errorf("got fragment size=%d bytes less than the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs) + if l := pkt.Data.Size(); l != int(fragmentSize) { + return buffer.VectorisedView{}, 0, false, fmt.Errorf("got fragment size=%d bytes not equal to the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs) } - vv.CapLength(int(fragmentSize)) f.mu.Lock() r, ok := f.reassemblers[id] @@ -179,15 +180,9 @@ func (f *Fragmentation) Process( f.releaseReassemblersLocked() } } - if releaseCB != nil { - if !r.setCallback(releaseCB) { - // We got a duplicate callback. Release it immediately. - releaseCB(false /* timedOut */) - } - } f.mu.Unlock() - res, firstFragmentProto, done, consumed, err := r.process(first, last, more, proto, vv) + res, firstFragmentProto, done, consumed, err := r.process(first, last, more, proto, pkt) if err != nil { // We probably got an invalid sequence of fragments. Just // discard the reassembler and move on. @@ -231,7 +226,9 @@ func (f *Fragmentation) release(r *reassembler, timedOut bool) { f.size = 0 } - r.release(timedOut) // releaseCB may run. + if h := f.timeoutHandler; timedOut && h != nil { + h.OnReassemblyTimeout(r.pkt) + } } // releaseReassemblersLocked releases already-expired reassemblers, then diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go index 5dcd10730..3a79688a8 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/network/testutil" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) // reassembleTimeout is dummy timeout used for testing, where the clock never @@ -40,13 +41,19 @@ func vv(size int, pieces ...string) buffer.VectorisedView { return buffer.NewVectorisedView(size, views) } +func pkt(size int, pieces ...string) *stack.PacketBuffer { + return stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv(size, pieces...), + }) +} + type processInput struct { id FragmentID first uint16 last uint16 more bool proto uint8 - vv buffer.VectorisedView + pkt *stack.PacketBuffer } type processOutput struct { @@ -63,8 +70,8 @@ var processTestCases = []struct { { comment: "One ID", in: []processInput{ - {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, vv: vv(2, "01")}, - {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, vv: vv(2, "23")}, + {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")}, + {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")}, }, out: []processOutput{ {vv: buffer.VectorisedView{}, done: false}, @@ -74,8 +81,8 @@ var processTestCases = []struct { { comment: "Next Header protocol mismatch", in: []processInput{ - {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, proto: 6, vv: vv(2, "01")}, - {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, proto: 17, vv: vv(2, "23")}, + {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, proto: 6, pkt: pkt(2, "01")}, + {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, proto: 17, pkt: pkt(2, "23")}, }, out: []processOutput{ {vv: buffer.VectorisedView{}, done: false}, @@ -85,10 +92,10 @@ var processTestCases = []struct { { comment: "Two IDs", in: []processInput{ - {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, vv: vv(2, "01")}, - {id: FragmentID{ID: 1}, first: 0, last: 1, more: true, vv: vv(2, "ab")}, - {id: FragmentID{ID: 1}, first: 2, last: 3, more: false, vv: vv(2, "cd")}, - {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, vv: vv(2, "23")}, + {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")}, + {id: FragmentID{ID: 1}, first: 0, last: 1, more: true, pkt: pkt(2, "ab")}, + {id: FragmentID{ID: 1}, first: 2, last: 3, more: false, pkt: pkt(2, "cd")}, + {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")}, }, out: []processOutput{ {vv: buffer.VectorisedView{}, done: false}, @@ -102,17 +109,17 @@ var processTestCases = []struct { func TestFragmentationProcess(t *testing.T) { for _, c := range processTestCases { t.Run(c.comment, func(t *testing.T) { - f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{}) + f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{}, nil) firstFragmentProto := c.in[0].proto for i, in := range c.in { - vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv, nil) + vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.pkt) if err != nil { - t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %X) failed: %s", - in.id, in.first, in.last, in.more, in.proto, in.vv.ToView(), err) + t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %#v) failed: %s", + in.id, in.first, in.last, in.more, in.proto, in.pkt, err) } if !reflect.DeepEqual(vv, c.out[i].vv) { - t.Errorf("got Process(%+v, %d, %d, %t, %d, %X) = (%X, _, _, _), want = (%X, _, _, _)", - in.id, in.first, in.last, in.more, in.proto, in.vv.ToView(), vv.ToView(), c.out[i].vv.ToView()) + t.Errorf("got Process(%+v, %d, %d, %t, %d, %#v) = (%X, _, _, _), want = (%X, _, _, _)", + in.id, in.first, in.last, in.more, in.proto, in.pkt, vv.ToView(), c.out[i].vv.ToView()) } if done != c.out[i].done { t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, _, %t, _), want = (_, _, %t, _)", @@ -236,11 +243,11 @@ func TestReassemblingTimeout(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { clock := faketime.NewManualClock() - f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassemblyTimeout, clock) + f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassemblyTimeout, clock, nil) for _, event := range test.events { clock.Advance(event.clockAdvance) if frag := event.fragment; frag != nil { - _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, vv(len(frag.data), frag.data), nil) + _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, pkt(len(frag.data), frag.data)) if err != nil { t.Fatalf("%s: f.Process failed: %s", event.name, err) } @@ -257,17 +264,17 @@ func TestReassemblingTimeout(t *testing.T) { } func TestMemoryLimits(t *testing.T) { - f := NewFragmentation(minBlockSize, 3, 1, reassembleTimeout, &faketime.NullClock{}) + f := NewFragmentation(minBlockSize, 3, 1, reassembleTimeout, &faketime.NullClock{}, nil) // Send first fragment with id = 0. - f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0"), nil) + f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0")) // Send first fragment with id = 1. - f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, vv(1, "1"), nil) + f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, pkt(1, "1")) // Send first fragment with id = 2. - f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, vv(1, "2"), nil) + f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, pkt(1, "2")) // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be // evicted. - f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, vv(1, "3"), nil) + f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, pkt(1, "3")) if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok { t.Errorf("Memory limits are not respected: id=0 has not been evicted.") @@ -281,11 +288,11 @@ func TestMemoryLimits(t *testing.T) { } func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { - f := NewFragmentation(minBlockSize, 1, 0, reassembleTimeout, &faketime.NullClock{}) + f := NewFragmentation(minBlockSize, 1, 0, reassembleTimeout, &faketime.NullClock{}, nil) // Send first fragment with id = 0. - f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"), nil) + f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")) // Send the same packet again. - f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"), nil) + f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")) got := f.size want := 1 @@ -327,6 +334,7 @@ func TestErrors(t *testing.T) { last: 3, more: true, data: "012", + err: ErrInvalidArgs, }, { name: "exact block size with more and too little data", @@ -376,8 +384,8 @@ func TestErrors(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}) - _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data), nil) + f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, nil) + _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, pkt(len(test.data), test.data)) if !errors.Is(err, test.err) { t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err) } @@ -498,57 +506,92 @@ func TestPacketFragmenter(t *testing.T) { } } -func TestReleaseCallback(t *testing.T) { +type testTimeoutHandler struct { + pkt *stack.PacketBuffer +} + +func (h *testTimeoutHandler) OnReassemblyTimeout(pkt *stack.PacketBuffer) { + h.pkt = pkt +} + +func TestTimeoutHandler(t *testing.T) { const ( proto = 99 ) - var result int - var callbackReasonIsTimeout bool - cb1 := func(timedOut bool) { result = 1; callbackReasonIsTimeout = timedOut } - cb2 := func(timedOut bool) { result = 2; callbackReasonIsTimeout = timedOut } + pk1 := pkt(1, "1") + pk2 := pkt(1, "2") + + type processParam struct { + first uint16 + last uint16 + more bool + pkt *stack.PacketBuffer + } tests := []struct { - name string - callbacks []func(bool) - timeout bool - wantResult int - wantCallbackReasonIsTimeout bool + name string + params []processParam + wantError bool + wantPkt *stack.PacketBuffer }{ { - name: "callback runs on release", - callbacks: []func(bool){cb1}, - timeout: false, - wantResult: 1, - wantCallbackReasonIsTimeout: false, - }, - { - name: "first callback is nil", - callbacks: []func(bool){nil, cb2}, - timeout: false, - wantResult: 2, - wantCallbackReasonIsTimeout: false, + name: "onTimeout runs", + params: []processParam{ + { + first: 0, + last: 0, + more: true, + pkt: pk1, + }, + }, + wantError: false, + wantPkt: pk1, }, { - name: "two callbacks - first one is set", - callbacks: []func(bool){cb1, cb2}, - timeout: false, - wantResult: 1, - wantCallbackReasonIsTimeout: false, + name: "no first fragment", + params: []processParam{ + { + first: 1, + last: 1, + more: true, + pkt: pk1, + }, + }, + wantError: false, + wantPkt: nil, }, { - name: "callback runs on timeout", - callbacks: []func(bool){cb1}, - timeout: true, - wantResult: 1, - wantCallbackReasonIsTimeout: true, + name: "second pkt is ignored", + params: []processParam{ + { + first: 0, + last: 0, + more: true, + pkt: pk1, + }, + { + first: 0, + last: 0, + more: true, + pkt: pk2, + }, + }, + wantError: false, + wantPkt: pk1, }, { - name: "no callbacks", - callbacks: []func(bool){nil}, - timeout: false, - wantResult: 0, - wantCallbackReasonIsTimeout: false, + name: "invalid args - first is greater than last", + params: []processParam{ + { + first: 1, + last: 0, + more: true, + pkt: pk1, + }, + }, + wantError: true, + wantPkt: nil, }, } @@ -556,29 +599,31 @@ func TestReleaseCallback(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - result = 0 - callbackReasonIsTimeout = false + handler := &testTimeoutHandler{pkt: nil} - f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}) + f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, handler) - for i, cb := range test.callbacks { - _, _, _, err := f.Process(id, uint16(i), uint16(i), true, proto, vv(1, "0"), cb) - if err != nil { + for _, p := range test.params { + if _, _, _, err := f.Process(id, p.first, p.last, p.more, proto, p.pkt); err != nil && !test.wantError { t.Errorf("f.Process error = %s", err) } } - - r, ok := f.reassemblers[id] - if !ok { - t.Fatalf("Reassemberr not found") - } - f.release(r, test.timeout) - - if result != test.wantResult { - t.Errorf("got result = %d, want = %d", result, test.wantResult) + if !test.wantError { + r, ok := f.reassemblers[id] + if !ok { + t.Fatal("Reassembler not found") + } + f.release(r, true) } - if callbackReasonIsTimeout != test.wantCallbackReasonIsTimeout { - t.Errorf("got callbackReasonIsTimeout = %t, want = %t", callbackReasonIsTimeout, test.wantCallbackReasonIsTimeout) + switch { + case handler.pkt != nil && test.wantPkt == nil: + t.Errorf("got handler.pkt = not nil (pkt.Data = %x), want = nil", handler.pkt.Data.ToView()) + case handler.pkt == nil && test.wantPkt != nil: + t.Errorf("got handler.pkt = nil, want = not nil (pkt.Data = %x)", test.wantPkt.Data.ToView()) + case handler.pkt != nil && test.wantPkt != nil: + if diff := cmp.Diff(test.wantPkt.Data.ToView(), handler.pkt.Data.ToView()); diff != "" { + t.Errorf("pkt.Data mismatch (-want, +got):\n%s", diff) + } } }) } diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go index c0cc0bde0..19f4920b3 100644 --- a/pkg/tcpip/network/fragmentation/reassembler.go +++ b/pkg/tcpip/network/fragmentation/reassembler.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) type hole struct { @@ -41,7 +42,7 @@ type reassembler struct { heap fragHeap done bool creationTime int64 - callback func(bool) + pkt *stack.PacketBuffer } func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler { @@ -79,7 +80,7 @@ func (r *reassembler) updateHoles(first, last uint16, more bool) bool { return used } -func (r *reassembler) process(first, last uint16, more bool, proto uint8, vv buffer.VectorisedView) (buffer.VectorisedView, uint8, bool, int, error) { +func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (buffer.VectorisedView, uint8, bool, int, error) { r.mu.Lock() defer r.mu.Unlock() consumed := 0 @@ -89,18 +90,20 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, vv buf // was waiting on the mutex. We don't have to do anything in this case. return buffer.VectorisedView{}, 0, false, consumed, nil } - // For IPv6, it is possible to have different Protocol values between - // fragments of a packet (because, unlike IPv4, the Protocol is not used to - // identify a fragment). In this case, only the Protocol of the first - // fragment must be used as per RFC 8200 Section 4.5. - // - // TODO(gvisor.dev/issue/3648): The entire first IP header should be recorded - // here (instead of just the protocol) because most IP options should be - // derived from the first fragment. - if first == 0 { - r.proto = proto - } if r.updateHoles(first, last, more) { + // For IPv6, it is possible to have different Protocol values between + // fragments of a packet (because, unlike IPv4, the Protocol is not used to + // identify a fragment). In this case, only the Protocol of the first + // fragment must be used as per RFC 8200 Section 4.5. + // + // TODO(gvisor.dev/issue/3648): During reassembly of an IPv6 packet, IP + // options received in the first fragment should be used - and they should + // override options from following fragments. + if first == 0 { + r.pkt = pkt + r.proto = proto + } + vv := pkt.Data // We store the incoming packet only if it filled some holes. heap.Push(&r.heap, fragment{offset: first, vv: vv.Clone(nil)}) consumed = vv.Size() @@ -124,24 +127,3 @@ func (r *reassembler) checkDoneOrMark() bool { r.mu.Unlock() return prev } - -func (r *reassembler) setCallback(c func(bool)) bool { - r.mu.Lock() - defer r.mu.Unlock() - if r.callback != nil { - return false - } - r.callback = c - return true -} - -func (r *reassembler) release(timedOut bool) { - r.mu.Lock() - callback := r.callback - r.callback = nil - r.mu.Unlock() - - if callback != nil { - callback(timedOut) - } -} diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go index fa2a70dc8..a0a04a027 100644 --- a/pkg/tcpip/network/fragmentation/reassembler_test.go +++ b/pkg/tcpip/network/fragmentation/reassembler_test.go @@ -105,26 +105,3 @@ func TestUpdateHoles(t *testing.T) { } } } - -func TestSetCallback(t *testing.T) { - result := 0 - reasonTimeout := false - - cb1 := func(timedOut bool) { result = 1; reasonTimeout = timedOut } - cb2 := func(timedOut bool) { result = 2; reasonTimeout = timedOut } - - r := newReassembler(FragmentID{}, &faketime.NullClock{}) - if !r.setCallback(cb1) { - t.Errorf("setCallback failed") - } - if r.setCallback(cb2) { - t.Errorf("setCallback should fail if one is already set") - } - r.release(true) - if result != 1 { - t.Errorf("got result = %d, want = 1", result) - } - if !reasonTimeout { - t.Errorf("got reasonTimeout = %t, want = true", reasonTimeout) - } -} diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index c7d26e14f..787399e08 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -34,16 +35,16 @@ import ( ) const ( - localIPv4Addr = "\x0a\x00\x00\x01" - remoteIPv4Addr = "\x0a\x00\x00\x02" - ipv4SubnetAddr = "\x0a\x00\x00\x00" - ipv4SubnetMask = "\xff\xff\xff\x00" - ipv4Gateway = "\x0a\x00\x00\x03" - localIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - remoteIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" - ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00" - ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" + localIPv4Addr = tcpip.Address("\x0a\x00\x00\x01") + remoteIPv4Addr = tcpip.Address("\x0a\x00\x00\x02") + ipv4SubnetAddr = tcpip.Address("\x0a\x00\x00\x00") + ipv4SubnetMask = tcpip.Address("\xff\xff\xff\x00") + ipv4Gateway = tcpip.Address("\x0a\x00\x00\x03") + localIPv6Addr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + remoteIPv6Addr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + ipv6SubnetAddr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00") + ipv6SubnetMask = tcpip.Address("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00") + ipv6Gateway = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") nicID = 1 ) @@ -192,10 +193,6 @@ func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBu panic("not implemented") } -func (*testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { - return tcpip.ErrNotSupported -} - // ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. func (*testObject) ARPHardwareType() header.ARPHardwareType { panic("not implemented") @@ -299,6 +296,10 @@ func (t *testInterface) Enabled() bool { return !t.mu.disabled } +func (*testInterface) Promiscuous() bool { + return false +} + func (t *testInterface) setEnabled(v bool) { t.mu.Lock() defer t.mu.Unlock() @@ -558,59 +559,135 @@ func TestIPv4Send(t *testing.T) { } } -func TestIPv4Receive(t *testing.T) { - s := buildDummyStack(t) - proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) - nic := testInterface{ - testObject: testObject{ - t: t, - v4: true, +func TestReceive(t *testing.T) { + tests := []struct { + name string + protoFactory stack.NetworkProtocolFactory + protoNum tcpip.NetworkProtocolNumber + v4 bool + epAddr tcpip.AddressWithPrefix + handlePacket func(*testing.T, stack.NetworkEndpoint, *testInterface) + }{ + { + name: "IPv4", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + v4: true, + epAddr: localIPv4Addr.WithPrefix(), + handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) { + const totalLen = header.IPv4MinimumSize + 30 /* payload length */ + + view := buffer.NewView(totalLen) + ip := header.IPv4(view) + ip.Encode(&header.IPv4Fields{ + TotalLength: totalLen, + TTL: ipv4.DefaultTTL, + Protocol: 10, + SrcAddr: remoteIPv4Addr, + DstAddr: localIPv4Addr, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + // Make payload be non-zero. + for i := header.IPv4MinimumSize; i < len(view); i++ { + view[i] = uint8(i) + } + + // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. + nic.testObject.protocol = 10 + nic.testObject.srcAddr = remoteIPv4Addr + nic.testObject.dstAddr = localIPv4Addr + nic.testObject.contents = view[header.IPv4MinimumSize:totalLen] + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: view.ToVectorisedView(), + }) + if ok := parse.IPv4(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(pkt) + }, }, - } - ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) - defer ep.Close() + { + name: "IPv6", + protoFactory: ipv6.NewProtocol, + protoNum: ipv6.ProtocolNumber, + v4: false, + epAddr: localIPv6Addr.WithPrefix(), + handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) { + const payloadLen = 30 + view := buffer.NewView(header.IPv6MinimumSize + payloadLen) + ip := header.IPv6(view) + ip.Encode(&header.IPv6Fields{ + PayloadLength: payloadLen, + NextHeader: 10, + HopLimit: ipv6.DefaultTTL, + SrcAddr: remoteIPv6Addr, + DstAddr: localIPv6Addr, + }) - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } + // Make payload be non-zero. + for i := header.IPv6MinimumSize; i < len(view); i++ { + view[i] = uint8(i) + } - totalLen := header.IPv4MinimumSize + 30 - view := buffer.NewView(totalLen) - ip := header.IPv4(view) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLen), - TTL: 20, - Protocol: 10, - SrcAddr: remoteIPv4Addr, - DstAddr: localIPv4Addr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) + // Give packet to ipv6 endpoint, dispatcher will validate that it's ok. + nic.testObject.protocol = 10 + nic.testObject.srcAddr = remoteIPv6Addr + nic.testObject.dstAddr = localIPv6Addr + nic.testObject.contents = view[header.IPv6MinimumSize:][:payloadLen] - // Make payload be non-zero. - for i := header.IPv4MinimumSize; i < totalLen; i++ { - view[i] = uint8(i) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: view.ToVectorisedView(), + }) + if _, _, _, _, ok := parse.IPv6(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(pkt) + }, + }, } - // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - nic.testObject.protocol = 10 - nic.testObject.srcAddr = remoteIPv4Addr - nic.testObject.dstAddr = localIPv4Addr - nic.testObject.contents = view[header.IPv4MinimumSize:totalLen] + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory}, + }) + nic := testInterface{ + testObject: testObject{ + t: t, + v4: test.v4, + }, + } + ep := s.NetworkProtocolInstance(test.protoNum).NewEndpoint(&nic, nil, nil, &nic.testObject) + defer ep.Close() - r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: view.ToVectorisedView(), - }) - if _, _, ok := proto.Parse(pkt); !ok { - t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) - } - r.PopulatePacketInfo(pkt) - ep.HandlePacket(pkt) - if nic.testObject.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + + addressableEndpoint, ok := ep.(stack.AddressableEndpoint) + if !ok { + t.Fatalf("expected network endpoint with number = %d to implement stack.AddressableEndpoint", test.protoNum) + } + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", test.epAddr, err) + } else { + ep.DecRef() + } + + stat := s.Stats().IP.PacketsReceived + if got := stat.Value(); got != 0 { + t.Fatalf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 0", got) + } + test.handlePacket(t, ep, &nic) + if nic.testObject.dataCalls != 1 { + t.Errorf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) + } + if got := stat.Value(); got != 1 { + t.Errorf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 1", got) + } + }) } } @@ -634,10 +711,6 @@ func TestIPv4ReceiveControl(t *testing.T) { {"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0}, {"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + 8}, } - r, err := buildIPv4Route(localIPv4Addr, "\x0a\x00\x00\xbb") - if err != nil { - t.Fatal(err) - } for _, c := range cases { t.Run(c.name, func(t *testing.T) { s := buildDummyStack(t) @@ -705,8 +778,18 @@ func TestIPv4ReceiveControl(t *testing.T) { nic.testObject.typ = c.expectedTyp nic.testObject.extra = c.expectedExtra + addressableEndpoint, ok := ep.(stack.AddressableEndpoint) + if !ok { + t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint") + } + addr := localIPv4Addr.WithPrefix() + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + ep.DecRef() + } + pkt := truncatedPacket(view, c.trunc, header.IPv4MinimumSize) - r.PopulatePacketInfo(pkt) ep.HandlePacket(pkt) if want := c.expectedCount; nic.testObject.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) @@ -716,7 +799,9 @@ func TestIPv4ReceiveControl(t *testing.T) { } func TestIPv4FragmentationReceive(t *testing.T) { - s := buildDummyStack(t) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + }) proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) nic := testInterface{ testObject: testObject{ @@ -774,11 +859,6 @@ func TestIPv4FragmentationReceive(t *testing.T) { nic.testObject.dstAddr = localIPv4Addr nic.testObject.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) - r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - // Send first segment. pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: frag1.ToVectorisedView(), @@ -786,7 +866,18 @@ func TestIPv4FragmentationReceive(t *testing.T) { if _, _, ok := proto.Parse(pkt); !ok { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } - r.PopulatePacketInfo(pkt) + + addressableEndpoint, ok := ep.(stack.AddressableEndpoint) + if !ok { + t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint") + } + addr := localIPv4Addr.WithPrefix() + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + ep.DecRef() + } + ep.HandlePacket(pkt) if nic.testObject.dataCalls != 0 { t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls) @@ -799,7 +890,6 @@ func TestIPv4FragmentationReceive(t *testing.T) { if _, _, ok := proto.Parse(pkt); !ok { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } - r.PopulatePacketInfo(pkt) ep.HandlePacket(pkt) if nic.testObject.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) @@ -852,61 +942,6 @@ func TestIPv6Send(t *testing.T) { } } -func TestIPv6Receive(t *testing.T) { - s := buildDummyStack(t) - proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) - nic := testInterface{ - testObject: testObject{ - t: t, - }, - } - ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) - defer ep.Close() - - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - - totalLen := header.IPv6MinimumSize + 30 - view := buffer.NewView(totalLen) - ip := header.IPv6(view) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(totalLen - header.IPv6MinimumSize), - NextHeader: 10, - HopLimit: 20, - SrcAddr: remoteIPv6Addr, - DstAddr: localIPv6Addr, - }) - - // Make payload be non-zero. - for i := header.IPv6MinimumSize; i < totalLen; i++ { - view[i] = uint8(i) - } - - // Give packet to ipv6 endpoint, dispatcher will validate that it's ok. - nic.testObject.protocol = 10 - nic.testObject.srcAddr = remoteIPv6Addr - nic.testObject.dstAddr = localIPv6Addr - nic.testObject.contents = view[header.IPv6MinimumSize:totalLen] - - r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: view.ToVectorisedView(), - }) - if _, _, ok := proto.Parse(pkt); !ok { - t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) - } - r.PopulatePacketInfo(pkt) - ep.HandlePacket(pkt) - if nic.testObject.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) - } -} - func TestIPv6ReceiveControl(t *testing.T) { newUint16 := func(v uint16) *uint16 { return &v } @@ -933,13 +968,6 @@ func TestIPv6ReceiveControl(t *testing.T) { {"Non-zero fragment offset", 0, newUint16(100), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0}, {"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8}, } - r, err := buildIPv6Route( - localIPv6Addr, - "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa", - ) - if err != nil { - t.Fatal(err) - } for _, c := range cases { t.Run(c.name, func(t *testing.T) { s := buildDummyStack(t) @@ -1018,8 +1046,17 @@ func TestIPv6ReceiveControl(t *testing.T) { // Set ICMPv6 checksum. icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{})) + addressableEndpoint, ok := ep.(stack.AddressableEndpoint) + if !ok { + t.Fatal("expected IPv6 network endpoint to implement stack.AddressableEndpoint") + } + addr := localIPv6Addr.WithPrefix() + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + ep.DecRef() + } pkt := truncatedPacket(view, c.trunc, header.IPv6MinimumSize) - r.PopulatePacketInfo(pkt) ep.HandlePacket(pkt) if want := c.expectedCount; nic.testObject.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) @@ -1202,7 +1239,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - ipHdrLen := header.IPv4MinimumSize + ipv4Options.AllocationSize() + ipHdrLen := header.IPv4MinimumSize + ipv4Options.SizeWithPadding() totalLen := ipHdrLen + len(data) hdr := buffer.NewPrependable(totalLen) if n := copy(hdr.Prepend(len(data)), data); n != len(data) { @@ -1247,7 +1284,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.AllocationSize())) + ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.SizeWithPadding())) ip.Encode(&header.IPv4Fields{ Protocol: transportProto, TTL: ipv4.DefaultTTL, diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 9b5e37fee..488945226 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -90,7 +90,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { iph := header.IPv4(pkt.NetworkHeader().View()) var newOptions header.IPv4Options - if len(iph) > header.IPv4MinimumSize { + if opts := iph.Options(); len(opts) != 0 { // RFC 1122 section 3.2.2.6 (page 43) (and similar for other round trip // type ICMP packets): // If a Record Route and/or Time Stamp option is received in an @@ -106,7 +106,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { } else { op = &optionUsageReceive{} } - aux, tmp, err := e.processIPOptions(pkt, iph.Options(), op) + aux, tmp, err := e.processIPOptions(pkt, opts, op) if err != nil { switch { case @@ -290,6 +290,13 @@ type icmpReasonProtoUnreachable struct{} func (*icmpReasonProtoUnreachable) isICMPReason() {} +// icmpReasonTTLExceeded is an error where a packet's time to live exceeded in +// transit to its final destination, as per RFC 792 page 6, Time Exceeded +// Message. +type icmpReasonTTLExceeded struct{} + +func (*icmpReasonTTLExceeded) isICMPReason() {} + // icmpReasonReassemblyTimeout is an error where insufficient fragments are // received to complete reassembly of a packet within a configured time after // the reception of the first-arriving fragment of that packet. @@ -342,11 +349,31 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi return nil } + // If we hit a TTL Exceeded error, then we know we are operating as a router. + // As per RFC 792 page 6, Time Exceeded Message, + // + // If the gateway processing a datagram finds the time to live field + // is zero it must discard the datagram. The gateway may also notify + // the source host via the time exceeded message. + // + // ... + // + // Code 0 may be received from a gateway. ... + // + // Note, Code 0 is the TTL exceeded error. + // + // If we are operating as a router/gateway, don't use the packet's destination + // address as the response's source address as we should not not own the + // destination address of a packet we are forwarding. + localAddr := origIPHdrDst + if _, ok := reason.(*icmpReasonTTLExceeded); ok { + localAddr = "" + } // Even if we were able to receive a packet from some remote, we may not have // a route to it - the remote may be blocked via routing rules. We must always // consult our routing table and find a route to the remote before sending any // packet. - route, err := p.stack.FindRoute(pkt.NICID, origIPHdrDst, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */) + route, err := p.stack.FindRoute(pkt.NICID, localAddr, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */) if err != nil { return err } @@ -454,6 +481,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4ProtoUnreachable) counter = sent.DstUnreachable + case *icmpReasonTTLExceeded: + icmpHdr.SetType(header.ICMPv4TimeExceeded) + icmpHdr.SetCode(header.ICMPv4TTLExceeded) + counter = sent.TimeExceeded case *icmpReasonReassemblyTimeout: icmpHdr.SetType(header.ICMPv4TimeExceeded) icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout) @@ -483,3 +514,18 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi counter.Increment() return nil } + +// OnReassemblyTimeout implements fragmentation.TimeoutHandler. +func (p *protocol) OnReassemblyTimeout(pkt *stack.PacketBuffer) { + // OnReassemblyTimeout sends a Time Exceeded Message, as per RFC 792: + // + // If a host reassembling a fragmented datagram cannot complete the + // reassembly due to missing fragments within its time limit it discards the + // datagram, and it may send a time exceeded message. + // + // If fragment zero is not available then no time exceeded need be sent at + // all. + if pkt != nil { + p.returnError(&icmpReasonReassemblyTimeout{}, pkt) + } +} diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index a376cb8ec..1efe6297a 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -206,12 +206,12 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s if opts, ok = params.Options.(header.IPv4Options); !ok { panic(fmt.Sprintf("want IPv4Options, got %T", params.Options)) } - hdrLen += opts.AllocationSize() + hdrLen += opts.SizeWithPadding() if hdrLen > header.IPv4MaximumHeaderSize { // Since we have no way to report an error we must either panic or create // a packet which is different to what was requested. Choose panic as this // would be a programming error that should be caught in testing. - panic(fmt.Sprintf("IPv4 Options %d bytes, Max %d", params.Options.AllocationSize(), header.IPv4MaximumOptionsSize)) + panic(fmt.Sprintf("IPv4 Options %d bytes, Max %d", params.Options.SizeWithPadding(), header.IPv4MaximumOptionsSize)) } } ip := header.IPv4(pkt.NetworkHeader().Push(hdrLen)) @@ -260,16 +260,13 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { e.addIPHeader(r, pkt, params) - return e.writePacket(r, gso, pkt) -} -func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer) *tcpip.Error { // iptables filtering. All packets that reach here are locally // generated. nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. - r.Stats().IP.IPTablesOutputDropped.Increment() + e.protocol.stack.Stats().IP.IPTablesOutputDropped.Increment() return nil } @@ -286,24 +283,27 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet if err == nil { pkt := pkt.CloneToInbound() if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) - route.PopulatePacketInfo(pkt) // Since we rewrote the packet but it is being routed back to us, we can // safely assume the checksum is valid. pkt.RXTransportChecksumValidated = true - ep.HandlePacket(pkt) + ep.(*endpoint).handlePacket(pkt) } return nil } } + return e.writePacket(r, gso, pkt, false /* headerIncluded */) +} + +func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, headerIncluded bool) *tcpip.Error { if r.Loop&stack.PacketLoop != 0 { pkt := pkt.CloneToInbound() if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - loopedR := r.MakeLoopedRoute() - loopedR.PopulatePacketInfo(pkt) - loopedR.Release() - e.HandlePacket(pkt) + // If the packet was generated by the stack (not a raw/packet endpoint + // where a packet may be written with the header included), then we can + // safely assume the checksum is valid. + pkt.RXTransportChecksumValidated = !headerIncluded + e.handlePacket(pkt) } } if r.Loop&stack.PacketOut == 0 { @@ -374,8 +374,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) // iptables filtering. All packets that reach here are locally // generated. - ipt := e.protocol.stack.IPTables() - dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName) + dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName) if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. @@ -400,9 +399,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { pkt := pkt.CloneToInbound() if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) - route.PopulatePacketInfo(pkt) - ep.HandlePacket(pkt) + // Since we rewrote the packet but it is being routed back to us, we + // can safely assume the checksum is valid. + pkt.RXTransportChecksumValidated = true + ep.(*endpoint).handlePacket(pkt) } n++ continue @@ -479,16 +479,85 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu return tcpip.ErrMalformedHeader } - return e.writePacket(r, nil /* gso */, pkt) + return e.writePacket(r, nil /* gso */, pkt, true /* headerIncluded */) +} + +// forwardPacket attempts to forward a packet to its final destination. +func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error { + h := header.IPv4(pkt.NetworkHeader().View()) + ttl := h.TTL() + if ttl == 0 { + // As per RFC 792 page 6, Time Exceeded Message, + // + // If the gateway processing a datagram finds the time to live field + // is zero it must discard the datagram. The gateway may also notify + // the source host via the time exceeded message. + return e.protocol.returnError(&icmpReasonTTLExceeded{}, pkt) + } + + dstAddr := h.DestinationAddress() + + // Check if the destination is owned by the stack. + networkEndpoint, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr) + if err == nil { + networkEndpoint.(*endpoint).handlePacket(pkt) + return nil + } + if err != tcpip.ErrBadAddress { + return err + } + + r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) + if err != nil { + return err + } + defer r.Release() + + // We need to do a deep copy of the IP packet because + // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do + // not own it. + newHdr := header.IPv4(stack.PayloadSince(pkt.NetworkHeader())) + + // As per RFC 791 page 30, Time to Live, + // + // This field must be decreased at each point that the internet header + // is processed to reflect the time spent processing the datagram. + // Even if no local information is available on the time actually + // spent, the field must be decremented by 1. + newHdr.SetTTL(ttl - 1) + + return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: buffer.View(newHdr).ToVectorisedView(), + })) } // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { + stats := e.protocol.stack.Stats() + stats.IP.PacketsReceived.Increment() + if !e.isEnabled() { + stats.IP.DisabledPacketsReceived.Increment() return } + // Loopback traffic skips the prerouting chain. + if !e.nic.IsLoopback() { + if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok { + // iptables is telling us to drop the packet. + stats.IP.IPTablesPreroutingDropped.Increment() + return + } + } + + e.handlePacket(pkt) +} + +// handlePacket is like HandlePacket except it does not perform the prerouting +// iptables hook. +func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { pkt.NICID = e.nic.ID() stats := e.protocol.stack.Stats() @@ -497,6 +566,21 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { stats.IP.MalformedPacketsReceived.Increment() return } + srcAddr := h.SourceAddress() + dstAddr := h.DestinationAddress() + + addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint) + if addressEndpoint == nil { + if !e.protocol.Forwarding() { + stats.IP.InvalidDestinationAddressesReceived.Increment() + return + } + + _ = e.forwardPacket(pkt) + return + } + subnet := addressEndpoint.AddressWithPrefix().Subnet() + addressEndpoint.DecRef() // There has been some confusion regarding verifying checksums. We need // just look for negative 0 (0xffff) as the checksum, as it's not possible to @@ -528,15 +612,16 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // When a host sends any datagram, the IP source address MUST // be one of its own IP addresses (but not a broadcast or // multicast address). - if pkt.NetworkPacketInfo.RemoteAddressBroadcast || header.IsV4MulticastAddress(h.SourceAddress()) { + if directedBroadcast := subnet.IsBroadcast(srcAddr); directedBroadcast || srcAddr == header.IPv4Broadcast || header.IsV4MulticastAddress(srcAddr) { stats.IP.InvalidSourceAddressesReceived.Increment() return } + pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast + // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. - ipt := e.protocol.stack.IPTables() - if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. stats.IP.IPTablesInputDropped.Increment() return @@ -565,29 +650,8 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } - // Set up a callback in case we need to send a Time Exceeded Message, as per - // RFC 792: - // - // If a host reassembling a fragmented datagram cannot complete the - // reassembly due to missing fragments within its time limit it discards - // the datagram, and it may send a time exceeded message. - // - // If fragment zero is not available then no time exceeded need be sent at - // all. - var releaseCB func(bool) - if start == 0 { - pkt := pkt.Clone() - releaseCB = func(timedOut bool) { - if timedOut { - _ = e.protocol.returnError(&icmpReasonReassemblyTimeout{}, pkt) - } - } - } - - var ready bool - var err error proto := h.Protocol() - pkt.Data, _, ready, err = e.protocol.fragmentation.Process( + data, _, ready, err := e.protocol.fragmentation.Process( // As per RFC 791 section 2.3, the identification value is unique // for a source-destination pair and protocol. fragmentation.FragmentID{ @@ -600,8 +664,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { start+uint16(pkt.Data.Size())-1, h.More(), proto, - pkt.Data, - releaseCB, + pkt, ) if err != nil { stats.IP.MalformedPacketsReceived.Increment() @@ -611,6 +674,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { if !ready { return } + pkt.Data = data // The reassembler doesn't take care of fixing up the header, so we need // to do it here. @@ -628,11 +692,11 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { e.handleICMP(pkt) return } - if len(h.Options()) != 0 { + if opts := h.Options(); len(opts) != 0 { // TODO(gvisor.dev/issue/4586): // When we add forwarding support we should use the verified options // rather than just throwing them away. - aux, _, err := e.processIPOptions(pkt, h.Options(), &optionUsageReceive{}) + aux, _, err := e.processIPOptions(pkt, opts, &optionUsageReceive{}) if err != nil { switch { case @@ -778,6 +842,7 @@ func (e *endpoint) IsInGroup(addr tcpip.Address) bool { var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) var _ stack.NetworkProtocol = (*protocol)(nil) +var _ fragmentation.TimeoutHandler = (*protocol)(nil) type protocol struct { stack *stack.Stack @@ -942,13 +1007,14 @@ func NewProtocol(s *stack.Stack) stack.NetworkProtocol { } hashIV := r[buckets] - return &protocol{ - stack: s, - ids: ids, - hashIV: hashIV, - defaultTTL: DefaultTTL, - fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock()), + p := &protocol{ + stack: s, + ids: ids, + hashIV: hashIV, + defaultTTL: DefaultTTL, } + p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p) + return p } func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader header.IPv4) (*stack.PacketBuffer, bool) { diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index c6e565455..4e4e1f3b4 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -103,6 +103,262 @@ func TestExcludeBroadcast(t *testing.T) { }) } +// TestIPv4Encode checks that ipv4.Encode correctly fills out the requested +// fields when options are supplied. +func TestIPv4EncodeOptions(t *testing.T) { + tests := []struct { + name string + options header.IPv4Options + encodedOptions header.IPv4Options // reply should look like this + wantIHL int + }{ + { + name: "valid no options", + wantIHL: header.IPv4MinimumSize, + }, + { + name: "one byte options", + options: header.IPv4Options{1}, + encodedOptions: header.IPv4Options{1, 0, 0, 0}, + wantIHL: header.IPv4MinimumSize + 4, + }, + { + name: "two byte options", + options: header.IPv4Options{1, 1}, + encodedOptions: header.IPv4Options{1, 1, 0, 0}, + wantIHL: header.IPv4MinimumSize + 4, + }, + { + name: "three byte options", + options: header.IPv4Options{1, 1, 1}, + encodedOptions: header.IPv4Options{1, 1, 1, 0}, + wantIHL: header.IPv4MinimumSize + 4, + }, + { + name: "four byte options", + options: header.IPv4Options{1, 1, 1, 1}, + encodedOptions: header.IPv4Options{1, 1, 1, 1}, + wantIHL: header.IPv4MinimumSize + 4, + }, + { + name: "five byte options", + options: header.IPv4Options{1, 1, 1, 1, 1}, + encodedOptions: header.IPv4Options{1, 1, 1, 1, 1, 0, 0, 0}, + wantIHL: header.IPv4MinimumSize + 8, + }, + { + name: "thirty nine byte options", + options: header.IPv4Options{ + 1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, + 25, 26, 27, 28, 29, 30, 31, 32, + 33, 34, 35, 36, 37, 38, 39, + }, + encodedOptions: header.IPv4Options{ + 1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, + 25, 26, 27, 28, 29, 30, 31, 32, + 33, 34, 35, 36, 37, 38, 39, 0, + }, + wantIHL: header.IPv4MinimumSize + 40, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + paddedOptionLength := test.options.SizeWithPadding() + ipHeaderLength := header.IPv4MinimumSize + paddedOptionLength + if ipHeaderLength > header.IPv4MaximumHeaderSize { + t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) + } + totalLen := uint16(ipHeaderLength) + hdr := buffer.NewPrependable(int(totalLen)) + ip := header.IPv4(hdr.Prepend(ipHeaderLength)) + // To check the padding works, poison the last byte of the options space. + if paddedOptionLength != len(test.options) { + ip.SetHeaderLength(uint8(ipHeaderLength)) + ip.Options()[paddedOptionLength-1] = 0xff + ip.SetHeaderLength(0) + } + ip.Encode(&header.IPv4Fields{ + Options: test.options, + }) + options := ip.Options() + wantOptions := test.encodedOptions + if got, want := int(ip.HeaderLength()), test.wantIHL; got != want { + t.Errorf("got IHL of %d, want %d", got, want) + } + + // cmp.Diff does not consider nil slices equal to empty slices, but we do. + if len(wantOptions) == 0 && len(options) == 0 { + return + } + + if diff := cmp.Diff(wantOptions, options); diff != "" { + t.Errorf("options mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestForwarding(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + randomSequence = 123 + randomIdent = 42 + ) + + ipv4Addr1 := tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()), + PrefixLen: 8, + } + ipv4Addr2 := tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()), + PrefixLen: 8, + } + remoteIPv4Addr1 := tcpip.Address(net.ParseIP("10.0.0.2").To4()) + remoteIPv4Addr2 := tcpip.Address(net.ParseIP("11.0.0.2").To4()) + + tests := []struct { + name string + TTL uint8 + expectErrorICMP bool + }{ + { + name: "TTL of zero", + TTL: 0, + expectErrorICMP: true, + }, + { + name: "TTL of one", + TTL: 1, + expectErrorICMP: false, + }, + { + name: "TTL of two", + TTL: 2, + expectErrorICMP: false, + }, + { + name: "Max TTL", + TTL: math.MaxUint8, + expectErrorICMP: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, + }) + // We expect at most a single packet in response to our ICMP Echo Request. + e1 := channel.New(1, ipv4.MaxTotalSize, "") + if err := s.CreateNIC(nicID1, e1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + ipv4ProtoAddr1 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr1} + if err := s.AddProtocolAddress(nicID1, ipv4ProtoAddr1); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv4ProtoAddr1, err) + } + + e2 := channel.New(1, ipv4.MaxTotalSize, "") + if err := s.CreateNIC(nicID2, e2); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + } + ipv4ProtoAddr2 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr2} + if err := s.AddProtocolAddress(nicID2, ipv4ProtoAddr2); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv4ProtoAddr2, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: ipv4Addr1.Subnet(), + NIC: nicID1, + }, + { + Destination: ipv4Addr2.Subnet(), + NIC: nicID2, + }, + }) + + if err := s.SetForwarding(header.IPv4ProtocolNumber, true); err != nil { + t.Fatalf("SetForwarding(%d, true): %s", header.IPv4ProtocolNumber, err) + } + + totalLen := uint16(header.IPv4MinimumSize + header.ICMPv4MinimumSize) + hdr := buffer.NewPrependable(int(totalLen)) + icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + icmp.SetIdent(randomIdent) + icmp.SetSequence(randomSequence) + icmp.SetType(header.ICMPv4Echo) + icmp.SetCode(header.ICMPv4UnusedCode) + icmp.SetChecksum(0) + icmp.SetChecksum(^header.Checksum(icmp, 0)) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + TotalLength: totalLen, + Protocol: uint8(header.ICMPv4ProtocolNumber), + TTL: test.TTL, + SrcAddr: remoteIPv4Addr1, + DstAddr: remoteIPv4Addr2, + }) + ip.SetChecksum(0) + ip.SetChecksum(^ip.CalculateChecksum()) + requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + }) + e1.InjectInbound(header.IPv4ProtocolNumber, requestPkt) + + if test.expectErrorICMP { + reply, ok := e1.Read() + if !ok { + t.Fatal("expected ICMP TTL Exceeded packet through incoming NIC") + } + + checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.SrcAddr(ipv4Addr1.Address), + checker.DstAddr(remoteIPv4Addr1), + checker.TTL(ipv4.DefaultTTL), + checker.ICMPv4( + checker.ICMPv4Checksum(), + checker.ICMPv4Type(header.ICMPv4TimeExceeded), + checker.ICMPv4Code(header.ICMPv4TTLExceeded), + checker.ICMPv4Payload([]byte(hdr.View())), + ), + ) + + if n := e2.Drain(); n != 0 { + t.Fatalf("got e2.Drain() = %d, want = 0", n) + } + } else { + reply, ok := e2.Read() + if !ok { + t.Fatal("expected ICMP Echo packet through outgoing NIC") + } + + checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.SrcAddr(remoteIPv4Addr1), + checker.DstAddr(remoteIPv4Addr2), + checker.TTL(test.TTL-1), + checker.ICMPv4( + checker.ICMPv4Checksum(), + checker.ICMPv4Type(header.ICMPv4Echo), + checker.ICMPv4Code(header.ICMPv4UnusedCode), + checker.ICMPv4Payload(nil), + ), + ) + + if n := e1.Drain(); n != 0 { + t.Fatalf("got e1.Drain() = %d, want = 0", n) + } + } + }) + } +} + // TestIPv4Sanity sends IP/ICMP packets with various problems to the stack and // checks the response. func TestIPv4Sanity(t *testing.T) { @@ -197,6 +453,14 @@ func TestIPv4Sanity(t *testing.T) { replyOptions: header.IPv4Options{1, 1, 0, 0}, }, { + name: "Check option padding", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{1, 1, 1}, + replyOptions: header.IPv4Options{1, 1, 1, 0}, + }, + { name: "bad header length", headerLength: header.IPv4MinimumSize - 1, maxTotalLength: ipv4.MaxTotalSize, @@ -599,9 +863,10 @@ func TestIPv4Sanity(t *testing.T) { }, }) - ipHeaderLength := header.IPv4MinimumSize + test.options.AllocationSize() + paddedOptionLength := test.options.SizeWithPadding() + ipHeaderLength := header.IPv4MinimumSize + paddedOptionLength if ipHeaderLength > header.IPv4MaximumHeaderSize { - t.Fatalf("too many bytes in options: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) + t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) } totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize) hdr := buffer.NewPrependable(int(totalLen)) @@ -618,6 +883,12 @@ func TestIPv4Sanity(t *testing.T) { if test.maxTotalLength < totalLen { totalLen = test.maxTotalLength } + // To check the padding works, poison the options space. + if paddedOptionLength != len(test.options) { + ip.SetHeaderLength(uint8(ipHeaderLength)) + ip.Options()[paddedOptionLength-1] = 0x01 + } + ip.Encode(&header.IPv4Fields{ TotalLength: totalLen, Protocol: test.transportProtocol, @@ -732,7 +1003,7 @@ func TestIPv4Sanity(t *testing.T) { } // If the IP options change size then the packet will change size, so // some IP header fields will need to be adjusted for the checks. - sizeChange := len(test.replyOptions) - len(test.options) + sizeChange := len(test.replyOptions) - paddedOptionLength checker.IPv4(t, replyIPHeader, checker.IPv4HeaderLength(ipHeaderLength+sizeChange), @@ -2441,9 +2712,6 @@ func TestPacketQueing(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, arp.ProtocolNumber, arp.ProtocolAddress, err) - } if err := s.AddProtocolAddress(nicID, host1IPv4Addr); err != nil { t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv4Addr, err) } diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 8502b848c..beb8f562e 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -750,6 +750,12 @@ type icmpReasonPortUnreachable struct{} func (*icmpReasonPortUnreachable) isICMPReason() {} +// icmpReasonHopLimitExceeded is an error where a packet's hop limit exceeded in +// transit to its final destination, as per RFC 4443 section 3.3. +type icmpReasonHopLimitExceeded struct{} + +func (*icmpReasonHopLimitExceeded) isICMPReason() {} + // icmpReasonReassemblyTimeout is an error where insufficient fragments are // received to complete reassembly of a packet within a configured time after // the reception of the first-arriving fragment of that packet. @@ -794,11 +800,27 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi return nil } + // If we hit a Hop Limit Exceeded error, then we know we are operating as a + // router. As per RFC 4443 section 3.3: + // + // If a router receives a packet with a Hop Limit of zero, or if a + // router decrements a packet's Hop Limit to zero, it MUST discard the + // packet and originate an ICMPv6 Time Exceeded message with Code 0 to + // the source of the packet. This indicates either a routing loop or + // too small an initial Hop Limit value. + // + // If we are operating as a router, do not use the packet's destination + // address as the response's source address as we should not own the + // destination address of a packet we are forwarding. + localAddr := origIPHdrDst + if _, ok := reason.(*icmpReasonHopLimitExceeded); ok { + localAddr = "" + } // Even if we were able to receive a packet from some remote, we may not have // a route to it - the remote may be blocked via routing rules. We must always // consult our routing table and find a route to the remote before sending any // packet. - route, err := p.stack.FindRoute(pkt.NICID, origIPHdrDst, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */) + route, err := p.stack.FindRoute(pkt.NICID, localAddr, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */) if err != nil { return err } @@ -811,8 +833,6 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi return nil } - network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() - if pkt.TransportProtocolNumber == header.ICMPv6ProtocolNumber { // TODO(gvisor.dev/issues/3810): Sort this out when ICMP headers are stored. // Unfortunately at this time ICMP Packets do not have a transport @@ -830,6 +850,8 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi } } + network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() + // As per RFC 4443 section 2.4 // // (c) Every ICMPv6 error message (type < 128) MUST include @@ -873,6 +895,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi icmpHdr.SetType(header.ICMPv6DstUnreachable) icmpHdr.SetCode(header.ICMPv6PortUnreachable) counter = sent.DstUnreachable + case *icmpReasonHopLimitExceeded: + icmpHdr.SetType(header.ICMPv6TimeExceeded) + icmpHdr.SetCode(header.ICMPv6HopLimitExceeded) + counter = sent.TimeExceeded case *icmpReasonReassemblyTimeout: icmpHdr.SetType(header.ICMPv6TimeExceeded) icmpHdr.SetCode(header.ICMPv6ReassemblyTimeout) @@ -896,3 +922,16 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi counter.Increment() return nil } + +// OnReassemblyTimeout implements fragmentation.TimeoutHandler. +func (p *protocol) OnReassemblyTimeout(pkt *stack.PacketBuffer) { + // OnReassemblyTimeout sends a Time Exceeded Message as per RFC 2460 Section + // 4.5: + // + // If the first fragment (i.e., the one with a Fragment Offset of zero) has + // been received, an ICMP Time Exceeded -- Fragment Reassembly Time Exceeded + // message should be sent to the source of that fragment. + if pkt != nil { + p.returnError(&icmpReasonReassemblyTimeout{}, pkt) + } +} diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 76013daa1..9bc02d851 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -144,6 +144,10 @@ func (*testInterface) Enabled() bool { return true } +func (*testInterface) Promiscuous() bool { + return false +} + func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { r := stack.Route{ NetProto: protocol, @@ -174,13 +178,8 @@ func TestICMPCounts(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, UseNeighborCache: test.useNeighborCache, }) - { - if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_, _) = %s", err) - } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } + if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_, _) = %s", err) } { subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) @@ -206,11 +205,16 @@ func TestICMPCounts(t *testing.T) { t.Fatalf("ep.Enable(): %s", err) } - r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) + addressableEndpoint, ok := ep.(stack.AddressableEndpoint) + if !ok { + t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") + } + addr := lladdr0.WithPrefix() + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + ep.DecRef() } - defer r.Release() var tllData [header.NDPLinkLayerAddressSize]byte header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ @@ -279,10 +283,9 @@ func TestICMPCounts(t *testing.T) { PayloadLength: uint16(len(icmp)), NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: header.NDPHopLimit, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) - r.PopulatePacketInfo(pkt) ep.HandlePacket(pkt) } @@ -290,7 +293,7 @@ func TestICMPCounts(t *testing.T) { icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) copy(icmp[typ.size:], typ.extraData) icmp.SetType(typ.typ) - icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) handleIPv6Payload(icmp) } @@ -317,13 +320,8 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, UseNeighborCache: true, }) - { - if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_, _) = %s", err) - } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } + if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_, _) = %s", err) } { subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) @@ -349,11 +347,16 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { t.Fatalf("ep.Enable(): %s", err) } - r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) + addressableEndpoint, ok := ep.(stack.AddressableEndpoint) + if !ok { + t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") + } + addr := lladdr0.WithPrefix() + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + ep.DecRef() } - defer r.Release() var tllData [header.NDPLinkLayerAddressSize]byte header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ @@ -422,10 +425,9 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { PayloadLength: uint16(len(icmp)), NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: header.NDPHopLimit, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) - r.PopulatePacketInfo(pkt) ep.HandlePacket(pkt) } @@ -433,7 +435,7 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) copy(icmp[typ.size:], typ.extraData) icmp.SetType(typ.typ) - icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) handleIPv6Payload(icmp) } @@ -1775,17 +1777,19 @@ func TestCallsToNeighborCache(t *testing.T) { t.Fatalf("ep.Enable(): %s", err) } - r, err := s.FindRoute(nicID, lladdr0, test.source, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) + addressableEndpoint, ok := ep.(stack.AddressableEndpoint) + if !ok { + t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") + } + addr := lladdr0.WithPrefix() + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + ep.DecRef() } - defer r.Release() - - // TODO(gvisor.dev/issue/4517): Remove the need for this manual patch. - r.LocalAddress = test.destination icmp := test.createPacket() - icmp.SetChecksum(header.ICMPv6Checksum(icmp, r.RemoteAddress, r.LocalAddress, buffer.VectorisedView{})) + icmp.SetChecksum(header.ICMPv6Checksum(icmp, test.source, test.destination, buffer.VectorisedView{})) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: header.IPv6MinimumSize, Data: buffer.View(icmp).ToVectorisedView(), @@ -1795,10 +1799,9 @@ func TestCallsToNeighborCache(t *testing.T) { PayloadLength: uint16(len(icmp)), NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: header.NDPHopLimit, - SrcAddr: r.RemoteAddress, - DstAddr: r.LocalAddress, + SrcAddr: test.source, + DstAddr: test.destination, }) - r.PopulatePacketInfo(pkt) ep.HandlePacket(pkt) // Confirm the endpoint calls the correct NUDHandler method. diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 0526190cc..7a00f6314 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -441,17 +441,13 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { e.addIPHeader(r, pkt, params) - return e.writePacket(r, gso, pkt, params.Protocol) -} -func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber) *tcpip.Error { // iptables filtering. All packets that reach here are locally // generated. nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - ipt := e.protocol.stack.IPTables() - if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. - r.Stats().IP.IPTablesOutputDropped.Increment() + e.protocol.stack.Stats().IP.IPTablesOutputDropped.Increment() return nil } @@ -467,24 +463,27 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { pkt := pkt.CloneToInbound() if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) - route.PopulatePacketInfo(pkt) // Since we rewrote the packet but it is being routed back to us, we can // safely assume the checksum is valid. pkt.RXTransportChecksumValidated = true - ep.HandlePacket(pkt) + ep.(*endpoint).handlePacket(pkt) } return nil } } + return e.writePacket(r, gso, pkt, params.Protocol, false /* headerIncluded */) +} + +func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) *tcpip.Error { if r.Loop&stack.PacketLoop != 0 { pkt := pkt.CloneToInbound() if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - loopedR := r.MakeLoopedRoute() - loopedR.PopulatePacketInfo(pkt) - loopedR.Release() - e.HandlePacket(pkt) + // If the packet was generated by the stack (not a raw/packet endpoint + // where a packet may be written with the header included), then we can + // safely assume the checksum is valid. + pkt.RXTransportChecksumValidated = !headerIncluded + e.handlePacket(pkt) } } if r.Loop&stack.PacketOut == 0 { @@ -558,8 +557,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // iptables filtering. All packets that reach here are locally // generated. nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - ipt := e.protocol.stack.IPTables() - dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName) + dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName) if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. @@ -584,9 +582,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { pkt := pkt.CloneToInbound() if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) - route.PopulatePacketInfo(pkt) - ep.HandlePacket(pkt) + // Since we rewrote the packet but it is being routed back to us, we + // can safely assume the checksum is valid. + pkt.RXTransportChecksumValidated = true + ep.(*endpoint).handlePacket(pkt) } n++ continue @@ -640,16 +639,85 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu return tcpip.ErrMalformedHeader } - return e.writePacket(r, nil /* gso */, pkt, proto) + return e.writePacket(r, nil /* gso */, pkt, proto, true /* headerIncluded */) +} + +// forwardPacket attempts to forward a packet to its final destination. +func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error { + h := header.IPv6(pkt.NetworkHeader().View()) + hopLimit := h.HopLimit() + if hopLimit <= 1 { + // As per RFC 4443 section 3.3, + // + // If a router receives a packet with a Hop Limit of zero, or if a + // router decrements a packet's Hop Limit to zero, it MUST discard the + // packet and originate an ICMPv6 Time Exceeded message with Code 0 to + // the source of the packet. This indicates either a routing loop or + // too small an initial Hop Limit value. + return e.protocol.returnError(&icmpReasonHopLimitExceeded{}, pkt) + } + + dstAddr := h.DestinationAddress() + + // Check if the destination is owned by the stack. + networkEndpoint, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr) + if err == nil { + networkEndpoint.(*endpoint).handlePacket(pkt) + return nil + } + if err != tcpip.ErrBadAddress { + return err + } + + r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) + if err != nil { + return err + } + defer r.Release() + + // We need to do a deep copy of the IP packet because + // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do + // not own it. + newHdr := header.IPv6(stack.PayloadSince(pkt.NetworkHeader())) + + // As per RFC 8200 section 3, + // + // Hop Limit 8-bit unsigned integer. Decremented by 1 by + // each node that forwards the packet. + newHdr.SetHopLimit(hopLimit - 1) + + return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: buffer.View(newHdr).ToVectorisedView(), + })) } // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { + stats := e.protocol.stack.Stats() + stats.IP.PacketsReceived.Increment() + if !e.isEnabled() { + stats.IP.DisabledPacketsReceived.Increment() return } + // Loopback traffic skips the prerouting chain. + if !e.nic.IsLoopback() { + if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok { + // iptables is telling us to drop the packet. + stats.IP.IPTablesPreroutingDropped.Increment() + return + } + } + + e.handlePacket(pkt) +} + +// handlePacket is like HandlePacket except it does not perform the prerouting +// iptables hook. +func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { pkt.NICID = e.nic.ID() stats := e.protocol.stack.Stats() @@ -669,6 +737,18 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } + addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint) + if addressEndpoint == nil { + if !e.protocol.Forwarding() { + stats.IP.InvalidDestinationAddressesReceived.Increment() + return + } + + _ = e.forwardPacket(pkt) + return + } + addressEndpoint.DecRef() + // vv consists of: // - Any IPv6 header bytes after the first 40 (i.e. extensions). // - The transport header, if present. @@ -681,8 +761,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // iptables filtering. All packets that reach here are intended for // this machine and need not be forwarded. - ipt := e.protocol.stack.IPTables() - if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. stats.IP.IPTablesInputDropped.Increment() return @@ -888,18 +967,6 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } - // Set up a callback in case we need to send a Time Exceeded Message as - // per RFC 2460 Section 4.5. - var releaseCB func(bool) - if start == 0 { - pkt := pkt.Clone() - releaseCB = func(timedOut bool) { - if timedOut { - _ = e.protocol.returnError(&icmpReasonReassemblyTimeout{}, pkt) - } - } - } - // Note that pkt doesn't have its transport header set after reassembly, // and won't until DeliverNetworkPacket sets it. data, proto, ready, err := e.protocol.fragmentation.Process( @@ -914,17 +981,17 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { start+uint16(fragmentPayloadLen)-1, extHdr.More(), uint8(rawPayload.Identifier), - rawPayload.Buf, - releaseCB, + pkt, ) if err != nil { stats.IP.MalformedPacketsReceived.Increment() stats.IP.MalformedFragmentsReceived.Increment() return } - pkt.Data = data if ready { + pkt.Data = data + // We create a new iterator with the reassembled packet because we could // have more extension headers in the reassembled payload, as per RFC // 8200 section 4.5. We also use the NextHeader value from the first @@ -1335,6 +1402,7 @@ func (e *endpoint) IsInGroup(addr tcpip.Address) bool { var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) var _ stack.NetworkProtocol = (*protocol)(nil) +var _ fragmentation.TimeoutHandler = (*protocol)(nil) type protocol struct { stack *stack.Stack @@ -1590,10 +1658,9 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { return func(s *stack.Stack) stack.NetworkProtocol { p := &protocol{ - stack: s, - fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock()), - ids: ids, - hashIV: hashIV, + stack: s, + ids: ids, + hashIV: hashIV, ndpDisp: opts.NDPDisp, ndpConfigs: opts.NDPConfigs, @@ -1601,6 +1668,7 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { tempIIDSeed: opts.TempIIDSeed, autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal, } + p.fragmentation = fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p) p.mu.eps = make(map[*endpoint]struct{}) p.SetDefaultTTL(DefaultTTL) return p diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 1bfcdde25..a671d4bac 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -18,6 +18,7 @@ import ( "encoding/hex" "fmt" "math" + "net" "testing" "github.com/google/go-cmp/cmp" @@ -2821,3 +2822,160 @@ func TestFragmentationErrors(t *testing.T) { }) } } + +func TestForwarding(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + randomSequence = 123 + randomIdent = 42 + ) + + ipv6Addr1 := tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("10::1").To16()), + PrefixLen: 64, + } + ipv6Addr2 := tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("11::1").To16()), + PrefixLen: 64, + } + remoteIPv6Addr1 := tcpip.Address(net.ParseIP("10::2").To16()) + remoteIPv6Addr2 := tcpip.Address(net.ParseIP("11::2").To16()) + + tests := []struct { + name string + TTL uint8 + expectErrorICMP bool + }{ + { + name: "TTL of zero", + TTL: 0, + expectErrorICMP: true, + }, + { + name: "TTL of one", + TTL: 1, + expectErrorICMP: true, + }, + { + name: "TTL of two", + TTL: 2, + expectErrorICMP: false, + }, + { + name: "TTL of three", + TTL: 3, + expectErrorICMP: false, + }, + { + name: "Max TTL", + TTL: math.MaxUint8, + expectErrorICMP: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, + }) + // We expect at most a single packet in response to our ICMP Echo Request. + e1 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID1, e1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + ipv6ProtoAddr1 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr1} + if err := s.AddProtocolAddress(nicID1, ipv6ProtoAddr1); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv6ProtoAddr1, err) + } + + e2 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID2, e2); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + } + ipv6ProtoAddr2 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr2} + if err := s.AddProtocolAddress(nicID2, ipv6ProtoAddr2); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv6ProtoAddr2, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: ipv6Addr1.Subnet(), + NIC: nicID1, + }, + { + Destination: ipv6Addr2.Subnet(), + NIC: nicID2, + }, + }) + + if err := s.SetForwarding(ProtocolNumber, true); err != nil { + t.Fatalf("SetForwarding(%d, true): %s", ProtocolNumber, err) + } + + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6MinimumSize) + icmp := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + icmp.SetIdent(randomIdent) + icmp.SetSequence(randomSequence) + icmp.SetType(header.ICMPv6EchoRequest) + icmp.SetCode(header.ICMPv6UnusedCode) + icmp.SetChecksum(0) + icmp.SetChecksum(header.ICMPv6Checksum(icmp, remoteIPv6Addr1, remoteIPv6Addr2, buffer.VectorisedView{})) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: header.ICMPv6MinimumSize, + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: test.TTL, + SrcAddr: remoteIPv6Addr1, + DstAddr: remoteIPv6Addr2, + }) + requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + }) + e1.InjectInbound(ProtocolNumber, requestPkt) + + if test.expectErrorICMP { + reply, ok := e1.Read() + if !ok { + t.Fatal("expected ICMP Hop Limit Exceeded packet through incoming NIC") + } + + checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.SrcAddr(ipv6Addr1.Address), + checker.DstAddr(remoteIPv6Addr1), + checker.TTL(DefaultTTL), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6TimeExceeded), + checker.ICMPv6Code(header.ICMPv6HopLimitExceeded), + checker.ICMPv6Payload([]byte(hdr.View())), + ), + ) + + if n := e2.Drain(); n != 0 { + t.Fatalf("got e2.Drain() = %d, want = 0", n) + } + } else { + reply, ok := e2.Read() + if !ok { + t.Fatal("expected ICMP Echo Request packet through outgoing NIC") + } + + checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.SrcAddr(remoteIPv6Addr1), + checker.DstAddr(remoteIPv6Addr2), + checker.TTL(test.TTL-1), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6EchoRequest), + checker.ICMPv6Code(header.ICMPv6UnusedCode), + checker.ICMPv6Payload(nil), + ), + ) + + if n := e1.Drain(); n != 0 { + t.Fatalf("got e1.Drain() = %d, want = 0", n) + } + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 981d1371a..37e8b1083 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -45,10 +45,6 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { t.Fatalf("CreateNIC(_) = %s", err) } - if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err) - } - { subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr)))) if err != nil { @@ -73,6 +69,17 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig } t.Cleanup(ep.Close) + addressableEndpoint, ok := ep.(stack.AddressableEndpoint) + if !ok { + t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") + } + addr := llladdr.WithPrefix() + if addressEP, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + addressEP.DecRef() + } + return s, ep } @@ -961,22 +968,17 @@ func TestNDPValidation(t *testing.T) { for _, stackTyp := range stacks { t.Run(stackTyp.name, func(t *testing.T) { - setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) { + setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint) { t.Helper() // Create a stack with the assigned link-local address lladdr0 // and an endpoint to lladdr1. s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1, stackTyp.useNeighborCache) - r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err) - } - - return s, ep, r + return s, ep } - handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) { + handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) { nextHdr := uint8(header.ICMPv6ProtocolNumber) var extensions buffer.View if atomicFragment { @@ -994,13 +996,12 @@ func TestNDPValidation(t *testing.T) { PayloadLength: uint16(len(payload) + len(extensions)), NextHeader: nextHdr, HopLimit: hopLimit, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) { t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n) } - r.PopulatePacketInfo(pkt) ep.HandlePacket(pkt) } @@ -1114,8 +1115,7 @@ func TestNDPValidation(t *testing.T) { t.Run(name, func(t *testing.T) { for _, test := range subTests { t.Run(test.name, func(t *testing.T) { - s, ep, r := setup(t) - defer r.Release() + s, ep := setup(t) if isRouter { // Enabling forwarding makes the stack act as a router. @@ -1131,7 +1131,7 @@ func TestNDPValidation(t *testing.T) { copy(icmp[typ.size:], typ.extraData) icmp.SetType(typ.typ) icmp.SetCode(test.code) - icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) // Rx count of the NDP message should initially be 0. if got := typStat.Value(); got != 0 { @@ -1152,7 +1152,7 @@ func TestNDPValidation(t *testing.T) { t.FailNow() } - handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep, &r) + handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep) // Rx count of the NDP packet should have increased. if got := typStat.Value(); got != 1 { diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go index 7cc52985e..5c3363759 100644 --- a/pkg/tcpip/network/testutil/testutil.go +++ b/pkg/tcpip/network/testutil/testutil.go @@ -85,21 +85,6 @@ func (ep *MockLinkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts st return n, nil } -// WriteRawPacket implements LinkEndpoint.WriteRawPacket. -func (ep *MockLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - if ep.allowPackets == 0 { - return ep.err - } - ep.allowPackets-- - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - }) - ep.WrittenPackets = append(ep.WrittenPackets, pkt) - - return nil -} - // Attach implements LinkEndpoint.Attach. func (*MockLinkEndpoint) Attach(stack.NetworkDispatcher) {} |