summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
authorTamir Duberstein <tamird@google.com>2018-09-12 20:38:27 -0700
committerShentubot <shentubot@google.com>2018-09-12 20:39:24 -0700
commit5adb3468d4de249df055d641e01ce6582b3a9388 (patch)
treefa75f573912b3647dcc7158961aa1085e572a8a1 /pkg/tcpip
parent9dec7a3db99d8c7045324bc6d8f0c27e88407f6c (diff)
Add multicast support
PiperOrigin-RevId: 212750821 Change-Id: I822fd63e48c684b45fd91f9ce057867b7eceb792
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/checker/checker.go16
-rw-r--r--pkg/tcpip/header/ipv4.go16
-rw-r--r--pkg/tcpip/header/ipv6.go9
-rw-r--r--pkg/tcpip/network/arp/arp.go7
-rw-r--r--pkg/tcpip/network/ip_test.go4
-rw-r--r--pkg/tcpip/network/ipv4/BUILD18
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go2
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go13
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go92
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go5
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go22
-rw-r--r--pkg/tcpip/stack/nic.go53
-rw-r--r--pkg/tcpip/stack/registration.go6
-rw-r--r--pkg/tcpip/stack/route.go9
-rw-r--r--pkg/tcpip/stack/stack.go25
-rw-r--r--pkg/tcpip/stack/stack_test.go8
-rw-r--r--pkg/tcpip/stack/transport_test.go2
-rw-r--r--pkg/tcpip/tcpip.go22
-rw-r--r--pkg/tcpip/transport/ping/endpoint.go5
-rw-r--r--pkg/tcpip/transport/tcp/connect.go58
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go2
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go123
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go235
23 files changed, 602 insertions, 150 deletions
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 518719de4..8e0e49efa 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -86,6 +86,22 @@ func DstAddr(addr tcpip.Address) NetworkChecker {
}
}
+// TTL creates a checker that checks the TTL (ipv4) or HopLimit (ipv6).
+func TTL(ttl uint8) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ var v uint8
+ switch ip := h[0].(type) {
+ case header.IPv4:
+ v = ip.TTL()
+ case header.IPv6:
+ v = ip.HopLimit()
+ }
+ if v != ttl {
+ t.Fatalf("Bad TTL, got %v, want %v", v, ttl)
+ }
+ }
+}
+
// PayloadLen creates a checker that checks the payload length.
func PayloadLen(plen int) NetworkChecker {
return func(t *testing.T, h []header.Network) {
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index 950c54f74..29570cc34 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -93,6 +93,12 @@ const (
// IPv4Version is the version of the ipv4 protocol.
IPv4Version = 4
+
+ // IPv4Broadcast is the broadcast address of the IPv4 procotol.
+ IPv4Broadcast tcpip.Address = "\xff\xff\xff\xff"
+
+ // IPv4Any is the non-routable IPv4 "any" meta address.
+ IPv4Any tcpip.Address = "\x00\x00\x00\x00"
)
// Flags that may be set in an IPv4 packet.
@@ -259,3 +265,13 @@ func (b IPv4) IsValid(pktSize int) bool {
return true
}
+
+// IsV4MulticastAddress determines if the provided address is an IPv4 multicast
+// address (range 224.0.0.0 to 239.255.255.255). The four most significant bits
+// will be 1110 = 0xe0.
+func IsV4MulticastAddress(addr tcpip.Address) bool {
+ if len(addr) != IPv4AddressSize {
+ return false
+ }
+ return (addr[0] & 0xf0) == 0xe0
+}
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index 62bc23c27..066e48a9a 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -193,3 +193,12 @@ func IsV4MappedAddress(addr tcpip.Address) bool {
return strings.HasPrefix(string(addr), "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff")
}
+
+// IsV6MulticastAddress determines if the provided address is an IPv6
+// multicast address (anything starting with FF).
+func IsV6MulticastAddress(addr tcpip.Address) bool {
+ if len(addr) != IPv6AddressSize {
+ return false
+ }
+ return addr[0] == 0xff
+}
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 6bf4be868..8f64e3f42 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -51,6 +51,11 @@ type endpoint struct {
linkAddrCache stack.LinkAddressCache
}
+// DefaultTTL is unused for ARP. It implements stack.NetworkEndpoint.
+func (e *endpoint) DefaultTTL() uint8 {
+ return 0
+}
+
func (e *endpoint) MTU() uint32 {
lmtu := e.linkEP.MTU()
return lmtu - uint32(e.MaxHeaderLength())
@@ -74,7 +79,7 @@ func (e *endpoint) MaxHeaderLength() uint16 {
func (e *endpoint) Close() {}
-func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
return tcpip.ErrNotSupported
}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 1e92b7ae9..0cb53fb42 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -222,7 +222,7 @@ func TestIPv4Send(t *testing.T) {
t.Fatalf("could not find route: %v", err)
}
vv := buffer.NewVectorisedView(len(payload), []buffer.View{payload})
- if err := ep.WritePacket(&r, &hdr, vv, 123); err != nil {
+ if err := ep.WritePacket(&r, &hdr, vv, 123, 123); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
}
@@ -461,7 +461,7 @@ func TestIPv6Send(t *testing.T) {
t.Fatalf("could not find route: %v", err)
}
vv := buffer.NewVectorisedView(len(payload), []buffer.View{payload})
- if err := ep.WritePacket(&r, &hdr, vv, 123); err != nil {
+ if err := ep.WritePacket(&r, &hdr, vv, 123, 123); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
}
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 19314e9bd..90d65d531 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -1,6 +1,6 @@
package(licenses = ["notice"]) # Apache 2.0
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
go_library(
name = "ipv4",
@@ -21,3 +21,19 @@ go_library(
"//pkg/tcpip/stack",
],
)
+
+go_test(
+ name = "ipv4_test",
+ size = "small",
+ srcs = ["ipv4_test.go"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/udp",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index de21e623e..74454f605 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -121,5 +121,5 @@ func sendPing4(r *stack.Route, code byte, data buffer.View) *tcpip.Error {
icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0)))
vv := buffer.NewVectorisedView(len(data), []buffer.View{data})
- return r.WritePacket(&hdr, vv, header.ICMPv4ProtocolNumber)
+ return r.WritePacket(&hdr, vv, header.ICMPv4ProtocolNumber, r.DefaultTTL())
}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 478957827..0a2378a6a 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -44,10 +44,6 @@ const (
// buckets is the number of identifier buckets.
buckets = 2048
-
- // defaultIPv4TTL is the defautl TTL for IPv4 Packets egressed by
- // Netstack.
- defaultIPv4TTL = 255
)
type address [header.IPv4AddressSize]byte
@@ -78,6 +74,11 @@ func newEndpoint(nicid tcpip.NICID, addr tcpip.Address, dispatcher stack.Transpo
return e
}
+// DefaultTTL is the default time-to-live value for this endpoint.
+func (e *endpoint) DefaultTTL() uint8 {
+ return 255
+}
+
// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
// the network layer max header length.
func (e *endpoint) MTU() uint32 {
@@ -106,7 +107,7 @@ func (e *endpoint) MaxHeaderLength() uint16 {
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
length := uint16(hdr.UsedLength() + payload.Size())
id := uint32(0)
@@ -119,7 +120,7 @@ func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload
IHL: header.IPv4MinimumSize,
TotalLength: length,
ID: uint16(id),
- TTL: defaultIPv4TTL,
+ TTL: ttl,
Protocol: uint8(protocol),
SrcAddr: tcpip.Address(e.address[:]),
DstAddr: r.RemoteAddress,
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
new file mode 100644
index 000000000..2b7067a50
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -0,0 +1,92 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv4_test
+
+import (
+ "testing"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/channel"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+func TestExcludeBroadcast(t *testing.T) {
+ s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}, stack.Options{})
+
+ const defaultMTU = 65536
+ id, _ := channel.New(256, defaultMTU, "")
+ if testing.Verbose() {
+ id = sniffer.New(id)
+ }
+ if err := s.CreateNIC(1, id); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+
+ if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Broadcast); err != nil {
+ t.Fatalf("AddAddress failed: %v", err)
+ }
+ if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Any); err != nil {
+ t.Fatalf("AddAddress failed: %v", err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: "\x00\x00\x00\x00",
+ Mask: "\x00\x00\x00\x00",
+ Gateway: "",
+ NIC: 1,
+ }})
+
+ randomAddr := tcpip.FullAddress{NIC: 1, Addr: "\x0a\x00\x00\x01", Port: 53}
+
+ var wq waiter.Queue
+ t.Run("WithoutPrimaryAddress", func(t *testing.T) {
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ep.Close()
+
+ // Cannot connect using a broadcast address as the source.
+ if err := ep.Connect(randomAddr); err != tcpip.ErrNoRoute {
+ t.Errorf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrNoRoute)
+ }
+
+ // However, we can bind to a broadcast address to listen.
+ if err := ep.Bind(tcpip.FullAddress{Addr: header.IPv4Broadcast, Port: 53, NIC: 1}, nil); err != nil {
+ t.Errorf("Bind failed: %v", err)
+ }
+ })
+
+ t.Run("WithPrimaryAddress", func(t *testing.T) {
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ep.Close()
+
+ // Add a valid primary endpoint address, now we can connect.
+ if err := s.AddAddress(1, ipv4.ProtocolNumber, "\x0a\x00\x00\x02"); err != nil {
+ t.Fatalf("AddAddress failed: %v", err)
+ }
+ if err := ep.Connect(randomAddr); err != nil {
+ t.Errorf("Connect failed: %v", err)
+ }
+ })
+}
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index b18b78830..2158ba8f7 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -62,7 +62,6 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv *buffer
e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv)
}
-// TODO: take buffer.VectorisedView by value.
func (e *endpoint) handleICMP(r *stack.Route, vv *buffer.VectorisedView) {
v := vv.First()
if len(v) < header.ICMPv6MinimumSize {
@@ -107,7 +106,7 @@ func (e *endpoint) handleICMP(r *stack.Route, vv *buffer.VectorisedView) {
pkt[icmpV6LengthOffset] = 1
copy(pkt[icmpV6LengthOffset+1:], r.LocalLinkAddress[:])
pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
- r.WritePacket(&hdr, buffer.VectorisedView{}, header.ICMPv6ProtocolNumber)
+ r.WritePacket(&hdr, buffer.VectorisedView{}, header.ICMPv6ProtocolNumber, r.DefaultTTL())
e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
@@ -131,7 +130,7 @@ func (e *endpoint) handleICMP(r *stack.Route, vv *buffer.VectorisedView) {
copy(pkt, h)
pkt.SetType(header.ICMPv6EchoReply)
pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, *vv))
- r.WritePacket(&hdr, *vv, header.ICMPv6ProtocolNumber)
+ r.WritePacket(&hdr, *vv, header.ICMPv6ProtocolNumber, r.DefaultTTL())
case header.ICMPv6EchoReply:
if len(v) < header.ICMPv6EchoMinimumSize {
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 19dc1b49e..eb89168c3 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -43,17 +43,19 @@ const (
defaultIPv6HopLimit = 255
)
-type address [header.IPv6AddressSize]byte
-
type endpoint struct {
nicid tcpip.NICID
id stack.NetworkEndpointID
- address address
linkEP stack.LinkEndpoint
linkAddrCache stack.LinkAddressCache
dispatcher stack.TransportDispatcher
}
+// DefaultTTL is the default hop limit for this endpoint.
+func (e *endpoint) DefaultTTL() uint8 {
+ return 255
+}
+
// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
// the network layer max header length.
func (e *endpoint) MTU() uint32 {
@@ -82,14 +84,14 @@ func (e *endpoint) MaxHeaderLength() uint16 {
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
length := uint16(hdr.UsedLength() + payload.Size())
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: length,
NextHeader: uint8(protocol),
- HopLimit: defaultIPv6HopLimit,
- SrcAddr: tcpip.Address(e.address[:]),
+ HopLimit: ttl,
+ SrcAddr: e.id.LocalAddress,
DstAddr: r.RemoteAddress,
})
r.Stats().IP.PacketsSent.Increment()
@@ -149,15 +151,13 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
// NewEndpoint creates a new ipv6 endpoint.
func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
- e := &endpoint{
+ return &endpoint{
nicid: nicid,
+ id: stack.NetworkEndpointID{LocalAddress: addr},
linkEP: linkEP,
linkAddrCache: linkAddrCache,
dispatcher: dispatcher,
- }
- copy(e.address[:], addr)
- e.id = stack.NetworkEndpointID{LocalAddress: tcpip.Address(e.address[:])}
- return e, nil
+ }, nil
}
// SetOption implements NetworkProtocol.SetOption.
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 845b40f11..61afa673e 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -43,6 +43,25 @@ type NIC struct {
subnets []tcpip.Subnet
}
+// PrimaryEndpointBehavior is an enumeration of an endpoint's primacy behavior.
+type PrimaryEndpointBehavior int
+
+const (
+ // CanBePrimaryEndpoint indicates the endpoint can be used as a primary
+ // endpoint for new connections with no local address. This is the
+ // default when calling NIC.AddAddress.
+ CanBePrimaryEndpoint PrimaryEndpointBehavior = iota
+
+ // FirstPrimaryEndpoint indicates the endpoint should be the first
+ // primary endpoint considered. If there are multiple endpoints with
+ // this behavior, the most recently-added one will be first.
+ FirstPrimaryEndpoint
+
+ // NeverPrimaryEndpoint indicates the endpoint should never be a
+ // primary endpoint.
+ NeverPrimaryEndpoint
+)
+
func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint) *NIC {
return &NIC{
stack: stack,
@@ -141,6 +160,11 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN
for e := list.Front(); e != nil; e = e.Next() {
r := e.(*referencedNetworkEndpoint)
+ // TODO: allow broadcast address when SO_BROADCAST is set.
+ switch r.ep.ID().LocalAddress {
+ case header.IPv4Broadcast, header.IPv4Any:
+ continue
+ }
if r.tryIncRef() {
return r
}
@@ -150,7 +174,7 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN
}
// findEndpoint finds the endpoint, if any, with the given address.
-func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address) *referencedNetworkEndpoint {
+func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint {
id := NetworkEndpointID{address}
n.mu.RLock()
@@ -171,7 +195,7 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A
n.mu.Lock()
ref = n.endpoints[id]
if ref == nil || !ref.tryIncRef() {
- ref, _ = n.addAddressLocked(protocol, address, true)
+ ref, _ = n.addAddressLocked(protocol, address, peb, true)
if ref != nil {
ref.holdsInsertRef = false
}
@@ -180,7 +204,7 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A
return ref
}
-func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, replace bool) (*referencedNetworkEndpoint, *tcpip.Error) {
+func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior, replace bool) (*referencedNetworkEndpoint, *tcpip.Error) {
netProto, ok := n.stack.networkProtocols[protocol]
if !ok {
return nil, tcpip.ErrUnknownProtocol
@@ -224,7 +248,12 @@ func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.
n.primary[protocol] = l
}
- l.PushBack(ref)
+ switch peb {
+ case CanBePrimaryEndpoint:
+ l.PushBack(ref)
+ case FirstPrimaryEndpoint:
+ l.PushFront(ref)
+ }
return ref, nil
}
@@ -232,9 +261,15 @@ func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.
// AddAddress adds a new address to n, so that it starts accepting packets
// targeted at the given address (and network protocol).
func (n *NIC) AddAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
+ return n.AddAddressWithOptions(protocol, addr, CanBePrimaryEndpoint)
+}
+
+// AddAddressWithOptions is the same as AddAddress, but allows you to specify
+// whether the new endpoint can be primary or not.
+func (n *NIC) AddAddressWithOptions(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) *tcpip.Error {
// Add the endpoint.
n.mu.Lock()
- _, err := n.addAddressLocked(protocol, addr, false)
+ _, err := n.addAddressLocked(protocol, addr, peb, false)
n.mu.Unlock()
return err
@@ -319,7 +354,11 @@ func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
}
delete(n.endpoints, id)
- n.primary[r.protocol].Remove(r)
+ wasInList := r.Next() != nil || r.Prev() != nil || r == n.primary[r.protocol].Front()
+ if wasInList {
+ n.primary[r.protocol].Remove(r)
+ }
+
r.ep.Close()
}
@@ -398,7 +437,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr tcpip.Lin
ref, ok = n.endpoints[id]
if !ok || !ref.tryIncRef() {
var err *tcpip.Error
- ref, err = n.addAddressLocked(protocol, dst, true)
+ ref, err = n.addAddressLocked(protocol, dst, CanBePrimaryEndpoint, true)
if err == nil {
ref.holdsInsertRef = false
}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index b9e2cc045..7afde3598 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -124,6 +124,10 @@ type TransportDispatcher interface {
// NetworkEndpoint is the interface that needs to be implemented by endpoints
// of network layer protocols (e.g., ipv4, ipv6).
type NetworkEndpoint interface {
+ // DefaultTTL is the default time-to-live value (or hop limit, in ipv6)
+ // for this endpoint.
+ DefaultTTL() uint8
+
// MTU is the maximum transmission unit for this endpoint. This is
// generally calculated as the MTU of the underlying data link endpoint
// minus the network endpoint max header length.
@@ -141,7 +145,7 @@ type NetworkEndpoint interface {
// WritePacket writes a packet to the given destination address and
// protocol.
- WritePacket(r *Route, hdr *buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber) *tcpip.Error
+ WritePacket(r *Route, hdr *buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error
// ID returns the network protocol endpoint ID.
ID() *NetworkEndpointID
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 533a0b560..9dfb95b4a 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -129,14 +129,19 @@ func (r *Route) IsResolutionRequired() bool {
}
// WritePacket writes the packet through the given route.
-func (r *Route) WritePacket(hdr *buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
- err := r.ref.ep.WritePacket(r, hdr, payload, protocol)
+func (r *Route) WritePacket(hdr *buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
+ err := r.ref.ep.WritePacket(r, hdr, payload, protocol, ttl)
if err == tcpip.ErrNoRoute {
r.Stats().IP.OutgoingPacketErrors.Increment()
}
return err
}
+// DefaultTTL returns the default TTL of the underlying network endpoint.
+func (r *Route) DefaultTTL() uint8 {
+ return r.ref.ep.DefaultTTL()
+}
+
// MTU returns the MTU of the underlying network endpoint.
func (r *Route) MTU() uint32 {
return r.ref.ep.MTU()
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 675ccc6fa..789c819dd 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -607,6 +607,12 @@ type NICStateFlags struct {
// AddAddress adds a new network-layer address to the specified NIC.
func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
+ return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint)
+}
+
+// AddAddressWithOptions is the same as AddAddress, but allows you to specify
+// whether the new endpoint can be primary or not.
+func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -615,7 +621,7 @@ func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber,
return tcpip.ErrUnknownNICID
}
- return nic.AddAddress(protocol, addr)
+ return nic.AddAddressWithOptions(protocol, addr, peb)
}
// AddSubnet adds a subnet range to the specified NIC.
@@ -703,7 +709,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
var ref *referencedNetworkEndpoint
if len(localAddr) != 0 {
- ref = nic.findEndpoint(netProto, localAddr)
+ ref = nic.findEndpoint(netProto, localAddr, CanBePrimaryEndpoint)
} else {
ref = nic.primaryEndpoint(netProto)
}
@@ -746,7 +752,7 @@ func (s *Stack) CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProto
return 0
}
- ref := nic.findEndpoint(protocol, addr)
+ ref := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint)
if ref == nil {
return 0
}
@@ -758,7 +764,7 @@ func (s *Stack) CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProto
// Go through all the NICs.
for _, nic := range s.nics {
- ref := nic.findEndpoint(protocol, addr)
+ ref := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint)
if ref != nil {
ref.decRef()
return nic.id
@@ -926,3 +932,14 @@ func (s *Stack) RemoveTCPProbe() {
s.tcpProbeFunc = nil
s.mu.Unlock()
}
+
+// JoinGroup joins the given multicast group on the given NIC.
+func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error {
+ // TODO: notify network of subscription via igmp protocol.
+ return s.AddAddressWithOptions(nicID, protocol, multicastAddr, NeverPrimaryEndpoint)
+}
+
+// LeaveGroup leaves the given multicast group on the given NIC.
+func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error {
+ return s.RemoveAddress(nicID, multicastAddr)
+}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 02f02dfa2..816707d27 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -65,6 +65,10 @@ func (f *fakeNetworkEndpoint) NICID() tcpip.NICID {
return f.nicid
}
+func (*fakeNetworkEndpoint) DefaultTTL() uint8 {
+ return 123
+}
+
func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID {
return &f.id
}
@@ -105,7 +109,7 @@ func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
return f.linkEP.Capabilities()
}
-func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
+func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, _ uint8) *tcpip.Error {
// Increment the sent packet count in the protocol descriptor.
f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++
@@ -269,7 +273,7 @@ func sendTo(t *testing.T, s *stack.Stack, addr tcpip.Address) {
defer r.Release()
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()))
- if err := r.WritePacket(&hdr, buffer.VectorisedView{}, fakeTransNumber); err != nil {
+ if err := r.WritePacket(&hdr, buffer.VectorisedView{}, fakeTransNumber, 123); err != nil {
t.Errorf("WritePacket failed: %v", err)
return
}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 5ab485c98..e4607192f 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -71,7 +71,7 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions)
return 0, err
}
vv := buffer.NewVectorisedView(len(v), []buffer.View{v})
- if err := f.route.WritePacket(&hdr, vv, fakeTransNumber); err != nil {
+ if err := f.route.WritePacket(&hdr, vv, fakeTransNumber, 123); err != nil {
return 0, err
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 4065fc30f..51360b11f 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -453,6 +453,28 @@ type KeepaliveIntervalOption time.Duration
// closed.
type KeepaliveCountOption int
+// MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default
+// TTL value for multicast messages. The default is 1.
+type MulticastTTLOption uint8
+
+// MembershipOption is used by SetSockOpt/GetSockOpt as an argument to
+// AddMembershipOption and RemoveMembershipOption.
+type MembershipOption struct {
+ NIC NICID
+ InterfaceAddr Address
+ MulticastAddr Address
+}
+
+// AddMembershipOption is used by SetSockOpt/GetSockOpt to join a multicast
+// group identified by the given multicast address, on the interface matching
+// the given interface address.
+type AddMembershipOption MembershipOption
+
+// RemoveMembershipOption is used by SetSockOpt/GetSockOpt to leave a multicast
+// group identified by the given multicast address, on the interface matching
+// the given interface address.
+type RemoveMembershipOption MembershipOption
+
// Route is a row in the routing table. It specifies through which NIC (and
// gateway) sets of packets should be routed. A row is considered viable if the
// masked target address matches the destination adddress in the row.
diff --git a/pkg/tcpip/transport/ping/endpoint.go b/pkg/tcpip/transport/ping/endpoint.go
index fc98c41eb..7aaf2d9c6 100644
--- a/pkg/tcpip/transport/ping/endpoint.go
+++ b/pkg/tcpip/transport/ping/endpoint.go
@@ -385,7 +385,7 @@ func sendPing4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0)))
vv := buffer.NewVectorisedView(len(data), []buffer.View{data})
- return r.WritePacket(&hdr, vv, header.ICMPv4ProtocolNumber)
+ return r.WritePacket(&hdr, vv, header.ICMPv4ProtocolNumber, r.DefaultTTL())
}
func sendPing6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
@@ -408,8 +408,9 @@ func sendPing6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
icmpv6.SetChecksum(0)
icmpv6.SetChecksum(^header.Checksum(icmpv6, header.Checksum(data, 0)))
+
vv := buffer.NewVectorisedView(len(data), []buffer.View{data})
- return r.WritePacket(&hdr, vv, header.ICMPv6ProtocolNumber)
+ return r.WritePacket(&hdr, vv, header.ICMPv6ProtocolNumber, r.DefaultTTL())
}
func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index de5f963cf..ce87d5818 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -563,14 +563,14 @@ func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, flags byte, seq, a
}
options := makeSynOptions(opts)
- err := sendTCPWithOptions(r, id, buffer.VectorisedView{}, flags, seq, ack, rcvWnd, options)
+ err := sendTCP(r, id, buffer.VectorisedView{}, r.DefaultTTL(), flags, seq, ack, rcvWnd, options)
putOptions(options)
return err
}
-// sendTCPWithOptions sends a TCP segment with the provided options via the
-// provided network endpoint and under the provided identity.
-func sendTCPWithOptions(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte) *tcpip.Error {
+// sendTCP sends a TCP segment with the provided options via the provided
+// network endpoint and under the provided identity.
+func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte) *tcpip.Error {
optLen := len(opts)
// Allocate a buffer for the TCP header.
hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen)
@@ -608,48 +608,7 @@ func sendTCPWithOptions(r *stack.Route, id stack.TransportEndpointID, data buffe
r.Stats().TCP.ResetsSent.Increment()
}
- return r.WritePacket(&hdr, data, ProtocolNumber)
-}
-
-// sendTCP sends a TCP segment via the provided network endpoint and under the
-// provided identity.
-func sendTCP(r *stack.Route, id stack.TransportEndpointID, payload buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error {
- // Allocate a buffer for the TCP header.
- hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()))
-
- if rcvWnd > 0xffff {
- rcvWnd = 0xffff
- }
-
- // Initialize the header.
- tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize))
- tcp.Encode(&header.TCPFields{
- SrcPort: id.LocalPort,
- DstPort: id.RemotePort,
- SeqNum: uint32(seq),
- AckNum: uint32(ack),
- DataOffset: header.TCPMinimumSize,
- Flags: flags,
- WindowSize: uint16(rcvWnd),
- })
-
- // Only calculate the checksum if offloading isn't supported.
- if r.Capabilities()&stack.CapabilityChecksumOffload == 0 {
- length := uint16(hdr.UsedLength() + payload.Size())
- xsum := r.PseudoHeaderChecksum(ProtocolNumber)
- for _, v := range payload.Views() {
- xsum = header.Checksum(v, xsum)
- }
-
- tcp.SetChecksum(^tcp.CalculateChecksum(xsum, length))
- }
-
- r.Stats().TCP.SegmentsSent.Increment()
- if (flags & flagRst) != 0 {
- r.Stats().TCP.ResetsSent.Increment()
- }
-
- return r.WritePacket(&hdr, payload, ProtocolNumber)
+ return r.WritePacket(&hdr, data, ProtocolNumber, ttl)
}
// makeOptions makes an options slice.
@@ -698,12 +657,7 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn
sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
}
options := e.makeOptions(sackBlocks)
- if len(options) > 0 {
- err := sendTCPWithOptions(&e.route, e.id, data, flags, seq, ack, rcvWnd, options)
- putOptions(options)
- return err
- }
- err := sendTCP(&e.route, e.id, data, flags, seq, ack, rcvWnd)
+ err := sendTCP(&e.route, e.id, data, e.route.DefaultTTL(), flags, seq, ack, rcvWnd, options)
putOptions(options)
return err
}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 006b2f074..fe21f2c78 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -147,7 +147,7 @@ func replyWithReset(s *segment) {
ack := s.sequenceNumber.Add(s.logicalLen())
- sendTCP(&s.route, s.id, buffer.VectorisedView{}, flagRst|flagAck, seq, ack, 0)
+ sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), flagRst|flagAck, seq, ack, 0, nil)
}
// SetOption implements TransportProtocol.SetOption.
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index d091a6196..5de518a55 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -70,19 +70,24 @@ type endpoint struct {
rcvTimestamp bool
// The following fields are protected by the mu mutex.
- mu sync.RWMutex `state:"nosave"`
- sndBufSize int
- id stack.TransportEndpointID
- state endpointState
- bindNICID tcpip.NICID
- regNICID tcpip.NICID
- route stack.Route `state:"manual"`
- dstPort uint16
- v6only bool
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ id stack.TransportEndpointID
+ state endpointState
+ bindNICID tcpip.NICID
+ regNICID tcpip.NICID
+ route stack.Route `state:"manual"`
+ dstPort uint16
+ v6only bool
+ multicastTTL uint8
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
+ // multicastMemberships that need to be remvoed when the endpoint is
+ // closed. Protected by the mu mutex.
+ multicastMemberships []multicastMembership
+
// effectiveNetProtos contains the network protocols actually in use. In
// most cases it will only contain "netProto", but in cases like IPv6
// endpoints with v6only set to false, this could include multiple
@@ -92,11 +97,29 @@ type endpoint struct {
effectiveNetProtos []tcpip.NetworkProtocolNumber
}
+type multicastMembership struct {
+ nicID tcpip.NICID
+ multicastAddr tcpip.Address
+}
+
func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
return &endpoint{
- stack: stack,
- netProto: netProto,
- waiterQueue: waiterQueue,
+ stack: stack,
+ netProto: netProto,
+ waiterQueue: waiterQueue,
+ // RFC 1075 section 5.4 recommends a TTL of 1 for membership
+ // requests.
+ //
+ // RFC 5135 4.2.1 appears to assume that IGMP messages have a
+ // TTL of 1.
+ //
+ // RFC 5135 Appendix A defines TTL=1: A multicast source that
+ // wants its traffic to not traverse a router (e.g., leave a
+ // home network) may find it useful to send traffic with IP
+ // TTL=1.
+ //
+ // Linux defaults to TTL=1.
+ multicastTTL: 1,
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024,
}
@@ -135,6 +158,11 @@ func (e *endpoint) Close() {
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
}
+ for _, mem := range e.multicastMemberships {
+ e.stack.LeaveGroup(e.netProto, mem.nicID, mem.multicastAddr)
+ }
+ e.multicastMemberships = nil
+
// Close the receive list and drain it.
e.rcvMu.Lock()
e.rcvClosed = true
@@ -329,8 +357,13 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tc
return 0, err
}
+ ttl := route.DefaultTTL()
+ if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) {
+ ttl = e.multicastTTL
+ }
+
vv := buffer.NewVectorisedView(len(v), []buffer.View{v})
- if err := sendUDP(route, vv, e.id.LocalPort, dstPort); err != nil {
+ if err := sendUDP(route, vv, e.id.LocalPort, dstPort, ttl); err != nil {
return 0, err
}
return uintptr(len(v)), nil
@@ -365,6 +398,56 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.rcvMu.Lock()
e.rcvTimestamp = v != 0
e.rcvMu.Unlock()
+
+ case tcpip.MulticastTTLOption:
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.multicastTTL = uint8(v)
+
+ case tcpip.AddMembershipOption:
+ nicID := v.NIC
+ if v.InterfaceAddr != header.IPv4Any {
+ nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
+ }
+ if nicID == 0 {
+ return tcpip.ErrNoRoute
+ }
+
+ // TODO: check that v.MulticastAddr is a multicast address.
+ if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+ return err
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ e.multicastMemberships = append(e.multicastMemberships, multicastMembership{nicID, v.MulticastAddr})
+
+ case tcpip.RemoveMembershipOption:
+ nicID := v.NIC
+ if v.InterfaceAddr != header.IPv4Any {
+ nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
+ }
+ if nicID == 0 {
+ return tcpip.ErrNoRoute
+ }
+
+ // TODO: check that v.MulticastAddr is a multicast address.
+ if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+ return err
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ for i, mem := range e.multicastMemberships {
+ if mem.nicID == nicID && mem.multicastAddr == v.MulticastAddr {
+ // Only remove the first match, so that each added membership above is
+ // paired with exactly 1 removal.
+ e.multicastMemberships[i] = e.multicastMemberships[len(e.multicastMemberships)-1]
+ e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1]
+ break
+ }
+ }
}
return nil
}
@@ -421,6 +504,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
*o = 1
}
e.rcvMu.Unlock()
+
+ case *tcpip.MulticastTTLOption:
+ e.mu.Lock()
+ *o = tcpip.MulticastTTLOption(e.multicastTTL)
+ e.mu.Unlock()
+ return nil
}
return tcpip.ErrUnknownProtocolOption
@@ -428,7 +517,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
// sendUDP sends a UDP segment via the provided network endpoint and under the
// provided identity.
-func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16) *tcpip.Error {
+func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8) *tcpip.Error {
// Allocate a buffer for the UDP header.
hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength()))
@@ -454,7 +543,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
// Track count of packets sent.
r.Stats().UDP.PacketsSent.Increment()
- return r.WritePacket(&hdr, data, ProtocolNumber)
+ return r.WritePacket(&hdr, data, ProtocolNumber, ttl)
}
func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
@@ -581,7 +670,9 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- if e.state != stateConnected {
+ // A socket in the bound state can still receive multicast messages,
+ // so we need to notify waiters on shutdown.
+ if e.state != stateBound && e.state != stateConnected {
return tcpip.ErrNotConnected
}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 4700193c2..6d7a737bd 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -34,16 +34,20 @@ import (
)
const (
- stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- stackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + stackAddr
- testV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + testAddr
- V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
-
- stackAddr = "\x0a\x00\x00\x01"
- stackPort = 1234
- testAddr = "\x0a\x00\x00\x02"
- testPort = 4096
+ stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ stackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + stackAddr
+ testV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + testAddr
+ multicastV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + multicastAddr
+ V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
+
+ stackAddr = "\x0a\x00\x00\x01"
+ stackPort = 1234
+ testAddr = "\x0a\x00\x00\x02"
+ testPort = 4096
+ multicastAddr = "\xe8\x2b\xd3\xea"
+ multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
+ multicastPort = 1234
// defaultMTU is the MTU, in bytes, used throughout the tests, except
// where another value is explicitly used. It is chosen to match the MTU
@@ -128,37 +132,35 @@ func (c *testContext) createV6Endpoint(v6only bool) {
}
}
-func (c *testContext) getV6Packet() []byte {
+func (c *testContext) getPacket(protocolNumber tcpip.NetworkProtocolNumber, multicast bool) []byte {
select {
case p := <-c.linkEP.C:
- if p.Proto != ipv6.ProtocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber)
+ if p.Proto != protocolNumber {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, protocolNumber)
}
b := make([]byte, len(p.Header)+len(p.Payload))
copy(b, p.Header)
copy(b[len(p.Header):], p.Payload)
- checker.IPv6(c.t, b, checker.SrcAddr(stackV6Addr), checker.DstAddr(testV6Addr))
- return b
-
- case <-time.After(2 * time.Second):
- c.t.Fatalf("Packet wasn't written out")
- }
-
- return nil
-}
-
-func (c *testContext) getPacket() []byte {
- select {
- case p := <-c.linkEP.C:
- if p.Proto != ipv4.ProtocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
+ var checkerFn func(*testing.T, []byte, ...checker.NetworkChecker)
+ var srcAddr, dstAddr tcpip.Address
+ switch protocolNumber {
+ case ipv4.ProtocolNumber:
+ checkerFn = checker.IPv4
+ srcAddr, dstAddr = stackAddr, testAddr
+ if multicast {
+ dstAddr = multicastAddr
+ }
+ case ipv6.ProtocolNumber:
+ checkerFn = checker.IPv6
+ srcAddr, dstAddr = stackV6Addr, testV6Addr
+ if multicast {
+ dstAddr = multicastV6Addr
+ }
+ default:
+ c.t.Fatalf("unknown protocol %d", protocolNumber)
}
- b := make([]byte, len(p.Header)+len(p.Payload))
- copy(b, p.Header)
- copy(b[len(p.Header):], p.Payload)
-
- checker.IPv4(c.t, b, checker.SrcAddr(stackAddr), checker.DstAddr(testAddr))
+ checkerFn(c.t, b, checker.SrcAddr(srcAddr), checker.DstAddr(dstAddr))
return b
case <-time.After(2 * time.Second):
@@ -495,7 +497,7 @@ func testV4Write(c *testContext) uint16 {
}
// Check that we received the packet.
- b := c.getPacket()
+ b := c.getPacket(ipv4.ProtocolNumber, false)
udp := header.UDP(header.IPv4(b).Payload())
checker.IPv4(c.t, b,
checker.UDP(
@@ -525,7 +527,7 @@ func testV6Write(c *testContext) uint16 {
}
// Check that we received the packet.
- b := c.getV6Packet()
+ b := c.getPacket(ipv6.ProtocolNumber, false)
udp := header.UDP(header.IPv6(b).Payload())
checker.IPv6(c.t, b,
checker.UDP(
@@ -682,7 +684,7 @@ func TestV6WriteOnConnected(t *testing.T) {
}
// Check that we received the packet.
- b := c.getV6Packet()
+ b := c.getPacket(ipv6.ProtocolNumber, false)
udp := header.UDP(header.IPv6(b).Payload())
checker.IPv6(c.t, b,
checker.UDP(
@@ -718,7 +720,7 @@ func TestV4WriteOnConnected(t *testing.T) {
}
// Check that we received the packet.
- b := c.getPacket()
+ b := c.getPacket(ipv4.ProtocolNumber, false)
udp := header.UDP(header.IPv4(b).Payload())
checker.IPv4(c.t, b,
checker.UDP(
@@ -769,3 +771,162 @@ func TestWriteIncrementsPacketsSent(t *testing.T) {
c.t.Fatalf("Write did not increment PacketsSent: got %v, want %v", got, want)
}
}
+
+func TestTTL(t *testing.T) {
+ payload := tcpip.SlicePayload(buffer.View(newPayload()))
+
+ for _, name := range []string{"v4", "v6", "dual"} {
+ t.Run(name, func(t *testing.T) {
+ var networkProtocolNumber tcpip.NetworkProtocolNumber
+ switch name {
+ case "v4":
+ networkProtocolNumber = ipv4.ProtocolNumber
+ case "v6", "dual":
+ networkProtocolNumber = ipv6.ProtocolNumber
+ default:
+ t.Fatal("unknown test variant")
+ }
+
+ var variants []string
+ switch name {
+ case "v4":
+ variants = []string{"v4"}
+ case "v6":
+ variants = []string{"v6"}
+ case "dual":
+ variants = []string{"v6", "mapped"}
+ }
+
+ for _, variant := range variants {
+ t.Run(variant, func(t *testing.T) {
+ for _, typ := range []string{"unicast", "multicast"} {
+ t.Run(typ, func(t *testing.T) {
+ var addr tcpip.Address
+ var port uint16
+ switch typ {
+ case "unicast":
+ port = testPort
+ switch variant {
+ case "v4":
+ addr = testAddr
+ case "mapped":
+ addr = testV4MappedAddr
+ case "v6":
+ addr = testV6Addr
+ default:
+ t.Fatal("unknown test variant")
+ }
+ case "multicast":
+ port = multicastPort
+ switch variant {
+ case "v4":
+ addr = multicastAddr
+ case "mapped":
+ addr = multicastV4MappedAddr
+ case "v6":
+ addr = multicastV6Addr
+ default:
+ t.Fatal("unknown test variant")
+ }
+ default:
+ t.Fatal("unknown test variant")
+ }
+
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ var err *tcpip.Error
+ c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, networkProtocolNumber, &c.wq)
+ if err != nil {
+ c.t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ switch name {
+ case "v4":
+ case "v6":
+ if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
+ case "dual":
+ if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
+ default:
+ t.Fatal("unknown test variant")
+ }
+
+ const multicastTTL = 42
+ if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
+
+ n, err := c.ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: addr, Port: port}})
+ if err != nil {
+ c.t.Fatalf("Write failed: %v", err)
+ }
+ if n != uintptr(len(payload)) {
+ c.t.Fatalf("got c.ep.Write(...) = %d, want = %d", n, len(payload))
+ }
+
+ checkerFn := checker.IPv4
+ switch variant {
+ case "v4", "mapped":
+ case "v6":
+ checkerFn = checker.IPv6
+ default:
+ t.Fatal("unknown test variant")
+ }
+ var wantTTL uint8
+ var multicast bool
+ switch typ {
+ case "unicast":
+ multicast = false
+ switch variant {
+ case "v4", "mapped":
+ ep, err := ipv4.NewProtocol().NewEndpoint(0, "", nil, nil, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ wantTTL = ep.DefaultTTL()
+ ep.Close()
+ case "v6":
+ ep, err := ipv6.NewProtocol().NewEndpoint(0, "", nil, nil, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ wantTTL = ep.DefaultTTL()
+ ep.Close()
+ default:
+ t.Fatal("unknown test variant")
+ }
+ case "multicast":
+ wantTTL = multicastTTL
+ multicast = true
+ default:
+ t.Fatal("unknown test variant")
+ }
+
+ var networkProtocolNumber tcpip.NetworkProtocolNumber
+ switch variant {
+ case "v4", "mapped":
+ networkProtocolNumber = ipv4.ProtocolNumber
+ case "v6":
+ networkProtocolNumber = ipv6.ProtocolNumber
+ default:
+ t.Fatal("unknown test variant")
+ }
+
+ b := c.getPacket(networkProtocolNumber, multicast)
+ checkerFn(c.t, b,
+ checker.TTL(wantTTL),
+ checker.UDP(
+ checker.DstPort(port),
+ ),
+ )
+ })
+ }
+ })
+ }
+ })
+ }
+}