summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/network
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/network')
-rw-r--r--pkg/tcpip/network/BUILD1
-rw-r--r--pkg/tcpip/network/arp/arp.go34
-rw-r--r--pkg/tcpip/network/arp/arp_test.go7
-rw-r--r--pkg/tcpip/network/fragmentation/BUILD1
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go71
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation_test.go207
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler.go50
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler_test.go23
-rw-r--r--pkg/tcpip/network/ip_test.go313
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go52
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go174
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go280
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go45
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go85
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go148
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go158
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go38
-rw-r--r--pkg/tcpip/network/testutil/testutil.go15
18 files changed, 1183 insertions, 519 deletions
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index c118a2929..b38aff0b8 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -14,6 +14,7 @@ go_test(
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
"//pkg/tcpip/header",
+ "//pkg/tcpip/header/parse",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
"//pkg/tcpip/network/ipv4",
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 33a4a0720..3d5c0d270 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -31,17 +31,15 @@ import (
const (
// ProtocolNumber is the ARP protocol number.
ProtocolNumber = header.ARPProtocolNumber
-
- // ProtocolAddress is the address expected by the ARP endpoint.
- ProtocolAddress = tcpip.Address("arp")
)
-var _ stack.AddressableEndpoint = (*endpoint)(nil)
+// ARP endpoints need to implement stack.NetworkEndpoint because the stack
+// considers the layer above the link-layer a network layer; the only
+// facility provided by the stack to deliver packets to a layer above
+// the link-layer is via stack.NetworkEndpoint.HandlePacket.
var _ stack.NetworkEndpoint = (*endpoint)(nil)
type endpoint struct {
- stack.AddressableEndpointState
-
protocol *protocol
// enabled is set to 1 when the NIC is enabled and 0 when it is disabled.
@@ -87,7 +85,7 @@ func (e *endpoint) Disable() {
}
// DefaultTTL is unused for ARP. It implements stack.NetworkEndpoint.
-func (e *endpoint) DefaultTTL() uint8 {
+func (*endpoint) DefaultTTL() uint8 {
return 0
}
@@ -100,25 +98,23 @@ func (e *endpoint) MaxHeaderLength() uint16 {
return e.nic.MaxHeaderLength() + header.ARPSize
}
-func (e *endpoint) Close() {
- e.AddressableEndpointState.Cleanup()
-}
+func (*endpoint) Close() {}
-func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error {
+func (*endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
-func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
+func (*endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return ProtocolNumber
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, *tcpip.Error) {
+func (*endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, *tcpip.Error) {
return 0, tcpip.ErrNotSupported
}
-func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
+func (*endpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
@@ -216,9 +212,8 @@ func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber
func (p *protocol) MinimumPacketSize() int { return header.ARPSize }
func (p *protocol) DefaultPrefixLen() int { return 0 }
-func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
- h := header.ARP(v)
- return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress
+func (*protocol) ParseAddresses(buffer.View) (src, dst tcpip.Address) {
+ return "", ""
}
func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
@@ -228,7 +223,6 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L
linkAddrCache: linkAddrCache,
nud: nud,
}
- e.AddressableEndpointState.Init(e)
return e
}
@@ -311,10 +305,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
}
// NewProtocol returns an ARP network protocol.
-//
-// Note, to make sure that the ARP endpoint receives ARP packets, the "arp"
-// address must be added to every NIC that should respond to ARP requests. See
-// ProtocolAddress for more details.
func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
return &protocol{stack: s}
}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 087ee9c66..f462524c9 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -200,9 +200,6 @@ func newTestContext(t *testing.T, useNeighborCache bool) *testContext {
t.Fatalf("AddAddress for ipv4 failed: %v", err)
}
}
- if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
- t.Fatalf("AddAddress for arp failed: %v", err)
- }
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv4EmptySubnet,
@@ -439,6 +436,10 @@ func (*testInterface) Enabled() bool {
return true
}
+func (*testInterface) Promiscuous() bool {
+ return false
+}
+
func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
r := stack.Route{
NetProto: protocol,
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
index 47fb63290..d8e4a3b54 100644
--- a/pkg/tcpip/network/fragmentation/BUILD
+++ b/pkg/tcpip/network/fragmentation/BUILD
@@ -47,6 +47,7 @@ go_test(
"//pkg/tcpip/buffer",
"//pkg/tcpip/faketime",
"//pkg/tcpip/network/testutil",
+ "//pkg/tcpip/stack",
"@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
index 936601287..c75ca7d71 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -71,16 +71,25 @@ type FragmentID struct {
// Fragmentation is the main structure that other modules
// of the stack should use to implement IP Fragmentation.
type Fragmentation struct {
- mu sync.Mutex
- highLimit int
- lowLimit int
- reassemblers map[FragmentID]*reassembler
- rList reassemblerList
- size int
- timeout time.Duration
- blockSize uint16
- clock tcpip.Clock
- releaseJob *tcpip.Job
+ mu sync.Mutex
+ highLimit int
+ lowLimit int
+ reassemblers map[FragmentID]*reassembler
+ rList reassemblerList
+ size int
+ timeout time.Duration
+ blockSize uint16
+ clock tcpip.Clock
+ releaseJob *tcpip.Job
+ timeoutHandler TimeoutHandler
+}
+
+// TimeoutHandler is consulted if a packet reassembly has timed out.
+type TimeoutHandler interface {
+ // OnReassemblyTimeout will be called with the first fragment (or nil, if the
+ // first fragment has not been received) of a packet whose reassembly has
+ // timed out.
+ OnReassemblyTimeout(pkt *stack.PacketBuffer)
}
// NewFragmentation creates a new Fragmentation.
@@ -97,7 +106,7 @@ type Fragmentation struct {
// reassemblingTimeout specifies the maximum time allowed to reassemble a packet.
// Fragments are lazily evicted only when a new packet with an
// already existing fragmentation-id arrives after the timeout.
-func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration, clock tcpip.Clock) *Fragmentation {
+func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration, clock tcpip.Clock, timeoutHandler TimeoutHandler) *Fragmentation {
if lowMemoryLimit >= highMemoryLimit {
lowMemoryLimit = highMemoryLimit
}
@@ -111,12 +120,13 @@ func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, rea
}
f := &Fragmentation{
- reassemblers: make(map[FragmentID]*reassembler),
- highLimit: highMemoryLimit,
- lowLimit: lowMemoryLimit,
- timeout: reassemblingTimeout,
- blockSize: blockSize,
- clock: clock,
+ reassemblers: make(map[FragmentID]*reassembler),
+ highLimit: highMemoryLimit,
+ lowLimit: lowMemoryLimit,
+ timeout: reassemblingTimeout,
+ blockSize: blockSize,
+ clock: clock,
+ timeoutHandler: timeoutHandler,
}
f.releaseJob = tcpip.NewJob(f.clock, &f.mu, f.releaseReassemblersLocked)
@@ -136,16 +146,8 @@ func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, rea
// proto is the protocol number marked in the fragment being processed. It has
// to be given here outside of the FragmentID struct because IPv6 should not use
// the protocol to identify a fragment.
-//
-// releaseCB is a callback that will run when the fragment reassembly of a
-// packet is complete or cancelled. releaseCB take a a boolean argument which is
-// true iff the reassembly is cancelled due to timeout. releaseCB should be
-// passed only with the first fragment of a packet. If more than one releaseCB
-// are passed for the same packet, only the first releaseCB will be saved for
-// the packet and the succeeding ones will be dropped by running them
-// immediately with a false argument.
func (f *Fragmentation) Process(
- id FragmentID, first, last uint16, more bool, proto uint8, vv buffer.VectorisedView, releaseCB func(bool)) (
+ id FragmentID, first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (
buffer.VectorisedView, uint8, bool, error) {
if first > last {
return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs)
@@ -160,10 +162,9 @@ func (f *Fragmentation) Process(
return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs)
}
- if l := vv.Size(); l < int(fragmentSize) {
- return buffer.VectorisedView{}, 0, false, fmt.Errorf("got fragment size=%d bytes less than the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs)
+ if l := pkt.Data.Size(); l != int(fragmentSize) {
+ return buffer.VectorisedView{}, 0, false, fmt.Errorf("got fragment size=%d bytes not equal to the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs)
}
- vv.CapLength(int(fragmentSize))
f.mu.Lock()
r, ok := f.reassemblers[id]
@@ -179,15 +180,9 @@ func (f *Fragmentation) Process(
f.releaseReassemblersLocked()
}
}
- if releaseCB != nil {
- if !r.setCallback(releaseCB) {
- // We got a duplicate callback. Release it immediately.
- releaseCB(false /* timedOut */)
- }
- }
f.mu.Unlock()
- res, firstFragmentProto, done, consumed, err := r.process(first, last, more, proto, vv)
+ res, firstFragmentProto, done, consumed, err := r.process(first, last, more, proto, pkt)
if err != nil {
// We probably got an invalid sequence of fragments. Just
// discard the reassembler and move on.
@@ -231,7 +226,9 @@ func (f *Fragmentation) release(r *reassembler, timedOut bool) {
f.size = 0
}
- r.release(timedOut) // releaseCB may run.
+ if h := f.timeoutHandler; timedOut && h != nil {
+ h.OnReassemblyTimeout(r.pkt)
+ }
}
// releaseReassemblersLocked releases already-expired reassemblers, then
diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go
index 5dcd10730..3a79688a8 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation_test.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/network/testutil"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
// reassembleTimeout is a dummy timeout used for testing, where the clock never
@@ -40,13 +41,19 @@ func vv(size int, pieces ...string) buffer.VectorisedView {
return buffer.NewVectorisedView(size, views)
}
+func pkt(size int, pieces ...string) *stack.PacketBuffer {
+ return stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv(size, pieces...),
+ })
+}
+
type processInput struct {
id FragmentID
first uint16
last uint16
more bool
proto uint8
- vv buffer.VectorisedView
+ pkt *stack.PacketBuffer
}
type processOutput struct {
@@ -63,8 +70,8 @@ var processTestCases = []struct {
{
comment: "One ID",
in: []processInput{
- {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, vv: vv(2, "01")},
- {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, vv: vv(2, "23")},
+ {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")},
+ {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")},
},
out: []processOutput{
{vv: buffer.VectorisedView{}, done: false},
@@ -74,8 +81,8 @@ var processTestCases = []struct {
{
comment: "Next Header protocol mismatch",
in: []processInput{
- {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, proto: 6, vv: vv(2, "01")},
- {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, proto: 17, vv: vv(2, "23")},
+ {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, proto: 6, pkt: pkt(2, "01")},
+ {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, proto: 17, pkt: pkt(2, "23")},
},
out: []processOutput{
{vv: buffer.VectorisedView{}, done: false},
@@ -85,10 +92,10 @@ var processTestCases = []struct {
{
comment: "Two IDs",
in: []processInput{
- {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, vv: vv(2, "01")},
- {id: FragmentID{ID: 1}, first: 0, last: 1, more: true, vv: vv(2, "ab")},
- {id: FragmentID{ID: 1}, first: 2, last: 3, more: false, vv: vv(2, "cd")},
- {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, vv: vv(2, "23")},
+ {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")},
+ {id: FragmentID{ID: 1}, first: 0, last: 1, more: true, pkt: pkt(2, "ab")},
+ {id: FragmentID{ID: 1}, first: 2, last: 3, more: false, pkt: pkt(2, "cd")},
+ {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")},
},
out: []processOutput{
{vv: buffer.VectorisedView{}, done: false},
@@ -102,17 +109,17 @@ var processTestCases = []struct {
func TestFragmentationProcess(t *testing.T) {
for _, c := range processTestCases {
t.Run(c.comment, func(t *testing.T) {
- f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{})
+ f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{}, nil)
firstFragmentProto := c.in[0].proto
for i, in := range c.in {
- vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv, nil)
+ vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.pkt)
if err != nil {
- t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %X) failed: %s",
- in.id, in.first, in.last, in.more, in.proto, in.vv.ToView(), err)
+ t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %#v) failed: %s",
+ in.id, in.first, in.last, in.more, in.proto, in.pkt, err)
}
if !reflect.DeepEqual(vv, c.out[i].vv) {
- t.Errorf("got Process(%+v, %d, %d, %t, %d, %X) = (%X, _, _, _), want = (%X, _, _, _)",
- in.id, in.first, in.last, in.more, in.proto, in.vv.ToView(), vv.ToView(), c.out[i].vv.ToView())
+ t.Errorf("got Process(%+v, %d, %d, %t, %d, %#v) = (%X, _, _, _), want = (%X, _, _, _)",
+ in.id, in.first, in.last, in.more, in.proto, in.pkt, vv.ToView(), c.out[i].vv.ToView())
}
if done != c.out[i].done {
t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, _, %t, _), want = (_, _, %t, _)",
@@ -236,11 +243,11 @@ func TestReassemblingTimeout(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
clock := faketime.NewManualClock()
- f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassemblyTimeout, clock)
+ f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassemblyTimeout, clock, nil)
for _, event := range test.events {
clock.Advance(event.clockAdvance)
if frag := event.fragment; frag != nil {
- _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, vv(len(frag.data), frag.data), nil)
+ _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, pkt(len(frag.data), frag.data))
if err != nil {
t.Fatalf("%s: f.Process failed: %s", event.name, err)
}
@@ -257,17 +264,17 @@ func TestReassemblingTimeout(t *testing.T) {
}
func TestMemoryLimits(t *testing.T) {
- f := NewFragmentation(minBlockSize, 3, 1, reassembleTimeout, &faketime.NullClock{})
+ f := NewFragmentation(minBlockSize, 3, 1, reassembleTimeout, &faketime.NullClock{}, nil)
// Send first fragment with id = 0.
- f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0"), nil)
+ f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0"))
// Send first fragment with id = 1.
- f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, vv(1, "1"), nil)
+ f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, pkt(1, "1"))
// Send first fragment with id = 2.
- f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, vv(1, "2"), nil)
+ f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, pkt(1, "2"))
// Send first fragment with id = 3. This should cause id = 0 and id = 1 to be
// evicted.
- f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, vv(1, "3"), nil)
+ f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, pkt(1, "3"))
if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok {
t.Errorf("Memory limits are not respected: id=0 has not been evicted.")
@@ -281,11 +288,11 @@ func TestMemoryLimits(t *testing.T) {
}
func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
- f := NewFragmentation(minBlockSize, 1, 0, reassembleTimeout, &faketime.NullClock{})
+ f := NewFragmentation(minBlockSize, 1, 0, reassembleTimeout, &faketime.NullClock{}, nil)
// Send first fragment with id = 0.
- f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"), nil)
+ f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0"))
// Send the same packet again.
- f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"), nil)
+ f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0"))
got := f.size
want := 1
@@ -327,6 +334,7 @@ func TestErrors(t *testing.T) {
last: 3,
more: true,
data: "012",
+ err: ErrInvalidArgs,
},
{
name: "exact block size with more and too little data",
@@ -376,8 +384,8 @@ func TestErrors(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{})
- _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data), nil)
+ f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, nil)
+ _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, pkt(len(test.data), test.data))
if !errors.Is(err, test.err) {
t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err)
}
@@ -498,57 +506,92 @@ func TestPacketFragmenter(t *testing.T) {
}
}
-func TestReleaseCallback(t *testing.T) {
+type testTimeoutHandler struct {
+ pkt *stack.PacketBuffer
+}
+
+func (h *testTimeoutHandler) OnReassemblyTimeout(pkt *stack.PacketBuffer) {
+ h.pkt = pkt
+}
+
+func TestTimeoutHandler(t *testing.T) {
const (
proto = 99
)
- var result int
- var callbackReasonIsTimeout bool
- cb1 := func(timedOut bool) { result = 1; callbackReasonIsTimeout = timedOut }
- cb2 := func(timedOut bool) { result = 2; callbackReasonIsTimeout = timedOut }
+ pk1 := pkt(1, "1")
+ pk2 := pkt(1, "2")
+
+ type processParam struct {
+ first uint16
+ last uint16
+ more bool
+ pkt *stack.PacketBuffer
+ }
tests := []struct {
- name string
- callbacks []func(bool)
- timeout bool
- wantResult int
- wantCallbackReasonIsTimeout bool
+ name string
+ params []processParam
+ wantError bool
+ wantPkt *stack.PacketBuffer
}{
{
- name: "callback runs on release",
- callbacks: []func(bool){cb1},
- timeout: false,
- wantResult: 1,
- wantCallbackReasonIsTimeout: false,
- },
- {
- name: "first callback is nil",
- callbacks: []func(bool){nil, cb2},
- timeout: false,
- wantResult: 2,
- wantCallbackReasonIsTimeout: false,
+ name: "onTimeout runs",
+ params: []processParam{
+ {
+ first: 0,
+ last: 0,
+ more: true,
+ pkt: pk1,
+ },
+ },
+ wantError: false,
+ wantPkt: pk1,
},
{
- name: "two callbacks - first one is set",
- callbacks: []func(bool){cb1, cb2},
- timeout: false,
- wantResult: 1,
- wantCallbackReasonIsTimeout: false,
+ name: "no first fragment",
+ params: []processParam{
+ {
+ first: 1,
+ last: 1,
+ more: true,
+ pkt: pk1,
+ },
+ },
+ wantError: false,
+ wantPkt: nil,
},
{
- name: "callback runs on timeout",
- callbacks: []func(bool){cb1},
- timeout: true,
- wantResult: 1,
- wantCallbackReasonIsTimeout: true,
+ name: "second pkt is ignored",
+ params: []processParam{
+ {
+ first: 0,
+ last: 0,
+ more: true,
+ pkt: pk1,
+ },
+ {
+ first: 0,
+ last: 0,
+ more: true,
+ pkt: pk2,
+ },
+ },
+ wantError: false,
+ wantPkt: pk1,
},
{
- name: "no callbacks",
- callbacks: []func(bool){nil},
- timeout: false,
- wantResult: 0,
- wantCallbackReasonIsTimeout: false,
+ name: "invalid args - first is greater than last",
+ params: []processParam{
+ {
+ first: 1,
+ last: 0,
+ more: true,
+ pkt: pk1,
+ },
+ },
+ wantError: true,
+ wantPkt: nil,
},
}
@@ -556,29 +599,31 @@ func TestReleaseCallback(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- result = 0
- callbackReasonIsTimeout = false
+ handler := &testTimeoutHandler{pkt: nil}
- f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{})
+ f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, handler)
- for i, cb := range test.callbacks {
- _, _, _, err := f.Process(id, uint16(i), uint16(i), true, proto, vv(1, "0"), cb)
- if err != nil {
+ for _, p := range test.params {
+ if _, _, _, err := f.Process(id, p.first, p.last, p.more, proto, p.pkt); err != nil && !test.wantError {
t.Errorf("f.Process error = %s", err)
}
}
-
- r, ok := f.reassemblers[id]
- if !ok {
- t.Fatalf("Reassemberr not found")
- }
- f.release(r, test.timeout)
-
- if result != test.wantResult {
- t.Errorf("got result = %d, want = %d", result, test.wantResult)
+ if !test.wantError {
+ r, ok := f.reassemblers[id]
+ if !ok {
+ t.Fatal("Reassembler not found")
+ }
+ f.release(r, true)
}
- if callbackReasonIsTimeout != test.wantCallbackReasonIsTimeout {
- t.Errorf("got callbackReasonIsTimeout = %t, want = %t", callbackReasonIsTimeout, test.wantCallbackReasonIsTimeout)
+ switch {
+ case handler.pkt != nil && test.wantPkt == nil:
+ t.Errorf("got handler.pkt = not nil (pkt.Data = %x), want = nil", handler.pkt.Data.ToView())
+ case handler.pkt == nil && test.wantPkt != nil:
+ t.Errorf("got handler.pkt = nil, want = not nil (pkt.Data = %x)", test.wantPkt.Data.ToView())
+ case handler.pkt != nil && test.wantPkt != nil:
+ if diff := cmp.Diff(test.wantPkt.Data.ToView(), handler.pkt.Data.ToView()); diff != "" {
+ t.Errorf("pkt.Data mismatch (-want, +got):\n%s", diff)
+ }
}
})
}
diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go
index c0cc0bde0..19f4920b3 100644
--- a/pkg/tcpip/network/fragmentation/reassembler.go
+++ b/pkg/tcpip/network/fragmentation/reassembler.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
type hole struct {
@@ -41,7 +42,7 @@ type reassembler struct {
heap fragHeap
done bool
creationTime int64
- callback func(bool)
+ pkt *stack.PacketBuffer
}
func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler {
@@ -79,7 +80,7 @@ func (r *reassembler) updateHoles(first, last uint16, more bool) bool {
return used
}
-func (r *reassembler) process(first, last uint16, more bool, proto uint8, vv buffer.VectorisedView) (buffer.VectorisedView, uint8, bool, int, error) {
+func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (buffer.VectorisedView, uint8, bool, int, error) {
r.mu.Lock()
defer r.mu.Unlock()
consumed := 0
@@ -89,18 +90,20 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, vv buf
// was waiting on the mutex. We don't have to do anything in this case.
return buffer.VectorisedView{}, 0, false, consumed, nil
}
- // For IPv6, it is possible to have different Protocol values between
- // fragments of a packet (because, unlike IPv4, the Protocol is not used to
- // identify a fragment). In this case, only the Protocol of the first
- // fragment must be used as per RFC 8200 Section 4.5.
- //
- // TODO(gvisor.dev/issue/3648): The entire first IP header should be recorded
- // here (instead of just the protocol) because most IP options should be
- // derived from the first fragment.
- if first == 0 {
- r.proto = proto
- }
if r.updateHoles(first, last, more) {
+ // For IPv6, it is possible to have different Protocol values between
+ // fragments of a packet (because, unlike IPv4, the Protocol is not used to
+ // identify a fragment). In this case, only the Protocol of the first
+ // fragment must be used as per RFC 8200 Section 4.5.
+ //
+ // TODO(gvisor.dev/issue/3648): During reassembly of an IPv6 packet, IP
+ // options received in the first fragment should be used - and they should
+ // override options from following fragments.
+ if first == 0 {
+ r.pkt = pkt
+ r.proto = proto
+ }
+ vv := pkt.Data
// We store the incoming packet only if it filled some holes.
heap.Push(&r.heap, fragment{offset: first, vv: vv.Clone(nil)})
consumed = vv.Size()
@@ -124,24 +127,3 @@ func (r *reassembler) checkDoneOrMark() bool {
r.mu.Unlock()
return prev
}
-
-func (r *reassembler) setCallback(c func(bool)) bool {
- r.mu.Lock()
- defer r.mu.Unlock()
- if r.callback != nil {
- return false
- }
- r.callback = c
- return true
-}
-
-func (r *reassembler) release(timedOut bool) {
- r.mu.Lock()
- callback := r.callback
- r.callback = nil
- r.mu.Unlock()
-
- if callback != nil {
- callback(timedOut)
- }
-}
diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go
index fa2a70dc8..a0a04a027 100644
--- a/pkg/tcpip/network/fragmentation/reassembler_test.go
+++ b/pkg/tcpip/network/fragmentation/reassembler_test.go
@@ -105,26 +105,3 @@ func TestUpdateHoles(t *testing.T) {
}
}
}
-
-func TestSetCallback(t *testing.T) {
- result := 0
- reasonTimeout := false
-
- cb1 := func(timedOut bool) { result = 1; reasonTimeout = timedOut }
- cb2 := func(timedOut bool) { result = 2; reasonTimeout = timedOut }
-
- r := newReassembler(FragmentID{}, &faketime.NullClock{})
- if !r.setCallback(cb1) {
- t.Errorf("setCallback failed")
- }
- if r.setCallback(cb2) {
- t.Errorf("setCallback should fail if one is already set")
- }
- r.release(true)
- if result != 1 {
- t.Errorf("got result = %d, want = 1", result)
- }
- if !reasonTimeout {
- t.Errorf("got reasonTimeout = %t, want = true", reasonTimeout)
- }
-}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index c7d26e14f..787399e08 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
@@ -34,16 +35,16 @@ import (
)
const (
- localIPv4Addr = "\x0a\x00\x00\x01"
- remoteIPv4Addr = "\x0a\x00\x00\x02"
- ipv4SubnetAddr = "\x0a\x00\x00\x00"
- ipv4SubnetMask = "\xff\xff\xff\x00"
- ipv4Gateway = "\x0a\x00\x00\x03"
- localIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- remoteIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
- ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00"
- ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"
+ localIPv4Addr = tcpip.Address("\x0a\x00\x00\x01")
+ remoteIPv4Addr = tcpip.Address("\x0a\x00\x00\x02")
+ ipv4SubnetAddr = tcpip.Address("\x0a\x00\x00\x00")
+ ipv4SubnetMask = tcpip.Address("\xff\xff\xff\x00")
+ ipv4Gateway = tcpip.Address("\x0a\x00\x00\x03")
+ localIPv6Addr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ remoteIPv6Addr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ ipv6SubnetAddr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00")
+ ipv6SubnetMask = tcpip.Address("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00")
+ ipv6Gateway = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
nicID = 1
)
@@ -192,10 +193,6 @@ func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBu
panic("not implemented")
}
-func (*testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
- return tcpip.ErrNotSupported
-}
-
// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
func (*testObject) ARPHardwareType() header.ARPHardwareType {
panic("not implemented")
@@ -299,6 +296,10 @@ func (t *testInterface) Enabled() bool {
return !t.mu.disabled
}
+func (*testInterface) Promiscuous() bool {
+ return false
+}
+
func (t *testInterface) setEnabled(v bool) {
t.mu.Lock()
defer t.mu.Unlock()
@@ -558,59 +559,135 @@ func TestIPv4Send(t *testing.T) {
}
}
-func TestIPv4Receive(t *testing.T) {
- s := buildDummyStack(t)
- proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
- nic := testInterface{
- testObject: testObject{
- t: t,
- v4: true,
+func TestReceive(t *testing.T) {
+ tests := []struct {
+ name string
+ protoFactory stack.NetworkProtocolFactory
+ protoNum tcpip.NetworkProtocolNumber
+ v4 bool
+ epAddr tcpip.AddressWithPrefix
+ handlePacket func(*testing.T, stack.NetworkEndpoint, *testInterface)
+ }{
+ {
+ name: "IPv4",
+ protoFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ v4: true,
+ epAddr: localIPv4Addr.WithPrefix(),
+ handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) {
+ const totalLen = header.IPv4MinimumSize + 30 /* payload length */
+
+ view := buffer.NewView(totalLen)
+ ip := header.IPv4(view)
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: totalLen,
+ TTL: ipv4.DefaultTTL,
+ Protocol: 10,
+ SrcAddr: remoteIPv4Addr,
+ DstAddr: localIPv4Addr,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ // Make payload be non-zero.
+ for i := header.IPv4MinimumSize; i < len(view); i++ {
+ view[i] = uint8(i)
+ }
+
+ // Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv4Addr
+ nic.testObject.dstAddr = localIPv4Addr
+ nic.testObject.contents = view[header.IPv4MinimumSize:totalLen]
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: view.ToVectorisedView(),
+ })
+ if ok := parse.IPv4(pkt); !ok {
+ t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
+ }
+ ep.HandlePacket(pkt)
+ },
},
- }
- ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
- defer ep.Close()
+ {
+ name: "IPv6",
+ protoFactory: ipv6.NewProtocol,
+ protoNum: ipv6.ProtocolNumber,
+ v4: false,
+ epAddr: localIPv6Addr.WithPrefix(),
+ handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) {
+ const payloadLen = 30
+ view := buffer.NewView(header.IPv6MinimumSize + payloadLen)
+ ip := header.IPv6(view)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: payloadLen,
+ NextHeader: 10,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: remoteIPv6Addr,
+ DstAddr: localIPv6Addr,
+ })
- if err := ep.Enable(); err != nil {
- t.Fatalf("ep.Enable(): %s", err)
- }
+ // Make payload be non-zero.
+ for i := header.IPv6MinimumSize; i < len(view); i++ {
+ view[i] = uint8(i)
+ }
- totalLen := header.IPv4MinimumSize + 30
- view := buffer.NewView(totalLen)
- ip := header.IPv4(view)
- ip.Encode(&header.IPv4Fields{
- TotalLength: uint16(totalLen),
- TTL: 20,
- Protocol: 10,
- SrcAddr: remoteIPv4Addr,
- DstAddr: localIPv4Addr,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
+ // Give packet to ipv6 endpoint, dispatcher will validate that it's ok.
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv6Addr
+ nic.testObject.dstAddr = localIPv6Addr
+ nic.testObject.contents = view[header.IPv6MinimumSize:][:payloadLen]
- // Make payload be non-zero.
- for i := header.IPv4MinimumSize; i < totalLen; i++ {
- view[i] = uint8(i)
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: view.ToVectorisedView(),
+ })
+ if _, _, _, _, ok := parse.IPv6(pkt); !ok {
+ t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
+ }
+ ep.HandlePacket(pkt)
+ },
+ },
}
- // Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
- nic.testObject.protocol = 10
- nic.testObject.srcAddr = remoteIPv4Addr
- nic.testObject.dstAddr = localIPv4Addr
- nic.testObject.contents = view[header.IPv4MinimumSize:totalLen]
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory},
+ })
+ nic := testInterface{
+ testObject: testObject{
+ t: t,
+ v4: test.v4,
+ },
+ }
+ ep := s.NetworkProtocolInstance(test.protoNum).NewEndpoint(&nic, nil, nil, &nic.testObject)
+ defer ep.Close()
- r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
- if err != nil {
- t.Fatalf("could not find route: %v", err)
- }
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: view.ToVectorisedView(),
- })
- if _, _, ok := proto.Parse(pkt); !ok {
- t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
- }
- r.PopulatePacketInfo(pkt)
- ep.HandlePacket(pkt)
- if nic.testObject.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
+ addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
+ if !ok {
+ t.Fatalf("expected network endpoint with number = %d to implement stack.AddressableEndpoint", test.protoNum)
+ }
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", test.epAddr, err)
+ } else {
+ ep.DecRef()
+ }
+
+ stat := s.Stats().IP.PacketsReceived
+ if got := stat.Value(); got != 0 {
+ t.Fatalf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 0", got)
+ }
+ test.handlePacket(t, ep, &nic)
+ if nic.testObject.dataCalls != 1 {
+ t.Errorf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
+ }
+ if got := stat.Value(); got != 1 {
+ t.Errorf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 1", got)
+ }
+ })
}
}
@@ -634,10 +711,6 @@ func TestIPv4ReceiveControl(t *testing.T) {
{"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0},
{"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + 8},
}
- r, err := buildIPv4Route(localIPv4Addr, "\x0a\x00\x00\xbb")
- if err != nil {
- t.Fatal(err)
- }
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
s := buildDummyStack(t)
@@ -705,8 +778,18 @@ func TestIPv4ReceiveControl(t *testing.T) {
nic.testObject.typ = c.expectedTyp
nic.testObject.extra = c.expectedExtra
+ addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
+ if !ok {
+ t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint")
+ }
+ addr := localIPv4Addr.WithPrefix()
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ } else {
+ ep.DecRef()
+ }
+
pkt := truncatedPacket(view, c.trunc, header.IPv4MinimumSize)
- r.PopulatePacketInfo(pkt)
ep.HandlePacket(pkt)
if want := c.expectedCount; nic.testObject.controlCalls != want {
t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want)
@@ -716,7 +799,9 @@ func TestIPv4ReceiveControl(t *testing.T) {
}
func TestIPv4FragmentationReceive(t *testing.T) {
- s := buildDummyStack(t)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ })
proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
nic := testInterface{
testObject: testObject{
@@ -774,11 +859,6 @@ func TestIPv4FragmentationReceive(t *testing.T) {
nic.testObject.dstAddr = localIPv4Addr
nic.testObject.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
- r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
- if err != nil {
- t.Fatalf("could not find route: %v", err)
- }
-
// Send first segment.
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: frag1.ToVectorisedView(),
@@ -786,7 +866,18 @@ func TestIPv4FragmentationReceive(t *testing.T) {
if _, _, ok := proto.Parse(pkt); !ok {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
- r.PopulatePacketInfo(pkt)
+
+ addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
+ if !ok {
+ t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint")
+ }
+ addr := localIPv4Addr.WithPrefix()
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ } else {
+ ep.DecRef()
+ }
+
ep.HandlePacket(pkt)
if nic.testObject.dataCalls != 0 {
t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls)
@@ -799,7 +890,6 @@ func TestIPv4FragmentationReceive(t *testing.T) {
if _, _, ok := proto.Parse(pkt); !ok {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
- r.PopulatePacketInfo(pkt)
ep.HandlePacket(pkt)
if nic.testObject.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
@@ -852,61 +942,6 @@ func TestIPv6Send(t *testing.T) {
}
}
-func TestIPv6Receive(t *testing.T) {
- s := buildDummyStack(t)
- proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
- nic := testInterface{
- testObject: testObject{
- t: t,
- },
- }
- ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
- defer ep.Close()
-
- if err := ep.Enable(); err != nil {
- t.Fatalf("ep.Enable(): %s", err)
- }
-
- totalLen := header.IPv6MinimumSize + 30
- view := buffer.NewView(totalLen)
- ip := header.IPv6(view)
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(totalLen - header.IPv6MinimumSize),
- NextHeader: 10,
- HopLimit: 20,
- SrcAddr: remoteIPv6Addr,
- DstAddr: localIPv6Addr,
- })
-
- // Make payload be non-zero.
- for i := header.IPv6MinimumSize; i < totalLen; i++ {
- view[i] = uint8(i)
- }
-
- // Give packet to ipv6 endpoint, dispatcher will validate that it's ok.
- nic.testObject.protocol = 10
- nic.testObject.srcAddr = remoteIPv6Addr
- nic.testObject.dstAddr = localIPv6Addr
- nic.testObject.contents = view[header.IPv6MinimumSize:totalLen]
-
- r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr)
- if err != nil {
- t.Fatalf("could not find route: %v", err)
- }
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: view.ToVectorisedView(),
- })
- if _, _, ok := proto.Parse(pkt); !ok {
- t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
- }
- r.PopulatePacketInfo(pkt)
- ep.HandlePacket(pkt)
- if nic.testObject.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
- }
-}
-
func TestIPv6ReceiveControl(t *testing.T) {
newUint16 := func(v uint16) *uint16 { return &v }
@@ -933,13 +968,6 @@ func TestIPv6ReceiveControl(t *testing.T) {
{"Non-zero fragment offset", 0, newUint16(100), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0},
{"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8},
}
- r, err := buildIPv6Route(
- localIPv6Addr,
- "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa",
- )
- if err != nil {
- t.Fatal(err)
- }
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
s := buildDummyStack(t)
@@ -1018,8 +1046,17 @@ func TestIPv6ReceiveControl(t *testing.T) {
// Set ICMPv6 checksum.
icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{}))
+ addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
+ if !ok {
+ t.Fatal("expected IPv6 network endpoint to implement stack.AddressableEndpoint")
+ }
+ addr := localIPv6Addr.WithPrefix()
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ } else {
+ ep.DecRef()
+ }
pkt := truncatedPacket(view, c.trunc, header.IPv6MinimumSize)
- r.PopulatePacketInfo(pkt)
ep.HandlePacket(pkt)
if want := c.expectedCount; nic.testObject.controlCalls != want {
t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want)
@@ -1202,7 +1239,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
nicAddr: localIPv4Addr,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
- ipHdrLen := header.IPv4MinimumSize + ipv4Options.AllocationSize()
+ ipHdrLen := header.IPv4MinimumSize + ipv4Options.SizeWithPadding()
totalLen := ipHdrLen + len(data)
hdr := buffer.NewPrependable(totalLen)
if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
@@ -1247,7 +1284,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
nicAddr: localIPv4Addr,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
- ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.AllocationSize()))
+ ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.SizeWithPadding()))
ip.Encode(&header.IPv4Fields{
Protocol: transportProto,
TTL: ipv4.DefaultTTL,
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 9b5e37fee..488945226 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -90,7 +90,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
iph := header.IPv4(pkt.NetworkHeader().View())
var newOptions header.IPv4Options
- if len(iph) > header.IPv4MinimumSize {
+ if opts := iph.Options(); len(opts) != 0 {
// RFC 1122 section 3.2.2.6 (page 43) (and similar for other round trip
// type ICMP packets):
// If a Record Route and/or Time Stamp option is received in an
@@ -106,7 +106,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
} else {
op = &optionUsageReceive{}
}
- aux, tmp, err := e.processIPOptions(pkt, iph.Options(), op)
+ aux, tmp, err := e.processIPOptions(pkt, opts, op)
if err != nil {
switch {
case
@@ -290,6 +290,13 @@ type icmpReasonProtoUnreachable struct{}
func (*icmpReasonProtoUnreachable) isICMPReason() {}
+// icmpReasonTTLExceeded is an error where a packet's time to live exceeded in
+// transit to its final destination, as per RFC 792 page 6, Time Exceeded
+// Message.
+type icmpReasonTTLExceeded struct{}
+
+func (*icmpReasonTTLExceeded) isICMPReason() {}
+
// icmpReasonReassemblyTimeout is an error where insufficient fragments are
// received to complete reassembly of a packet within a configured time after
// the reception of the first-arriving fragment of that packet.
@@ -342,11 +349,31 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
return nil
}
+ // If we hit a TTL Exceeded error, then we know we are operating as a router.
+ // As per RFC 792 page 6, Time Exceeded Message,
+ //
+ // If the gateway processing a datagram finds the time to live field
+ // is zero it must discard the datagram. The gateway may also notify
+ // the source host via the time exceeded message.
+ //
+ // ...
+ //
+ // Code 0 may be received from a gateway. ...
+ //
+ // Note, Code 0 is the TTL exceeded error.
+ //
+ // If we are operating as a router/gateway, don't use the packet's destination
+ // address as the response's source address as we should not own the
+ // destination address of a packet we are forwarding.
+ localAddr := origIPHdrDst
+ if _, ok := reason.(*icmpReasonTTLExceeded); ok {
+ localAddr = ""
+ }
// Even if we were able to receive a packet from some remote, we may not have
// a route to it - the remote may be blocked via routing rules. We must always
// consult our routing table and find a route to the remote before sending any
// packet.
- route, err := p.stack.FindRoute(pkt.NICID, origIPHdrDst, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */)
+ route, err := p.stack.FindRoute(pkt.NICID, localAddr, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */)
if err != nil {
return err
}
@@ -454,6 +481,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
icmpHdr.SetType(header.ICMPv4DstUnreachable)
icmpHdr.SetCode(header.ICMPv4ProtoUnreachable)
counter = sent.DstUnreachable
+ case *icmpReasonTTLExceeded:
+ icmpHdr.SetType(header.ICMPv4TimeExceeded)
+ icmpHdr.SetCode(header.ICMPv4TTLExceeded)
+ counter = sent.TimeExceeded
case *icmpReasonReassemblyTimeout:
icmpHdr.SetType(header.ICMPv4TimeExceeded)
icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout)
@@ -483,3 +514,18 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
counter.Increment()
return nil
}
+
+// OnReassemblyTimeout implements fragmentation.TimeoutHandler.
+func (p *protocol) OnReassemblyTimeout(pkt *stack.PacketBuffer) {
+ // OnReassemblyTimeout sends a Time Exceeded Message, as per RFC 792:
+ //
+ // If a host reassembling a fragmented datagram cannot complete the
+ // reassembly due to missing fragments within its time limit it discards the
+ // datagram, and it may send a time exceeded message.
+ //
+ // If fragment zero is not available then no time exceeded need be sent at
+ // all.
+ if pkt != nil {
+ p.returnError(&icmpReasonReassemblyTimeout{}, pkt)
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index a376cb8ec..1efe6297a 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -206,12 +206,12 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s
if opts, ok = params.Options.(header.IPv4Options); !ok {
panic(fmt.Sprintf("want IPv4Options, got %T", params.Options))
}
- hdrLen += opts.AllocationSize()
+ hdrLen += opts.SizeWithPadding()
if hdrLen > header.IPv4MaximumHeaderSize {
// Since we have no way to report an error we must either panic or create
// a packet which is different to what was requested. Choose panic as this
// would be a programming error that should be caught in testing.
- panic(fmt.Sprintf("IPv4 Options %d bytes, Max %d", params.Options.AllocationSize(), header.IPv4MaximumOptionsSize))
+ panic(fmt.Sprintf("IPv4 Options %d bytes, Max %d", params.Options.SizeWithPadding(), header.IPv4MaximumOptionsSize))
}
}
ip := header.IPv4(pkt.NetworkHeader().Push(hdrLen))
@@ -260,16 +260,13 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
e.addIPHeader(r, pkt, params)
- return e.writePacket(r, gso, pkt)
-}
-func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer) *tcpip.Error {
// iptables filtering. All packets that reach here are locally
// generated.
nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok {
// iptables is telling us to drop the packet.
- r.Stats().IP.IPTablesOutputDropped.Increment()
+ e.protocol.stack.Stats().IP.IPTablesOutputDropped.Increment()
return nil
}
@@ -286,24 +283,27 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
if err == nil {
pkt := pkt.CloneToInbound()
if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
- route.PopulatePacketInfo(pkt)
// Since we rewrote the packet but it is being routed back to us, we can
// safely assume the checksum is valid.
pkt.RXTransportChecksumValidated = true
- ep.HandlePacket(pkt)
+ ep.(*endpoint).handlePacket(pkt)
}
return nil
}
}
+ return e.writePacket(r, gso, pkt, false /* headerIncluded */)
+}
+
+func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, headerIncluded bool) *tcpip.Error {
if r.Loop&stack.PacketLoop != 0 {
pkt := pkt.CloneToInbound()
if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- loopedR := r.MakeLoopedRoute()
- loopedR.PopulatePacketInfo(pkt)
- loopedR.Release()
- e.HandlePacket(pkt)
+ // If the packet was generated by the stack (not a raw/packet endpoint
+ // where a packet may be written with the header included), then we can
+ // safely assume the checksum is valid.
+ pkt.RXTransportChecksumValidated = !headerIncluded
+ e.handlePacket(pkt)
}
}
if r.Loop&stack.PacketOut == 0 {
@@ -374,8 +374,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
// iptables filtering. All packets that reach here are locally
// generated.
- ipt := e.protocol.stack.IPTables()
- dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName)
+ dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName)
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
// faster WritePackets API directly.
@@ -400,9 +399,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
pkt := pkt.CloneToInbound()
if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
- route.PopulatePacketInfo(pkt)
- ep.HandlePacket(pkt)
+ // Since we rewrote the packet but it is being routed back to us, we
+ // can safely assume the checksum is valid.
+ pkt.RXTransportChecksumValidated = true
+ ep.(*endpoint).handlePacket(pkt)
}
n++
continue
@@ -479,16 +479,85 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
return tcpip.ErrMalformedHeader
}
- return e.writePacket(r, nil /* gso */, pkt)
+ return e.writePacket(r, nil /* gso */, pkt, true /* headerIncluded */)
+}
+
+// forwardPacket attempts to forward a packet to its final destination.
+func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error {
+ h := header.IPv4(pkt.NetworkHeader().View())
+ ttl := h.TTL()
+ if ttl == 0 {
+ // As per RFC 792 page 6, Time Exceeded Message,
+ //
+ // If the gateway processing a datagram finds the time to live field
+ // is zero it must discard the datagram. The gateway may also notify
+ // the source host via the time exceeded message.
+ return e.protocol.returnError(&icmpReasonTTLExceeded{}, pkt)
+ }
+
+ dstAddr := h.DestinationAddress()
+
+ // Check if the destination is owned by the stack.
+ networkEndpoint, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr)
+ if err == nil {
+ networkEndpoint.(*endpoint).handlePacket(pkt)
+ return nil
+ }
+ if err != tcpip.ErrBadAddress {
+ return err
+ }
+
+ r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer r.Release()
+
+ // We need to do a deep copy of the IP packet because
+ // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do
+ // not own it.
+ newHdr := header.IPv4(stack.PayloadSince(pkt.NetworkHeader()))
+
+ // As per RFC 791 page 30, Time to Live,
+ //
+ // This field must be decreased at each point that the internet header
+ // is processed to reflect the time spent processing the datagram.
+ // Even if no local information is available on the time actually
+ // spent, the field must be decremented by 1.
+ newHdr.SetTTL(ttl - 1)
+
+ return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: buffer.View(newHdr).ToVectorisedView(),
+ }))
}
// HandlePacket is called by the link layer when new ipv4 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
+ stats := e.protocol.stack.Stats()
+ stats.IP.PacketsReceived.Increment()
+
if !e.isEnabled() {
+ stats.IP.DisabledPacketsReceived.Increment()
return
}
+ // Loopback traffic skips the prerouting chain.
+ if !e.nic.IsLoopback() {
+ if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok {
+ // iptables is telling us to drop the packet.
+ stats.IP.IPTablesPreroutingDropped.Increment()
+ return
+ }
+ }
+
+ e.handlePacket(pkt)
+}
+
+// handlePacket is like HandlePacket except it does not perform the prerouting
+// iptables hook.
+func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
pkt.NICID = e.nic.ID()
stats := e.protocol.stack.Stats()
@@ -497,6 +566,21 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
stats.IP.MalformedPacketsReceived.Increment()
return
}
+ srcAddr := h.SourceAddress()
+ dstAddr := h.DestinationAddress()
+
+ addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint)
+ if addressEndpoint == nil {
+ if !e.protocol.Forwarding() {
+ stats.IP.InvalidDestinationAddressesReceived.Increment()
+ return
+ }
+
+ _ = e.forwardPacket(pkt)
+ return
+ }
+ subnet := addressEndpoint.AddressWithPrefix().Subnet()
+ addressEndpoint.DecRef()
// There has been some confusion regarding verifying checksums. We need
// just look for negative 0 (0xffff) as the checksum, as it's not possible to
@@ -528,15 +612,16 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// When a host sends any datagram, the IP source address MUST
// be one of its own IP addresses (but not a broadcast or
// multicast address).
- if pkt.NetworkPacketInfo.RemoteAddressBroadcast || header.IsV4MulticastAddress(h.SourceAddress()) {
+ if directedBroadcast := subnet.IsBroadcast(srcAddr); directedBroadcast || srcAddr == header.IPv4Broadcast || header.IsV4MulticastAddress(srcAddr) {
stats.IP.InvalidSourceAddressesReceived.Increment()
return
}
+ pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast
+
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
- ipt := e.protocol.stack.IPTables()
- if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
+ if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
stats.IP.IPTablesInputDropped.Increment()
return
@@ -565,29 +650,8 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return
}
- // Set up a callback in case we need to send a Time Exceeded Message, as per
- // RFC 792:
- //
- // If a host reassembling a fragmented datagram cannot complete the
- // reassembly due to missing fragments within its time limit it discards
- // the datagram, and it may send a time exceeded message.
- //
- // If fragment zero is not available then no time exceeded need be sent at
- // all.
- var releaseCB func(bool)
- if start == 0 {
- pkt := pkt.Clone()
- releaseCB = func(timedOut bool) {
- if timedOut {
- _ = e.protocol.returnError(&icmpReasonReassemblyTimeout{}, pkt)
- }
- }
- }
-
- var ready bool
- var err error
proto := h.Protocol()
- pkt.Data, _, ready, err = e.protocol.fragmentation.Process(
+ data, _, ready, err := e.protocol.fragmentation.Process(
// As per RFC 791 section 2.3, the identification value is unique
// for a source-destination pair and protocol.
fragmentation.FragmentID{
@@ -600,8 +664,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
start+uint16(pkt.Data.Size())-1,
h.More(),
proto,
- pkt.Data,
- releaseCB,
+ pkt,
)
if err != nil {
stats.IP.MalformedPacketsReceived.Increment()
@@ -611,6 +674,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
if !ready {
return
}
+ pkt.Data = data
// The reassembler doesn't take care of fixing up the header, so we need
// to do it here.
@@ -628,11 +692,11 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
e.handleICMP(pkt)
return
}
- if len(h.Options()) != 0 {
+ if opts := h.Options(); len(opts) != 0 {
// TODO(gvisor.dev/issue/4586):
// When we add forwarding support we should use the verified options
// rather than just throwing them away.
- aux, _, err := e.processIPOptions(pkt, h.Options(), &optionUsageReceive{})
+ aux, _, err := e.processIPOptions(pkt, opts, &optionUsageReceive{})
if err != nil {
switch {
case
@@ -778,6 +842,7 @@ func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
var _ stack.NetworkProtocol = (*protocol)(nil)
+var _ fragmentation.TimeoutHandler = (*protocol)(nil)
type protocol struct {
stack *stack.Stack
@@ -942,13 +1007,14 @@ func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
}
hashIV := r[buckets]
- return &protocol{
- stack: s,
- ids: ids,
- hashIV: hashIV,
- defaultTTL: DefaultTTL,
- fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock()),
+ p := &protocol{
+ stack: s,
+ ids: ids,
+ hashIV: hashIV,
+ defaultTTL: DefaultTTL,
}
+ p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p)
+ return p
}
func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader header.IPv4) (*stack.PacketBuffer, bool) {
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index c6e565455..4e4e1f3b4 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -103,6 +103,262 @@ func TestExcludeBroadcast(t *testing.T) {
})
}
+// TestIPv4EncodeOptions checks that header.IPv4.Encode correctly fills out the
+// requested fields when options are supplied.
+func TestIPv4EncodeOptions(t *testing.T) {
+ tests := []struct {
+ name string
+ options header.IPv4Options
+ encodedOptions header.IPv4Options // options as they should look once encoded (with padding)
+ wantIHL int
+ }{
+ {
+ name: "valid no options",
+ wantIHL: header.IPv4MinimumSize,
+ },
+ {
+ name: "one byte options",
+ options: header.IPv4Options{1},
+ encodedOptions: header.IPv4Options{1, 0, 0, 0},
+ wantIHL: header.IPv4MinimumSize + 4,
+ },
+ {
+ name: "two byte options",
+ options: header.IPv4Options{1, 1},
+ encodedOptions: header.IPv4Options{1, 1, 0, 0},
+ wantIHL: header.IPv4MinimumSize + 4,
+ },
+ {
+ name: "three byte options",
+ options: header.IPv4Options{1, 1, 1},
+ encodedOptions: header.IPv4Options{1, 1, 1, 0},
+ wantIHL: header.IPv4MinimumSize + 4,
+ },
+ {
+ name: "four byte options",
+ options: header.IPv4Options{1, 1, 1, 1},
+ encodedOptions: header.IPv4Options{1, 1, 1, 1},
+ wantIHL: header.IPv4MinimumSize + 4,
+ },
+ {
+ name: "five byte options",
+ options: header.IPv4Options{1, 1, 1, 1, 1},
+ encodedOptions: header.IPv4Options{1, 1, 1, 1, 1, 0, 0, 0},
+ wantIHL: header.IPv4MinimumSize + 8,
+ },
+ {
+ name: "thirty nine byte options",
+ options: header.IPv4Options{
+ 1, 2, 3, 4, 5, 6, 7, 8,
+ 9, 10, 11, 12, 13, 14, 15, 16,
+ 17, 18, 19, 20, 21, 22, 23, 24,
+ 25, 26, 27, 28, 29, 30, 31, 32,
+ 33, 34, 35, 36, 37, 38, 39,
+ },
+ encodedOptions: header.IPv4Options{
+ 1, 2, 3, 4, 5, 6, 7, 8,
+ 9, 10, 11, 12, 13, 14, 15, 16,
+ 17, 18, 19, 20, 21, 22, 23, 24,
+ 25, 26, 27, 28, 29, 30, 31, 32,
+ 33, 34, 35, 36, 37, 38, 39, 0,
+ },
+ wantIHL: header.IPv4MinimumSize + 40,
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ paddedOptionLength := test.options.SizeWithPadding()
+ ipHeaderLength := header.IPv4MinimumSize + paddedOptionLength
+ if ipHeaderLength > header.IPv4MaximumHeaderSize {
+ t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
+ }
+ totalLen := uint16(ipHeaderLength)
+ hdr := buffer.NewPrependable(int(totalLen))
+ ip := header.IPv4(hdr.Prepend(ipHeaderLength))
+ // To check the padding works, poison the last byte of the options space.
+ if paddedOptionLength != len(test.options) {
+ ip.SetHeaderLength(uint8(ipHeaderLength))
+ ip.Options()[paddedOptionLength-1] = 0xff
+ ip.SetHeaderLength(0)
+ }
+ ip.Encode(&header.IPv4Fields{
+ Options: test.options,
+ })
+ options := ip.Options()
+ wantOptions := test.encodedOptions
+ if got, want := int(ip.HeaderLength()), test.wantIHL; got != want {
+ t.Errorf("got IHL of %d, want %d", got, want)
+ }
+
+ // cmp.Diff does not consider nil slices equal to empty slices, but we do.
+ if len(wantOptions) == 0 && len(options) == 0 {
+ return
+ }
+
+ if diff := cmp.Diff(wantOptions, options); diff != "" {
+ t.Errorf("options mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestForwarding(t *testing.T) {
+ const (
+ nicID1 = 1
+ nicID2 = 2
+ randomSequence = 123
+ randomIdent = 42
+ )
+
+ ipv4Addr1 := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()),
+ PrefixLen: 8,
+ }
+ ipv4Addr2 := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()),
+ PrefixLen: 8,
+ }
+ remoteIPv4Addr1 := tcpip.Address(net.ParseIP("10.0.0.2").To4())
+ remoteIPv4Addr2 := tcpip.Address(net.ParseIP("11.0.0.2").To4())
+
+ tests := []struct {
+ name string
+ TTL uint8
+ expectErrorICMP bool
+ }{
+ {
+ name: "TTL of zero",
+ TTL: 0,
+ expectErrorICMP: true,
+ },
+ {
+ name: "TTL of one",
+ TTL: 1,
+ expectErrorICMP: false,
+ },
+ {
+ name: "TTL of two",
+ TTL: 2,
+ expectErrorICMP: false,
+ },
+ {
+ name: "Max TTL",
+ TTL: math.MaxUint8,
+ expectErrorICMP: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
+ })
+ // We expect at most a single packet in response to our ICMP Echo Request.
+ e1 := channel.New(1, ipv4.MaxTotalSize, "")
+ if err := s.CreateNIC(nicID1, e1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+ ipv4ProtoAddr1 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr1}
+ if err := s.AddProtocolAddress(nicID1, ipv4ProtoAddr1); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv4ProtoAddr1, err)
+ }
+
+ e2 := channel.New(1, ipv4.MaxTotalSize, "")
+ if err := s.CreateNIC(nicID2, e2); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ }
+ ipv4ProtoAddr2 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr2}
+ if err := s.AddProtocolAddress(nicID2, ipv4ProtoAddr2); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv4ProtoAddr2, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: ipv4Addr1.Subnet(),
+ NIC: nicID1,
+ },
+ {
+ Destination: ipv4Addr2.Subnet(),
+ NIC: nicID2,
+ },
+ })
+
+ if err := s.SetForwarding(header.IPv4ProtocolNumber, true); err != nil {
+ t.Fatalf("SetForwarding(%d, true): %s", header.IPv4ProtocolNumber, err)
+ }
+
+ totalLen := uint16(header.IPv4MinimumSize + header.ICMPv4MinimumSize)
+ hdr := buffer.NewPrependable(int(totalLen))
+ icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ icmp.SetIdent(randomIdent)
+ icmp.SetSequence(randomSequence)
+ icmp.SetType(header.ICMPv4Echo)
+ icmp.SetCode(header.ICMPv4UnusedCode)
+ icmp.SetChecksum(0)
+ icmp.SetChecksum(^header.Checksum(icmp, 0))
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: totalLen,
+ Protocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: test.TTL,
+ SrcAddr: remoteIPv4Addr1,
+ DstAddr: remoteIPv4Addr2,
+ })
+ ip.SetChecksum(0)
+ ip.SetChecksum(^ip.CalculateChecksum())
+ requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ })
+ e1.InjectInbound(header.IPv4ProtocolNumber, requestPkt)
+
+ if test.expectErrorICMP {
+ reply, ok := e1.Read()
+ if !ok {
+ t.Fatal("expected ICMP TTL Exceeded packet through incoming NIC")
+ }
+
+ checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())),
+ checker.SrcAddr(ipv4Addr1.Address),
+ checker.DstAddr(remoteIPv4Addr1),
+ checker.TTL(ipv4.DefaultTTL),
+ checker.ICMPv4(
+ checker.ICMPv4Checksum(),
+ checker.ICMPv4Type(header.ICMPv4TimeExceeded),
+ checker.ICMPv4Code(header.ICMPv4TTLExceeded),
+ checker.ICMPv4Payload([]byte(hdr.View())),
+ ),
+ )
+
+ if n := e2.Drain(); n != 0 {
+ t.Fatalf("got e2.Drain() = %d, want = 0", n)
+ }
+ } else {
+ reply, ok := e2.Read()
+ if !ok {
+ t.Fatal("expected ICMP Echo packet through outgoing NIC")
+ }
+
+ checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())),
+ checker.SrcAddr(remoteIPv4Addr1),
+ checker.DstAddr(remoteIPv4Addr2),
+ checker.TTL(test.TTL-1),
+ checker.ICMPv4(
+ checker.ICMPv4Checksum(),
+ checker.ICMPv4Type(header.ICMPv4Echo),
+ checker.ICMPv4Code(header.ICMPv4UnusedCode),
+ checker.ICMPv4Payload(nil),
+ ),
+ )
+
+ if n := e1.Drain(); n != 0 {
+ t.Fatalf("got e1.Drain() = %d, want = 0", n)
+ }
+ }
+ })
+ }
+}
+
// TestIPv4Sanity sends IP/ICMP packets with various problems to the stack and
// checks the response.
func TestIPv4Sanity(t *testing.T) {
@@ -197,6 +453,14 @@ func TestIPv4Sanity(t *testing.T) {
replyOptions: header.IPv4Options{1, 1, 0, 0},
},
{
+ name: "Check option padding",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: header.IPv4Options{1, 1, 1},
+ replyOptions: header.IPv4Options{1, 1, 1, 0},
+ },
+ {
name: "bad header length",
headerLength: header.IPv4MinimumSize - 1,
maxTotalLength: ipv4.MaxTotalSize,
@@ -599,9 +863,10 @@ func TestIPv4Sanity(t *testing.T) {
},
})
- ipHeaderLength := header.IPv4MinimumSize + test.options.AllocationSize()
+ paddedOptionLength := test.options.SizeWithPadding()
+ ipHeaderLength := header.IPv4MinimumSize + paddedOptionLength
if ipHeaderLength > header.IPv4MaximumHeaderSize {
- t.Fatalf("too many bytes in options: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
+ t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
}
totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize)
hdr := buffer.NewPrependable(int(totalLen))
@@ -618,6 +883,12 @@ func TestIPv4Sanity(t *testing.T) {
if test.maxTotalLength < totalLen {
totalLen = test.maxTotalLength
}
+ // To check the padding works, poison the options space.
+ if paddedOptionLength != len(test.options) {
+ ip.SetHeaderLength(uint8(ipHeaderLength))
+ ip.Options()[paddedOptionLength-1] = 0x01
+ }
+
ip.Encode(&header.IPv4Fields{
TotalLength: totalLen,
Protocol: test.transportProtocol,
@@ -732,7 +1003,7 @@ func TestIPv4Sanity(t *testing.T) {
}
// If the IP options change size then the packet will change size, so
// some IP header fields will need to be adjusted for the checks.
- sizeChange := len(test.replyOptions) - len(test.options)
+ sizeChange := len(test.replyOptions) - paddedOptionLength
checker.IPv4(t, replyIPHeader,
checker.IPv4HeaderLength(ipHeaderLength+sizeChange),
@@ -2441,9 +2712,6 @@ func TestPacketQueing(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, arp.ProtocolNumber, arp.ProtocolAddress, err)
- }
if err := s.AddProtocolAddress(nicID, host1IPv4Addr); err != nil {
t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv4Addr, err)
}
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 8502b848c..beb8f562e 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -750,6 +750,12 @@ type icmpReasonPortUnreachable struct{}
func (*icmpReasonPortUnreachable) isICMPReason() {}
+// icmpReasonHopLimitExceeded is an error where a packet's hop limit was
+// exceeded in transit to its final destination, as per RFC 4443 section 3.3.
+type icmpReasonHopLimitExceeded struct{}
+
+func (*icmpReasonHopLimitExceeded) isICMPReason() {}
+
// icmpReasonReassemblyTimeout is an error where insufficient fragments are
// received to complete reassembly of a packet within a configured time after
// the reception of the first-arriving fragment of that packet.
@@ -794,11 +800,27 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
return nil
}
+ // If we hit a Hop Limit Exceeded error, then we know we are operating as a
+ // router. As per RFC 4443 section 3.3:
+ //
+ // If a router receives a packet with a Hop Limit of zero, or if a
+ // router decrements a packet's Hop Limit to zero, it MUST discard the
+ // packet and originate an ICMPv6 Time Exceeded message with Code 0 to
+ // the source of the packet. This indicates either a routing loop or
+ // too small an initial Hop Limit value.
+ //
+ // If we are operating as a router, do not use the packet's destination
+ // address as the response's source address as we should not own the
+ // destination address of a packet we are forwarding.
+ localAddr := origIPHdrDst
+ if _, ok := reason.(*icmpReasonHopLimitExceeded); ok {
+ localAddr = ""
+ }
// Even if we were able to receive a packet from some remote, we may not have
// a route to it - the remote may be blocked via routing rules. We must always
// consult our routing table and find a route to the remote before sending any
// packet.
- route, err := p.stack.FindRoute(pkt.NICID, origIPHdrDst, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */)
+ route, err := p.stack.FindRoute(pkt.NICID, localAddr, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */)
if err != nil {
return err
}
@@ -811,8 +833,6 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
return nil
}
- network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
-
if pkt.TransportProtocolNumber == header.ICMPv6ProtocolNumber {
// TODO(gvisor.dev/issues/3810): Sort this out when ICMP headers are stored.
// Unfortunately at this time ICMP Packets do not have a transport
@@ -830,6 +850,8 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
}
}
+ network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
+
// As per RFC 4443 section 2.4
//
// (c) Every ICMPv6 error message (type < 128) MUST include
@@ -873,6 +895,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
icmpHdr.SetType(header.ICMPv6DstUnreachable)
icmpHdr.SetCode(header.ICMPv6PortUnreachable)
counter = sent.DstUnreachable
+ case *icmpReasonHopLimitExceeded:
+ icmpHdr.SetType(header.ICMPv6TimeExceeded)
+ icmpHdr.SetCode(header.ICMPv6HopLimitExceeded)
+ counter = sent.TimeExceeded
case *icmpReasonReassemblyTimeout:
icmpHdr.SetType(header.ICMPv6TimeExceeded)
icmpHdr.SetCode(header.ICMPv6ReassemblyTimeout)
@@ -896,3 +922,16 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
counter.Increment()
return nil
}
+
+// OnReassemblyTimeout implements fragmentation.TimeoutHandler.
+func (p *protocol) OnReassemblyTimeout(pkt *stack.PacketBuffer) {
+ // OnReassemblyTimeout sends a Time Exceeded Message as per RFC 2460 Section
+ // 4.5:
+ //
+ // If the first fragment (i.e., the one with a Fragment Offset of zero) has
+ // been received, an ICMP Time Exceeded -- Fragment Reassembly Time Exceeded
+ // message should be sent to the source of that fragment.
+ if pkt != nil {
+ p.returnError(&icmpReasonReassemblyTimeout{}, pkt)
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 76013daa1..9bc02d851 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -144,6 +144,10 @@ func (*testInterface) Enabled() bool {
return true
}
+func (*testInterface) Promiscuous() bool {
+ return false
+}
+
func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
r := stack.Route{
NetProto: protocol,
@@ -174,13 +178,8 @@ func TestICMPCounts(t *testing.T) {
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
UseNeighborCache: test.useNeighborCache,
})
- {
- if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
- t.Fatalf("CreateNIC(_, _) = %s", err)
- }
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
- }
+ if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_, _) = %s", err)
}
{
subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -206,11 +205,16 @@ func TestICMPCounts(t *testing.T) {
t.Fatalf("ep.Enable(): %s", err)
}
- r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
+ addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
+ if !ok {
+ t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint")
+ }
+ addr := lladdr0.WithPrefix()
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ } else {
+ ep.DecRef()
}
- defer r.Release()
var tllData [header.NDPLinkLayerAddressSize]byte
header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
@@ -279,10 +283,9 @@ func TestICMPCounts(t *testing.T) {
PayloadLength: uint16(len(icmp)),
NextHeader: uint8(header.ICMPv6ProtocolNumber),
HopLimit: header.NDPHopLimit,
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
- r.PopulatePacketInfo(pkt)
ep.HandlePacket(pkt)
}
@@ -290,7 +293,7 @@ func TestICMPCounts(t *testing.T) {
icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
copy(icmp[typ.size:], typ.extraData)
icmp.SetType(typ.typ)
- icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView()))
handleIPv6Payload(icmp)
}
@@ -317,13 +320,8 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
UseNeighborCache: true,
})
- {
- if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
- t.Fatalf("CreateNIC(_, _) = %s", err)
- }
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
- }
+ if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_, _) = %s", err)
}
{
subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -349,11 +347,16 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
t.Fatalf("ep.Enable(): %s", err)
}
- r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
+ addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
+ if !ok {
+ t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint")
+ }
+ addr := lladdr0.WithPrefix()
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ } else {
+ ep.DecRef()
}
- defer r.Release()
var tllData [header.NDPLinkLayerAddressSize]byte
header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
@@ -422,10 +425,9 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
PayloadLength: uint16(len(icmp)),
NextHeader: uint8(header.ICMPv6ProtocolNumber),
HopLimit: header.NDPHopLimit,
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
- r.PopulatePacketInfo(pkt)
ep.HandlePacket(pkt)
}
@@ -433,7 +435,7 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
copy(icmp[typ.size:], typ.extraData)
icmp.SetType(typ.typ)
- icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView()))
handleIPv6Payload(icmp)
}
@@ -1775,17 +1777,19 @@ func TestCallsToNeighborCache(t *testing.T) {
t.Fatalf("ep.Enable(): %s", err)
}
- r, err := s.FindRoute(nicID, lladdr0, test.source, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
+ addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
+ if !ok {
+ t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint")
+ }
+ addr := lladdr0.WithPrefix()
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ } else {
+ ep.DecRef()
}
- defer r.Release()
-
- // TODO(gvisor.dev/issue/4517): Remove the need for this manual patch.
- r.LocalAddress = test.destination
icmp := test.createPacket()
- icmp.SetChecksum(header.ICMPv6Checksum(icmp, r.RemoteAddress, r.LocalAddress, buffer.VectorisedView{}))
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, test.source, test.destination, buffer.VectorisedView{}))
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: header.IPv6MinimumSize,
Data: buffer.View(icmp).ToVectorisedView(),
@@ -1795,10 +1799,9 @@ func TestCallsToNeighborCache(t *testing.T) {
PayloadLength: uint16(len(icmp)),
NextHeader: uint8(header.ICMPv6ProtocolNumber),
HopLimit: header.NDPHopLimit,
- SrcAddr: r.RemoteAddress,
- DstAddr: r.LocalAddress,
+ SrcAddr: test.source,
+ DstAddr: test.destination,
})
- r.PopulatePacketInfo(pkt)
ep.HandlePacket(pkt)
// Confirm the endpoint calls the correct NUDHandler method.
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 0526190cc..7a00f6314 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -441,17 +441,13 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
e.addIPHeader(r, pkt, params)
- return e.writePacket(r, gso, pkt, params.Protocol)
-}
-func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
// iptables filtering. All packets that reach here are locally
// generated.
nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- ipt := e.protocol.stack.IPTables()
- if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
+ if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok {
// iptables is telling us to drop the packet.
- r.Stats().IP.IPTablesOutputDropped.Increment()
+ e.protocol.stack.Stats().IP.IPTablesOutputDropped.Increment()
return nil
}
@@ -467,24 +463,27 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
pkt := pkt.CloneToInbound()
if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
- route.PopulatePacketInfo(pkt)
// Since we rewrote the packet but it is being routed back to us, we can
// safely assume the checksum is valid.
pkt.RXTransportChecksumValidated = true
- ep.HandlePacket(pkt)
+ ep.(*endpoint).handlePacket(pkt)
}
return nil
}
}
+ return e.writePacket(r, gso, pkt, params.Protocol, false /* headerIncluded */)
+}
+
+func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) *tcpip.Error {
if r.Loop&stack.PacketLoop != 0 {
pkt := pkt.CloneToInbound()
if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- loopedR := r.MakeLoopedRoute()
- loopedR.PopulatePacketInfo(pkt)
- loopedR.Release()
- e.HandlePacket(pkt)
+ // If the packet was generated by the stack (not a raw/packet endpoint
+ // where a packet may be written with the header included), then we can
+ // safely assume the checksum is valid.
+ pkt.RXTransportChecksumValidated = !headerIncluded
+ e.handlePacket(pkt)
}
}
if r.Loop&stack.PacketOut == 0 {
@@ -558,8 +557,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// iptables filtering. All packets that reach here are locally
// generated.
nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- ipt := e.protocol.stack.IPTables()
- dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName)
+ dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName)
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
// faster WritePackets API directly.
@@ -584,9 +582,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
pkt := pkt.CloneToInbound()
if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
- route.PopulatePacketInfo(pkt)
- ep.HandlePacket(pkt)
+ // Since we rewrote the packet but it is being routed back to us, we
+ // can safely assume the checksum is valid.
+ pkt.RXTransportChecksumValidated = true
+ ep.(*endpoint).handlePacket(pkt)
}
n++
continue
@@ -640,16 +639,85 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
return tcpip.ErrMalformedHeader
}
- return e.writePacket(r, nil /* gso */, pkt, proto)
+ return e.writePacket(r, nil /* gso */, pkt, proto, true /* headerIncluded */)
+}
+
+// forwardPacket attempts to forward a packet to its final destination.
+func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error {
+ h := header.IPv6(pkt.NetworkHeader().View())
+ hopLimit := h.HopLimit()
+ if hopLimit <= 1 {
+ // As per RFC 4443 section 3.3,
+ //
+ // If a router receives a packet with a Hop Limit of zero, or if a
+ // router decrements a packet's Hop Limit to zero, it MUST discard the
+ // packet and originate an ICMPv6 Time Exceeded message with Code 0 to
+ // the source of the packet. This indicates either a routing loop or
+ // too small an initial Hop Limit value.
+ return e.protocol.returnError(&icmpReasonHopLimitExceeded{}, pkt)
+ }
+
+ dstAddr := h.DestinationAddress()
+
+ // Check if the destination is owned by the stack.
+ networkEndpoint, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr)
+ if err == nil {
+ networkEndpoint.(*endpoint).handlePacket(pkt)
+ return nil
+ }
+ if err != tcpip.ErrBadAddress {
+ return err
+ }
+
+ r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer r.Release()
+
+ // We need to do a deep copy of the IP packet because
+ // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do
+ // not own it.
+ newHdr := header.IPv6(stack.PayloadSince(pkt.NetworkHeader()))
+
+ // As per RFC 8200 section 3,
+ //
+ // Hop Limit 8-bit unsigned integer. Decremented by 1 by
+ // each node that forwards the packet.
+ newHdr.SetHopLimit(hopLimit - 1)
+
+ return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: buffer.View(newHdr).ToVectorisedView(),
+ }))
}
// HandlePacket is called by the link layer when new ipv6 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
+ stats := e.protocol.stack.Stats()
+ stats.IP.PacketsReceived.Increment()
+
if !e.isEnabled() {
+ stats.IP.DisabledPacketsReceived.Increment()
return
}
+ // Loopback traffic skips the prerouting chain.
+ if !e.nic.IsLoopback() {
+ if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok {
+ // iptables is telling us to drop the packet.
+ stats.IP.IPTablesPreroutingDropped.Increment()
+ return
+ }
+ }
+
+ e.handlePacket(pkt)
+}
+
+// handlePacket is like HandlePacket except it does not perform the prerouting
+// iptables hook.
+func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
pkt.NICID = e.nic.ID()
stats := e.protocol.stack.Stats()
@@ -669,6 +737,18 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return
}
+ addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint)
+ if addressEndpoint == nil {
+ if !e.protocol.Forwarding() {
+ stats.IP.InvalidDestinationAddressesReceived.Increment()
+ return
+ }
+
+ _ = e.forwardPacket(pkt)
+ return
+ }
+ addressEndpoint.DecRef()
+
// vv consists of:
// - Any IPv6 header bytes after the first 40 (i.e. extensions).
// - The transport header, if present.
@@ -681,8 +761,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// iptables filtering. All packets that reach here are intended for
// this machine and need not be forwarded.
- ipt := e.protocol.stack.IPTables()
- if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
+ if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
stats.IP.IPTablesInputDropped.Increment()
return
@@ -888,18 +967,6 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return
}
- // Set up a callback in case we need to send a Time Exceeded Message as
- // per RFC 2460 Section 4.5.
- var releaseCB func(bool)
- if start == 0 {
- pkt := pkt.Clone()
- releaseCB = func(timedOut bool) {
- if timedOut {
- _ = e.protocol.returnError(&icmpReasonReassemblyTimeout{}, pkt)
- }
- }
- }
-
// Note that pkt doesn't have its transport header set after reassembly,
// and won't until DeliverNetworkPacket sets it.
data, proto, ready, err := e.protocol.fragmentation.Process(
@@ -914,17 +981,17 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
start+uint16(fragmentPayloadLen)-1,
extHdr.More(),
uint8(rawPayload.Identifier),
- rawPayload.Buf,
- releaseCB,
+ pkt,
)
if err != nil {
stats.IP.MalformedPacketsReceived.Increment()
stats.IP.MalformedFragmentsReceived.Increment()
return
}
- pkt.Data = data
if ready {
+ pkt.Data = data
+
// We create a new iterator with the reassembled packet because we could
// have more extension headers in the reassembled payload, as per RFC
// 8200 section 4.5. We also use the NextHeader value from the first
@@ -1335,6 +1402,7 @@ func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
var _ stack.NetworkProtocol = (*protocol)(nil)
+var _ fragmentation.TimeoutHandler = (*protocol)(nil)
type protocol struct {
stack *stack.Stack
@@ -1590,10 +1658,9 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
return func(s *stack.Stack) stack.NetworkProtocol {
p := &protocol{
- stack: s,
- fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock()),
- ids: ids,
- hashIV: hashIV,
+ stack: s,
+ ids: ids,
+ hashIV: hashIV,
ndpDisp: opts.NDPDisp,
ndpConfigs: opts.NDPConfigs,
@@ -1601,6 +1668,7 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
tempIIDSeed: opts.TempIIDSeed,
autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal,
}
+ p.fragmentation = fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p)
p.mu.eps = make(map[*endpoint]struct{})
p.SetDefaultTTL(DefaultTTL)
return p
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 1bfcdde25..a671d4bac 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -18,6 +18,7 @@ import (
"encoding/hex"
"fmt"
"math"
+ "net"
"testing"
"github.com/google/go-cmp/cmp"
@@ -2821,3 +2822,160 @@ func TestFragmentationErrors(t *testing.T) {
})
}
}
+
+func TestForwarding(t *testing.T) {
+ const (
+ nicID1 = 1
+ nicID2 = 2
+ randomSequence = 123
+ randomIdent = 42
+ )
+
+ ipv6Addr1 := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("10::1").To16()),
+ PrefixLen: 64,
+ }
+ ipv6Addr2 := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("11::1").To16()),
+ PrefixLen: 64,
+ }
+ remoteIPv6Addr1 := tcpip.Address(net.ParseIP("10::2").To16())
+ remoteIPv6Addr2 := tcpip.Address(net.ParseIP("11::2").To16())
+
+ tests := []struct {
+ name string
+ TTL uint8
+ expectErrorICMP bool
+ }{
+ {
+ name: "TTL of zero",
+ TTL: 0,
+ expectErrorICMP: true,
+ },
+ {
+ name: "TTL of one",
+ TTL: 1,
+ expectErrorICMP: true,
+ },
+ {
+ name: "TTL of two",
+ TTL: 2,
+ expectErrorICMP: false,
+ },
+ {
+ name: "TTL of three",
+ TTL: 3,
+ expectErrorICMP: false,
+ },
+ {
+ name: "Max TTL",
+ TTL: math.MaxUint8,
+ expectErrorICMP: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
+ })
+ // We expect at most a single packet in response to our ICMP Echo Request.
+ e1 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID1, e1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+ ipv6ProtoAddr1 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr1}
+ if err := s.AddProtocolAddress(nicID1, ipv6ProtoAddr1); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv6ProtoAddr1, err)
+ }
+
+ e2 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID2, e2); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ }
+ ipv6ProtoAddr2 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr2}
+ if err := s.AddProtocolAddress(nicID2, ipv6ProtoAddr2); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv6ProtoAddr2, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: ipv6Addr1.Subnet(),
+ NIC: nicID1,
+ },
+ {
+ Destination: ipv6Addr2.Subnet(),
+ NIC: nicID2,
+ },
+ })
+
+ if err := s.SetForwarding(ProtocolNumber, true); err != nil {
+ t.Fatalf("SetForwarding(%d, true): %s", ProtocolNumber, err)
+ }
+
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6MinimumSize)
+ icmp := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
+ icmp.SetIdent(randomIdent)
+ icmp.SetSequence(randomSequence)
+ icmp.SetType(header.ICMPv6EchoRequest)
+ icmp.SetCode(header.ICMPv6UnusedCode)
+ icmp.SetChecksum(0)
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, remoteIPv6Addr1, remoteIPv6Addr2, buffer.VectorisedView{}))
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: header.ICMPv6MinimumSize,
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: test.TTL,
+ SrcAddr: remoteIPv6Addr1,
+ DstAddr: remoteIPv6Addr2,
+ })
+ requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ })
+ e1.InjectInbound(ProtocolNumber, requestPkt)
+
+ if test.expectErrorICMP {
+ reply, ok := e1.Read()
+ if !ok {
+ t.Fatal("expected ICMP Hop Limit Exceeded packet through incoming NIC")
+ }
+
+ checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())),
+ checker.SrcAddr(ipv6Addr1.Address),
+ checker.DstAddr(remoteIPv6Addr1),
+ checker.TTL(DefaultTTL),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6TimeExceeded),
+ checker.ICMPv6Code(header.ICMPv6HopLimitExceeded),
+ checker.ICMPv6Payload([]byte(hdr.View())),
+ ),
+ )
+
+ if n := e2.Drain(); n != 0 {
+ t.Fatalf("got e2.Drain() = %d, want = 0", n)
+ }
+ } else {
+ reply, ok := e2.Read()
+ if !ok {
+ t.Fatal("expected ICMP Echo Request packet through outgoing NIC")
+ }
+
+ checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())),
+ checker.SrcAddr(remoteIPv6Addr1),
+ checker.DstAddr(remoteIPv6Addr2),
+ checker.TTL(test.TTL-1),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6EchoRequest),
+ checker.ICMPv6Code(header.ICMPv6UnusedCode),
+ checker.ICMPv6Payload(nil),
+ ),
+ )
+
+ if n := e1.Drain(); n != 0 {
+ t.Fatalf("got e1.Drain() = %d, want = 0", n)
+ }
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 981d1371a..37e8b1083 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -45,10 +45,6 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig
if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
t.Fatalf("CreateNIC(_) = %s", err)
}
- if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err)
- }
-
{
subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr))))
if err != nil {
@@ -73,6 +69,17 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig
}
t.Cleanup(ep.Close)
+ addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
+ if !ok {
+ t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint")
+ }
+ addr := llladdr.WithPrefix()
+ if addressEP, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ } else {
+ addressEP.DecRef()
+ }
+
return s, ep
}
@@ -961,22 +968,17 @@ func TestNDPValidation(t *testing.T) {
for _, stackTyp := range stacks {
t.Run(stackTyp.name, func(t *testing.T) {
- setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) {
+ setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint) {
t.Helper()
// Create a stack with the assigned link-local address lladdr0
// and an endpoint to lladdr1.
s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1, stackTyp.useNeighborCache)
- r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
- }
-
- return s, ep, r
+ return s, ep
}
- handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) {
+ handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) {
nextHdr := uint8(header.ICMPv6ProtocolNumber)
var extensions buffer.View
if atomicFragment {
@@ -994,13 +996,12 @@ func TestNDPValidation(t *testing.T) {
PayloadLength: uint16(len(payload) + len(extensions)),
NextHeader: nextHdr,
HopLimit: hopLimit,
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) {
t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n)
}
- r.PopulatePacketInfo(pkt)
ep.HandlePacket(pkt)
}
@@ -1114,8 +1115,7 @@ func TestNDPValidation(t *testing.T) {
t.Run(name, func(t *testing.T) {
for _, test := range subTests {
t.Run(test.name, func(t *testing.T) {
- s, ep, r := setup(t)
- defer r.Release()
+ s, ep := setup(t)
if isRouter {
// Enabling forwarding makes the stack act as a router.
@@ -1131,7 +1131,7 @@ func TestNDPValidation(t *testing.T) {
copy(icmp[typ.size:], typ.extraData)
icmp.SetType(typ.typ)
icmp.SetCode(test.code)
- icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView()))
// Rx count of the NDP message should initially be 0.
if got := typStat.Value(); got != 0 {
@@ -1152,7 +1152,7 @@ func TestNDPValidation(t *testing.T) {
t.FailNow()
}
- handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep, &r)
+ handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep)
// Rx count of the NDP packet should have increased.
if got := typStat.Value(); got != 1 {
diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go
index 7cc52985e..5c3363759 100644
--- a/pkg/tcpip/network/testutil/testutil.go
+++ b/pkg/tcpip/network/testutil/testutil.go
@@ -85,21 +85,6 @@ func (ep *MockLinkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts st
return n, nil
}
-// WriteRawPacket implements LinkEndpoint.WriteRawPacket.
-func (ep *MockLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- if ep.allowPackets == 0 {
- return ep.err
- }
- ep.allowPackets--
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- })
- ep.WrittenPackets = append(ep.WrittenPackets, pkt)
-
- return nil
-}
-
// Attach implements LinkEndpoint.Attach.
func (*MockLinkEndpoint) Attach(stack.NetworkDispatcher) {}