diff options
author | Tamir Duberstein <tamird@google.com> | 2018-09-12 20:38:27 -0700 |
---|---|---|
committer | Shentubot <shentubot@google.com> | 2018-09-12 20:39:24 -0700 |
commit | 5adb3468d4de249df055d641e01ce6582b3a9388 (patch) | |
tree | fa75f573912b3647dcc7158961aa1085e572a8a1 | |
parent | 9dec7a3db99d8c7045324bc6d8f0c27e88407f6c (diff) |
Add multicast support
PiperOrigin-RevId: 212750821
Change-Id: I822fd63e48c684b45fd91f9ce057867b7eceb792
-rw-r--r-- | pkg/dhcp/client.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/checker/checker.go | 16 | ||||
-rw-r--r-- | pkg/tcpip/header/ipv4.go | 16 | ||||
-rw-r--r-- | pkg/tcpip/header/ipv6.go | 9 | ||||
-rw-r--r-- | pkg/tcpip/network/arp/arp.go | 7 | ||||
-rw-r--r-- | pkg/tcpip/network/ip_test.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/BUILD | 18 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/icmp.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 13 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4_test.go | 92 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 22 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 53 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/stack/route.go | 9 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 25 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 22 | ||||
-rw-r--r-- | pkg/tcpip/transport/ping/endpoint.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 58 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/protocol.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 123 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/udp_test.go | 235 |
24 files changed, 605 insertions, 153 deletions
diff --git a/pkg/dhcp/client.go b/pkg/dhcp/client.go index 909040e79..cf8472c5f 100644 --- a/pkg/dhcp/client.go +++ b/pkg/dhcp/client.go @@ -119,10 +119,10 @@ func (c *Client) Config() Config { // If the server sets a lease limit a timer is set to automatically // renew it. func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) (cfg Config, reterr error) { - if err := c.stack.AddAddress(c.nicid, ipv4.ProtocolNumber, "\xff\xff\xff\xff"); err != nil && err != tcpip.ErrDuplicateAddress { + if err := c.stack.AddAddressWithOptions(c.nicid, ipv4.ProtocolNumber, "\xff\xff\xff\xff", stack.NeverPrimaryEndpoint); err != nil && err != tcpip.ErrDuplicateAddress { return Config{}, fmt.Errorf("dhcp: %v", err) } - if err := c.stack.AddAddress(c.nicid, ipv4.ProtocolNumber, "\x00\x00\x00\x00"); err != nil && err != tcpip.ErrDuplicateAddress { + if err := c.stack.AddAddressWithOptions(c.nicid, ipv4.ProtocolNumber, "\x00\x00\x00\x00", stack.NeverPrimaryEndpoint); err != nil && err != tcpip.ErrDuplicateAddress { return Config{}, fmt.Errorf("dhcp: %v", err) } defer c.stack.RemoveAddress(c.nicid, "\xff\xff\xff\xff") @@ -237,7 +237,7 @@ func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) (cfg // DHCPREQUEST addr := tcpip.Address(h.yiaddr()) - if err := c.stack.AddAddress(c.nicid, ipv4.ProtocolNumber, addr); err != nil { + if err := c.stack.AddAddressWithOptions(c.nicid, ipv4.ProtocolNumber, addr, stack.FirstPrimaryEndpoint); err != nil { if err != tcpip.ErrDuplicateAddress { return Config{}, fmt.Errorf("adding address: %v", err) } diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 518719de4..8e0e49efa 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -86,6 +86,22 @@ func DstAddr(addr tcpip.Address) NetworkChecker { } } +// TTL creates a checker that checks the TTL (ipv4) or HopLimit (ipv6). +func TTL(ttl uint8) NetworkChecker { + return func(t *testing.T, h []header.Network) { + var v uint8 + switch ip := h[0].(type) { + case header.IPv4: + v = ip.TTL() + case header.IPv6: + v = ip.HopLimit() + } + if v != ttl { + t.Fatalf("Bad TTL, got %v, want %v", v, ttl) + } + } +} + // PayloadLen creates a checker that checks the payload length. func PayloadLen(plen int) NetworkChecker { return func(t *testing.T, h []header.Network) { diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 950c54f74..29570cc34 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -93,6 +93,12 @@ const ( // IPv4Version is the version of the ipv4 protocol. IPv4Version = 4 + + // IPv4Broadcast is the broadcast address of the IPv4 procotol. + IPv4Broadcast tcpip.Address = "\xff\xff\xff\xff" + + // IPv4Any is the non-routable IPv4 "any" meta address. + IPv4Any tcpip.Address = "\x00\x00\x00\x00" ) // Flags that may be set in an IPv4 packet. @@ -259,3 +265,13 @@ func (b IPv4) IsValid(pktSize int) bool { return true } + +// IsV4MulticastAddress determines if the provided address is an IPv4 multicast +// address (range 224.0.0.0 to 239.255.255.255). The four most significant bits +// will be 1110 = 0xe0. +func IsV4MulticastAddress(addr tcpip.Address) bool { + if len(addr) != IPv4AddressSize { + return false + } + return (addr[0] & 0xf0) == 0xe0 +} diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 62bc23c27..066e48a9a 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -193,3 +193,12 @@ func IsV4MappedAddress(addr tcpip.Address) bool { return strings.HasPrefix(string(addr), "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff") } + +// IsV6MulticastAddress determines if the provided address is an IPv6 +// multicast address (anything starting with FF). +func IsV6MulticastAddress(addr tcpip.Address) bool { + if len(addr) != IPv6AddressSize { + return false + } + return addr[0] == 0xff +} diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 6bf4be868..8f64e3f42 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -51,6 +51,11 @@ type endpoint struct { linkAddrCache stack.LinkAddressCache } +// DefaultTTL is unused for ARP. It implements stack.NetworkEndpoint. +func (e *endpoint) DefaultTTL() uint8 { + return 0 +} + func (e *endpoint) MTU() uint32 { lmtu := e.linkEP.MTU() return lmtu - uint32(e.MaxHeaderLength()) @@ -74,7 +79,7 @@ func (e *endpoint) MaxHeaderLength() uint16 { func (e *endpoint) Close() {} -func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error { return tcpip.ErrNotSupported } diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 1e92b7ae9..0cb53fb42 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -222,7 +222,7 @@ func TestIPv4Send(t *testing.T) { t.Fatalf("could not find route: %v", err) } vv := buffer.NewVectorisedView(len(payload), []buffer.View{payload}) - if err := ep.WritePacket(&r, &hdr, vv, 123); err != nil { + if err := ep.WritePacket(&r, &hdr, vv, 123, 123); err != nil { t.Fatalf("WritePacket failed: %v", err) } } @@ -461,7 +461,7 @@ func TestIPv6Send(t *testing.T) { t.Fatalf("could not find route: %v", err) } vv := buffer.NewVectorisedView(len(payload), []buffer.View{payload}) - if err := ep.WritePacket(&r, &hdr, vv, 123); err != nil { + if err := ep.WritePacket(&r, &hdr, vv, 123, 123); err != nil { t.Fatalf("WritePacket failed: %v", err) } } diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 19314e9bd..90d65d531 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -1,6 +1,6 @@ package(licenses = ["notice"]) # Apache 2.0 -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools/go_stateify:defs.bzl", "go_library", "go_test") go_library( name = "ipv4", @@ -21,3 +21,19 @@ go_library( "//pkg/tcpip/stack", ], ) + +go_test( + name = "ipv4_test", + size = "small", + srcs = ["ipv4_test.go"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport/udp", + "//pkg/waiter", + ], +) diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index de21e623e..74454f605 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -121,5 +121,5 @@ func sendPing4(r *stack.Route, code byte, data buffer.View) *tcpip.Error { icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0))) vv := buffer.NewVectorisedView(len(data), []buffer.View{data}) - return r.WritePacket(&hdr, vv, header.ICMPv4ProtocolNumber) + return r.WritePacket(&hdr, vv, header.ICMPv4ProtocolNumber, r.DefaultTTL()) } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 478957827..0a2378a6a 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -44,10 +44,6 @@ const ( // buckets is the number of identifier buckets. buckets = 2048 - - // defaultIPv4TTL is the defautl TTL for IPv4 Packets egressed by - // Netstack. - defaultIPv4TTL = 255 ) type address [header.IPv4AddressSize]byte @@ -78,6 +74,11 @@ func newEndpoint(nicid tcpip.NICID, addr tcpip.Address, dispatcher stack.Transpo return e } +// DefaultTTL is the default time-to-live value for this endpoint. +func (e *endpoint) DefaultTTL() uint8 { + return 255 +} + // MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus // the network layer max header length. func (e *endpoint) MTU() uint32 { @@ -106,7 +107,7 @@ func (e *endpoint) MaxHeaderLength() uint16 { } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error { ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) length := uint16(hdr.UsedLength() + payload.Size()) id := uint32(0) @@ -119,7 +120,7 @@ func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload IHL: header.IPv4MinimumSize, TotalLength: length, ID: uint16(id), - TTL: defaultIPv4TTL, + TTL: ttl, Protocol: uint8(protocol), SrcAddr: tcpip.Address(e.address[:]), DstAddr: r.RemoteAddress, diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go new file mode 100644 index 000000000..2b7067a50 --- /dev/null +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -0,0 +1,92 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv4_test + +import ( + "testing" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/channel" + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sniffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +func TestExcludeBroadcast(t *testing.T) { + s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}, stack.Options{}) + + const defaultMTU = 65536 + id, _ := channel.New(256, defaultMTU, "") + if testing.Verbose() { + id = sniffer.New(id) + } + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Broadcast); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Any); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + s.SetRouteTable([]tcpip.Route{{ + Destination: "\x00\x00\x00\x00", + Mask: "\x00\x00\x00\x00", + Gateway: "", + NIC: 1, + }}) + + randomAddr := tcpip.FullAddress{NIC: 1, Addr: "\x0a\x00\x00\x01", Port: 53} + + var wq waiter.Queue + t.Run("WithoutPrimaryAddress", func(t *testing.T) { + ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatal(err) + } + defer ep.Close() + + // Cannot connect using a broadcast address as the source. + if err := ep.Connect(randomAddr); err != tcpip.ErrNoRoute { + t.Errorf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrNoRoute) + } + + // However, we can bind to a broadcast address to listen. + if err := ep.Bind(tcpip.FullAddress{Addr: header.IPv4Broadcast, Port: 53, NIC: 1}, nil); err != nil { + t.Errorf("Bind failed: %v", err) + } + }) + + t.Run("WithPrimaryAddress", func(t *testing.T) { + ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatal(err) + } + defer ep.Close() + + // Add a valid primary endpoint address, now we can connect. + if err := s.AddAddress(1, ipv4.ProtocolNumber, "\x0a\x00\x00\x02"); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + if err := ep.Connect(randomAddr); err != nil { + t.Errorf("Connect failed: %v", err) + } + }) +} diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index b18b78830..2158ba8f7 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -62,7 +62,6 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv *buffer e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv) } -// TODO: take buffer.VectorisedView by value. func (e *endpoint) handleICMP(r *stack.Route, vv *buffer.VectorisedView) { v := vv.First() if len(v) < header.ICMPv6MinimumSize { @@ -107,7 +106,7 @@ func (e *endpoint) handleICMP(r *stack.Route, vv *buffer.VectorisedView) { pkt[icmpV6LengthOffset] = 1 copy(pkt[icmpV6LengthOffset+1:], r.LocalLinkAddress[:]) pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - r.WritePacket(&hdr, buffer.VectorisedView{}, header.ICMPv6ProtocolNumber) + r.WritePacket(&hdr, buffer.VectorisedView{}, header.ICMPv6ProtocolNumber, r.DefaultTTL()) e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress) @@ -131,7 +130,7 @@ func (e *endpoint) handleICMP(r *stack.Route, vv *buffer.VectorisedView) { copy(pkt, h) pkt.SetType(header.ICMPv6EchoReply) pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, *vv)) - r.WritePacket(&hdr, *vv, header.ICMPv6ProtocolNumber) + r.WritePacket(&hdr, *vv, header.ICMPv6ProtocolNumber, r.DefaultTTL()) case header.ICMPv6EchoReply: if len(v) < header.ICMPv6EchoMinimumSize { diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 19dc1b49e..eb89168c3 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -43,17 +43,19 @@ const ( defaultIPv6HopLimit = 255 ) -type address [header.IPv6AddressSize]byte - type endpoint struct { nicid tcpip.NICID id stack.NetworkEndpointID - address address linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache dispatcher stack.TransportDispatcher } +// DefaultTTL is the default hop limit for this endpoint. +func (e *endpoint) DefaultTTL() uint8 { + return 255 +} + // MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus // the network layer max header length. func (e *endpoint) MTU() uint32 { @@ -82,14 +84,14 @@ func (e *endpoint) MaxHeaderLength() uint16 { } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error { length := uint16(hdr.UsedLength() + payload.Size()) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: length, NextHeader: uint8(protocol), - HopLimit: defaultIPv6HopLimit, - SrcAddr: tcpip.Address(e.address[:]), + HopLimit: ttl, + SrcAddr: e.id.LocalAddress, DstAddr: r.RemoteAddress, }) r.Stats().IP.PacketsSent.Increment() @@ -149,15 +151,13 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { // NewEndpoint creates a new ipv6 endpoint. func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { - e := &endpoint{ + return &endpoint{ nicid: nicid, + id: stack.NetworkEndpointID{LocalAddress: addr}, linkEP: linkEP, linkAddrCache: linkAddrCache, dispatcher: dispatcher, - } - copy(e.address[:], addr) - e.id = stack.NetworkEndpointID{LocalAddress: tcpip.Address(e.address[:])} - return e, nil + }, nil } // SetOption implements NetworkProtocol.SetOption. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 845b40f11..61afa673e 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -43,6 +43,25 @@ type NIC struct { subnets []tcpip.Subnet } +// PrimaryEndpointBehavior is an enumeration of an endpoint's primacy behavior. +type PrimaryEndpointBehavior int + +const ( + // CanBePrimaryEndpoint indicates the endpoint can be used as a primary + // endpoint for new connections with no local address. This is the + // default when calling NIC.AddAddress. + CanBePrimaryEndpoint PrimaryEndpointBehavior = iota + + // FirstPrimaryEndpoint indicates the endpoint should be the first + // primary endpoint considered. If there are multiple endpoints with + // this behavior, the most recently-added one will be first. + FirstPrimaryEndpoint + + // NeverPrimaryEndpoint indicates the endpoint should never be a + // primary endpoint. + NeverPrimaryEndpoint +) + func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint) *NIC { return &NIC{ stack: stack, @@ -141,6 +160,11 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN for e := list.Front(); e != nil; e = e.Next() { r := e.(*referencedNetworkEndpoint) + // TODO: allow broadcast address when SO_BROADCAST is set. + switch r.ep.ID().LocalAddress { + case header.IPv4Broadcast, header.IPv4Any: + continue + } if r.tryIncRef() { return r } @@ -150,7 +174,7 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN } // findEndpoint finds the endpoint, if any, with the given address. -func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address) *referencedNetworkEndpoint { +func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint { id := NetworkEndpointID{address} n.mu.RLock() @@ -171,7 +195,7 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A n.mu.Lock() ref = n.endpoints[id] if ref == nil || !ref.tryIncRef() { - ref, _ = n.addAddressLocked(protocol, address, true) + ref, _ = n.addAddressLocked(protocol, address, peb, true) if ref != nil { ref.holdsInsertRef = false } @@ -180,7 +204,7 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A return ref } -func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, replace bool) (*referencedNetworkEndpoint, *tcpip.Error) { +func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior, replace bool) (*referencedNetworkEndpoint, *tcpip.Error) { netProto, ok := n.stack.networkProtocols[protocol] if !ok { return nil, tcpip.ErrUnknownProtocol @@ -224,7 +248,12 @@ func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip. n.primary[protocol] = l } - l.PushBack(ref) + switch peb { + case CanBePrimaryEndpoint: + l.PushBack(ref) + case FirstPrimaryEndpoint: + l.PushFront(ref) + } return ref, nil } @@ -232,9 +261,15 @@ func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip. // AddAddress adds a new address to n, so that it starts accepting packets // targeted at the given address (and network protocol). func (n *NIC) AddAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { + return n.AddAddressWithOptions(protocol, addr, CanBePrimaryEndpoint) +} + +// AddAddressWithOptions is the same as AddAddress, but allows you to specify +// whether the new endpoint can be primary or not. +func (n *NIC) AddAddressWithOptions(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) *tcpip.Error { // Add the endpoint. n.mu.Lock() - _, err := n.addAddressLocked(protocol, addr, false) + _, err := n.addAddressLocked(protocol, addr, peb, false) n.mu.Unlock() return err @@ -319,7 +354,11 @@ func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { } delete(n.endpoints, id) - n.primary[r.protocol].Remove(r) + wasInList := r.Next() != nil || r.Prev() != nil || r == n.primary[r.protocol].Front() + if wasInList { + n.primary[r.protocol].Remove(r) + } + r.ep.Close() } @@ -398,7 +437,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr tcpip.Lin ref, ok = n.endpoints[id] if !ok || !ref.tryIncRef() { var err *tcpip.Error - ref, err = n.addAddressLocked(protocol, dst, true) + ref, err = n.addAddressLocked(protocol, dst, CanBePrimaryEndpoint, true) if err == nil { ref.holdsInsertRef = false } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index b9e2cc045..7afde3598 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -124,6 +124,10 @@ type TransportDispatcher interface { // NetworkEndpoint is the interface that needs to be implemented by endpoints // of network layer protocols (e.g., ipv4, ipv6). type NetworkEndpoint interface { + // DefaultTTL is the default time-to-live value (or hop limit, in ipv6) + // for this endpoint. + DefaultTTL() uint8 + // MTU is the maximum transmission unit for this endpoint. This is // generally calculated as the MTU of the underlying data link endpoint // minus the network endpoint max header length. @@ -141,7 +145,7 @@ type NetworkEndpoint interface { // WritePacket writes a packet to the given destination address and // protocol. - WritePacket(r *Route, hdr *buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber) *tcpip.Error + WritePacket(r *Route, hdr *buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error // ID returns the network protocol endpoint ID. ID() *NetworkEndpointID diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 533a0b560..9dfb95b4a 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -129,14 +129,19 @@ func (r *Route) IsResolutionRequired() bool { } // WritePacket writes the packet through the given route. -func (r *Route) WritePacket(hdr *buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber) *tcpip.Error { - err := r.ref.ep.WritePacket(r, hdr, payload, protocol) +func (r *Route) WritePacket(hdr *buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error { + err := r.ref.ep.WritePacket(r, hdr, payload, protocol, ttl) if err == tcpip.ErrNoRoute { r.Stats().IP.OutgoingPacketErrors.Increment() } return err } +// DefaultTTL returns the default TTL of the underlying network endpoint. +func (r *Route) DefaultTTL() uint8 { + return r.ref.ep.DefaultTTL() +} + // MTU returns the MTU of the underlying network endpoint. func (r *Route) MTU() uint32 { return r.ref.ep.MTU() diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 675ccc6fa..789c819dd 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -607,6 +607,12 @@ type NICStateFlags struct { // AddAddress adds a new network-layer address to the specified NIC. func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { + return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint) +} + +// AddAddressWithOptions is the same as AddAddress, but allows you to specify +// whether the new endpoint can be primary or not. +func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) *tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() @@ -615,7 +621,7 @@ func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, return tcpip.ErrUnknownNICID } - return nic.AddAddress(protocol, addr) + return nic.AddAddressWithOptions(protocol, addr, peb) } // AddSubnet adds a subnet range to the specified NIC. @@ -703,7 +709,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n var ref *referencedNetworkEndpoint if len(localAddr) != 0 { - ref = nic.findEndpoint(netProto, localAddr) + ref = nic.findEndpoint(netProto, localAddr, CanBePrimaryEndpoint) } else { ref = nic.primaryEndpoint(netProto) } @@ -746,7 +752,7 @@ func (s *Stack) CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProto return 0 } - ref := nic.findEndpoint(protocol, addr) + ref := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint) if ref == nil { return 0 } @@ -758,7 +764,7 @@ func (s *Stack) CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProto // Go through all the NICs. for _, nic := range s.nics { - ref := nic.findEndpoint(protocol, addr) + ref := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint) if ref != nil { ref.decRef() return nic.id @@ -926,3 +932,14 @@ func (s *Stack) RemoveTCPProbe() { s.tcpProbeFunc = nil s.mu.Unlock() } + +// JoinGroup joins the given multicast group on the given NIC. +func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error { + // TODO: notify network of subscription via igmp protocol. + return s.AddAddressWithOptions(nicID, protocol, multicastAddr, NeverPrimaryEndpoint) +} + +// LeaveGroup leaves the given multicast group on the given NIC. +func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error { + return s.RemoveAddress(nicID, multicastAddr) +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 02f02dfa2..816707d27 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -65,6 +65,10 @@ func (f *fakeNetworkEndpoint) NICID() tcpip.NICID { return f.nicid } +func (*fakeNetworkEndpoint) DefaultTTL() uint8 { + return 123 +} + func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID { return &f.id } @@ -105,7 +109,7 @@ func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return f.linkEP.Capabilities() } -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber) *tcpip.Error { +func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, _ uint8) *tcpip.Error { // Increment the sent packet count in the protocol descriptor. f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ @@ -269,7 +273,7 @@ func sendTo(t *testing.T, s *stack.Stack, addr tcpip.Address) { defer r.Release() hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - if err := r.WritePacket(&hdr, buffer.VectorisedView{}, fakeTransNumber); err != nil { + if err := r.WritePacket(&hdr, buffer.VectorisedView{}, fakeTransNumber, 123); err != nil { t.Errorf("WritePacket failed: %v", err) return } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 5ab485c98..e4607192f 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -71,7 +71,7 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) return 0, err } vv := buffer.NewVectorisedView(len(v), []buffer.View{v}) - if err := f.route.WritePacket(&hdr, vv, fakeTransNumber); err != nil { + if err := f.route.WritePacket(&hdr, vv, fakeTransNumber, 123); err != nil { return 0, err } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 4065fc30f..51360b11f 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -453,6 +453,28 @@ type KeepaliveIntervalOption time.Duration // closed. type KeepaliveCountOption int +// MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default +// TTL value for multicast messages. The default is 1. +type MulticastTTLOption uint8 + +// MembershipOption is used by SetSockOpt/GetSockOpt as an argument to +// AddMembershipOption and RemoveMembershipOption. +type MembershipOption struct { + NIC NICID + InterfaceAddr Address + MulticastAddr Address +} + +// AddMembershipOption is used by SetSockOpt/GetSockOpt to join a multicast +// group identified by the given multicast address, on the interface matching +// the given interface address. +type AddMembershipOption MembershipOption + +// RemoveMembershipOption is used by SetSockOpt/GetSockOpt to leave a multicast +// group identified by the given multicast address, on the interface matching +// the given interface address. +type RemoveMembershipOption MembershipOption + // Route is a row in the routing table. It specifies through which NIC (and // gateway) sets of packets should be routed. A row is considered viable if the // masked target address matches the destination adddress in the row. diff --git a/pkg/tcpip/transport/ping/endpoint.go b/pkg/tcpip/transport/ping/endpoint.go index fc98c41eb..7aaf2d9c6 100644 --- a/pkg/tcpip/transport/ping/endpoint.go +++ b/pkg/tcpip/transport/ping/endpoint.go @@ -385,7 +385,7 @@ func sendPing4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0))) vv := buffer.NewVectorisedView(len(data), []buffer.View{data}) - return r.WritePacket(&hdr, vv, header.ICMPv4ProtocolNumber) + return r.WritePacket(&hdr, vv, header.ICMPv4ProtocolNumber, r.DefaultTTL()) } func sendPing6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { @@ -408,8 +408,9 @@ func sendPing6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { icmpv6.SetChecksum(0) icmpv6.SetChecksum(^header.Checksum(icmpv6, header.Checksum(data, 0))) + vv := buffer.NewVectorisedView(len(data), []buffer.View{data}) - return r.WritePacket(&hdr, vv, header.ICMPv6ProtocolNumber) + return r.WritePacket(&hdr, vv, header.ICMPv6ProtocolNumber, r.DefaultTTL()) } func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index de5f963cf..ce87d5818 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -563,14 +563,14 @@ func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, flags byte, seq, a } options := makeSynOptions(opts) - err := sendTCPWithOptions(r, id, buffer.VectorisedView{}, flags, seq, ack, rcvWnd, options) + err := sendTCP(r, id, buffer.VectorisedView{}, r.DefaultTTL(), flags, seq, ack, rcvWnd, options) putOptions(options) return err } -// sendTCPWithOptions sends a TCP segment with the provided options via the -// provided network endpoint and under the provided identity. -func sendTCPWithOptions(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte) *tcpip.Error { +// sendTCP sends a TCP segment with the provided options via the provided +// network endpoint and under the provided identity. +func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte) *tcpip.Error { optLen := len(opts) // Allocate a buffer for the TCP header. hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen) @@ -608,48 +608,7 @@ func sendTCPWithOptions(r *stack.Route, id stack.TransportEndpointID, data buffe r.Stats().TCP.ResetsSent.Increment() } - return r.WritePacket(&hdr, data, ProtocolNumber) -} - -// sendTCP sends a TCP segment via the provided network endpoint and under the -// provided identity. -func sendTCP(r *stack.Route, id stack.TransportEndpointID, payload buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error { - // Allocate a buffer for the TCP header. - hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength())) - - if rcvWnd > 0xffff { - rcvWnd = 0xffff - } - - // Initialize the header. - tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize)) - tcp.Encode(&header.TCPFields{ - SrcPort: id.LocalPort, - DstPort: id.RemotePort, - SeqNum: uint32(seq), - AckNum: uint32(ack), - DataOffset: header.TCPMinimumSize, - Flags: flags, - WindowSize: uint16(rcvWnd), - }) - - // Only calculate the checksum if offloading isn't supported. - if r.Capabilities()&stack.CapabilityChecksumOffload == 0 { - length := uint16(hdr.UsedLength() + payload.Size()) - xsum := r.PseudoHeaderChecksum(ProtocolNumber) - for _, v := range payload.Views() { - xsum = header.Checksum(v, xsum) - } - - tcp.SetChecksum(^tcp.CalculateChecksum(xsum, length)) - } - - r.Stats().TCP.SegmentsSent.Increment() - if (flags & flagRst) != 0 { - r.Stats().TCP.ResetsSent.Increment() - } - - return r.WritePacket(&hdr, payload, ProtocolNumber) + return r.WritePacket(&hdr, data, ProtocolNumber, ttl) } // makeOptions makes an options slice. @@ -698,12 +657,7 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] } options := e.makeOptions(sackBlocks) - if len(options) > 0 { - err := sendTCPWithOptions(&e.route, e.id, data, flags, seq, ack, rcvWnd, options) - putOptions(options) - return err - } - err := sendTCP(&e.route, e.id, data, flags, seq, ack, rcvWnd) + err := sendTCP(&e.route, e.id, data, e.route.DefaultTTL(), flags, seq, ack, rcvWnd, options) putOptions(options) return err } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 006b2f074..fe21f2c78 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -147,7 +147,7 @@ func replyWithReset(s *segment) { ack := s.sequenceNumber.Add(s.logicalLen()) - sendTCP(&s.route, s.id, buffer.VectorisedView{}, flagRst|flagAck, seq, ack, 0) + sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), flagRst|flagAck, seq, ack, 0, nil) } // SetOption implements TransportProtocol.SetOption. diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index d091a6196..5de518a55 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -70,19 +70,24 @@ type endpoint struct { rcvTimestamp bool // The following fields are protected by the mu mutex. - mu sync.RWMutex `state:"nosave"` - sndBufSize int - id stack.TransportEndpointID - state endpointState - bindNICID tcpip.NICID - regNICID tcpip.NICID - route stack.Route `state:"manual"` - dstPort uint16 - v6only bool + mu sync.RWMutex `state:"nosave"` + sndBufSize int + id stack.TransportEndpointID + state endpointState + bindNICID tcpip.NICID + regNICID tcpip.NICID + route stack.Route `state:"manual"` + dstPort uint16 + v6only bool + multicastTTL uint8 // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags + // multicastMemberships that need to be remvoed when the endpoint is + // closed. Protected by the mu mutex. + multicastMemberships []multicastMembership + // effectiveNetProtos contains the network protocols actually in use. In // most cases it will only contain "netProto", but in cases like IPv6 // endpoints with v6only set to false, this could include multiple @@ -92,11 +97,29 @@ type endpoint struct { effectiveNetProtos []tcpip.NetworkProtocolNumber } +type multicastMembership struct { + nicID tcpip.NICID + multicastAddr tcpip.Address +} + func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { return &endpoint{ - stack: stack, - netProto: netProto, - waiterQueue: waiterQueue, + stack: stack, + netProto: netProto, + waiterQueue: waiterQueue, + // RFC 1075 section 5.4 recommends a TTL of 1 for membership + // requests. + // + // RFC 5135 4.2.1 appears to assume that IGMP messages have a + // TTL of 1. + // + // RFC 5135 Appendix A defines TTL=1: A multicast source that + // wants its traffic to not traverse a router (e.g., leave a + // home network) may find it useful to send traffic with IP + // TTL=1. + // + // Linux defaults to TTL=1. + multicastTTL: 1, rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, } @@ -135,6 +158,11 @@ func (e *endpoint) Close() { e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) } + for _, mem := range e.multicastMemberships { + e.stack.LeaveGroup(e.netProto, mem.nicID, mem.multicastAddr) + } + e.multicastMemberships = nil + // Close the receive list and drain it. e.rcvMu.Lock() e.rcvClosed = true @@ -329,8 +357,13 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tc return 0, err } + ttl := route.DefaultTTL() + if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) { + ttl = e.multicastTTL + } + vv := buffer.NewVectorisedView(len(v), []buffer.View{v}) - if err := sendUDP(route, vv, e.id.LocalPort, dstPort); err != nil { + if err := sendUDP(route, vv, e.id.LocalPort, dstPort, ttl); err != nil { return 0, err } return uintptr(len(v)), nil @@ -365,6 +398,56 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.rcvMu.Lock() e.rcvTimestamp = v != 0 e.rcvMu.Unlock() + + case tcpip.MulticastTTLOption: + e.mu.Lock() + defer e.mu.Unlock() + e.multicastTTL = uint8(v) + + case tcpip.AddMembershipOption: + nicID := v.NIC + if v.InterfaceAddr != header.IPv4Any { + nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr) + } + if nicID == 0 { + return tcpip.ErrNoRoute + } + + // TODO: check that v.MulticastAddr is a multicast address. + if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil { + return err + } + + e.mu.Lock() + defer e.mu.Unlock() + + e.multicastMemberships = append(e.multicastMemberships, multicastMembership{nicID, v.MulticastAddr}) + + case tcpip.RemoveMembershipOption: + nicID := v.NIC + if v.InterfaceAddr != header.IPv4Any { + nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr) + } + if nicID == 0 { + return tcpip.ErrNoRoute + } + + // TODO: check that v.MulticastAddr is a multicast address. + if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil { + return err + } + + e.mu.Lock() + defer e.mu.Unlock() + for i, mem := range e.multicastMemberships { + if mem.nicID == nicID && mem.multicastAddr == v.MulticastAddr { + // Only remove the first match, so that each added membership above is + // paired with exactly 1 removal. + e.multicastMemberships[i] = e.multicastMemberships[len(e.multicastMemberships)-1] + e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1] + break + } + } } return nil } @@ -421,6 +504,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { *o = 1 } e.rcvMu.Unlock() + + case *tcpip.MulticastTTLOption: + e.mu.Lock() + *o = tcpip.MulticastTTLOption(e.multicastTTL) + e.mu.Unlock() + return nil } return tcpip.ErrUnknownProtocolOption @@ -428,7 +517,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { // sendUDP sends a UDP segment via the provided network endpoint and under the // provided identity. -func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16) *tcpip.Error { +func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8) *tcpip.Error { // Allocate a buffer for the UDP header. hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength())) @@ -454,7 +543,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u // Track count of packets sent. r.Stats().UDP.PacketsSent.Increment() - return r.WritePacket(&hdr, data, ProtocolNumber) + return r.WritePacket(&hdr, data, ProtocolNumber, ttl) } func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { @@ -581,7 +670,9 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - if e.state != stateConnected { + // A socket in the bound state can still receive multicast messages, + // so we need to notify waiters on shutdown. + if e.state != stateBound && e.state != stateConnected { return tcpip.ErrNotConnected } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 4700193c2..6d7a737bd 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -34,16 +34,20 @@ import ( ) const ( - stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - stackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + stackAddr - testV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + testAddr - V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00" - - stackAddr = "\x0a\x00\x00\x01" - stackPort = 1234 - testAddr = "\x0a\x00\x00\x02" - testPort = 4096 + stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + stackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + stackAddr + testV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + testAddr + multicastV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + multicastAddr + V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00" + + stackAddr = "\x0a\x00\x00\x01" + stackPort = 1234 + testAddr = "\x0a\x00\x00\x02" + testPort = 4096 + multicastAddr = "\xe8\x2b\xd3\xea" + multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + multicastPort = 1234 // defaultMTU is the MTU, in bytes, used throughout the tests, except // where another value is explicitly used. It is chosen to match the MTU @@ -128,37 +132,35 @@ func (c *testContext) createV6Endpoint(v6only bool) { } } -func (c *testContext) getV6Packet() []byte { +func (c *testContext) getPacket(protocolNumber tcpip.NetworkProtocolNumber, multicast bool) []byte { select { case p := <-c.linkEP.C: - if p.Proto != ipv6.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber) + if p.Proto != protocolNumber { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, protocolNumber) } b := make([]byte, len(p.Header)+len(p.Payload)) copy(b, p.Header) copy(b[len(p.Header):], p.Payload) - checker.IPv6(c.t, b, checker.SrcAddr(stackV6Addr), checker.DstAddr(testV6Addr)) - return b - - case <-time.After(2 * time.Second): - c.t.Fatalf("Packet wasn't written out") - } - - return nil -} - -func (c *testContext) getPacket() []byte { - select { - case p := <-c.linkEP.C: - if p.Proto != ipv4.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) + var checkerFn func(*testing.T, []byte, ...checker.NetworkChecker) + var srcAddr, dstAddr tcpip.Address + switch protocolNumber { + case ipv4.ProtocolNumber: + checkerFn = checker.IPv4 + srcAddr, dstAddr = stackAddr, testAddr + if multicast { + dstAddr = multicastAddr + } + case ipv6.ProtocolNumber: + checkerFn = checker.IPv6 + srcAddr, dstAddr = stackV6Addr, testV6Addr + if multicast { + dstAddr = multicastV6Addr + } + default: + c.t.Fatalf("unknown protocol %d", protocolNumber) } - b := make([]byte, len(p.Header)+len(p.Payload)) - copy(b, p.Header) - copy(b[len(p.Header):], p.Payload) - - checker.IPv4(c.t, b, checker.SrcAddr(stackAddr), checker.DstAddr(testAddr)) + checkerFn(c.t, b, checker.SrcAddr(srcAddr), checker.DstAddr(dstAddr)) return b case <-time.After(2 * time.Second): @@ -495,7 +497,7 @@ func testV4Write(c *testContext) uint16 { } // Check that we received the packet. - b := c.getPacket() + b := c.getPacket(ipv4.ProtocolNumber, false) udp := header.UDP(header.IPv4(b).Payload()) checker.IPv4(c.t, b, checker.UDP( @@ -525,7 +527,7 @@ func testV6Write(c *testContext) uint16 { } // Check that we received the packet. - b := c.getV6Packet() + b := c.getPacket(ipv6.ProtocolNumber, false) udp := header.UDP(header.IPv6(b).Payload()) checker.IPv6(c.t, b, checker.UDP( @@ -682,7 +684,7 @@ func TestV6WriteOnConnected(t *testing.T) { } // Check that we received the packet. - b := c.getV6Packet() + b := c.getPacket(ipv6.ProtocolNumber, false) udp := header.UDP(header.IPv6(b).Payload()) checker.IPv6(c.t, b, checker.UDP( @@ -718,7 +720,7 @@ func TestV4WriteOnConnected(t *testing.T) { } // Check that we received the packet. - b := c.getPacket() + b := c.getPacket(ipv4.ProtocolNumber, false) udp := header.UDP(header.IPv4(b).Payload()) checker.IPv4(c.t, b, checker.UDP( @@ -769,3 +771,162 @@ func TestWriteIncrementsPacketsSent(t *testing.T) { c.t.Fatalf("Write did not increment PacketsSent: got %v, want %v", got, want) } } + +func TestTTL(t *testing.T) { + payload := tcpip.SlicePayload(buffer.View(newPayload())) + + for _, name := range []string{"v4", "v6", "dual"} { + t.Run(name, func(t *testing.T) { + var networkProtocolNumber tcpip.NetworkProtocolNumber + switch name { + case "v4": + networkProtocolNumber = ipv4.ProtocolNumber + case "v6", "dual": + networkProtocolNumber = ipv6.ProtocolNumber + default: + t.Fatal("unknown test variant") + } + + var variants []string + switch name { + case "v4": + variants = []string{"v4"} + case "v6": + variants = []string{"v6"} + case "dual": + variants = []string{"v6", "mapped"} + } + + for _, variant := range variants { + t.Run(variant, func(t *testing.T) { + for _, typ := range []string{"unicast", "multicast"} { + t.Run(typ, func(t *testing.T) { + var addr tcpip.Address + var port uint16 + switch typ { + case "unicast": + port = testPort + switch variant { + case "v4": + addr = testAddr + case "mapped": + addr = testV4MappedAddr + case "v6": + addr = testV6Addr + default: + t.Fatal("unknown test variant") + } + case "multicast": + port = multicastPort + switch variant { + case "v4": + addr = multicastAddr + case "mapped": + addr = multicastV4MappedAddr + case "v6": + addr = multicastV6Addr + default: + t.Fatal("unknown test variant") + } + default: + t.Fatal("unknown test variant") + } + + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + var err *tcpip.Error + c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, networkProtocolNumber, &c.wq) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + + switch name { + case "v4": + case "v6": + if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } + case "dual": + if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } + default: + t.Fatal("unknown test variant") + } + + const multicastTTL = 42 + if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } + + n, err := c.ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: addr, Port: port}}) + if err != nil { + c.t.Fatalf("Write failed: %v", err) + } + if n != uintptr(len(payload)) { + c.t.Fatalf("got c.ep.Write(...) = %d, want = %d", n, len(payload)) + } + + checkerFn := checker.IPv4 + switch variant { + case "v4", "mapped": + case "v6": + checkerFn = checker.IPv6 + default: + t.Fatal("unknown test variant") + } + var wantTTL uint8 + var multicast bool + switch typ { + case "unicast": + multicast = false + switch variant { + case "v4", "mapped": + ep, err := ipv4.NewProtocol().NewEndpoint(0, "", nil, nil, nil) + if err != nil { + t.Fatal(err) + } + wantTTL = ep.DefaultTTL() + ep.Close() + case "v6": + ep, err := ipv6.NewProtocol().NewEndpoint(0, "", nil, nil, nil) + if err != nil { + t.Fatal(err) + } + wantTTL = ep.DefaultTTL() + ep.Close() + default: + t.Fatal("unknown test variant") + } + case "multicast": + wantTTL = multicastTTL + multicast = true + default: + t.Fatal("unknown test variant") + } + + var networkProtocolNumber tcpip.NetworkProtocolNumber + switch variant { + case "v4", "mapped": + networkProtocolNumber = ipv4.ProtocolNumber + case "v6": + networkProtocolNumber = ipv6.ProtocolNumber + default: + t.Fatal("unknown test variant") + } + + b := c.getPacket(networkProtocolNumber, multicast) + checkerFn(c.t, b, + checker.TTL(wantTTL), + checker.UDP( + checker.DstPort(port), + ), + ) + }) + } + }) + } + }) + } +} |