summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport
diff options
context:
space:
mode:
authorTamir Duberstein <tamird@google.com>2018-09-12 20:38:27 -0700
committerShentubot <shentubot@google.com>2018-09-12 20:39:24 -0700
commit5adb3468d4de249df055d641e01ce6582b3a9388 (patch)
treefa75f573912b3647dcc7158961aa1085e572a8a1 /pkg/tcpip/transport
parent9dec7a3db99d8c7045324bc6d8f0c27e88407f6c (diff)
Add multicast support
PiperOrigin-RevId: 212750821 Change-Id: I822fd63e48c684b45fd91f9ce057867b7eceb792
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r--pkg/tcpip/transport/ping/endpoint.go5
-rw-r--r--pkg/tcpip/transport/tcp/connect.go58
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go2
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go123
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go235
5 files changed, 315 insertions, 108 deletions
diff --git a/pkg/tcpip/transport/ping/endpoint.go b/pkg/tcpip/transport/ping/endpoint.go
index fc98c41eb..7aaf2d9c6 100644
--- a/pkg/tcpip/transport/ping/endpoint.go
+++ b/pkg/tcpip/transport/ping/endpoint.go
@@ -385,7 +385,7 @@ func sendPing4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0)))
vv := buffer.NewVectorisedView(len(data), []buffer.View{data})
- return r.WritePacket(&hdr, vv, header.ICMPv4ProtocolNumber)
+ return r.WritePacket(&hdr, vv, header.ICMPv4ProtocolNumber, r.DefaultTTL())
}
func sendPing6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
@@ -408,8 +408,9 @@ func sendPing6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
icmpv6.SetChecksum(0)
icmpv6.SetChecksum(^header.Checksum(icmpv6, header.Checksum(data, 0)))
+
vv := buffer.NewVectorisedView(len(data), []buffer.View{data})
- return r.WritePacket(&hdr, vv, header.ICMPv6ProtocolNumber)
+ return r.WritePacket(&hdr, vv, header.ICMPv6ProtocolNumber, r.DefaultTTL())
}
func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index de5f963cf..ce87d5818 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -563,14 +563,14 @@ func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, flags byte, seq, a
}
options := makeSynOptions(opts)
- err := sendTCPWithOptions(r, id, buffer.VectorisedView{}, flags, seq, ack, rcvWnd, options)
+ err := sendTCP(r, id, buffer.VectorisedView{}, r.DefaultTTL(), flags, seq, ack, rcvWnd, options)
putOptions(options)
return err
}
-// sendTCPWithOptions sends a TCP segment with the provided options via the
-// provided network endpoint and under the provided identity.
-func sendTCPWithOptions(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte) *tcpip.Error {
+// sendTCP sends a TCP segment with the provided options via the provided
+// network endpoint and under the provided identity.
+func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte) *tcpip.Error {
optLen := len(opts)
// Allocate a buffer for the TCP header.
hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen)
@@ -608,48 +608,7 @@ func sendTCPWithOptions(r *stack.Route, id stack.TransportEndpointID, data buffe
r.Stats().TCP.ResetsSent.Increment()
}
- return r.WritePacket(&hdr, data, ProtocolNumber)
-}
-
-// sendTCP sends a TCP segment via the provided network endpoint and under the
-// provided identity.
-func sendTCP(r *stack.Route, id stack.TransportEndpointID, payload buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error {
- // Allocate a buffer for the TCP header.
- hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()))
-
- if rcvWnd > 0xffff {
- rcvWnd = 0xffff
- }
-
- // Initialize the header.
- tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize))
- tcp.Encode(&header.TCPFields{
- SrcPort: id.LocalPort,
- DstPort: id.RemotePort,
- SeqNum: uint32(seq),
- AckNum: uint32(ack),
- DataOffset: header.TCPMinimumSize,
- Flags: flags,
- WindowSize: uint16(rcvWnd),
- })
-
- // Only calculate the checksum if offloading isn't supported.
- if r.Capabilities()&stack.CapabilityChecksumOffload == 0 {
- length := uint16(hdr.UsedLength() + payload.Size())
- xsum := r.PseudoHeaderChecksum(ProtocolNumber)
- for _, v := range payload.Views() {
- xsum = header.Checksum(v, xsum)
- }
-
- tcp.SetChecksum(^tcp.CalculateChecksum(xsum, length))
- }
-
- r.Stats().TCP.SegmentsSent.Increment()
- if (flags & flagRst) != 0 {
- r.Stats().TCP.ResetsSent.Increment()
- }
-
- return r.WritePacket(&hdr, payload, ProtocolNumber)
+ return r.WritePacket(&hdr, data, ProtocolNumber, ttl)
}
// makeOptions makes an options slice.
@@ -698,12 +657,7 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn
sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
}
options := e.makeOptions(sackBlocks)
- if len(options) > 0 {
- err := sendTCPWithOptions(&e.route, e.id, data, flags, seq, ack, rcvWnd, options)
- putOptions(options)
- return err
- }
- err := sendTCP(&e.route, e.id, data, flags, seq, ack, rcvWnd)
+ err := sendTCP(&e.route, e.id, data, e.route.DefaultTTL(), flags, seq, ack, rcvWnd, options)
putOptions(options)
return err
}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 006b2f074..fe21f2c78 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -147,7 +147,7 @@ func replyWithReset(s *segment) {
ack := s.sequenceNumber.Add(s.logicalLen())
- sendTCP(&s.route, s.id, buffer.VectorisedView{}, flagRst|flagAck, seq, ack, 0)
+ sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), flagRst|flagAck, seq, ack, 0, nil)
}
// SetOption implements TransportProtocol.SetOption.
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index d091a6196..5de518a55 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -70,19 +70,24 @@ type endpoint struct {
rcvTimestamp bool
// The following fields are protected by the mu mutex.
- mu sync.RWMutex `state:"nosave"`
- sndBufSize int
- id stack.TransportEndpointID
- state endpointState
- bindNICID tcpip.NICID
- regNICID tcpip.NICID
- route stack.Route `state:"manual"`
- dstPort uint16
- v6only bool
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ id stack.TransportEndpointID
+ state endpointState
+ bindNICID tcpip.NICID
+ regNICID tcpip.NICID
+ route stack.Route `state:"manual"`
+ dstPort uint16
+ v6only bool
+ multicastTTL uint8
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
+ // multicastMemberships that need to be remvoed when the endpoint is
+ // closed. Protected by the mu mutex.
+ multicastMemberships []multicastMembership
+
// effectiveNetProtos contains the network protocols actually in use. In
// most cases it will only contain "netProto", but in cases like IPv6
// endpoints with v6only set to false, this could include multiple
@@ -92,11 +97,29 @@ type endpoint struct {
effectiveNetProtos []tcpip.NetworkProtocolNumber
}
+type multicastMembership struct {
+ nicID tcpip.NICID
+ multicastAddr tcpip.Address
+}
+
func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
return &endpoint{
- stack: stack,
- netProto: netProto,
- waiterQueue: waiterQueue,
+ stack: stack,
+ netProto: netProto,
+ waiterQueue: waiterQueue,
+ // RFC 1075 section 5.4 recommends a TTL of 1 for membership
+ // requests.
+ //
+ // RFC 5135 4.2.1 appears to assume that IGMP messages have a
+ // TTL of 1.
+ //
+ // RFC 5135 Appendix A defines TTL=1: A multicast source that
+ // wants its traffic to not traverse a router (e.g., leave a
+ // home network) may find it useful to send traffic with IP
+ // TTL=1.
+ //
+ // Linux defaults to TTL=1.
+ multicastTTL: 1,
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024,
}
@@ -135,6 +158,11 @@ func (e *endpoint) Close() {
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
}
+ for _, mem := range e.multicastMemberships {
+ e.stack.LeaveGroup(e.netProto, mem.nicID, mem.multicastAddr)
+ }
+ e.multicastMemberships = nil
+
// Close the receive list and drain it.
e.rcvMu.Lock()
e.rcvClosed = true
@@ -329,8 +357,13 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tc
return 0, err
}
+ ttl := route.DefaultTTL()
+ if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) {
+ ttl = e.multicastTTL
+ }
+
vv := buffer.NewVectorisedView(len(v), []buffer.View{v})
- if err := sendUDP(route, vv, e.id.LocalPort, dstPort); err != nil {
+ if err := sendUDP(route, vv, e.id.LocalPort, dstPort, ttl); err != nil {
return 0, err
}
return uintptr(len(v)), nil
@@ -365,6 +398,56 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.rcvMu.Lock()
e.rcvTimestamp = v != 0
e.rcvMu.Unlock()
+
+ case tcpip.MulticastTTLOption:
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.multicastTTL = uint8(v)
+
+ case tcpip.AddMembershipOption:
+ nicID := v.NIC
+ if v.InterfaceAddr != header.IPv4Any {
+ nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
+ }
+ if nicID == 0 {
+ return tcpip.ErrNoRoute
+ }
+
+ // TODO: check that v.MulticastAddr is a multicast address.
+ if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+ return err
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ e.multicastMemberships = append(e.multicastMemberships, multicastMembership{nicID, v.MulticastAddr})
+
+ case tcpip.RemoveMembershipOption:
+ nicID := v.NIC
+ if v.InterfaceAddr != header.IPv4Any {
+ nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
+ }
+ if nicID == 0 {
+ return tcpip.ErrNoRoute
+ }
+
+ // TODO: check that v.MulticastAddr is a multicast address.
+ if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+ return err
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ for i, mem := range e.multicastMemberships {
+ if mem.nicID == nicID && mem.multicastAddr == v.MulticastAddr {
+ // Only remove the first match, so that each added membership above is
+ // paired with exactly 1 removal.
+ e.multicastMemberships[i] = e.multicastMemberships[len(e.multicastMemberships)-1]
+ e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1]
+ break
+ }
+ }
}
return nil
}
@@ -421,6 +504,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
*o = 1
}
e.rcvMu.Unlock()
+
+ case *tcpip.MulticastTTLOption:
+ e.mu.Lock()
+ *o = tcpip.MulticastTTLOption(e.multicastTTL)
+ e.mu.Unlock()
+ return nil
}
return tcpip.ErrUnknownProtocolOption
@@ -428,7 +517,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
// sendUDP sends a UDP segment via the provided network endpoint and under the
// provided identity.
-func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16) *tcpip.Error {
+func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8) *tcpip.Error {
// Allocate a buffer for the UDP header.
hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength()))
@@ -454,7 +543,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
// Track count of packets sent.
r.Stats().UDP.PacketsSent.Increment()
- return r.WritePacket(&hdr, data, ProtocolNumber)
+ return r.WritePacket(&hdr, data, ProtocolNumber, ttl)
}
func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
@@ -581,7 +670,9 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- if e.state != stateConnected {
+ // A socket in the bound state can still receive multicast messages,
+ // so we need to notify waiters on shutdown.
+ if e.state != stateBound && e.state != stateConnected {
return tcpip.ErrNotConnected
}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 4700193c2..6d7a737bd 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -34,16 +34,20 @@ import (
)
const (
- stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- stackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + stackAddr
- testV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + testAddr
- V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
-
- stackAddr = "\x0a\x00\x00\x01"
- stackPort = 1234
- testAddr = "\x0a\x00\x00\x02"
- testPort = 4096
+ stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ stackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + stackAddr
+ testV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + testAddr
+ multicastV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + multicastAddr
+ V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
+
+ stackAddr = "\x0a\x00\x00\x01"
+ stackPort = 1234
+ testAddr = "\x0a\x00\x00\x02"
+ testPort = 4096
+ multicastAddr = "\xe8\x2b\xd3\xea"
+ multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
+ multicastPort = 1234
// defaultMTU is the MTU, in bytes, used throughout the tests, except
// where another value is explicitly used. It is chosen to match the MTU
@@ -128,37 +132,35 @@ func (c *testContext) createV6Endpoint(v6only bool) {
}
}
-func (c *testContext) getV6Packet() []byte {
+func (c *testContext) getPacket(protocolNumber tcpip.NetworkProtocolNumber, multicast bool) []byte {
select {
case p := <-c.linkEP.C:
- if p.Proto != ipv6.ProtocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber)
+ if p.Proto != protocolNumber {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, protocolNumber)
}
b := make([]byte, len(p.Header)+len(p.Payload))
copy(b, p.Header)
copy(b[len(p.Header):], p.Payload)
- checker.IPv6(c.t, b, checker.SrcAddr(stackV6Addr), checker.DstAddr(testV6Addr))
- return b
-
- case <-time.After(2 * time.Second):
- c.t.Fatalf("Packet wasn't written out")
- }
-
- return nil
-}
-
-func (c *testContext) getPacket() []byte {
- select {
- case p := <-c.linkEP.C:
- if p.Proto != ipv4.ProtocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
+ var checkerFn func(*testing.T, []byte, ...checker.NetworkChecker)
+ var srcAddr, dstAddr tcpip.Address
+ switch protocolNumber {
+ case ipv4.ProtocolNumber:
+ checkerFn = checker.IPv4
+ srcAddr, dstAddr = stackAddr, testAddr
+ if multicast {
+ dstAddr = multicastAddr
+ }
+ case ipv6.ProtocolNumber:
+ checkerFn = checker.IPv6
+ srcAddr, dstAddr = stackV6Addr, testV6Addr
+ if multicast {
+ dstAddr = multicastV6Addr
+ }
+ default:
+ c.t.Fatalf("unknown protocol %d", protocolNumber)
}
- b := make([]byte, len(p.Header)+len(p.Payload))
- copy(b, p.Header)
- copy(b[len(p.Header):], p.Payload)
-
- checker.IPv4(c.t, b, checker.SrcAddr(stackAddr), checker.DstAddr(testAddr))
+ checkerFn(c.t, b, checker.SrcAddr(srcAddr), checker.DstAddr(dstAddr))
return b
case <-time.After(2 * time.Second):
@@ -495,7 +497,7 @@ func testV4Write(c *testContext) uint16 {
}
// Check that we received the packet.
- b := c.getPacket()
+ b := c.getPacket(ipv4.ProtocolNumber, false)
udp := header.UDP(header.IPv4(b).Payload())
checker.IPv4(c.t, b,
checker.UDP(
@@ -525,7 +527,7 @@ func testV6Write(c *testContext) uint16 {
}
// Check that we received the packet.
- b := c.getV6Packet()
+ b := c.getPacket(ipv6.ProtocolNumber, false)
udp := header.UDP(header.IPv6(b).Payload())
checker.IPv6(c.t, b,
checker.UDP(
@@ -682,7 +684,7 @@ func TestV6WriteOnConnected(t *testing.T) {
}
// Check that we received the packet.
- b := c.getV6Packet()
+ b := c.getPacket(ipv6.ProtocolNumber, false)
udp := header.UDP(header.IPv6(b).Payload())
checker.IPv6(c.t, b,
checker.UDP(
@@ -718,7 +720,7 @@ func TestV4WriteOnConnected(t *testing.T) {
}
// Check that we received the packet.
- b := c.getPacket()
+ b := c.getPacket(ipv4.ProtocolNumber, false)
udp := header.UDP(header.IPv4(b).Payload())
checker.IPv4(c.t, b,
checker.UDP(
@@ -769,3 +771,162 @@ func TestWriteIncrementsPacketsSent(t *testing.T) {
c.t.Fatalf("Write did not increment PacketsSent: got %v, want %v", got, want)
}
}
+
+func TestTTL(t *testing.T) {
+ payload := tcpip.SlicePayload(buffer.View(newPayload()))
+
+ for _, name := range []string{"v4", "v6", "dual"} {
+ t.Run(name, func(t *testing.T) {
+ var networkProtocolNumber tcpip.NetworkProtocolNumber
+ switch name {
+ case "v4":
+ networkProtocolNumber = ipv4.ProtocolNumber
+ case "v6", "dual":
+ networkProtocolNumber = ipv6.ProtocolNumber
+ default:
+ t.Fatal("unknown test variant")
+ }
+
+ var variants []string
+ switch name {
+ case "v4":
+ variants = []string{"v4"}
+ case "v6":
+ variants = []string{"v6"}
+ case "dual":
+ variants = []string{"v6", "mapped"}
+ }
+
+ for _, variant := range variants {
+ t.Run(variant, func(t *testing.T) {
+ for _, typ := range []string{"unicast", "multicast"} {
+ t.Run(typ, func(t *testing.T) {
+ var addr tcpip.Address
+ var port uint16
+ switch typ {
+ case "unicast":
+ port = testPort
+ switch variant {
+ case "v4":
+ addr = testAddr
+ case "mapped":
+ addr = testV4MappedAddr
+ case "v6":
+ addr = testV6Addr
+ default:
+ t.Fatal("unknown test variant")
+ }
+ case "multicast":
+ port = multicastPort
+ switch variant {
+ case "v4":
+ addr = multicastAddr
+ case "mapped":
+ addr = multicastV4MappedAddr
+ case "v6":
+ addr = multicastV6Addr
+ default:
+ t.Fatal("unknown test variant")
+ }
+ default:
+ t.Fatal("unknown test variant")
+ }
+
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ var err *tcpip.Error
+ c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, networkProtocolNumber, &c.wq)
+ if err != nil {
+ c.t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ switch name {
+ case "v4":
+ case "v6":
+ if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
+ case "dual":
+ if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
+ default:
+ t.Fatal("unknown test variant")
+ }
+
+ const multicastTTL = 42
+ if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
+
+ n, err := c.ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: addr, Port: port}})
+ if err != nil {
+ c.t.Fatalf("Write failed: %v", err)
+ }
+ if n != uintptr(len(payload)) {
+ c.t.Fatalf("got c.ep.Write(...) = %d, want = %d", n, len(payload))
+ }
+
+ checkerFn := checker.IPv4
+ switch variant {
+ case "v4", "mapped":
+ case "v6":
+ checkerFn = checker.IPv6
+ default:
+ t.Fatal("unknown test variant")
+ }
+ var wantTTL uint8
+ var multicast bool
+ switch typ {
+ case "unicast":
+ multicast = false
+ switch variant {
+ case "v4", "mapped":
+ ep, err := ipv4.NewProtocol().NewEndpoint(0, "", nil, nil, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ wantTTL = ep.DefaultTTL()
+ ep.Close()
+ case "v6":
+ ep, err := ipv6.NewProtocol().NewEndpoint(0, "", nil, nil, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ wantTTL = ep.DefaultTTL()
+ ep.Close()
+ default:
+ t.Fatal("unknown test variant")
+ }
+ case "multicast":
+ wantTTL = multicastTTL
+ multicast = true
+ default:
+ t.Fatal("unknown test variant")
+ }
+
+ var networkProtocolNumber tcpip.NetworkProtocolNumber
+ switch variant {
+ case "v4", "mapped":
+ networkProtocolNumber = ipv4.ProtocolNumber
+ case "v6":
+ networkProtocolNumber = ipv6.ProtocolNumber
+ default:
+ t.Fatal("unknown test variant")
+ }
+
+ b := c.getPacket(networkProtocolNumber, multicast)
+ checkerFn(c.t, b,
+ checker.TTL(wantTTL),
+ checker.UDP(
+ checker.DstPort(port),
+ ),
+ )
+ })
+ }
+ })
+ }
+ })
+ }
+}