diff options
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r-- | pkg/tcpip/transport/ping/endpoint.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 58 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/protocol.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 123 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/udp_test.go | 235 |
5 files changed, 315 insertions, 108 deletions
diff --git a/pkg/tcpip/transport/ping/endpoint.go b/pkg/tcpip/transport/ping/endpoint.go index fc98c41eb..7aaf2d9c6 100644 --- a/pkg/tcpip/transport/ping/endpoint.go +++ b/pkg/tcpip/transport/ping/endpoint.go @@ -385,7 +385,7 @@ func sendPing4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0))) vv := buffer.NewVectorisedView(len(data), []buffer.View{data}) - return r.WritePacket(&hdr, vv, header.ICMPv4ProtocolNumber) + return r.WritePacket(&hdr, vv, header.ICMPv4ProtocolNumber, r.DefaultTTL()) } func sendPing6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { @@ -408,8 +408,9 @@ func sendPing6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { icmpv6.SetChecksum(0) icmpv6.SetChecksum(^header.Checksum(icmpv6, header.Checksum(data, 0))) + vv := buffer.NewVectorisedView(len(data), []buffer.View{data}) - return r.WritePacket(&hdr, vv, header.ICMPv6ProtocolNumber) + return r.WritePacket(&hdr, vv, header.ICMPv6ProtocolNumber, r.DefaultTTL()) } func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index de5f963cf..ce87d5818 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -563,14 +563,14 @@ func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, flags byte, seq, a } options := makeSynOptions(opts) - err := sendTCPWithOptions(r, id, buffer.VectorisedView{}, flags, seq, ack, rcvWnd, options) + err := sendTCP(r, id, buffer.VectorisedView{}, r.DefaultTTL(), flags, seq, ack, rcvWnd, options) putOptions(options) return err } -// sendTCPWithOptions sends a TCP segment with the provided options via the -// provided network endpoint and under the provided identity. -func sendTCPWithOptions(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte) *tcpip.Error { +// sendTCP sends a TCP segment with the provided options via the provided +// network endpoint and under the provided identity. +func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte) *tcpip.Error { optLen := len(opts) // Allocate a buffer for the TCP header. hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen) @@ -608,48 +608,7 @@ func sendTCPWithOptions(r *stack.Route, id stack.TransportEndpointID, data buffe r.Stats().TCP.ResetsSent.Increment() } - return r.WritePacket(&hdr, data, ProtocolNumber) -} - -// sendTCP sends a TCP segment via the provided network endpoint and under the -// provided identity. -func sendTCP(r *stack.Route, id stack.TransportEndpointID, payload buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error { - // Allocate a buffer for the TCP header. - hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength())) - - if rcvWnd > 0xffff { - rcvWnd = 0xffff - } - - // Initialize the header. - tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize)) - tcp.Encode(&header.TCPFields{ - SrcPort: id.LocalPort, - DstPort: id.RemotePort, - SeqNum: uint32(seq), - AckNum: uint32(ack), - DataOffset: header.TCPMinimumSize, - Flags: flags, - WindowSize: uint16(rcvWnd), - }) - - // Only calculate the checksum if offloading isn't supported. - if r.Capabilities()&stack.CapabilityChecksumOffload == 0 { - length := uint16(hdr.UsedLength() + payload.Size()) - xsum := r.PseudoHeaderChecksum(ProtocolNumber) - for _, v := range payload.Views() { - xsum = header.Checksum(v, xsum) - } - - tcp.SetChecksum(^tcp.CalculateChecksum(xsum, length)) - } - - r.Stats().TCP.SegmentsSent.Increment() - if (flags & flagRst) != 0 { - r.Stats().TCP.ResetsSent.Increment() - } - - return r.WritePacket(&hdr, payload, ProtocolNumber) + return r.WritePacket(&hdr, data, ProtocolNumber, ttl) } // makeOptions makes an options slice. @@ -698,12 +657,7 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] } options := e.makeOptions(sackBlocks) - if len(options) > 0 { - err := sendTCPWithOptions(&e.route, e.id, data, flags, seq, ack, rcvWnd, options) - putOptions(options) - return err - } - err := sendTCP(&e.route, e.id, data, flags, seq, ack, rcvWnd) + err := sendTCP(&e.route, e.id, data, e.route.DefaultTTL(), flags, seq, ack, rcvWnd, options) putOptions(options) return err } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 006b2f074..fe21f2c78 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -147,7 +147,7 @@ func replyWithReset(s *segment) { ack := s.sequenceNumber.Add(s.logicalLen()) - sendTCP(&s.route, s.id, buffer.VectorisedView{}, flagRst|flagAck, seq, ack, 0) + sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), flagRst|flagAck, seq, ack, 0, nil) } // SetOption implements TransportProtocol.SetOption. diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index d091a6196..5de518a55 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -70,19 +70,24 @@ type endpoint struct { rcvTimestamp bool // The following fields are protected by the mu mutex. - mu sync.RWMutex `state:"nosave"` - sndBufSize int - id stack.TransportEndpointID - state endpointState - bindNICID tcpip.NICID - regNICID tcpip.NICID - route stack.Route `state:"manual"` - dstPort uint16 - v6only bool + mu sync.RWMutex `state:"nosave"` + sndBufSize int + id stack.TransportEndpointID + state endpointState + bindNICID tcpip.NICID + regNICID tcpip.NICID + route stack.Route `state:"manual"` + dstPort uint16 + v6only bool + multicastTTL uint8 // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags + // multicastMemberships that need to be remvoed when the endpoint is + // closed. Protected by the mu mutex. + multicastMemberships []multicastMembership + // effectiveNetProtos contains the network protocols actually in use. In // most cases it will only contain "netProto", but in cases like IPv6 // endpoints with v6only set to false, this could include multiple @@ -92,11 +97,29 @@ type endpoint struct { effectiveNetProtos []tcpip.NetworkProtocolNumber } +type multicastMembership struct { + nicID tcpip.NICID + multicastAddr tcpip.Address +} + func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { return &endpoint{ - stack: stack, - netProto: netProto, - waiterQueue: waiterQueue, + stack: stack, + netProto: netProto, + waiterQueue: waiterQueue, + // RFC 1075 section 5.4 recommends a TTL of 1 for membership + // requests. + // + // RFC 5135 4.2.1 appears to assume that IGMP messages have a + // TTL of 1. + // + // RFC 5135 Appendix A defines TTL=1: A multicast source that + // wants its traffic to not traverse a router (e.g., leave a + // home network) may find it useful to send traffic with IP + // TTL=1. + // + // Linux defaults to TTL=1. + multicastTTL: 1, rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, } @@ -135,6 +158,11 @@ func (e *endpoint) Close() { e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) } + for _, mem := range e.multicastMemberships { + e.stack.LeaveGroup(e.netProto, mem.nicID, mem.multicastAddr) + } + e.multicastMemberships = nil + // Close the receive list and drain it. e.rcvMu.Lock() e.rcvClosed = true @@ -329,8 +357,13 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tc return 0, err } + ttl := route.DefaultTTL() + if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) { + ttl = e.multicastTTL + } + vv := buffer.NewVectorisedView(len(v), []buffer.View{v}) - if err := sendUDP(route, vv, e.id.LocalPort, dstPort); err != nil { + if err := sendUDP(route, vv, e.id.LocalPort, dstPort, ttl); err != nil { return 0, err } return uintptr(len(v)), nil @@ -365,6 +398,56 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.rcvMu.Lock() e.rcvTimestamp = v != 0 e.rcvMu.Unlock() + + case tcpip.MulticastTTLOption: + e.mu.Lock() + defer e.mu.Unlock() + e.multicastTTL = uint8(v) + + case tcpip.AddMembershipOption: + nicID := v.NIC + if v.InterfaceAddr != header.IPv4Any { + nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr) + } + if nicID == 0 { + return tcpip.ErrNoRoute + } + + // TODO: check that v.MulticastAddr is a multicast address. + if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil { + return err + } + + e.mu.Lock() + defer e.mu.Unlock() + + e.multicastMemberships = append(e.multicastMemberships, multicastMembership{nicID, v.MulticastAddr}) + + case tcpip.RemoveMembershipOption: + nicID := v.NIC + if v.InterfaceAddr != header.IPv4Any { + nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr) + } + if nicID == 0 { + return tcpip.ErrNoRoute + } + + // TODO: check that v.MulticastAddr is a multicast address. + if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil { + return err + } + + e.mu.Lock() + defer e.mu.Unlock() + for i, mem := range e.multicastMemberships { + if mem.nicID == nicID && mem.multicastAddr == v.MulticastAddr { + // Only remove the first match, so that each added membership above is + // paired with exactly 1 removal. + e.multicastMemberships[i] = e.multicastMemberships[len(e.multicastMemberships)-1] + e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1] + break + } + } } return nil } @@ -421,6 +504,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { *o = 1 } e.rcvMu.Unlock() + + case *tcpip.MulticastTTLOption: + e.mu.Lock() + *o = tcpip.MulticastTTLOption(e.multicastTTL) + e.mu.Unlock() + return nil } return tcpip.ErrUnknownProtocolOption @@ -428,7 +517,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { // sendUDP sends a UDP segment via the provided network endpoint and under the // provided identity. -func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16) *tcpip.Error { +func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8) *tcpip.Error { // Allocate a buffer for the UDP header. hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength())) @@ -454,7 +543,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u // Track count of packets sent. r.Stats().UDP.PacketsSent.Increment() - return r.WritePacket(&hdr, data, ProtocolNumber) + return r.WritePacket(&hdr, data, ProtocolNumber, ttl) } func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { @@ -581,7 +670,9 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - if e.state != stateConnected { + // A socket in the bound state can still receive multicast messages, + // so we need to notify waiters on shutdown. + if e.state != stateBound && e.state != stateConnected { return tcpip.ErrNotConnected } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 4700193c2..6d7a737bd 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -34,16 +34,20 @@ import ( ) const ( - stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - stackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + stackAddr - testV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + testAddr - V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00" - - stackAddr = "\x0a\x00\x00\x01" - stackPort = 1234 - testAddr = "\x0a\x00\x00\x02" - testPort = 4096 + stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + stackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + stackAddr + testV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + testAddr + multicastV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + multicastAddr + V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00" + + stackAddr = "\x0a\x00\x00\x01" + stackPort = 1234 + testAddr = "\x0a\x00\x00\x02" + testPort = 4096 + multicastAddr = "\xe8\x2b\xd3\xea" + multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + multicastPort = 1234 // defaultMTU is the MTU, in bytes, used throughout the tests, except // where another value is explicitly used. It is chosen to match the MTU @@ -128,37 +132,35 @@ func (c *testContext) createV6Endpoint(v6only bool) { } } -func (c *testContext) getV6Packet() []byte { +func (c *testContext) getPacket(protocolNumber tcpip.NetworkProtocolNumber, multicast bool) []byte { select { case p := <-c.linkEP.C: - if p.Proto != ipv6.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber) + if p.Proto != protocolNumber { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, protocolNumber) } b := make([]byte, len(p.Header)+len(p.Payload)) copy(b, p.Header) copy(b[len(p.Header):], p.Payload) - checker.IPv6(c.t, b, checker.SrcAddr(stackV6Addr), checker.DstAddr(testV6Addr)) - return b - - case <-time.After(2 * time.Second): - c.t.Fatalf("Packet wasn't written out") - } - - return nil -} - -func (c *testContext) getPacket() []byte { - select { - case p := <-c.linkEP.C: - if p.Proto != ipv4.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) + var checkerFn func(*testing.T, []byte, ...checker.NetworkChecker) + var srcAddr, dstAddr tcpip.Address + switch protocolNumber { + case ipv4.ProtocolNumber: + checkerFn = checker.IPv4 + srcAddr, dstAddr = stackAddr, testAddr + if multicast { + dstAddr = multicastAddr + } + case ipv6.ProtocolNumber: + checkerFn = checker.IPv6 + srcAddr, dstAddr = stackV6Addr, testV6Addr + if multicast { + dstAddr = multicastV6Addr + } + default: + c.t.Fatalf("unknown protocol %d", protocolNumber) } - b := make([]byte, len(p.Header)+len(p.Payload)) - copy(b, p.Header) - copy(b[len(p.Header):], p.Payload) - - checker.IPv4(c.t, b, checker.SrcAddr(stackAddr), checker.DstAddr(testAddr)) + checkerFn(c.t, b, checker.SrcAddr(srcAddr), checker.DstAddr(dstAddr)) return b case <-time.After(2 * time.Second): @@ -495,7 +497,7 @@ func testV4Write(c *testContext) uint16 { } // Check that we received the packet. - b := c.getPacket() + b := c.getPacket(ipv4.ProtocolNumber, false) udp := header.UDP(header.IPv4(b).Payload()) checker.IPv4(c.t, b, checker.UDP( @@ -525,7 +527,7 @@ func testV6Write(c *testContext) uint16 { } // Check that we received the packet. - b := c.getV6Packet() + b := c.getPacket(ipv6.ProtocolNumber, false) udp := header.UDP(header.IPv6(b).Payload()) checker.IPv6(c.t, b, checker.UDP( @@ -682,7 +684,7 @@ func TestV6WriteOnConnected(t *testing.T) { } // Check that we received the packet. - b := c.getV6Packet() + b := c.getPacket(ipv6.ProtocolNumber, false) udp := header.UDP(header.IPv6(b).Payload()) checker.IPv6(c.t, b, checker.UDP( @@ -718,7 +720,7 @@ func TestV4WriteOnConnected(t *testing.T) { } // Check that we received the packet. - b := c.getPacket() + b := c.getPacket(ipv4.ProtocolNumber, false) udp := header.UDP(header.IPv4(b).Payload()) checker.IPv4(c.t, b, checker.UDP( @@ -769,3 +771,162 @@ func TestWriteIncrementsPacketsSent(t *testing.T) { c.t.Fatalf("Write did not increment PacketsSent: got %v, want %v", got, want) } } + +func TestTTL(t *testing.T) { + payload := tcpip.SlicePayload(buffer.View(newPayload())) + + for _, name := range []string{"v4", "v6", "dual"} { + t.Run(name, func(t *testing.T) { + var networkProtocolNumber tcpip.NetworkProtocolNumber + switch name { + case "v4": + networkProtocolNumber = ipv4.ProtocolNumber + case "v6", "dual": + networkProtocolNumber = ipv6.ProtocolNumber + default: + t.Fatal("unknown test variant") + } + + var variants []string + switch name { + case "v4": + variants = []string{"v4"} + case "v6": + variants = []string{"v6"} + case "dual": + variants = []string{"v6", "mapped"} + } + + for _, variant := range variants { + t.Run(variant, func(t *testing.T) { + for _, typ := range []string{"unicast", "multicast"} { + t.Run(typ, func(t *testing.T) { + var addr tcpip.Address + var port uint16 + switch typ { + case "unicast": + port = testPort + switch variant { + case "v4": + addr = testAddr + case "mapped": + addr = testV4MappedAddr + case "v6": + addr = testV6Addr + default: + t.Fatal("unknown test variant") + } + case "multicast": + port = multicastPort + switch variant { + case "v4": + addr = multicastAddr + case "mapped": + addr = multicastV4MappedAddr + case "v6": + addr = multicastV6Addr + default: + t.Fatal("unknown test variant") + } + default: + t.Fatal("unknown test variant") + } + + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + var err *tcpip.Error + c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, networkProtocolNumber, &c.wq) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + + switch name { + case "v4": + case "v6": + if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } + case "dual": + if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } + default: + t.Fatal("unknown test variant") + } + + const multicastTTL = 42 + if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } + + n, err := c.ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: addr, Port: port}}) + if err != nil { + c.t.Fatalf("Write failed: %v", err) + } + if n != uintptr(len(payload)) { + c.t.Fatalf("got c.ep.Write(...) = %d, want = %d", n, len(payload)) + } + + checkerFn := checker.IPv4 + switch variant { + case "v4", "mapped": + case "v6": + checkerFn = checker.IPv6 + default: + t.Fatal("unknown test variant") + } + var wantTTL uint8 + var multicast bool + switch typ { + case "unicast": + multicast = false + switch variant { + case "v4", "mapped": + ep, err := ipv4.NewProtocol().NewEndpoint(0, "", nil, nil, nil) + if err != nil { + t.Fatal(err) + } + wantTTL = ep.DefaultTTL() + ep.Close() + case "v6": + ep, err := ipv6.NewProtocol().NewEndpoint(0, "", nil, nil, nil) + if err != nil { + t.Fatal(err) + } + wantTTL = ep.DefaultTTL() + ep.Close() + default: + t.Fatal("unknown test variant") + } + case "multicast": + wantTTL = multicastTTL + multicast = true + default: + t.Fatal("unknown test variant") + } + + var networkProtocolNumber tcpip.NetworkProtocolNumber + switch variant { + case "v4", "mapped": + networkProtocolNumber = ipv4.ProtocolNumber + case "v6": + networkProtocolNumber = ipv6.ProtocolNumber + default: + t.Fatal("unknown test variant") + } + + b := c.getPacket(networkProtocolNumber, multicast) + checkerFn(c.t, b, + checker.TTL(wantTTL), + checker.UDP( + checker.DstPort(port), + ), + ) + }) + } + }) + } + }) + } +} |