summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/BUILD6
-rw-r--r--pkg/tcpip/adapters/gonet/BUILD6
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go20
-rw-r--r--pkg/tcpip/buffer/BUILD6
-rw-r--r--pkg/tcpip/checker/BUILD3
-rw-r--r--pkg/tcpip/checker/checker.go50
-rw-r--r--pkg/tcpip/hash/jenkins/BUILD6
-rw-r--r--pkg/tcpip/header/BUILD6
-rw-r--r--pkg/tcpip/header/checksum.go2
-rw-r--r--pkg/tcpip/header/eth.go41
-rw-r--r--pkg/tcpip/header/eth_test.go34
-rw-r--r--pkg/tcpip/header/icmpv6.go2
-rw-r--r--pkg/tcpip/header/ipv6_test.go29
-rw-r--r--pkg/tcpip/header/ndp_options.go81
-rw-r--r--pkg/tcpip/header/ndp_test.go180
-rw-r--r--pkg/tcpip/iptables/BUILD3
-rw-r--r--pkg/tcpip/iptables/iptables.go2
-rw-r--r--pkg/tcpip/iptables/types.go4
-rw-r--r--pkg/tcpip/iptables/udp_matcher.go10
-rw-r--r--pkg/tcpip/link/channel/BUILD3
-rw-r--r--pkg/tcpip/link/channel/channel.go72
-rw-r--r--pkg/tcpip/link/fdbased/BUILD6
-rw-r--r--pkg/tcpip/link/loopback/BUILD3
-rw-r--r--pkg/tcpip/link/muxed/BUILD6
-rw-r--r--pkg/tcpip/link/rawfile/BUILD3
-rw-r--r--pkg/tcpip/link/sharedmem/BUILD6
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/BUILD6
-rw-r--r--pkg/tcpip/link/sharedmem/queue/BUILD6
-rw-r--r--pkg/tcpip/link/sniffer/BUILD3
-rw-r--r--pkg/tcpip/link/tun/BUILD3
-rw-r--r--pkg/tcpip/link/waitable/BUILD6
-rw-r--r--pkg/tcpip/network/BUILD2
-rw-r--r--pkg/tcpip/network/arp/BUILD4
-rw-r--r--pkg/tcpip/network/arp/arp.go19
-rw-r--r--pkg/tcpip/network/arp/arp_test.go16
-rw-r--r--pkg/tcpip/network/fragmentation/BUILD6
-rw-r--r--pkg/tcpip/network/hash/BUILD3
-rw-r--r--pkg/tcpip/network/ipv4/BUILD4
-rw-r--r--pkg/tcpip/network/ipv6/BUILD6
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go67
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go19
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go135
-rw-r--r--pkg/tcpip/ports/BUILD6
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/BUILD2
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/BUILD2
-rw-r--r--pkg/tcpip/seqnum/BUILD3
-rw-r--r--pkg/tcpip/stack/BUILD6
-rw-r--r--pkg/tcpip/stack/ndp.go25
-rw-r--r--pkg/tcpip/stack/ndp_test.go117
-rw-r--r--pkg/tcpip/stack/route.go4
-rw-r--r--pkg/tcpip/stack/stack_test.go4
-rw-r--r--pkg/tcpip/stack/transport_test.go6
-rw-r--r--pkg/tcpip/tcpip.go6
-rw-r--r--pkg/tcpip/transport/icmp/BUILD3
-rw-r--r--pkg/tcpip/transport/packet/BUILD3
-rw-r--r--pkg/tcpip/transport/raw/BUILD3
-rw-r--r--pkg/tcpip/transport/tcp/BUILD6
-rw-r--r--pkg/tcpip/transport/tcp/accept.go29
-rw-r--r--pkg/tcpip/transport/tcp/connect.go57
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go36
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go4
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go181
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/BUILD3
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go92
-rw-r--r--pkg/tcpip/transport/tcpconntrack/BUILD4
-rw-r--r--pkg/tcpip/transport/udp/BUILD4
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go182
67 files changed, 1237 insertions, 446 deletions
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index 23e4b09e7..26f7ba86b 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -1,5 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -12,7 +11,6 @@ go_library(
"time_unsafe.go",
"timer.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip",
visibility = ["//visibility:public"],
deps = [
"//pkg/sync",
@@ -25,7 +23,7 @@ go_test(
name = "tcpip_test",
size = "small",
srcs = ["tcpip_test.go"],
- embed = [":tcpip"],
+ library = ":tcpip",
)
go_test(
diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD
index 3df7d18d3..a984f1712 100644
--- a/pkg/tcpip/adapters/gonet/BUILD
+++ b/pkg/tcpip/adapters/gonet/BUILD
@@ -1,12 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "gonet",
srcs = ["gonet.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet",
visibility = ["//visibility:public"],
deps = [
"//pkg/sync",
@@ -23,7 +21,7 @@ go_test(
name = "gonet_test",
size = "small",
srcs = ["gonet_test.go"],
- embed = [":gonet"],
+ library = ":gonet",
deps = [
"//pkg/tcpip",
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index 3bba4028b..711969b9b 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -556,6 +556,17 @@ type PacketConn struct {
wq *waiter.Queue
}
+// NewPacketConn creates a new PacketConn.
+func NewPacketConn(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *PacketConn {
+ c := &PacketConn{
+ stack: s,
+ ep: ep,
+ wq: wq,
+ }
+ c.deadlineTimer.init()
+ return c
+}
+
// DialUDP creates a new PacketConn.
//
// If laddr is nil, a local address is automatically chosen.
@@ -580,12 +591,7 @@ func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.Netw
}
}
- c := PacketConn{
- stack: s,
- ep: ep,
- wq: &wq,
- }
- c.deadlineTimer.init()
+ c := NewPacketConn(s, &wq, ep)
if raddr != nil {
if err := c.ep.Connect(*raddr); err != nil {
@@ -599,7 +605,7 @@ func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.Netw
}
}
- return &c, nil
+ return c, nil
}
func (c *PacketConn) newOpError(op string, err error) *net.OpError {
diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD
index d6c31bfa2..563bc78ea 100644
--- a/pkg/tcpip/buffer/BUILD
+++ b/pkg/tcpip/buffer/BUILD
@@ -1,5 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -9,7 +8,6 @@ go_library(
"prependable.go",
"view.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/buffer",
visibility = ["//visibility:public"],
)
@@ -17,5 +15,5 @@ go_test(
name = "buffer_test",
size = "small",
srcs = ["view_test.go"],
- embed = [":buffer"],
+ library = ":buffer",
)
diff --git a/pkg/tcpip/checker/BUILD b/pkg/tcpip/checker/BUILD
index b6fa6fc37..ed434807f 100644
--- a/pkg/tcpip/checker/BUILD
+++ b/pkg/tcpip/checker/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -6,7 +6,6 @@ go_library(
name = "checker",
testonly = 1,
srcs = ["checker.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/checker",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 885d773b0..4d6ae0871 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -771,6 +771,56 @@ func NDPNSTargetAddress(want tcpip.Address) TransportChecker {
}
}
+// NDPNSOptions creates a checker that checks that the packet contains the
+// provided NDP options within an NDP Neighbor Solicitation message.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid NDPNS message as far as the size is concerned.
+func NDPNSOptions(opts []header.NDPOption) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ it, err := ns.Options().Iter(true)
+ if err != nil {
+ t.Errorf("opts.Iter(true): %s", err)
+ return
+ }
+
+ i := 0
+ for {
+ opt, done, _ := it.Next()
+ if done {
+ break
+ }
+
+ if i >= len(opts) {
+ t.Errorf("got unexpected option: %s", opt)
+ continue
+ }
+
+ switch wantOpt := opts[i].(type) {
+ case header.NDPSourceLinkLayerAddressOption:
+ gotOpt, ok := opt.(header.NDPSourceLinkLayerAddressOption)
+ if !ok {
+ t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt)
+ } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want {
+ t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want)
+ }
+ default:
+ panic("not implemented")
+ }
+
+ i++
+ }
+
+ if missing := opts[i:]; len(missing) > 0 {
+ t.Errorf("missing options: %s", missing)
+ }
+ }
+}
+
// NDPRS creates a checker that checks that the packet contains a valid NDP
// Router Solicitation message (as per the raw wire format).
func NDPRS() NetworkChecker {
diff --git a/pkg/tcpip/hash/jenkins/BUILD b/pkg/tcpip/hash/jenkins/BUILD
index e648efa71..ff2719291 100644
--- a/pkg/tcpip/hash/jenkins/BUILD
+++ b/pkg/tcpip/hash/jenkins/BUILD
@@ -1,12 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "jenkins",
srcs = ["jenkins.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins",
visibility = ["//visibility:public"],
)
@@ -16,5 +14,5 @@ go_test(
srcs = [
"jenkins_test.go",
],
- embed = [":jenkins"],
+ library = ":jenkins",
)
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD
index cd747d100..9da0d71f8 100644
--- a/pkg/tcpip/header/BUILD
+++ b/pkg/tcpip/header/BUILD
@@ -1,5 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -24,7 +23,6 @@ go_library(
"tcp.go",
"udp.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/header",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
@@ -59,7 +57,7 @@ go_test(
"eth_test.go",
"ndp_test.go",
],
- embed = [":header"],
+ library = ":header",
deps = [
"//pkg/tcpip",
"@com_github_google_go-cmp//cmp:go_default_library",
diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go
index 204285576..14a4b2b44 100644
--- a/pkg/tcpip/header/checksum.go
+++ b/pkg/tcpip/header/checksum.go
@@ -213,7 +213,7 @@ func ChecksumVVWithOffset(vv buffer.VectorisedView, initial uint16, off int, siz
}
v = v[:l]
- sum, odd = calculateChecksum(v, odd, uint32(sum))
+ sum, odd = unrolledCalculateChecksum(v, odd, uint32(sum))
size -= len(v)
if size == 0 {
diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go
index f5d2c127f..b1e92d2d7 100644
--- a/pkg/tcpip/header/eth.go
+++ b/pkg/tcpip/header/eth.go
@@ -134,3 +134,44 @@ func IsValidUnicastEthernetAddress(addr tcpip.LinkAddress) bool {
// addr is a valid unicast ethernet address.
return true
}
+
+// EthernetAddressFromMulticastIPv4Address returns a multicast Ethernet address
+// for a multicast IPv4 address.
+//
+// addr MUST be a multicast IPv4 address.
+func EthernetAddressFromMulticastIPv4Address(addr tcpip.Address) tcpip.LinkAddress {
+ var linkAddrBytes [EthernetAddressSize]byte
+ // RFC 1112 Host Extensions for IP Multicasting
+ //
+ // 6.4. Extensions to an Ethernet Local Network Module:
+ //
+ // An IP host group address is mapped to an Ethernet multicast
+ // address by placing the low-order 23-bits of the IP address
+ // into the low-order 23 bits of the Ethernet multicast address
+ // 01-00-5E-00-00-00 (hex).
+ linkAddrBytes[0] = 0x1
+ linkAddrBytes[2] = 0x5e
+ linkAddrBytes[3] = addr[1] & 0x7F
+ copy(linkAddrBytes[4:], addr[IPv4AddressSize-2:])
+ return tcpip.LinkAddress(linkAddrBytes[:])
+}
+
+// EthernetAddressFromMulticastIPv6Address returns a multicast Ethernet address
+// for a multicast IPv6 address.
+//
+// addr MUST be a multicast IPv6 address.
+func EthernetAddressFromMulticastIPv6Address(addr tcpip.Address) tcpip.LinkAddress {
+ // RFC 2464 Transmission of IPv6 Packets over Ethernet Networks
+ //
+ // 7. Address Mapping -- Multicast
+ //
+ // An IPv6 packet with a multicast destination address DST,
+ // consisting of the sixteen octets DST[1] through DST[16], is
+ // transmitted to the Ethernet multicast address whose first
+ // two octets are the value 3333 hexadecimal and whose last
+ // four octets are the last four octets of DST.
+ linkAddrBytes := []byte(addr[IPv6AddressSize-EthernetAddressSize:])
+ linkAddrBytes[0] = 0x33
+ linkAddrBytes[1] = 0x33
+ return tcpip.LinkAddress(linkAddrBytes[:])
+}
diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go
index 6634c90f5..7a0014ad9 100644
--- a/pkg/tcpip/header/eth_test.go
+++ b/pkg/tcpip/header/eth_test.go
@@ -66,3 +66,37 @@ func TestIsValidUnicastEthernetAddress(t *testing.T) {
})
}
}
+
+func TestEthernetAddressFromMulticastIPv4Address(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ expectedLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "IPv4 Multicast without 24th bit set",
+ addr: "\xe0\x7e\xdc\xba",
+ expectedLinkAddr: "\x01\x00\x5e\x7e\xdc\xba",
+ },
+ {
+ name: "IPv4 Multicast with 24th bit set",
+ addr: "\xe0\xfe\xdc\xba",
+ expectedLinkAddr: "\x01\x00\x5e\x7e\xdc\xba",
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if got := EthernetAddressFromMulticastIPv4Address(test.addr); got != test.expectedLinkAddr {
+ t.Fatalf("got EthernetAddressFromMulticastIPv4Address(%s) = %s, want = %s", got, test.expectedLinkAddr)
+ }
+ })
+ }
+}
+
+func TestEthernetAddressFromMulticastIPv6Address(t *testing.T) {
+ addr := tcpip.Address("\xff\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x1a")
+ if got, want := EthernetAddressFromMulticastIPv6Address(addr), tcpip.LinkAddress("\x33\x33\x0d\x0e\x0f\x1a"); got != want {
+ t.Fatalf("got EthernetAddressFromMulticastIPv6Address(%s) = %s, want = %s", addr, got, want)
+ }
+}
diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go
index b4037b6c8..c7ee2de57 100644
--- a/pkg/tcpip/header/icmpv6.go
+++ b/pkg/tcpip/header/icmpv6.go
@@ -52,7 +52,7 @@ const (
// ICMPv6NeighborAdvertSize is size of a neighbor advertisement
// including the NDP Target Link Layer option for an Ethernet
// address.
- ICMPv6NeighborAdvertSize = ICMPv6HeaderSize + NDPNAMinimumSize + ndpTargetEthernetLinkLayerAddressSize
+ ICMPv6NeighborAdvertSize = ICMPv6HeaderSize + NDPNAMinimumSize + NDPLinkLayerAddressSize
// ICMPv6EchoMinimumSize is the minimum size of a valid ICMP echo packet.
ICMPv6EchoMinimumSize = 8
diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go
index 29f54bc57..c3ad503aa 100644
--- a/pkg/tcpip/header/ipv6_test.go
+++ b/pkg/tcpip/header/ipv6_test.go
@@ -17,6 +17,7 @@ package header_test
import (
"bytes"
"crypto/sha256"
+ "fmt"
"testing"
"github.com/google/go-cmp/cmp"
@@ -300,3 +301,31 @@ func TestScopeForIPv6Address(t *testing.T) {
})
}
}
+
+func TestSolicitedNodeAddr(t *testing.T) {
+ tests := []struct {
+ addr tcpip.Address
+ want tcpip.Address
+ }{
+ {
+ addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\xa0",
+ want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x0e\x0f\xa0",
+ },
+ {
+ addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\xdd\x0e\x0f\xa0",
+ want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x0e\x0f\xa0",
+ },
+ {
+ addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\xdd\x01\x02\x03",
+ want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x01\x02\x03",
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(fmt.Sprintf("%s", test.addr), func(t *testing.T) {
+ if got := header.SolicitedNodeAddr(test.addr); got != test.want {
+ t.Fatalf("got header.SolicitedNodeAddr(%s) = %s, want = %s", test.addr, got, test.want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go
index 06e0bace2..e6a6ad39b 100644
--- a/pkg/tcpip/header/ndp_options.go
+++ b/pkg/tcpip/header/ndp_options.go
@@ -17,6 +17,7 @@ package header
import (
"encoding/binary"
"errors"
+ "fmt"
"math"
"time"
@@ -24,13 +25,17 @@ import (
)
const (
- // NDPTargetLinkLayerAddressOptionType is the type of the Target
- // Link-Layer Address option, as per RFC 4861 section 4.6.1.
+ // NDPSourceLinkLayerAddressOptionType is the type of the Source Link Layer
+ // Address option, as per RFC 4861 section 4.6.1.
+ NDPSourceLinkLayerAddressOptionType = 1
+
+ // NDPTargetLinkLayerAddressOptionType is the type of the Target Link Layer
+ // Address option, as per RFC 4861 section 4.6.1.
NDPTargetLinkLayerAddressOptionType = 2
- // ndpTargetEthernetLinkLayerAddressSize is the size of a Target
- // Link Layer Option for an Ethernet address.
- ndpTargetEthernetLinkLayerAddressSize = 8
+ // NDPLinkLayerAddressSize is the size of a Source or Target Link Layer
+ // Address option for an Ethernet address.
+ NDPLinkLayerAddressSize = 8
// NDPPrefixInformationType is the type of the Prefix Information
// option, as per RFC 4861 section 4.6.2.
@@ -189,6 +194,9 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) {
i.opts = i.opts[numBytes:]
switch t {
+ case NDPSourceLinkLayerAddressOptionType:
+ return NDPSourceLinkLayerAddressOption(body), false, nil
+
case NDPTargetLinkLayerAddressOptionType:
return NDPTargetLinkLayerAddressOption(body), false, nil
@@ -293,6 +301,8 @@ func (b NDPOptions) Serialize(s NDPOptionsSerializer) int {
// NDPOption is the set of functions to be implemented by all NDP option types.
type NDPOption interface {
+ fmt.Stringer
+
// Type returns the type of the receiver.
Type() uint8
@@ -368,6 +378,46 @@ func (b NDPOptionsSerializer) Length() int {
return l
}
+// NDPSourceLinkLayerAddressOption is the NDP Source Link Layer Option
+// as defined by RFC 4861 section 4.6.1.
+//
+// It is the first X bytes following the NDP option's Type and Length field
+// where X is the value in Length multiplied by lengthByteUnits - 2 bytes.
+type NDPSourceLinkLayerAddressOption tcpip.LinkAddress
+
+// Type implements NDPOption.Type.
+func (o NDPSourceLinkLayerAddressOption) Type() uint8 {
+ return NDPSourceLinkLayerAddressOptionType
+}
+
+// Length implements NDPOption.Length.
+func (o NDPSourceLinkLayerAddressOption) Length() int {
+ return len(o)
+}
+
+// serializeInto implements NDPOption.serializeInto.
+func (o NDPSourceLinkLayerAddressOption) serializeInto(b []byte) int {
+ return copy(b, o)
+}
+
+// String implements fmt.Stringer.String.
+func (o NDPSourceLinkLayerAddressOption) String() string {
+ return fmt.Sprintf("%T(%s)", o, tcpip.LinkAddress(o))
+}
+
+// EthernetAddress will return an ethernet (MAC) address if the
+// NDPSourceLinkLayerAddressOption's body has at minimum EthernetAddressSize
+// bytes. If the body has more than EthernetAddressSize bytes, only the first
+// EthernetAddressSize bytes are returned as that is all that is needed for an
+// Ethernet address.
+func (o NDPSourceLinkLayerAddressOption) EthernetAddress() tcpip.LinkAddress {
+ if len(o) >= EthernetAddressSize {
+ return tcpip.LinkAddress(o[:EthernetAddressSize])
+ }
+
+ return tcpip.LinkAddress([]byte(nil))
+}
+
// NDPTargetLinkLayerAddressOption is the NDP Target Link Layer Option
// as defined by RFC 4861 section 4.6.1.
//
@@ -390,6 +440,11 @@ func (o NDPTargetLinkLayerAddressOption) serializeInto(b []byte) int {
return copy(b, o)
}
+// String implements fmt.Stringer.String.
+func (o NDPTargetLinkLayerAddressOption) String() string {
+ return fmt.Sprintf("%T(%s)", o, tcpip.LinkAddress(o))
+}
+
// EthernetAddress will return an ethernet (MAC) address if the
// NDPTargetLinkLayerAddressOption's body has at minimum EthernetAddressSize
// bytes. If the body has more than EthernetAddressSize bytes, only the first
@@ -436,6 +491,17 @@ func (o NDPPrefixInformation) serializeInto(b []byte) int {
return used
}
+// String implements fmt.Stringer.String.
+func (o NDPPrefixInformation) String() string {
+ return fmt.Sprintf("%T(O=%t, A=%t, PL=%s, VL=%s, Prefix=%s)",
+ o,
+ o.OnLinkFlag(),
+ o.AutonomousAddressConfigurationFlag(),
+ o.PreferredLifetime(),
+ o.ValidLifetime(),
+ o.Subnet())
+}
+
// PrefixLength returns the value in the number of leading bits in the Prefix
// that are valid.
//
@@ -545,6 +611,11 @@ func (o NDPRecursiveDNSServer) serializeInto(b []byte) int {
return used
}
+// String implements fmt.Stringer.String.
+func (o NDPRecursiveDNSServer) String() string {
+ return fmt.Sprintf("%T(%s valid for %s)", o, o.Addresses(), o.Lifetime())
+}
+
// Lifetime returns the length of time that the DNS server addresses
// in this option may be used for name resolution.
//
diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go
index 2c439d70c..1cb9f5dc8 100644
--- a/pkg/tcpip/header/ndp_test.go
+++ b/pkg/tcpip/header/ndp_test.go
@@ -153,6 +153,125 @@ func TestNDPRouterAdvert(t *testing.T) {
}
}
+// TestNDPSourceLinkLayerAddressOptionEthernetAddress tests getting the
+// Ethernet address from an NDPSourceLinkLayerAddressOption.
+func TestNDPSourceLinkLayerAddressOptionEthernetAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ buf []byte
+ expected tcpip.LinkAddress
+ }{
+ {
+ "ValidMAC",
+ []byte{1, 2, 3, 4, 5, 6},
+ tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
+ },
+ {
+ "SLLBodyTooShort",
+ []byte{1, 2, 3, 4, 5},
+ tcpip.LinkAddress([]byte(nil)),
+ },
+ {
+ "SLLBodyLargerThanNeeded",
+ []byte{1, 2, 3, 4, 5, 6, 7, 8},
+ tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ sll := NDPSourceLinkLayerAddressOption(test.buf)
+ if got := sll.EthernetAddress(); got != test.expected {
+ t.Errorf("got sll.EthernetAddress = %s, want = %s", got, test.expected)
+ }
+ })
+ }
+}
+
+// TestNDPSourceLinkLayerAddressOptionSerialize tests serializing a
+// NDPSourceLinkLayerAddressOption.
+func TestNDPSourceLinkLayerAddressOptionSerialize(t *testing.T) {
+ tests := []struct {
+ name string
+ buf []byte
+ expectedBuf []byte
+ addr tcpip.LinkAddress
+ }{
+ {
+ "Ethernet",
+ make([]byte, 8),
+ []byte{1, 1, 1, 2, 3, 4, 5, 6},
+ "\x01\x02\x03\x04\x05\x06",
+ },
+ {
+ "Padding",
+ []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
+ []byte{1, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0},
+ "\x01\x02\x03\x04\x05\x06\x07\x08",
+ },
+ {
+ "Empty",
+ nil,
+ nil,
+ "",
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opts := NDPOptions(test.buf)
+ serializer := NDPOptionsSerializer{
+ NDPSourceLinkLayerAddressOption(test.addr),
+ }
+ if got, want := int(serializer.Length()), len(test.expectedBuf); got != want {
+ t.Fatalf("got Length = %d, want = %d", got, want)
+ }
+ opts.Serialize(serializer)
+ if !bytes.Equal(test.buf, test.expectedBuf) {
+ t.Fatalf("got b = %d, want = %d", test.buf, test.expectedBuf)
+ }
+
+ it, err := opts.Iter(true)
+ if err != nil {
+ t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
+ }
+
+ if len(test.expectedBuf) > 0 {
+ next, done, err := it.Next()
+ if err != nil {
+ t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if done {
+ t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+ }
+ if got := next.Type(); got != NDPSourceLinkLayerAddressOptionType {
+ t.Fatalf("got Type = %d, want = %d", got, NDPSourceLinkLayerAddressOptionType)
+ }
+ sll := next.(NDPSourceLinkLayerAddressOption)
+ if got, want := []byte(sll), test.expectedBuf[2:]; !bytes.Equal(got, want) {
+ t.Fatalf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
+ }
+
+ if got, want := sll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want {
+ t.Errorf("got sll.EthernetAddress = %s, want = %s", got, want)
+ }
+ }
+
+ // Iterator should not return anything else.
+ next, done, err := it.Next()
+ if err != nil {
+ t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if !done {
+ t.Error("got Next = (_, false, _), want = (_, true, _)")
+ }
+ if next != nil {
+ t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
+ }
+ })
+ }
+}
+
// TestNDPTargetLinkLayerAddressOptionEthernetAddress tests getting the
// Ethernet address from an NDPTargetLinkLayerAddressOption.
func TestNDPTargetLinkLayerAddressOptionEthernetAddress(t *testing.T) {
@@ -186,7 +305,6 @@ func TestNDPTargetLinkLayerAddressOptionEthernetAddress(t *testing.T) {
}
})
}
-
}
// TestNDPTargetLinkLayerAddressOptionSerialize tests serializing a
@@ -212,8 +330,8 @@ func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) {
},
{
"Empty",
- []byte{},
- []byte{},
+ nil,
+ nil,
"",
},
}
@@ -246,7 +364,7 @@ func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) {
t.Fatal("got Next = (_, true, _), want = (_, false, _)")
}
if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType {
- t.Fatalf("got Type %= %d, want = %d", got, NDPTargetLinkLayerAddressOptionType)
+ t.Fatalf("got Type = %d, want = %d", got, NDPTargetLinkLayerAddressOptionType)
}
tll := next.(NDPTargetLinkLayerAddressOption)
if got, want := []byte(tll), test.expectedBuf[2:]; !bytes.Equal(got, want) {
@@ -254,7 +372,7 @@ func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) {
}
if got, want := tll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want {
- t.Errorf("got tll.MACAddress = %s, want = %s", got, want)
+ t.Errorf("got tll.EthernetAddress = %s, want = %s", got, want)
}
}
@@ -510,7 +628,7 @@ func TestNDPRecursiveDNSServerOption(t *testing.T) {
t.Fatal("got Next = (_, true, _), want = (_, false, _)")
}
if got := next.Type(); got != NDPRecursiveDNSServerOptionType {
- t.Fatalf("got Type %= %d, want = %d", got, NDPRecursiveDNSServerOptionType)
+ t.Fatalf("got Type = %d, want = %d", got, NDPRecursiveDNSServerOptionType)
}
opt, ok := next.(NDPRecursiveDNSServer)
@@ -553,6 +671,16 @@ func TestNDPOptionsIterCheck(t *testing.T) {
ErrNDPOptZeroLength,
},
{
+ "ValidSourceLinkLayerAddressOption",
+ []byte{1, 1, 1, 2, 3, 4, 5, 6},
+ nil,
+ },
+ {
+ "TooSmallSourceLinkLayerAddressOption",
+ []byte{1, 1, 1, 2, 3, 4, 5},
+ ErrNDPOptBufExhausted,
+ },
+ {
"ValidTargetLinkLayerAddressOption",
[]byte{2, 1, 1, 2, 3, 4, 5, 6},
nil,
@@ -603,10 +731,13 @@ func TestNDPOptionsIterCheck(t *testing.T) {
ErrNDPOptMalformedBody,
},
{
- "ValidTargetLinkLayerAddressWithPrefixInformation",
+ "ValidSourceAndTargetLinkLayerAddressWithPrefixInformation",
[]byte{
+ // Source Link-Layer Address.
+ 1, 1, 1, 2, 3, 4, 5, 6,
+
// Target Link-Layer Address.
- 2, 1, 1, 2, 3, 4, 5, 6,
+ 2, 1, 7, 8, 9, 10, 11, 12,
// Prefix information.
3, 4, 43, 64,
@@ -621,10 +752,13 @@ func TestNDPOptionsIterCheck(t *testing.T) {
nil,
},
{
- "ValidTargetLinkLayerAddressWithPrefixInformationWithUnrecognized",
+ "ValidSourceAndTargetLinkLayerAddressWithPrefixInformationWithUnrecognized",
[]byte{
+ // Source Link-Layer Address.
+ 1, 1, 1, 2, 3, 4, 5, 6,
+
// Target Link-Layer Address.
- 2, 1, 1, 2, 3, 4, 5, 6,
+ 2, 1, 7, 8, 9, 10, 11, 12,
// 255 is an unrecognized type. If 255 ends up
// being the type for some recognized type,
@@ -714,8 +848,11 @@ func TestNDPOptionsIterCheck(t *testing.T) {
// here.
func TestNDPOptionsIter(t *testing.T) {
buf := []byte{
+ // Source Link-Layer Address.
+ 1, 1, 1, 2, 3, 4, 5, 6,
+
// Target Link-Layer Address.
- 2, 1, 1, 2, 3, 4, 5, 6,
+ 2, 1, 7, 8, 9, 10, 11, 12,
// 255 is an unrecognized type. If 255 ends up being the type
// for some recognized type, update 255 to some other
@@ -740,7 +877,7 @@ func TestNDPOptionsIter(t *testing.T) {
t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
}
- // Test the first (Taret Link-Layer) option.
+ // Test the first (Source Link-Layer) option.
next, done, err := it.Next()
if err != nil {
t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
@@ -748,7 +885,22 @@ func TestNDPOptionsIter(t *testing.T) {
if done {
t.Fatal("got Next = (_, true, _), want = (_, false, _)")
}
- if got, want := []byte(next.(NDPTargetLinkLayerAddressOption)), buf[2:][:6]; !bytes.Equal(got, want) {
+ if got, want := []byte(next.(NDPSourceLinkLayerAddressOption)), buf[2:][:6]; !bytes.Equal(got, want) {
+ t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
+ }
+ if got := next.Type(); got != NDPSourceLinkLayerAddressOptionType {
+ t.Errorf("got Type = %d, want = %d", got, NDPSourceLinkLayerAddressOptionType)
+ }
+
+ // Test the next (Target Link-Layer) option.
+ next, done, err = it.Next()
+ if err != nil {
+ t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if done {
+ t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+ }
+ if got, want := []byte(next.(NDPTargetLinkLayerAddressOption)), buf[10:][:6]; !bytes.Equal(got, want) {
t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
}
if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType {
@@ -764,7 +916,7 @@ func TestNDPOptionsIter(t *testing.T) {
if done {
t.Fatal("got Next = (_, true, _), want = (_, false, _)")
}
- if got, want := next.(NDPPrefixInformation), buf[26:][:30]; !bytes.Equal(got, want) {
+ if got, want := next.(NDPPrefixInformation), buf[34:][:30]; !bytes.Equal(got, want) {
t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
}
if got := next.Type(); got != NDPPrefixInformationType {
diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD
index e41c645ed..bab26580b 100644
--- a/pkg/tcpip/iptables/BUILD
+++ b/pkg/tcpip/iptables/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -10,7 +10,6 @@ go_library(
"types.go",
"udp_matcher.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/iptables",
visibility = ["//visibility:public"],
deps = [
"//pkg/log",
diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/iptables/iptables.go
index accedba1e..1b9485bbd 100644
--- a/pkg/tcpip/iptables/iptables.go
+++ b/pkg/tcpip/iptables/iptables.go
@@ -1,4 +1,4 @@
-// Copyright 2019 The gVisor authors.
+// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/iptables/types.go
index ba5ed75b4..2ea8994ae 100644
--- a/pkg/tcpip/iptables/types.go
+++ b/pkg/tcpip/iptables/types.go
@@ -1,4 +1,4 @@
-// Copyright 2019 The gVisor authors.
+// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -132,7 +132,7 @@ type Table struct {
// ValidHooks returns a bitmap of the builtin hooks for the given table.
func (table *Table) ValidHooks() uint32 {
hooks := uint32(0)
- for hook, _ := range table.BuiltinChains {
+ for hook := range table.BuiltinChains {
hooks |= 1 << hook
}
return hooks
diff --git a/pkg/tcpip/iptables/udp_matcher.go b/pkg/tcpip/iptables/udp_matcher.go
index 3bb076f9c..496931d7a 100644
--- a/pkg/tcpip/iptables/udp_matcher.go
+++ b/pkg/tcpip/iptables/udp_matcher.go
@@ -54,7 +54,7 @@ func NewUDPMatcher(filter IPHeaderFilter, data UDPMatcherParams) (Matcher, error
}
if filter.Protocol != header.UDPProtocolNumber {
- return nil, fmt.Errorf("UDP matching is only valid for protocol %d.", header.UDPProtocolNumber)
+ return nil, fmt.Errorf("UDP matching is only valid for protocol %d", header.UDPProtocolNumber)
}
return &UDPMatcher{Data: data}, nil
@@ -73,7 +73,6 @@ func (um *UDPMatcher) Match(hook Hook, pkt tcpip.PacketBuffer, interfaceName str
// We dont't match fragments.
if frag := netHeader.FragmentOffset(); frag != 0 {
if frag == 1 {
- log.Warningf("Dropping UDP packet: malicious fragmented packet.")
return false, true
}
return false, false
@@ -92,7 +91,6 @@ func (um *UDPMatcher) Match(hook Hook, pkt tcpip.PacketBuffer, interfaceName str
if len(pkt.Data.First()) < header.UDPMinimumSize {
// There's no valid UDP header here, so we hotdrop the
// packet.
- log.Warningf("Dropping UDP packet: size too small.")
return false, true
}
udpHeader = header.UDP(pkt.Data.First())
@@ -100,12 +98,10 @@ func (um *UDPMatcher) Match(hook Hook, pkt tcpip.PacketBuffer, interfaceName str
// Check whether the source and destination ports are within the
// matching range.
- sourcePort := udpHeader.SourcePort()
- destinationPort := udpHeader.DestinationPort()
- if sourcePort < um.Data.SourcePortStart || um.Data.SourcePortEnd < sourcePort {
+ if sourcePort := udpHeader.SourcePort(); sourcePort < um.Data.SourcePortStart || um.Data.SourcePortEnd < sourcePort {
return false, false
}
- if destinationPort < um.Data.DestinationPortStart || um.Data.DestinationPortEnd < destinationPort {
+ if destinationPort := udpHeader.DestinationPort(); destinationPort < um.Data.DestinationPortStart || um.Data.DestinationPortEnd < destinationPort {
return false, false
}
diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD
index 7dbc05754..3974c464e 100644
--- a/pkg/tcpip/link/channel/BUILD
+++ b/pkg/tcpip/link/channel/BUILD
@@ -1,11 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
name = "channel",
srcs = ["channel.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/channel",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index 70188551f..78d447acd 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -18,6 +18,8 @@
package channel
import (
+ "context"
+
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -28,35 +30,63 @@ type PacketInfo struct {
Pkt tcpip.PacketBuffer
Proto tcpip.NetworkProtocolNumber
GSO *stack.GSO
+ Route stack.Route
}
// Endpoint is link layer endpoint that stores outbound packets in a channel
// and allows injection of inbound packets.
type Endpoint struct {
- dispatcher stack.NetworkDispatcher
- mtu uint32
- linkAddr tcpip.LinkAddress
- GSO bool
+ dispatcher stack.NetworkDispatcher
+ mtu uint32
+ linkAddr tcpip.LinkAddress
+ LinkEPCapabilities stack.LinkEndpointCapabilities
- // C is where outbound packets are queued.
- C chan PacketInfo
+ // c is where outbound packets are queued.
+ c chan PacketInfo
}
// New creates a new channel endpoint.
func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint {
return &Endpoint{
- C: make(chan PacketInfo, size),
+ c: make(chan PacketInfo, size),
mtu: mtu,
linkAddr: linkAddr,
}
}
+// Close closes e. Further packet injections will panic. Reads continue to
+// succeed until all packets are read.
+func (e *Endpoint) Close() {
+ close(e.c)
+}
+
+// Read does non-blocking read for one packet from the outbound packet queue.
+func (e *Endpoint) Read() (PacketInfo, bool) {
+ select {
+ case pkt := <-e.c:
+ return pkt, true
+ default:
+ return PacketInfo{}, false
+ }
+}
+
+// ReadContext does blocking read for one packet from the outbound packet queue.
+// It can be cancelled by ctx, and in this case, it returns false.
+func (e *Endpoint) ReadContext(ctx context.Context) (PacketInfo, bool) {
+ select {
+ case pkt := <-e.c:
+ return pkt, true
+ case <-ctx.Done():
+ return PacketInfo{}, false
+ }
+}
+
// Drain removes all outbound packets from the channel and counts them.
func (e *Endpoint) Drain() int {
c := 0
for {
select {
- case <-e.C:
+ case <-e.c:
c++
default:
return c
@@ -93,11 +123,7 @@ func (e *Endpoint) MTU() uint32 {
// Capabilities implements stack.LinkEndpoint.Capabilities.
func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
- caps := stack.LinkEndpointCapabilities(0)
- if e.GSO {
- caps |= stack.CapabilityHardwareGSO
- }
- return caps
+ return e.LinkEPCapabilities
}
// GSOMaxSize returns the maximum GSO packet size.
@@ -117,15 +143,20 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
}
// WritePacket stores outbound packets into the channel.
-func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
+ // Clone r then release its resource so we only get the relevant fields from
+ // stack.Route without holding a reference to a NIC's endpoint.
+ route := r.Clone()
+ route.Release()
p := PacketInfo{
Pkt: pkt,
Proto: protocol,
GSO: gso,
+ Route: route,
}
select {
- case e.C <- p:
+ case e.c <- p:
default:
}
@@ -133,7 +164,11 @@ func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, protocol tcpip.Ne
}
// WritePackets stores outbound packets into the channel.
-func (e *Endpoint) WritePackets(_ *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ // Clone r then release its resource so we only get the relevant fields from
+ // stack.Route without holding a reference to a NIC's endpoint.
+ route := r.Clone()
+ route.Release()
payloadView := pkts[0].Data.ToView()
n := 0
packetLoop:
@@ -147,10 +182,11 @@ packetLoop:
},
Proto: protocol,
GSO: gso,
+ Route: route,
}
select {
- case e.C <- p:
+ case e.c <- p:
n++
default:
break packetLoop
@@ -169,7 +205,7 @@ func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
}
select {
- case e.C <- p:
+ case e.c <- p:
default:
}
diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD
index 66cc53ed4..abe725548 100644
--- a/pkg/tcpip/link/fdbased/BUILD
+++ b/pkg/tcpip/link/fdbased/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -13,7 +12,6 @@ go_library(
"mmap_unsafe.go",
"packet_dispatchers.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/fdbased",
visibility = ["//visibility:public"],
deps = [
"//pkg/sync",
@@ -30,7 +28,7 @@ go_test(
name = "fdbased_test",
size = "small",
srcs = ["endpoint_test.go"],
- embed = [":fdbased"],
+ library = ":fdbased",
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/link/loopback/BUILD b/pkg/tcpip/link/loopback/BUILD
index f35fcdff4..6bf3805b7 100644
--- a/pkg/tcpip/link/loopback/BUILD
+++ b/pkg/tcpip/link/loopback/BUILD
@@ -1,11 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
name = "loopback",
srcs = ["loopback.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/loopback",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD
index 1ac7948b6..82b441b79 100644
--- a/pkg/tcpip/link/muxed/BUILD
+++ b/pkg/tcpip/link/muxed/BUILD
@@ -1,12 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "muxed",
srcs = ["injectable.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/muxed",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
@@ -19,7 +17,7 @@ go_test(
name = "muxed_test",
size = "small",
srcs = ["injectable_test.go"],
- embed = [":muxed"],
+ library = ":muxed",
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD
index d8211e93d..14b527bc2 100644
--- a/pkg/tcpip/link/rawfile/BUILD
+++ b/pkg/tcpip/link/rawfile/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -12,7 +12,6 @@ go_library(
"errors.go",
"rawfile_unsafe.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/rawfile",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD
index 09165dd4c..13243ebbb 100644
--- a/pkg/tcpip/link/sharedmem/BUILD
+++ b/pkg/tcpip/link/sharedmem/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -11,7 +10,6 @@ go_library(
"sharedmem_unsafe.go",
"tx.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem",
visibility = ["//visibility:public"],
deps = [
"//pkg/log",
@@ -30,7 +28,7 @@ go_test(
srcs = [
"sharedmem_test.go",
],
- embed = [":sharedmem"],
+ library = ":sharedmem",
deps = [
"//pkg/sync",
"//pkg/tcpip",
diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD
index a0d4ad0be..87020ec08 100644
--- a/pkg/tcpip/link/sharedmem/pipe/BUILD
+++ b/pkg/tcpip/link/sharedmem/pipe/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -11,7 +10,6 @@ go_library(
"rx.go",
"tx.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe",
visibility = ["//visibility:public"],
)
@@ -20,6 +18,6 @@ go_test(
srcs = [
"pipe_test.go",
],
- embed = [":pipe"],
+ library = ":pipe",
deps = ["//pkg/sync"],
)
diff --git a/pkg/tcpip/link/sharedmem/queue/BUILD b/pkg/tcpip/link/sharedmem/queue/BUILD
index 8c9234d54..3ba06af73 100644
--- a/pkg/tcpip/link/sharedmem/queue/BUILD
+++ b/pkg/tcpip/link/sharedmem/queue/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -9,7 +8,6 @@ go_library(
"rx.go",
"tx.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue",
visibility = ["//visibility:public"],
deps = [
"//pkg/log",
@@ -22,7 +20,7 @@ go_test(
srcs = [
"queue_test.go",
],
- embed = [":queue"],
+ library = ":queue",
deps = [
"//pkg/tcpip/link/sharedmem/pipe",
],
diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD
index d6ae0368a..230a8d53a 100644
--- a/pkg/tcpip/link/sniffer/BUILD
+++ b/pkg/tcpip/link/sniffer/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -8,7 +8,6 @@ go_library(
"pcap.go",
"sniffer.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sniffer",
visibility = ["//visibility:public"],
deps = [
"//pkg/log",
diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD
index a71a493fc..e5096ea38 100644
--- a/pkg/tcpip/link/tun/BUILD
+++ b/pkg/tcpip/link/tun/BUILD
@@ -1,10 +1,9 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
name = "tun",
srcs = ["tun_unsafe.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/tun",
visibility = ["//visibility:public"],
)
diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD
index 134837943..0956d2c65 100644
--- a/pkg/tcpip/link/waitable/BUILD
+++ b/pkg/tcpip/link/waitable/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -8,7 +7,6 @@ go_library(
srcs = [
"waitable.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/waitable",
visibility = ["//visibility:public"],
deps = [
"//pkg/gate",
@@ -23,7 +21,7 @@ go_test(
srcs = [
"waitable_test.go",
],
- embed = [":waitable"],
+ library = ":waitable",
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index 9d16ff8c9..6a4839fb8 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
index e7617229b..eddf7b725 100644
--- a/pkg/tcpip/network/arp/BUILD
+++ b/pkg/tcpip/network/arp/BUILD
@@ -1,12 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "arp",
srcs = ["arp.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/network/arp",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 1ceaebfbd..4da13c5df 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -178,24 +178,9 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo
return broadcastMAC, true
}
if header.IsV4MulticastAddress(addr) {
- // RFC 1112 Host Extensions for IP Multicasting
- //
- // 6.4. Extensions to an Ethernet Local Network Module:
- //
- // An IP host group address is mapped to an Ethernet multicast
- // address by placing the low-order 23-bits of the IP address
- // into the low-order 23 bits of the Ethernet multicast address
- // 01-00-5E-00-00-00 (hex).
- return tcpip.LinkAddress([]byte{
- 0x01,
- 0x00,
- 0x5e,
- addr[header.IPv4AddressSize-3] & 0x7f,
- addr[header.IPv4AddressSize-2],
- addr[header.IPv4AddressSize-1],
- }), true
+ return header.EthernetAddressFromMulticastIPv4Address(addr), true
}
- return "", false
+ return tcpip.LinkAddress([]byte(nil)), false
}
// SetOption implements NetworkProtocol.
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 8e6048a21..03cf03b6d 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -15,6 +15,7 @@
package arp_test
import (
+ "context"
"strconv"
"testing"
"time"
@@ -83,7 +84,7 @@ func newTestContext(t *testing.T) *testContext {
}
func (c *testContext) cleanup() {
- close(c.linkEP.C)
+ c.linkEP.Close()
}
func TestDirectRequest(t *testing.T) {
@@ -110,7 +111,7 @@ func TestDirectRequest(t *testing.T) {
for i, address := range []tcpip.Address{stackAddr1, stackAddr2} {
t.Run(strconv.Itoa(i), func(t *testing.T) {
inject(address)
- pi := <-c.linkEP.C
+ pi, _ := c.linkEP.ReadContext(context.Background())
if pi.Proto != arp.ProtocolNumber {
t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto)
}
@@ -134,12 +135,11 @@ func TestDirectRequest(t *testing.T) {
}
inject(stackAddrBad)
- select {
- case pkt := <-c.linkEP.C:
+ // Sleep tests are gross, but this will only potentially flake
+ // if there's a bug. If there is no bug this will reliably
+ // succeed.
+ ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond)
+ if pkt, ok := c.linkEP.ReadContext(ctx); ok {
t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto)
- case <-time.After(100 * time.Millisecond):
- // Sleep tests are gross, but this will only potentially flake
- // if there's a bug. If there is no bug this will reliably
- // succeed.
}
}
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
index ed16076fd..d1c728ccf 100644
--- a/pkg/tcpip/network/fragmentation/BUILD
+++ b/pkg/tcpip/network/fragmentation/BUILD
@@ -1,6 +1,5 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -24,7 +23,6 @@ go_library(
"reassembler.go",
"reassembler_list.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation",
visibility = ["//visibility:public"],
deps = [
"//pkg/log",
@@ -42,6 +40,6 @@ go_test(
"fragmentation_test.go",
"reassembler_test.go",
],
- embed = [":fragmentation"],
+ library = ":fragmentation",
deps = ["//pkg/tcpip/buffer"],
)
diff --git a/pkg/tcpip/network/hash/BUILD b/pkg/tcpip/network/hash/BUILD
index e6db5c0b0..872165866 100644
--- a/pkg/tcpip/network/hash/BUILD
+++ b/pkg/tcpip/network/hash/BUILD
@@ -1,11 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
name = "hash",
srcs = ["hash.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/network/hash",
visibility = ["//visibility:public"],
deps = [
"//pkg/rand",
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 4e2aae9a3..0fef2b1f1 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -9,7 +8,6 @@ go_library(
"icmp.go",
"ipv4.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv4",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index e4e273460..fb11874c6 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -9,7 +8,6 @@ go_library(
"icmp.go",
"ipv6.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv6",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
@@ -27,7 +25,7 @@ go_test(
"ipv6_test.go",
"ndp_test.go",
],
- embed = [":ipv6"],
+ library = ":ipv6",
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 1c3410618..60817d36d 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -137,21 +137,24 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P
}
ns := header.NDPNeighborSolicit(h.NDPPayload())
+ it, err := ns.Options().Iter(true)
+ if err != nil {
+ // If we have a malformed NDP NS option, drop the packet.
+ received.Invalid.Increment()
+ return
+ }
+
targetAddr := ns.TargetAddress()
s := r.Stack()
rxNICID := r.NICID()
-
- isTentative, err := s.IsAddrTentative(rxNICID, targetAddr)
- if err != nil {
+ if isTentative, err := s.IsAddrTentative(rxNICID, targetAddr); err != nil {
// We will only get an error if rxNICID is unrecognized,
// which should not happen. For now short-circuit this
// packet.
//
// TODO(b/141002840): Handle this better?
return
- }
-
- if isTentative {
+ } else if isTentative {
// If the target address is tentative and the source
// of the packet is a unicast (specified) address, then
// the source of the packet is attempting to perform
@@ -185,6 +188,23 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P
return
}
+ // If the NS message has the source link layer option, update the link
+ // address cache with the link address for the sender of the message.
+ //
+ // TODO(b/148429853): Properly process the NS message and do Neighbor
+ // Unreachability Detection.
+ for {
+ opt, done, _ := it.Next()
+ if done {
+ break
+ }
+
+ switch opt := opt.(type) {
+ case header.NDPSourceLinkLayerAddressOption:
+ e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, opt.EthernetAddress())
+ }
+ }
+
optsSerializer := header.NDPOptionsSerializer{
header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress[:]),
}
@@ -211,15 +231,6 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P
r.LocalAddress = targetAddr
packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
- // TODO(tamird/ghanan): there exists an explicit NDP option that is
- // used to update the neighbor table with link addresses for a
- // neighbor from an NS (see the Source Link Layer option RFC
- // 4861 section 4.6.1 and section 7.2.3).
- //
- // Furthermore, the entirety of NDP handling here seems to be
- // contradicted by RFC 4861.
- e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, r.RemoteLinkAddress)
-
// RFC 4861 Neighbor Discovery for IP version 6 (IPv6)
//
// 7.1.2. Validation of Neighbor Advertisements
@@ -397,10 +408,14 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
// LinkAddressRequest implements stack.LinkAddressResolver.
func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
snaddr := header.SolicitedNodeAddr(addr)
+
+ // TODO(b/148672031): Use stack.FindRoute instead of manually creating the
+ // route here. Note, we would need the nicID to do this properly so the right
+ // NIC (associated to linkEP) is used to send the NDP NS message.
r := &stack.Route{
LocalAddress: localAddr,
RemoteAddress: snaddr,
- RemoteLinkAddress: broadcastMAC,
+ RemoteLinkAddress: header.EthernetAddressFromMulticastIPv6Address(snaddr),
}
hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
@@ -430,23 +445,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.
// ResolveStaticAddress implements stack.LinkAddressResolver.
func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
if header.IsV6MulticastAddress(addr) {
- // RFC 2464 Transmission of IPv6 Packets over Ethernet Networks
- //
- // 7. Address Mapping -- Multicast
- //
- // An IPv6 packet with a multicast destination address DST,
- // consisting of the sixteen octets DST[1] through DST[16], is
- // transmitted to the Ethernet multicast address whose first
- // two octets are the value 3333 hexadecimal and whose last
- // four octets are the last four octets of DST.
- return tcpip.LinkAddress([]byte{
- 0x33,
- 0x33,
- addr[header.IPv6AddressSize-4],
- addr[header.IPv6AddressSize-3],
- addr[header.IPv6AddressSize-2],
- addr[header.IPv6AddressSize-1],
- }), true
+ return header.EthernetAddressFromMulticastIPv6Address(addr), true
}
- return "", false
+ return tcpip.LinkAddress([]byte(nil)), false
}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index a2fdc5dcd..d0e930e20 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -15,6 +15,7 @@
package ipv6
import (
+ "context"
"reflect"
"strings"
"testing"
@@ -264,19 +265,20 @@ func newTestContext(t *testing.T) *testContext {
}
func (c *testContext) cleanup() {
- close(c.linkEP0.C)
- close(c.linkEP1.C)
+ c.linkEP0.Close()
+ c.linkEP1.Close()
}
type routeArgs struct {
- src, dst *channel.Endpoint
- typ header.ICMPv6Type
+ src, dst *channel.Endpoint
+ typ header.ICMPv6Type
+ remoteLinkAddr tcpip.LinkAddress
}
func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) {
t.Helper()
- pi := <-args.src.C
+ pi, _ := args.src.ReadContext(context.Background())
{
views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()}
@@ -291,6 +293,11 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.
t.Errorf("unexpected protocol number %d", pi.Proto)
return
}
+
+ if len(args.remoteLinkAddr) != 0 && args.remoteLinkAddr != pi.Route.RemoteLinkAddress {
+ t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr)
+ }
+
ipv6 := header.IPv6(pi.Pkt.Header.View())
transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader())
if transProto != header.ICMPv6ProtocolNumber {
@@ -338,7 +345,7 @@ func TestLinkResolution(t *testing.T) {
t.Fatalf("ep.Write(_) = _, <non-nil>, %s, want = _, <non-nil>, tcpip.ErrNoLinkAddress", err)
}
for _, args := range []routeArgs{
- {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit},
+ {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))},
{src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6NeighborAdvert},
} {
routeICMPv6Packet(t, args, func(t *testing.T, icmpv6 header.ICMPv6) {
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index fe895b376..bd732f93f 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -70,6 +70,141 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack
return s, ep
}
+// TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving an
+// NDP NS message with the Source Link Layer Address option results in a
+// new entry in the link address cache for the sender of the message.
+func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ e := channel.New(0, 1280, linkAddr0)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ }
+
+ ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns.SetTargetAddress(lladdr0)
+ ns.Options().Serialize(header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(linkAddr1),
+ })
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+ e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil)
+ if err != nil {
+ t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err)
+ }
+ if c != nil {
+ t.Errorf("got unexpected channel")
+ }
+ if linkAddr != linkAddr1 {
+ t.Errorf("got link address = %s, want = %s", linkAddr, linkAddr1)
+ }
+}
+
+// TestNeighorSolicitationWithInvalidSourceLinkLayerOption tests that receiving
+// an NDP NS message with an invalid Source Link Layer Address option does not
+// result in a new entry in the link address cache for the sender of the
+// message.
+func TestNeighorSolicitationWithInvalidSourceLinkLayerOption(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ optsBuf []byte
+ }{
+ {
+ name: "Too Small",
+ optsBuf: []byte{1, 1, 1, 2, 3, 4, 5},
+ },
+ {
+ name: "Invalid Length",
+ optsBuf: []byte{1, 2, 1, 2, 3, 4, 5, 6},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ e := channel.New(0, 1280, linkAddr0)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ }
+
+ ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns.SetTargetAddress(lladdr0)
+ opts := ns.Options()
+ copy(opts, test.optsBuf)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+
+ invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+
+ e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ // Invalid count should have increased.
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+
+ linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock)
+ }
+ if c == nil {
+ t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber)
+ }
+ if linkAddr != "" {
+ t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (%s, _, ), want = ('', _, _)", nicID, lladdr1, lladdr0, ProtocolNumber, linkAddr)
+ }
+ })
+ }
+}
+
// TestHopLimitValidation is a test that makes sure that NDP packets are only
// received if their IP header's hop limit is set to 255.
func TestHopLimitValidation(t *testing.T) {
diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD
index a6ef3bdcc..2bad05a2e 100644
--- a/pkg/tcpip/ports/BUILD
+++ b/pkg/tcpip/ports/BUILD
@@ -1,12 +1,10 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "ports",
srcs = ["ports.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/ports",
visibility = ["//visibility:public"],
deps = [
"//pkg/sync",
@@ -17,7 +15,7 @@ go_library(
go_test(
name = "ports_test",
srcs = ["ports_test.go"],
- embed = [":ports"],
+ library = ":ports",
deps = [
"//pkg/tcpip",
],
diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD
index d7496fde6..cf0a5fefe 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/BUILD
+++ b/pkg/tcpip/sample/tun_tcp_connect/BUILD
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
+load("//tools:defs.bzl", "go_binary")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/sample/tun_tcp_echo/BUILD b/pkg/tcpip/sample/tun_tcp_echo/BUILD
index 875561566..43264b76d 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/BUILD
+++ b/pkg/tcpip/sample/tun_tcp_echo/BUILD
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
+load("//tools:defs.bzl", "go_binary")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/seqnum/BUILD b/pkg/tcpip/seqnum/BUILD
index b31ddba2f..45f503845 100644
--- a/pkg/tcpip/seqnum/BUILD
+++ b/pkg/tcpip/seqnum/BUILD
@@ -1,10 +1,9 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
name = "seqnum",
srcs = ["seqnum.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/seqnum",
visibility = ["//visibility:public"],
)
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 783351a69..f5b750046 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -1,6 +1,5 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -30,7 +29,6 @@ go_library(
"stack_global_state.go",
"transport_demuxer.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/stack",
visibility = ["//visibility:public"],
deps = [
"//pkg/ilist",
@@ -81,7 +79,7 @@ go_test(
name = "stack_test",
size = "small",
srcs = ["linkaddrcache_test.go"],
- embed = [":stack"],
+ library = ":stack",
deps = [
"//pkg/sleep",
"//pkg/sync",
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index d983ac390..6123fda33 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -167,8 +167,8 @@ type NDPDispatcher interface {
// reason, such as the address being removed). If an error occured
// during DAD, err will be set and resolved must be ignored.
//
- // This function is permitted to block indefinitely without interfering
- // with the stack's operation.
+ // This function is not permitted to block indefinitely. This function
+ // is also not permitted to call into the stack.
OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error)
// OnDefaultRouterDiscovered will be called when a new default router is
@@ -538,6 +538,14 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error {
r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false)
defer r.Release()
+ // Route should resolve immediately since snmc is a multicast address so a
+ // remote link address can be calculated without a resolution process.
+ if c, err := r.Resolve(nil); err != nil {
+ log.Fatalf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.nic.ID(), err)
+ } else if c != nil {
+ log.Fatalf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID())
+ }
+
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize))
pkt.SetType(header.ICMPv6NeighborSolicit)
@@ -589,8 +597,8 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) {
delete(ndp.dad, addr)
// Let the integrator know DAD did not resolve.
- if ndp.nic.stack.ndpDisp != nil {
- go ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil)
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil)
}
}
@@ -1197,6 +1205,15 @@ func (ndp *ndpState) startSolicitingRouters() {
r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false)
defer r.Release()
+ // Route should resolve immediately since
+ // header.IPv6AllRoutersMulticastAddress is a multicast address so a
+ // remote link address can be calculated without a resolution process.
+ if c, err := r.Resolve(nil); err != nil {
+ log.Fatalf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID(), err)
+ } else if c != nil {
+ log.Fatalf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID())
+ }
+
payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize
hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadSize)
pkt := header.ICMPv6(hdr.Prepend(payloadSize))
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index f9460bd51..8af8565f7 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -15,6 +15,7 @@
package stack_test
import (
+ "context"
"encoding/binary"
"fmt"
"testing"
@@ -335,6 +336,7 @@ func TestDADResolve(t *testing.T) {
opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits
e := channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
s := stack.New(opts)
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -405,18 +407,32 @@ func TestDADResolve(t *testing.T) {
// Validate the sent Neighbor Solicitation messages.
for i := uint8(0); i < test.dupAddrDetectTransmits; i++ {
- p := <-e.C
+ p, _ := e.ReadContext(context.Background())
// Make sure its an IPv6 packet.
if p.Proto != header.IPv6ProtocolNumber {
t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
}
- // Check NDP packet.
+ // Make sure the right remote link address is used.
+ snmc := header.SolicitedNodeAddr(addr1)
+ if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want {
+ t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ }
+
+ // Check NDP NS packet.
+ //
+ // As per RFC 4861 section 4.3, a possible option is the Source Link
+ // Layer option, but this option MUST NOT be included when the source
+ // address of the packet is the unspecified address.
checker.IPv6(t, p.Pkt.Header.View().ToVectorisedView().First(),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(snmc),
checker.TTL(header.NDPHopLimit),
checker.NDPNS(
- checker.NDPNSTargetAddress(addr1)))
+ checker.NDPNSTargetAddress(addr1),
+ checker.NDPNSOptions(nil),
+ ))
}
})
}
@@ -492,7 +508,7 @@ func TestDADFail(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent),
+ dadC: make(chan ndpDADEvent, 1),
}
ndpConfigs := stack.DefaultNDPConfigurations()
opts := stack.Options{
@@ -571,7 +587,7 @@ func TestDADFail(t *testing.T) {
// removed.
func TestDADStop(t *testing.T) {
ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent),
+ dadC: make(chan ndpDADEvent, 1),
}
ndpConfigs := stack.NDPConfigurations{
RetransmitTimer: time.Second,
@@ -3283,31 +3299,38 @@ func TestRouterSolicitation(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
e := channel.New(int(test.maxRtrSolicit), 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
waitForPkt := func(timeout time.Duration) {
t.Helper()
- select {
- case p := <-e.C:
- if p.Proto != header.IPv6ProtocolNumber {
- t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
- }
- checker.IPv6(t,
- p.Pkt.Header.View(),
- checker.SrcAddr(header.IPv6Any),
- checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
- checker.TTL(header.NDPHopLimit),
- checker.NDPRS(),
- )
-
- case <-time.After(timeout):
+ ctx, _ := context.WithTimeout(context.Background(), timeout)
+ p, ok := e.ReadContext(ctx)
+ if !ok {
t.Fatal("timed out waiting for packet")
+ return
+ }
+
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
}
+
+ // Make sure the right remote link address is used.
+ if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want {
+ t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ }
+
+ checker.IPv6(t,
+ p.Pkt.Header.View(),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS(),
+ )
}
waitForNothing := func(timeout time.Duration) {
t.Helper()
- select {
- case <-e.C:
+ ctx, _ := context.WithTimeout(context.Background(), timeout)
+ if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got a packet")
- case <-time.After(timeout):
}
}
s := stack.New(stack.Options{
@@ -3362,20 +3385,21 @@ func TestStopStartSolicitingRouters(t *testing.T) {
e := channel.New(maxRtrSolicitations, 1280, linkAddr1)
waitForPkt := func(timeout time.Duration) {
t.Helper()
- select {
- case p := <-e.C:
- if p.Proto != header.IPv6ProtocolNumber {
- t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
- }
- checker.IPv6(t, p.Pkt.Header.View(),
- checker.SrcAddr(header.IPv6Any),
- checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
- checker.TTL(header.NDPHopLimit),
- checker.NDPRS())
-
- case <-time.After(timeout):
+ ctx, _ := context.WithTimeout(context.Background(), timeout)
+ p, ok := e.ReadContext(ctx)
+ if !ok {
t.Fatal("timed out waiting for packet")
+ return
+ }
+
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
}
+ checker.IPv6(t, p.Pkt.Header.View(),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS())
}
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
@@ -3391,23 +3415,20 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Enable forwarding which should stop router solicitations.
s.SetForwarding(true)
- select {
- case <-e.C:
+ ctx, _ := context.WithTimeout(context.Background(), delay+defaultTimeout)
+ if _, ok := e.ReadContext(ctx); ok {
// A single RS may have been sent before forwarding was enabled.
- select {
- case <-e.C:
+ ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout)
+ if _, ok = e.ReadContext(ctx); ok {
t.Fatal("Should not have sent more than one RS message")
- case <-time.After(interval + defaultTimeout):
}
- case <-time.After(delay + defaultTimeout):
}
// Enabling forwarding again should do nothing.
s.SetForwarding(true)
- select {
- case <-e.C:
+ ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout)
+ if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got a packet after becoming a router")
- case <-time.After(delay + defaultTimeout):
}
// Disable forwarding which should start router solicitations.
@@ -3415,17 +3436,15 @@ func TestStopStartSolicitingRouters(t *testing.T) {
waitForPkt(delay + defaultAsyncEventTimeout)
waitForPkt(interval + defaultAsyncEventTimeout)
waitForPkt(interval + defaultAsyncEventTimeout)
- select {
- case <-e.C:
+ ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout)
+ if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got an extra packet after sending out the expected RSs")
- case <-time.After(interval + defaultTimeout):
}
// Disabling forwarding again should do nothing.
s.SetForwarding(false)
- select {
- case <-e.C:
+ ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout)
+ if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got a packet after becoming a router")
- case <-time.After(delay + defaultTimeout):
}
}
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 517f4b941..f565aafb2 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -225,7 +225,9 @@ func (r *Route) Release() {
// Clone Clone a route such that the original one can be released and the new
// one will remain valid.
func (r *Route) Clone() Route {
- r.ref.incRef()
+ if r.ref != nil {
+ r.ref.incRef()
+ }
return *r
}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index dad288642..834fe9487 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -1880,9 +1880,7 @@ func TestNICForwarding(t *testing.T) {
Data: buf.ToVectorisedView(),
})
- select {
- case <-ep2.C:
- default:
+ if _, ok := ep2.Read(); !ok {
t.Fatal("Packet not forwarded")
}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index f50604a8a..869c69a6d 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -623,10 +623,8 @@ func TestTransportForwarding(t *testing.T) {
t.Fatalf("Write failed: %v", err)
}
- var p channel.PacketInfo
- select {
- case p = <-ep2.C:
- default:
+ p, ok := ep2.Read()
+ if !ok {
t.Fatal("Response packet not forwarded")
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 59c9b3fb0..0fa141d58 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -626,6 +626,12 @@ type TCPLingerTimeoutOption time.Duration
// before being marked closed.
type TCPTimeWaitTimeoutOption time.Duration
+// TCPDeferAcceptOption is used by SetSockOpt/GetSockOpt to allow a
+// accept to return a completed connection only when there is data to be
+// read. This usually means the listening socket will drop the final ACK
+// for a handshake till the specified timeout until a segment with data arrives.
+type TCPDeferAcceptOption time.Duration
+
// MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default
// TTL value for multicast messages. The default is 1.
type MulticastTTLOption uint8
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
index 3aa23d529..ac18ec5b1 100644
--- a/pkg/tcpip/transport/icmp/BUILD
+++ b/pkg/tcpip/transport/icmp/BUILD
@@ -1,5 +1,5 @@
+load("//tools:defs.bzl", "go_library")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -23,7 +23,6 @@ go_library(
"icmp_packet_list.go",
"protocol.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/icmp",
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD
index 4858d150c..d22de6b26 100644
--- a/pkg/tcpip/transport/packet/BUILD
+++ b/pkg/tcpip/transport/packet/BUILD
@@ -1,5 +1,5 @@
+load("//tools:defs.bzl", "go_library")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -22,7 +22,6 @@ go_library(
"endpoint_state.go",
"packet_list.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/packet",
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD
index 2f2131ff7..c9baf4600 100644
--- a/pkg/tcpip/transport/raw/BUILD
+++ b/pkg/tcpip/transport/raw/BUILD
@@ -1,5 +1,5 @@
+load("//tools:defs.bzl", "go_library")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -23,7 +23,6 @@ go_library(
"protocol.go",
"raw_packet_list.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/raw",
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 0e3ab05ad..272e8f570 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -1,6 +1,5 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -55,10 +54,10 @@ go_library(
"tcp_segment_list.go",
"timer.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcp",
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/log",
"//pkg/rand",
"//pkg/sleep",
"//pkg/sync",
@@ -92,6 +91,7 @@ go_test(
tags = ["flaky"],
deps = [
":tcp",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index d469758eb..08afb7c17 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -222,13 +222,13 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu
// createConnectingEndpoint creates a new endpoint in a connecting state, with
// the connection parameters given by the arguments.
-func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions) (*endpoint, *tcpip.Error) {
+func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) {
// Create a new endpoint.
netProto := l.netProto
if netProto == 0 {
netProto = s.route.NetProto
}
- n := newEndpoint(l.stack, netProto, nil)
+ n := newEndpoint(l.stack, netProto, queue)
n.v6only = l.v6only
n.ID = s.id
n.boundNICID = s.route.NICID()
@@ -271,18 +271,19 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
return n, nil
}
-// createEndpoint creates a new endpoint in connected state and then performs
-// the TCP 3-way handshake.
-func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) {
+// createEndpointAndPerformHandshake creates a new endpoint in connected state
+// and then performs the TCP 3-way handshake.
+func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) {
// Create new endpoint.
irs := s.sequenceNumber
isn := generateSecureISN(s.id, l.stack.Seed())
- ep, err := l.createConnectingEndpoint(s, isn, irs, opts)
+ ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue)
if err != nil {
return nil, err
}
// listenEP is nil when listenContext is used by tcp.Forwarder.
+ deferAccept := time.Duration(0)
if l.listenEP != nil {
l.listenEP.mu.Lock()
if l.listenEP.EndpointState() != StateListen {
@@ -290,13 +291,12 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
return nil, tcpip.ErrConnectionAborted
}
l.addPendingEndpoint(ep)
+ deferAccept = l.listenEP.deferAccept
l.listenEP.mu.Unlock()
}
// Perform the 3-way handshake.
- h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow()))
-
- h.resetToSynRcvd(isn, irs, opts)
+ h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept)
if err := h.execute(); err != nil {
ep.Close()
if l.listenEP != nil {
@@ -377,16 +377,14 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
defer e.decSynRcvdCount()
defer s.decRef()
- n, err := ctx.createEndpointAndPerformHandshake(s, opts)
+ n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{})
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
return
}
ctx.removePendingEndpoint(n)
- // Start the protocol goroutine.
- wq := &waiter.Queue{}
- n.startAcceptedLoop(wq)
+ n.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
e.deliverAccepted(n)
@@ -546,7 +544,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr
}
- n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions)
+ n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions, &waiter.Queue{})
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
@@ -576,8 +574,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// space available in the backlog.
// Start the protocol goroutine.
- wq := &waiter.Queue{}
- n.startAcceptedLoop(wq)
+ n.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
go e.deliverAccepted(n)
}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 4e3c5419c..5c5397823 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -86,6 +86,19 @@ type handshake struct {
// rcvWndScale is the receive window scale, as defined in RFC 1323.
rcvWndScale int
+
+ // startTime is the time at which the first SYN/SYN-ACK was sent.
+ startTime time.Time
+
+ // deferAccept if non-zero will drop the final ACK for a passive
+ // handshake till an ACK segment with data is received or the timeout is
+ // hit.
+ deferAccept time.Duration
+
+ // acked is true if the the final ACK for a 3-way handshake has
+ // been received. This is required to stop retransmitting the
+ // original SYN-ACK when deferAccept is enabled.
+ acked bool
}
func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake {
@@ -112,6 +125,12 @@ func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake {
return h
}
+func newPassiveHandshake(ep *endpoint, rcvWnd seqnum.Size, isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) handshake {
+ h := newHandshake(ep, rcvWnd)
+ h.resetToSynRcvd(isn, irs, opts, deferAccept)
+ return h
+}
+
// FindWndScale determines the window scale to use for the given maximum window
// size.
func FindWndScale(wnd seqnum.Size) int {
@@ -181,7 +200,7 @@ func (h *handshake) effectiveRcvWndScale() uint8 {
// resetToSynRcvd resets the state of the handshake object to the SYN-RCVD
// state.
-func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions) {
+func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) {
h.active = false
h.state = handshakeSynRcvd
h.flags = header.TCPFlagSyn | header.TCPFlagAck
@@ -189,6 +208,7 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea
h.ackNum = irs + 1
h.mss = opts.MSS
h.sndWndScale = opts.WS
+ h.deferAccept = deferAccept
h.ep.mu.Lock()
h.ep.setEndpointState(StateSynRecv)
h.ep.mu.Unlock()
@@ -352,6 +372,14 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
// We have previously received (and acknowledged) the peer's SYN. If the
// peer acknowledges our SYN, the handshake is completed.
if s.flagIsSet(header.TCPFlagAck) {
+ // If deferAccept is not zero and this is a bare ACK and the
+ // timeout is not hit then drop the ACK.
+ if h.deferAccept != 0 && s.data.Size() == 0 && time.Since(h.startTime) < h.deferAccept {
+ h.acked = true
+ h.ep.stack.Stats().DroppedPackets.Increment()
+ return nil
+ }
+
// If the timestamp option is negotiated and the segment does
// not carry a timestamp option then the segment must be dropped
// as per https://tools.ietf.org/html/rfc7323#section-3.2.
@@ -365,10 +393,16 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber)
}
h.state = handshakeCompleted
+
h.ep.mu.Lock()
h.ep.transitionToStateEstablishedLocked(h)
+ // If the segment has data then requeue it for the receiver
+ // to process it again once main loop is started.
+ if s.data.Size() > 0 {
+ s.incRef()
+ h.ep.enqueueSegment(s)
+ }
h.ep.mu.Unlock()
-
return nil
}
@@ -471,6 +505,7 @@ func (h *handshake) execute() *tcpip.Error {
}
}
+ h.startTime = time.Now()
// Initialize the resend timer.
resendWaker := sleep.Waker{}
timeOut := time.Duration(time.Second)
@@ -524,11 +559,21 @@ func (h *handshake) execute() *tcpip.Error {
switch index, _ := s.Fetch(true); index {
case wakerForResend:
timeOut *= 2
- if timeOut > 60*time.Second {
+ if timeOut > MaxRTO {
return tcpip.ErrTimeout
}
rt.Reset(timeOut)
- h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+ // Resend the SYN/SYN-ACK only if the following conditions hold.
+ // - It's an active handshake (deferAccept does not apply)
+ // - It's a passive handshake and we have not yet got the final-ACK.
+ // - It's a passive handshake and we got an ACK but deferAccept is
+ // enabled and we are now past the deferAccept duration.
+ // The last is required to provide a way for the peer to complete
+ // the connection with another ACK or data (as ACKs are never
+ // retransmitted on their own).
+ if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept {
+ h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+ }
case wakerForNotification:
n := h.ep.fetchNotifications()
@@ -944,6 +989,10 @@ func (e *endpoint) transitionToStateCloseLocked() {
// to any other listening endpoint. We reply with RST if we cannot find one.
func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) {
ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, &s.route)
+ if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.EndpointInfo.TransportEndpointInfo.ID.LocalAddress.To4() != "" {
+ // Dual-stack socket, try IPv4.
+ ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, &s.route)
+ }
if ep == nil {
replyWithReset(s)
s.decRef()
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 13718ff55..b5a8e15ee 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -498,6 +498,13 @@ type endpoint struct {
// without any data being acked.
userTimeout time.Duration
+ // deferAccept if non-zero specifies a user specified time during
+ // which the final ACK of a handshake will be dropped provided the
+ // ACK is a bare ACK and carries no data. If the timeout is crossed then
+ // the bare ACK is accepted and the connection is delivered to the
+ // listener.
+ deferAccept time.Duration
+
// pendingAccepted is a synchronization primitive used to track number
// of connections that are queued up to be delivered to the accepted
// channel. We use this to ensure that all goroutines blocked on writing
@@ -1574,6 +1581,15 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Unlock()
return nil
+ case tcpip.TCPDeferAcceptOption:
+ e.mu.Lock()
+ if time.Duration(v) > MaxRTO {
+ v = tcpip.TCPDeferAcceptOption(MaxRTO)
+ }
+ e.deferAccept = time.Duration(v)
+ e.mu.Unlock()
+ return nil
+
default:
return nil
}
@@ -1798,6 +1814,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Unlock()
return nil
+ case *tcpip.TCPDeferAcceptOption:
+ e.mu.Lock()
+ *o = tcpip.TCPDeferAcceptOption(e.deferAccept)
+ e.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -2025,8 +2047,14 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
// work mutex is available.
if e.workMu.TryLock() {
e.mu.Lock()
- e.resetConnectionLocked(tcpip.ErrConnectionAborted)
- e.notifyProtocolGoroutine(notifyTickleWorker)
+ // We need to double check here to make
+ // sure worker has not transitioned the
+ // endpoint out of a connected state
+ // before trying to send a reset.
+ if e.EndpointState().connected() {
+ e.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ e.notifyProtocolGoroutine(notifyTickleWorker)
+ }
e.mu.Unlock()
e.workMu.Unlock()
} else {
@@ -2149,9 +2177,8 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
// startAcceptedLoop sets up required state and starts a goroutine with the
// main loop for accepted connections.
-func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) {
+func (e *endpoint) startAcceptedLoop() {
e.mu.Lock()
- e.waiterQueue = waiterQueue
e.workerRunning = true
e.mu.Unlock()
wakerInitDone := make(chan struct{})
@@ -2177,7 +2204,6 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
default:
return nil, nil, tcpip.ErrWouldBlock
}
-
return n, n.waiterQueue, nil
}
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 7eb613be5..c9ee5bf06 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -157,13 +157,13 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
TSVal: r.synOptions.TSVal,
TSEcr: r.synOptions.TSEcr,
SACKPermitted: r.synOptions.SACKPermitted,
- })
+ }, queue)
if err != nil {
return nil, err
}
// Start the protocol goroutine.
- ep.startAcceptedLoop(queue)
+ ep.startAcceptedLoop()
return ep, nil
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index df2fb1071..2c1505067 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -21,6 +21,7 @@ import (
"testing"
"time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -6787,3 +6788,183 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) {
),
)
}
+
+func TestTCPDeferAccept(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ const tcpDeferAccept = 1 * time.Second
+ if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil {
+ t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %v", tcpDeferAccept, err)
+ }
+
+ irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
+
+ if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock {
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: %s", err, tcpip.ErrWouldBlock)
+ }
+
+ // Send data. This should result in an acceptable endpoint.
+ c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ })
+
+ // Receive ACK for the data we sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(iss+1)),
+ checker.AckNum(uint32(irs+5))))
+
+ // Give a bit of time for the socket to be delivered to the accept queue.
+ time.Sleep(50 * time.Millisecond)
+ aep, _, err := c.EP.Accept()
+ if err != nil {
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: nil", err)
+ }
+
+ aep.Close()
+ // Closing aep without reading the data should trigger a RST.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
+ checker.SeqNum(uint32(iss+1)),
+ checker.AckNum(uint32(irs+5))))
+}
+
+func TestTCPDeferAcceptTimeout(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ const tcpDeferAccept = 1 * time.Second
+ if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil {
+ t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %v", tcpDeferAccept, err)
+ }
+
+ irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
+
+ if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock {
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: %s", err, tcpip.ErrWouldBlock)
+ }
+
+ // Sleep for a little of the tcpDeferAccept timeout.
+ time.Sleep(tcpDeferAccept + 100*time.Millisecond)
+
+ // On timeout expiry we should get a SYN-ACK retransmission.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
+ checker.AckNum(uint32(irs)+1)))
+
+ // Send data. This should result in an acceptable endpoint.
+ c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ })
+
+ // Receive ACK for the data we sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(iss+1)),
+ checker.AckNum(uint32(irs+5))))
+
+ // Give sometime for the endpoint to be delivered to the accept queue.
+ time.Sleep(50 * time.Millisecond)
+ aep, _, err := c.EP.Accept()
+ if err != nil {
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: nil", err)
+ }
+
+ aep.Close()
+ // Closing aep without reading the data should trigger a RST.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
+ checker.SeqNum(uint32(iss+1)),
+ checker.AckNum(uint32(irs+5))))
+}
+
+func TestResetDuringClose(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ iss := seqnum.Value(789)
+ c.CreateConnected(iss, 30000, -1 /* epRecvBuf */)
+ // Send some data to make sure there is some unread
+ // data to trigger a reset on c.Close.
+ irs := c.IRS
+ c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss.Add(1),
+ AckNum: irs.Add(1),
+ RcvWnd: 30000,
+ })
+
+ // Receive ACK for the data we sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(irs.Add(1))),
+ checker.AckNum(uint32(iss.Add(5)))))
+
+ // Close in a separate goroutine so that we can trigger
+ // a race with the RST we send below. This should not
+ // panic due to the route being released depeding on
+ // whether Close() sends an active RST or the RST sent
+ // below is processed by the worker first.
+ var wg sync.WaitGroup
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ SeqNum: iss.Add(5),
+ AckNum: c.IRS.Add(5),
+ RcvWnd: 30000,
+ Flags: header.TCPFlagRst,
+ })
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ c.EP.Close()
+ }()
+
+ wg.Wait()
+}
diff --git a/pkg/tcpip/transport/tcp/testing/context/BUILD b/pkg/tcpip/transport/tcp/testing/context/BUILD
index b33ec2087..ce6a2c31d 100644
--- a/pkg/tcpip/transport/tcp/testing/context/BUILD
+++ b/pkg/tcpip/transport/tcp/testing/context/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -6,7 +6,6 @@ go_library(
name = "context",
testonly = 1,
srcs = ["context.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context",
visibility = [
"//visibility:public",
],
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 822907998..1e9a0dea3 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -18,6 +18,7 @@ package context
import (
"bytes"
+ "context"
"testing"
"time"
@@ -215,11 +216,9 @@ func (c *Context) Stack() *stack.Stack {
func (c *Context) CheckNoPacketTimeout(errMsg string, wait time.Duration) {
c.t.Helper()
- select {
- case <-c.linkEP.C:
+ ctx, _ := context.WithTimeout(context.Background(), wait)
+ if _, ok := c.linkEP.ReadContext(ctx); ok {
c.t.Fatal(errMsg)
-
- case <-time.After(wait):
}
}
@@ -234,27 +233,27 @@ func (c *Context) CheckNoPacket(errMsg string) {
// 2 seconds.
func (c *Context) GetPacket() []byte {
c.t.Helper()
- select {
- case p := <-c.linkEP.C:
- if p.Proto != ipv4.ProtocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
- }
- hdr := p.Pkt.Header.View()
- b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
+ ctx, _ := context.WithTimeout(context.Background(), 2*time.Second)
+ p, ok := c.linkEP.ReadContext(ctx)
+ if !ok {
+ c.t.Fatalf("Packet wasn't written out")
+ return nil
+ }
- if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize {
- c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize)
- }
+ if p.Proto != ipv4.ProtocolNumber {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
+ }
- checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
- return b
+ hdr := p.Pkt.Header.View()
+ b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
- case <-time.After(2 * time.Second):
- c.t.Fatalf("Packet wasn't written out")
+ if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize {
+ c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize)
}
- return nil
+ checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
+ return b
}
// GetPacketNonBlocking reads a packet from the link layer endpoint
@@ -263,20 +262,21 @@ func (c *Context) GetPacket() []byte {
// nil immediately.
func (c *Context) GetPacketNonBlocking() []byte {
c.t.Helper()
- select {
- case p := <-c.linkEP.C:
- if p.Proto != ipv4.ProtocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
- }
-
- hdr := p.Pkt.Header.View()
- b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
- checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
- return b
- default:
+ p, ok := c.linkEP.Read()
+ if !ok {
return nil
}
+
+ if p.Proto != ipv4.ProtocolNumber {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
+ }
+
+ hdr := p.Pkt.Header.View()
+ b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
+
+ checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
+ return b
}
// SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint.
@@ -484,23 +484,23 @@ func (c *Context) CreateV6Endpoint(v6only bool) {
// and asserts that it is an IPv6 Packet with the expected src/dest addresses.
func (c *Context) GetV6Packet() []byte {
c.t.Helper()
- select {
- case p := <-c.linkEP.C:
- if p.Proto != ipv6.ProtocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber)
- }
- b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size())
- copy(b, p.Pkt.Header.View())
- copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView())
- checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr))
- return b
-
- case <-time.After(2 * time.Second):
+ ctx, _ := context.WithTimeout(context.Background(), 2*time.Second)
+ p, ok := c.linkEP.ReadContext(ctx)
+ if !ok {
c.t.Fatalf("Packet wasn't written out")
+ return nil
+ }
+
+ if p.Proto != ipv6.ProtocolNumber {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber)
}
+ b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size())
+ copy(b, p.Pkt.Header.View())
+ copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView())
- return nil
+ checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr))
+ return b
}
// SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of
@@ -1082,7 +1082,11 @@ func (c *Context) SACKEnabled() bool {
// SetGSOEnabled enables or disables generic segmentation offload.
func (c *Context) SetGSOEnabled(enable bool) {
- c.linkEP.GSO = enable
+ if enable {
+ c.linkEP.LinkEPCapabilities |= stack.CapabilityHardwareGSO
+ } else {
+ c.linkEP.LinkEPCapabilities &^= stack.CapabilityHardwareGSO
+ }
}
// MSSWithoutOptions returns the value for the MSS used by the stack when no
diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD
index 43fcc27f0..3ad6994a7 100644
--- a/pkg/tcpip/transport/tcpconntrack/BUILD
+++ b/pkg/tcpip/transport/tcpconntrack/BUILD
@@ -1,12 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "tcpconntrack",
srcs = ["tcp_conntrack.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index 57ff123e3..adc908e24 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -1,6 +1,5 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -25,7 +24,6 @@ go_library(
"protocol.go",
"udp_packet_list.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/udp",
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index c6927cfe3..f0ff3fe71 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -16,6 +16,7 @@ package udp_test
import (
"bytes"
+ "context"
"fmt"
"math/rand"
"testing"
@@ -357,30 +358,29 @@ func (c *testContext) createEndpointForFlow(flow testFlow) {
func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte {
c.t.Helper()
- select {
- case p := <-c.linkEP.C:
- if p.Proto != flow.netProto() {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto())
- }
-
- hdr := p.Pkt.Header.View()
- b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
-
- h := flow.header4Tuple(outgoing)
- checkers := append(
- checkers,
- checker.SrcAddr(h.srcAddr.Addr),
- checker.DstAddr(h.dstAddr.Addr),
- checker.UDP(checker.DstPort(h.dstAddr.Port)),
- )
- flow.checkerFn()(c.t, b, checkers...)
- return b
-
- case <-time.After(2 * time.Second):
+ ctx, _ := context.WithTimeout(context.Background(), 2*time.Second)
+ p, ok := c.linkEP.ReadContext(ctx)
+ if !ok {
c.t.Fatalf("Packet wasn't written out")
+ return nil
}
- return nil
+ if p.Proto != flow.netProto() {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto())
+ }
+
+ hdr := p.Pkt.Header.View()
+ b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
+
+ h := flow.header4Tuple(outgoing)
+ checkers = append(
+ checkers,
+ checker.SrcAddr(h.srcAddr.Addr),
+ checker.DstAddr(h.dstAddr.Addr),
+ checker.UDP(checker.DstPort(h.dstAddr.Port)),
+ )
+ flow.checkerFn()(c.t, b, checkers...)
+ return b
}
// injectPacket creates a packet of the given flow and with the given payload,
@@ -1541,48 +1541,50 @@ func TestV4UnknownDestination(t *testing.T) {
}
c.injectPacket(tc.flow, payload)
if !tc.icmpRequired {
- select {
- case p := <-c.linkEP.C:
+ ctx, _ := context.WithTimeout(context.Background(), time.Second)
+ if p, ok := c.linkEP.ReadContext(ctx); ok {
t.Fatalf("unexpected packet received: %+v", p)
- case <-time.After(1 * time.Second):
- return
}
+ return
}
- select {
- case p := <-c.linkEP.C:
- var pkt []byte
- pkt = append(pkt, p.Pkt.Header.View()...)
- pkt = append(pkt, p.Pkt.Data.ToView()...)
- if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want {
- t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
- }
+ // ICMP required.
+ ctx, _ := context.WithTimeout(context.Background(), time.Second)
+ p, ok := c.linkEP.ReadContext(ctx)
+ if !ok {
+ t.Fatalf("packet wasn't written out")
+ return
+ }
- hdr := header.IPv4(pkt)
- checker.IPv4(t, hdr, checker.ICMPv4(
- checker.ICMPv4Type(header.ICMPv4DstUnreachable),
- checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
+ var pkt []byte
+ pkt = append(pkt, p.Pkt.Header.View()...)
+ pkt = append(pkt, p.Pkt.Data.ToView()...)
+ if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want {
+ t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
+ }
- icmpPkt := header.ICMPv4(hdr.Payload())
- payloadIPHeader := header.IPv4(icmpPkt.Payload())
- wantLen := len(payload)
- if tc.largePayload {
- wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize
- }
+ hdr := header.IPv4(pkt)
+ checker.IPv4(t, hdr, checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4DstUnreachable),
+ checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
- // In case of large payloads the IP packet may be truncated. Update
- // the length field before retrieving the udp datagram payload.
- payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize))
+ icmpPkt := header.ICMPv4(hdr.Payload())
+ payloadIPHeader := header.IPv4(icmpPkt.Payload())
+ wantLen := len(payload)
+ if tc.largePayload {
+ wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize
+ }
- origDgram := header.UDP(payloadIPHeader.Payload())
- if got, want := len(origDgram.Payload()), wantLen; got != want {
- t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
- }
- if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
- t.Fatalf("unexpected payload got: %d, want: %d", got, want)
- }
- case <-time.After(1 * time.Second):
- t.Fatalf("packet wasn't written out")
+ // In case of large payloads the IP packet may be truncated. Update
+ // the length field before retrieving the udp datagram payload.
+ payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize))
+
+ origDgram := header.UDP(payloadIPHeader.Payload())
+ if got, want := len(origDgram.Payload()), wantLen; got != want {
+ t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
+ }
+ if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
+ t.Fatalf("unexpected payload got: %d, want: %d", got, want)
}
})
}
@@ -1615,47 +1617,49 @@ func TestV6UnknownDestination(t *testing.T) {
}
c.injectPacket(tc.flow, payload)
if !tc.icmpRequired {
- select {
- case p := <-c.linkEP.C:
+ ctx, _ := context.WithTimeout(context.Background(), time.Second)
+ if p, ok := c.linkEP.ReadContext(ctx); ok {
t.Fatalf("unexpected packet received: %+v", p)
- case <-time.After(1 * time.Second):
- return
}
+ return
}
- select {
- case p := <-c.linkEP.C:
- var pkt []byte
- pkt = append(pkt, p.Pkt.Header.View()...)
- pkt = append(pkt, p.Pkt.Data.ToView()...)
- if got, want := len(pkt), header.IPv6MinimumMTU; got > want {
- t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
- }
+ // ICMP required.
+ ctx, _ := context.WithTimeout(context.Background(), time.Second)
+ p, ok := c.linkEP.ReadContext(ctx)
+ if !ok {
+ t.Fatalf("packet wasn't written out")
+ return
+ }
+
+ var pkt []byte
+ pkt = append(pkt, p.Pkt.Header.View()...)
+ pkt = append(pkt, p.Pkt.Data.ToView()...)
+ if got, want := len(pkt), header.IPv6MinimumMTU; got > want {
+ t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
+ }
- hdr := header.IPv6(pkt)
- checker.IPv6(t, hdr, checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6DstUnreachable),
- checker.ICMPv6Code(header.ICMPv6PortUnreachable)))
+ hdr := header.IPv6(pkt)
+ checker.IPv6(t, hdr, checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6DstUnreachable),
+ checker.ICMPv6Code(header.ICMPv6PortUnreachable)))
- icmpPkt := header.ICMPv6(hdr.Payload())
- payloadIPHeader := header.IPv6(icmpPkt.Payload())
- wantLen := len(payload)
- if tc.largePayload {
- wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
- }
- // In case of large payloads the IP packet may be truncated. Update
- // the length field before retrieving the udp datagram payload.
- payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize))
+ icmpPkt := header.ICMPv6(hdr.Payload())
+ payloadIPHeader := header.IPv6(icmpPkt.Payload())
+ wantLen := len(payload)
+ if tc.largePayload {
+ wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
+ }
+ // In case of large payloads the IP packet may be truncated. Update
+ // the length field before retrieving the udp datagram payload.
+ payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize))
- origDgram := header.UDP(payloadIPHeader.Payload())
- if got, want := len(origDgram.Payload()), wantLen; got != want {
- t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
- }
- if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
- t.Fatalf("unexpected payload got: %v, want: %v", got, want)
- }
- case <-time.After(1 * time.Second):
- t.Fatalf("packet wasn't written out")
+ origDgram := header.UDP(payloadIPHeader.Payload())
+ if got, want := len(origDgram.Payload()), wantLen; got != want {
+ t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
+ }
+ if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
+ t.Fatalf("unexpected payload got: %v, want: %v", got, want)
}
})
}