summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/BUILD6
-rw-r--r--pkg/tcpip/stack/ndp.go22
-rw-r--r--pkg/tcpip/stack/ndp_test.go93
-rw-r--r--pkg/tcpip/stack/stack_test.go4
-rw-r--r--pkg/tcpip/stack/transport_test.go6
5 files changed, 72 insertions, 59 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 783351a69..f5b750046 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -1,6 +1,5 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -30,7 +29,6 @@ go_library(
"stack_global_state.go",
"transport_demuxer.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/stack",
visibility = ["//visibility:public"],
deps = [
"//pkg/ilist",
@@ -81,7 +79,7 @@ go_test(
name = "stack_test",
size = "small",
srcs = ["linkaddrcache_test.go"],
- embed = [":stack"],
+ library = ":stack",
deps = [
"//pkg/sleep",
"//pkg/sync",
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index d983ac390..245694118 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -538,11 +538,29 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error {
r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false)
defer r.Release()
- hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize))
+ linkAddr := ndp.nic.linkEP.LinkAddress()
+ isValidLinkAddr := header.IsValidUnicastEthernetAddress(linkAddr)
+ ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize
+ if isValidLinkAddr {
+ // Only include a Source Link Layer Address option if the NIC has a valid
+ // link layer address.
+ //
+ // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
+ // LinkEndpoint.LinkAddress) before reaching this point.
+ ndpNSSize += header.NDPLinkLayerAddressSize
+ }
+
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + ndpNSSize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
pkt.SetType(header.ICMPv6NeighborSolicit)
ns := header.NDPNeighborSolicit(pkt.NDPPayload())
ns.SetTargetAddress(addr)
+
+ if isValidLinkAddr {
+ ns.Options().Serialize(header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(linkAddr),
+ })
+ }
pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
sent := r.Stats().ICMP.V6PacketsSent
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index f9460bd51..726468e41 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -15,6 +15,7 @@
package stack_test
import (
+ "context"
"encoding/binary"
"fmt"
"testing"
@@ -405,7 +406,7 @@ func TestDADResolve(t *testing.T) {
// Validate the sent Neighbor Solicitation messages.
for i := uint8(0); i < test.dupAddrDetectTransmits; i++ {
- p := <-e.C
+ p, _ := e.ReadContext(context.Background())
// Make sure its an IPv6 packet.
if p.Proto != header.IPv6ProtocolNumber {
@@ -416,7 +417,11 @@ func TestDADResolve(t *testing.T) {
checker.IPv6(t, p.Pkt.Header.View().ToVectorisedView().First(),
checker.TTL(header.NDPHopLimit),
checker.NDPNS(
- checker.NDPNSTargetAddress(addr1)))
+ checker.NDPNSTargetAddress(addr1),
+ checker.NDPNSOptions([]header.NDPOption{
+ header.NDPSourceLinkLayerAddressOption(linkAddr1),
+ }),
+ ))
}
})
}
@@ -3285,29 +3290,29 @@ func TestRouterSolicitation(t *testing.T) {
e := channel.New(int(test.maxRtrSolicit), 1280, linkAddr1)
waitForPkt := func(timeout time.Duration) {
t.Helper()
- select {
- case p := <-e.C:
- if p.Proto != header.IPv6ProtocolNumber {
- t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
- }
- checker.IPv6(t,
- p.Pkt.Header.View(),
- checker.SrcAddr(header.IPv6Any),
- checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
- checker.TTL(header.NDPHopLimit),
- checker.NDPRS(),
- )
-
- case <-time.After(timeout):
+ ctx, _ := context.WithTimeout(context.Background(), timeout)
+ p, ok := e.ReadContext(ctx)
+ if !ok {
t.Fatal("timed out waiting for packet")
+ return
}
+
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
+ checker.IPv6(t,
+ p.Pkt.Header.View(),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS(),
+ )
}
waitForNothing := func(timeout time.Duration) {
t.Helper()
- select {
- case <-e.C:
+ ctx, _ := context.WithTimeout(context.Background(), timeout)
+ if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got a packet")
- case <-time.After(timeout):
}
}
s := stack.New(stack.Options{
@@ -3362,20 +3367,21 @@ func TestStopStartSolicitingRouters(t *testing.T) {
e := channel.New(maxRtrSolicitations, 1280, linkAddr1)
waitForPkt := func(timeout time.Duration) {
t.Helper()
- select {
- case p := <-e.C:
- if p.Proto != header.IPv6ProtocolNumber {
- t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
- }
- checker.IPv6(t, p.Pkt.Header.View(),
- checker.SrcAddr(header.IPv6Any),
- checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
- checker.TTL(header.NDPHopLimit),
- checker.NDPRS())
-
- case <-time.After(timeout):
+ ctx, _ := context.WithTimeout(context.Background(), timeout)
+ p, ok := e.ReadContext(ctx)
+ if !ok {
t.Fatal("timed out waiting for packet")
+ return
}
+
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
+ checker.IPv6(t, p.Pkt.Header.View(),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS())
}
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
@@ -3391,23 +3397,20 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Enable forwarding which should stop router solicitations.
s.SetForwarding(true)
- select {
- case <-e.C:
+ ctx, _ := context.WithTimeout(context.Background(), delay+defaultTimeout)
+ if _, ok := e.ReadContext(ctx); ok {
// A single RS may have been sent before forwarding was enabled.
- select {
- case <-e.C:
+ ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout)
+ if _, ok = e.ReadContext(ctx); ok {
t.Fatal("Should not have sent more than one RS message")
- case <-time.After(interval + defaultTimeout):
}
- case <-time.After(delay + defaultTimeout):
}
// Enabling forwarding again should do nothing.
s.SetForwarding(true)
- select {
- case <-e.C:
+ ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout)
+ if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got a packet after becoming a router")
- case <-time.After(delay + defaultTimeout):
}
// Disable forwarding which should start router solicitations.
@@ -3415,17 +3418,15 @@ func TestStopStartSolicitingRouters(t *testing.T) {
waitForPkt(delay + defaultAsyncEventTimeout)
waitForPkt(interval + defaultAsyncEventTimeout)
waitForPkt(interval + defaultAsyncEventTimeout)
- select {
- case <-e.C:
+ ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout)
+ if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got an extra packet after sending out the expected RSs")
- case <-time.After(interval + defaultTimeout):
}
// Disabling forwarding again should do nothing.
s.SetForwarding(false)
- select {
- case <-e.C:
+ ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout)
+ if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got a packet after becoming a router")
- case <-time.After(delay + defaultTimeout):
}
}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index dad288642..834fe9487 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -1880,9 +1880,7 @@ func TestNICForwarding(t *testing.T) {
Data: buf.ToVectorisedView(),
})
- select {
- case <-ep2.C:
- default:
+ if _, ok := ep2.Read(); !ok {
t.Fatal("Packet not forwarded")
}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index f50604a8a..869c69a6d 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -623,10 +623,8 @@ func TestTransportForwarding(t *testing.T) {
t.Fatalf("Write failed: %v", err)
}
- var p channel.PacketInfo
- select {
- case p = <-ep2.C:
- default:
+ p, ok := ep2.Read()
+ if !ok {
t.Fatal("Response packet not forwarded")
}