diff options
author | Ghanan Gowripalan <ghanan@google.com> | 2020-11-05 15:49:51 -0800 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2020-11-05 15:52:16 -0800 |
commit | 8c0701462a84ff77e602f1626aec49479c308127 (patch) | |
tree | adab5dade34c76acdf66ffeb675a3d356a97891a /pkg | |
parent | 7caefd68df06062d2c0a3547132f1d25af49af22 (diff) |
Use stack.Route exclusively for writing packets
* Remove stack.Route from incoming packet path.
There is no need to pass around a stack.Route during the incoming path
of a packet. Instead, pass around the packet's link/network layer
information in the packet buffer since all layers may need this
information.
* Support address bound and outgoing packet NIC in routes.
When forwarding is enabled, the source address of a packet may be bound
to a different interface than the outgoing interface. This change
updates stack.Route to hold both NICs so that one can be used to write
packets while the other is used to check if the route's bound address
is valid. Note, we need to hold the address's interface so we can check
if the address is a spoofed address.
* Introduce the concept of a local route.
Local routes are routes where the packet never needs to leave the stack;
the destination is stack-local. We can now route between interfaces
within a stack if the packet never needs to leave the stack, even when
forwarding is disabled.
* Always obtain a route from the stack before sending a packet.
If a packet needs to be sent in response to an incoming packet, a route
must be obtained from the stack to ensure the stack is configured to
send packets to the packet's source from the packet's destination.
* Enable spoofing if a stack may send packets from unowned addresses.
This change required changes to some netgophers since previously,
promiscuous mode was enough to let the netstack respond to all
incoming packets regardless of the packet's destination address. Now
that a stack.Route is not held for each incoming packet, finding a route
may fail with local addresses we don't own but accepted packets for
while in promiscuous mode. Since we also want to be able to send from
any address (in response the received promiscuous mode packets), we need
to enable spoofing.
* Skip transport layer checksum checks for locally generated packets.
If a packet is locally generated, the stack can safely assume that no
errors were introduced while being locally routed since the packet is
never sent out the wire.
Some bugs fixed:
- transport layer checksum was never calculated after NAT.
- handleLocal didn't handle routing across interfaces.
- stack didn't support forwarding across interfaces.
- always consult the routing table before creating an endpoint.
Updates #4688
Fixes #3906
PiperOrigin-RevId: 340943442
Diffstat (limited to 'pkg')
41 files changed, 2085 insertions, 545 deletions
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index a79379abb..33a4a0720 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -122,7 +122,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu return tcpip.ErrNotSupported } -func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { +func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { if !e.isEnabled() { return } @@ -145,7 +145,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr) } else { - if r.Stack().CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 { + if e.protocol.stack.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 { return // we have no useful answer, ignore the request } @@ -158,6 +158,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize, }) packet := header.ARP(respPkt.NetworkHeader().Push(header.ARPSize)) + respPkt.NetworkProtocolNumber = ProtocolNumber packet.SetIPv4OverEthernet() packet.SetOp(header.ARPReply) // TODO(gvisor.dev/issue/4582): check copied length once TAP devices have a diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 969579601..8873bd91f 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -110,8 +110,9 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff // DeliverTransportPacket is called by network endpoints after parsing incoming // packets. This is used by the test object to verify that the results of the // parsing are expected. -func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) stack.TransportPacketDisposition { - t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress) +func (t *testObject) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) stack.TransportPacketDisposition { + netHdr := pkt.Network() + t.checkValues(protocol, pkt.Data, netHdr.SourceAddress(), netHdr.DestinationAddress()) t.dataCalls++ return stack.TransportPacketHandled } @@ -608,7 +609,8 @@ func TestIPv4Receive(t *testing.T) { if _, _, ok := proto.Parse(pkt); !ok { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } - ep.HandlePacket(&r, pkt) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) if nic.testObject.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) } @@ -707,7 +709,9 @@ func TestIPv4ReceiveControl(t *testing.T) { nic.testObject.typ = c.expectedTyp nic.testObject.extra = c.expectedExtra - ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize)) + pkt := truncatedPacket(view, c.trunc, header.IPv4MinimumSize) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) if want := c.expectedCount; nic.testObject.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) } @@ -788,7 +792,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { if _, _, ok := proto.Parse(pkt); !ok { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } - ep.HandlePacket(&r, pkt) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) if nic.testObject.dataCalls != 0 { t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls) } @@ -800,7 +805,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { if _, _, ok := proto.Parse(pkt); !ok { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } - ep.HandlePacket(&r, pkt) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) if nic.testObject.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) } @@ -900,7 +906,8 @@ func TestIPv6Receive(t *testing.T) { if _, _, ok := proto.Parse(pkt); !ok { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } - ep.HandlePacket(&r, pkt) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) if nic.testObject.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) } @@ -1017,7 +1024,9 @@ func TestIPv6ReceiveControl(t *testing.T) { // Set ICMPv6 checksum. icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{})) - ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize)) + pkt := truncatedPacket(view, c.trunc, header.IPv6MinimumSize) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) if want := c.expectedCount; nic.testObject.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) } @@ -1071,7 +1080,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum tcpip.NetworkProtocolNumber nicAddr tcpip.Address remoteAddr tcpip.Address - pktGen func(*testing.T, tcpip.Address) buffer.View + pktGen func(*testing.T, tcpip.Address) buffer.VectorisedView checker func(*testing.T, *stack.PacketBuffer, tcpip.Address) expectedErr *tcpip.Error }{ @@ -1081,7 +1090,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv4.ProtocolNumber, nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv4MinimumSize + len(data) hdr := buffer.NewPrependable(totalLen) if n := copy(hdr.Prepend(len(data)), data); n != len(data) { @@ -1095,7 +1104,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { SrcAddr: src, DstAddr: header.IPv4Any, }) - return hdr.View() + return hdr.View().ToVectorisedView() }, checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { if src == header.IPv4Any { @@ -1123,7 +1132,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv4.ProtocolNumber, nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv4MinimumSize + len(data) hdr := buffer.NewPrependable(totalLen) if n := copy(hdr.Prepend(len(data)), data); n != len(data) { @@ -1137,7 +1146,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { SrcAddr: src, DstAddr: header.IPv4Any, }) - return hdr.View() + return hdr.View().ToVectorisedView() }, expectedErr: tcpip.ErrMalformedHeader, }, @@ -1147,7 +1156,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv4.ProtocolNumber, nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) ip.Encode(&header.IPv4Fields{ IHL: header.IPv4MinimumSize, @@ -1156,7 +1165,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { SrcAddr: src, DstAddr: header.IPv4Any, }) - return buffer.View(ip[:len(ip)-1]) + return buffer.View(ip[:len(ip)-1]).ToVectorisedView() }, expectedErr: tcpip.ErrMalformedHeader, }, @@ -1166,7 +1175,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv4.ProtocolNumber, nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) ip.Encode(&header.IPv4Fields{ IHL: header.IPv4MinimumSize, @@ -1175,7 +1184,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { SrcAddr: src, DstAddr: header.IPv4Any, }) - return buffer.View(ip) + return buffer.View(ip).ToVectorisedView() }, checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { if src == header.IPv4Any { @@ -1203,7 +1212,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv4.ProtocolNumber, nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ipHdrLen := header.IPv4MinimumSize + len(ipv4Options) totalLen := ipHdrLen + len(data) hdr := buffer.NewPrependable(totalLen) @@ -1221,7 +1230,49 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { if n := copy(ip.Options(), ipv4Options); n != len(ipv4Options) { t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv4Options)) } - return hdr.View() + return hdr.View().ToVectorisedView() + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv4Any { + src = localIPv4Addr + } + + netHdr := pkt.NetworkHeader() + + hdrLen := header.IPv4MinimumSize + len(ipv4Options) + if len(netHdr.View()) != hdrLen { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen) + } + + checker.IPv4(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv4Addr), + checker.IPv4HeaderLength(hdrLen), + checker.IPFullLength(uint16(hdrLen+len(data))), + checker.IPv4Options(ipv4Options), + checker.IPPayload(data), + ) + }, + }, + { + name: "IPv4 with options and data across views", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + nicAddr: localIPv4Addr, + remoteAddr: remoteIPv4Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { + ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: uint8(header.IPv4MinimumSize + len(ipv4Options)), + Protocol: transportProto, + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + vv := buffer.View(ip).ToVectorisedView() + vv.AppendView(ipv4Options) + vv.AppendView(data) + return vv }, checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { if src == header.IPv4Any { @@ -1251,7 +1302,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv6.ProtocolNumber, nicAddr: localIPv6Addr, remoteAddr: remoteIPv6Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv6MinimumSize + len(data) hdr := buffer.NewPrependable(totalLen) if n := copy(hdr.Prepend(len(data)), data); n != len(data) { @@ -1264,7 +1315,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { SrcAddr: src, DstAddr: header.IPv4Any, }) - return hdr.View() + return hdr.View().ToVectorisedView() }, checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { if src == header.IPv6Any { @@ -1291,7 +1342,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv6.ProtocolNumber, nicAddr: localIPv6Addr, remoteAddr: remoteIPv6Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data) hdr := buffer.NewPrependable(totalLen) if n := copy(hdr.Prepend(len(data)), data); n != len(data) { @@ -1307,7 +1358,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { SrcAddr: src, DstAddr: header.IPv4Any, }) - return hdr.View() + return hdr.View().ToVectorisedView() }, checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { if src == header.IPv6Any { @@ -1334,7 +1385,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv6.ProtocolNumber, nicAddr: localIPv6Addr, remoteAddr: remoteIPv6Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ NextHeader: transportProto, @@ -1342,7 +1393,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { SrcAddr: src, DstAddr: header.IPv4Any, }) - return buffer.View(ip) + return buffer.View(ip).ToVectorisedView() }, checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { if src == header.IPv6Any { @@ -1369,7 +1420,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv6.ProtocolNumber, nicAddr: localIPv6Addr, remoteAddr: remoteIPv6Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ NextHeader: transportProto, @@ -1377,7 +1428,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { SrcAddr: src, DstAddr: header.IPv4Any, }) - return buffer.View(ip[:len(ip)-1]) + return buffer.View(ip[:len(ip)-1]).ToVectorisedView() }, expectedErr: tcpip.ErrMalformedHeader, }, @@ -1421,7 +1472,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { defer r.Release() if err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: test.pktGen(t, subTest.srcAddr).ToVectorisedView(), + Data: test.pktGen(t, subTest.srcAddr), })); err != test.expectedErr { t.Fatalf("got r.WriteHeaderIncludedPacket(_) = %s, want = %s", err, test.expectedErr) } diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index cf287446e..9b5e37fee 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -42,8 +42,8 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack // // Drop packet if it doesn't have the basic IPv4 header or if the // original source address doesn't match an address we own. - src := hdr.SourceAddress() - if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 { + srcAddr := hdr.SourceAddress() + if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, srcAddr) == 0 { return } @@ -58,11 +58,11 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack // Skip the ip header, then deliver control message. pkt.Data.TrimFront(hlen) p := hdr.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(src, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + e.dispatcher.DeliverTransportControlPacket(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { - stats := r.Stats() +func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { + stats := e.protocol.stack.Stats() received := stats.ICMP.V4PacketsReceived // TODO(gvisor.dev/issue/170): ICMP packets don't have their // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a @@ -83,7 +83,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { // packets with checksum errors. switch h.Type() { case header.ICMPv4Echo: - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) + e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt) } return } @@ -106,7 +106,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { } else { op = &optionUsageReceive{} } - aux, tmp, err := processIPOptions(r, iph.Options(), op) + aux, tmp, err := e.processIPOptions(pkt, iph.Options(), op) if err != nil { switch { case @@ -116,9 +116,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { errors.Is(err, errIPv4TimestampOptInvalidLength), errors.Is(err, errIPv4TimestampOptInvalidPointer), errors.Is(err, errIPv4TimestampOptOverflow): - _ = e.protocol.returnError(r, &icmpReasonParamProblem{pointer: aux}, pkt) - e.protocol.stack.Stats().MalformedRcvdPackets.Increment() - r.Stats().IP.MalformedPacketsReceived.Increment() + _ = e.protocol.returnError(&icmpReasonParamProblem{pointer: aux}, pkt) + stats.MalformedRcvdPackets.Increment() + stats.IP.MalformedPacketsReceived.Increment() } return } @@ -131,7 +131,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { received.Echo.Increment() sent := stats.ICMP.V4PacketsSent - if !r.Stack().AllowICMPMessage() { + if !e.protocol.stack.AllowICMPMessage() { sent.RateLimited.Increment() return } @@ -144,10 +144,13 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { // waiting endpoints. Consider moving responsibility for doing the copy to // DeliverTransportPacket so that is is only done when needed. replyData := pkt.Data.ToOwnedView() + ipHdr := header.IPv4(pkt.NetworkHeader().View()) + localAddressBroadcast := pkt.NetworkPacketInfo.LocalAddressBroadcast // It's possible that a raw socket expects to receive this. - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) + e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt) pkt = nil + // Take the base of the incoming request IP header but replace the options. replyHeaderLength := uint8(header.IPv4MinimumSize + len(newOptions)) replyIPHdr := header.IPv4(append(iph[:header.IPv4MinimumSize:header.IPv4MinimumSize], newOptions...)) @@ -156,12 +159,12 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { // As per RFC 1122 section 3.2.1.3, when a host sends any datagram, the IP // source address MUST be one of its own IP addresses (but not a broadcast // or multicast address). - localAddr := r.LocalAddress - if r.IsInboundBroadcast() || header.IsV4MulticastAddress(localAddr) { + localAddr := ipHdr.DestinationAddress() + if localAddressBroadcast || header.IsV4MulticastAddress(localAddr) { localAddr = "" } - r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + r, err := e.protocol.stack.FindRoute(e.nic.ID(), localAddr, ipHdr.SourceAddress(), ProtocolNumber, false /* multicastLoop */) if err != nil { // If we cannot find a route to the destination, silently drop the packet. return @@ -218,7 +221,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { case header.ICMPv4EchoReply: received.EchoReply.Increment() - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) + e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt) case header.ICMPv4DstUnreachable: received.DstUnreachable.Increment() @@ -307,7 +310,11 @@ func (*icmpReasonParamProblem) isICMPReason() {} // the problematic packet. It incorporates as much of that packet as // possible as well as any error metadata as is available. returnError // expects pkt to hold a valid IPv4 packet as per the wire format. -func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { +func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { + origIPHdr := header.IPv4(pkt.NetworkHeader().View()) + origIPHdrSrc := origIPHdr.SourceAddress() + origIPHdrDst := origIPHdr.DestinationAddress() + // We check we are responding only when we are allowed to. // See RFC 1812 section 4.3.2.7 (shown below). // @@ -331,8 +338,7 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac // // TODO(gvisor.dev/issues/4058): Make sure we don't send ICMP errors in // response to a non-initial fragment, but it currently can not happen. - - if r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || r.RemoteAddress == header.IPv4Any { + if pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(origIPHdrDst) || origIPHdrSrc == header.IPv4Any { return nil } @@ -340,14 +346,11 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac // a route to it - the remote may be blocked via routing rules. We must always // consult our routing table and find a route to the remote before sending any // packet. - route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + route, err := p.stack.FindRoute(pkt.NICID, origIPHdrDst, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */) if err != nil { return err } defer route.Release() - // From this point on, the incoming route should no longer be used; route - // must be used to send the ICMP error. - r = nil sent := p.stack.Stats().ICMP.V4PacketsSent if !p.stack.AllowICMPMessage() { @@ -355,11 +358,10 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac return nil } - networkHeader := pkt.NetworkHeader().View() transportHeader := pkt.TransportHeader().View() // Don't respond to icmp error packets. - if header.IPv4(networkHeader).Protocol() == uint8(header.ICMPv4ProtocolNumber) { + if origIPHdr.Protocol() == uint8(header.ICMPv4ProtocolNumber) { // TODO(gvisor.dev/issue/3810): // Unfortunately the current stack pretty much always has ICMPv4 headers // in the Data section of the packet but there is no guarantee that is the @@ -416,7 +418,7 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac return nil } - payloadLen := networkHeader.Size() + transportHeader.Size() + pkt.Data.Size() + payloadLen := len(origIPHdr) + transportHeader.Size() + pkt.Data.Size() if payloadLen > available { payloadLen = available } @@ -428,7 +430,7 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac // view with the entire incoming IP packet reassembled and truncated as // required. This is now the payload of the new ICMP packet and no longer // considered a packet in its own right. - newHeader := append(buffer.View(nil), networkHeader...) + newHeader := append(buffer.View(nil), origIPHdr...) newHeader = append(newHeader, transportHeader...) payload := newHeader.ToVectorisedView() payload.AppendView(pkt.Data.ToView()) diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 4592984a5..1bc2c4aff 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -252,8 +252,7 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet // iptables filtering. All packets that reach here are locally // generated. nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - ipt := e.protocol.stack.IPTables() - if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. r.Stats().IP.IPTablesOutputDropped.Increment() return nil @@ -270,16 +269,27 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet netHeader := header.IPv4(pkt.NetworkHeader().View()) ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()) if err == nil { - route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) - ep.HandlePacket(&route, pkt) + pkt := pkt.CloneToInbound() + if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { + route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) + route.PopulatePacketInfo(pkt) + // Since we rewrote the packet but it is being routed back to us, we can + // safely assume the checksum is valid. + pkt.RXTransportChecksumValidated = true + ep.HandlePacket(pkt) + } return nil } } if r.Loop&stack.PacketLoop != 0 { - loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, pkt) - loopedR.Release() + pkt := pkt.CloneToInbound() + if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { + loopedR := r.MakeLoopedRoute() + loopedR.PopulatePacketInfo(pkt) + loopedR.Release() + e.HandlePacket(pkt) + } } if r.Loop&stack.PacketOut == 0 { return nil @@ -373,10 +383,12 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe if _, ok := natPkts[pkt]; ok { netHeader := header.IPv4(pkt.NetworkHeader().View()) if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { - src := netHeader.SourceAddress() - dst := netHeader.DestinationAddress() - route := r.ReverseRoute(src, dst) - ep.HandlePacket(&route, pkt) + pkt := pkt.CloneToInbound() + if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { + route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) + route.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) + } n++ continue } @@ -403,6 +415,16 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu if !ok { return tcpip.ErrMalformedHeader } + + hdrLen := header.IPv4(h).HeaderLength() + if hdrLen < header.IPv4MinimumSize { + return tcpip.ErrMalformedHeader + } + + h, ok = pkt.Data.PullUp(int(hdrLen)) + if !ok { + return tcpip.ErrMalformedHeader + } ip := header.IPv4(h) // Always set the total length. @@ -447,14 +469,17 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { +func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { if !e.isEnabled() { return } + pkt.NICID = e.nic.ID() + stats := e.protocol.stack.Stats() + h := header.IPv4(pkt.NetworkHeader().View()) if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { - r.Stats().IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() return } @@ -480,7 +505,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // is all 1 bits (-0 in 1's complement arithmetic), the check // succeeds. if h.CalculateChecksum() != 0xffff { - r.Stats().IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() return } @@ -488,8 +513,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // When a host sends any datagram, the IP source address MUST // be one of its own IP addresses (but not a broadcast or // multicast address). - if r.IsOutboundBroadcast() || header.IsV4MulticastAddress(r.RemoteAddress) { - r.Stats().IP.InvalidSourceAddressesReceived.Increment() + if pkt.NetworkPacketInfo.RemoteAddressBroadcast || header.IsV4MulticastAddress(h.SourceAddress()) { + stats.IP.InvalidSourceAddressesReceived.Increment() return } @@ -498,7 +523,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { ipt := e.protocol.stack.IPTables() if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. - r.Stats().IP.IPTablesInputDropped.Increment() + stats.IP.IPTablesInputDropped.Increment() return } @@ -506,8 +531,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { if pkt.Data.Size()+pkt.TransportHeader().View().Size() == 0 { // Drop the packet as it's marked as a fragment but has // no payload. - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() return } // The packet is a fragment, let's try to reassemble it. @@ -520,8 +545,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // size). Otherwise the packet would've been rejected as invalid before // reaching here. if int(start)+pkt.Data.Size() > header.IPv4MaximumPayloadSize { - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() return } @@ -537,12 +562,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { var releaseCB func(bool) if start == 0 { pkt := pkt.Clone() - r := r.Clone() releaseCB = func(timedOut bool) { if timedOut { - _ = e.protocol.returnError(&r, &icmpReasonReassemblyTimeout{}, pkt) + _ = e.protocol.returnError(&icmpReasonReassemblyTimeout{}, pkt) } - r.Release() } } @@ -566,8 +589,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { releaseCB, ) if err != nil { - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() return } if !ready { @@ -579,7 +602,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { h.SetTotalLength(uint16(pkt.Data.Size() + len((h)))) h.SetFlagsFragmentOffset(0, 0) } - r.Stats().IP.PacketsDelivered.Increment() + stats.IP.PacketsDelivered.Increment() p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { @@ -587,14 +610,14 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // headers, the setting of the transport number here should be // unnecessary and removed. pkt.TransportProtocolNumber = p - e.handleICMP(r, pkt) + e.handleICMP(pkt) return } if len(h.Options()) != 0 { // TODO(gvisor.dev/issue/4586): // When we add forwarding support we should use the verified options // rather than just throwing them away. - aux, _, err := processIPOptions(r, h.Options(), &optionUsageReceive{}) + aux, _, err := e.processIPOptions(pkt, h.Options(), &optionUsageReceive{}) if err != nil { switch { case @@ -604,15 +627,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { errors.Is(err, errIPv4TimestampOptInvalidLength), errors.Is(err, errIPv4TimestampOptInvalidPointer), errors.Is(err, errIPv4TimestampOptOverflow): - _ = e.protocol.returnError(r, &icmpReasonParamProblem{pointer: aux}, pkt) - e.protocol.stack.Stats().MalformedRcvdPackets.Increment() - r.Stats().IP.MalformedPacketsReceived.Increment() + _ = e.protocol.returnError(&icmpReasonParamProblem{pointer: aux}, pkt) + stats.MalformedRcvdPackets.Increment() + stats.IP.MalformedPacketsReceived.Increment() } return } } - switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res { + switch res := e.dispatcher.DeliverTransportPacket(p, pkt); res { case stack.TransportPacketHandled: case stack.TransportPacketDestinationPortUnreachable: // As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination @@ -620,13 +643,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // 3 (Port Unreachable), when the designated transport protocol // (e.g., UDP) is unable to demultiplex the datagram but has no // protocol mechanism to inform the sender. - _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt) + _ = e.protocol.returnError(&icmpReasonPortUnreachable{}, pkt) case stack.TransportPacketProtocolUnreachable: // As per RFC: 1122 Section 3.2.2.1 // A host SHOULD generate Destination Unreachable messages with code: // 2 (Protocol Unreachable), when the designated transport protocol // is not supported - _ = e.protocol.returnError(r, &icmpReasonProtoUnreachable{}, pkt) + _ = e.protocol.returnError(&icmpReasonProtoUnreachable{}, pkt) default: panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res)) } @@ -919,6 +942,7 @@ func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader head originalIPHeaderLength := len(originalIPHeader) nextFragIPHeader := header.IPv4(fragPkt.NetworkHeader().Push(originalIPHeaderLength)) + fragPkt.NetworkProtocolNumber = ProtocolNumber if copied := copy(nextFragIPHeader, originalIPHeader); copied != len(originalIPHeader) { panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got = %d, want = %d", copied, originalIPHeaderLength)) @@ -1172,8 +1196,8 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad // - The location of an error if there was one (or 0 if no error) // - If there is an error, information as to what it was was. // - The replacement option set. -func processIPOptions(r *stack.Route, orig header.IPv4Options, usage optionsUsage) (uint8, header.IPv4Options, error) { - +func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Options, usage optionsUsage) (uint8, header.IPv4Options, error) { + stats := e.protocol.stack.Stats() opts := header.IPv4Options(orig) optIter := opts.MakeIterator() @@ -1186,13 +1210,15 @@ func processIPOptions(r *stack.Route, orig header.IPv4Options, usage optionsUsag // This will need tweaking when we start really forwarding packets // as we may need to get two addresses, for rx and tx interfaces. // We will also have to take usage into account. - prefixedAddress, err := r.Stack().GetMainNICAddress(r.NICID(), ProtocolNumber) + prefixedAddress, err := e.protocol.stack.GetMainNICAddress(e.nic.ID(), ProtocolNumber) localAddress := prefixedAddress.Address if err != nil { - if r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) { + h := header.IPv4(pkt.NetworkHeader().View()) + dstAddr := h.DestinationAddress() + if pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(dstAddr) { return 0 /* errCursor */, nil, header.ErrIPv4OptionAddress } - localAddress = r.LocalAddress + localAddress = dstAddr } for { @@ -1219,9 +1245,9 @@ func processIPOptions(r *stack.Route, orig header.IPv4Options, usage optionsUsag optLen := int(option.Size()) switch option := option.(type) { case *header.IPv4OptionTimestamp: - r.Stats().IP.OptionTSReceived.Increment() + stats.IP.OptionTSReceived.Increment() if usage.actions().timestamp != optionRemove { - clock := r.Stack().Clock() + clock := e.protocol.stack.Clock() newBuffer := optIter.RemainingBuffer()[:len(*option)] _ = copy(newBuffer, option.Contents()) offset, err := handleTimestamp(header.IPv4OptionTimestamp(newBuffer), localAddress, clock, usage) @@ -1232,7 +1258,7 @@ func processIPOptions(r *stack.Route, orig header.IPv4Options, usage optionsUsag } case *header.IPv4OptionRecordRoute: - r.Stats().IP.OptionRRReceived.Increment() + stats.IP.OptionRRReceived.Increment() if usage.actions().recordRoute != optionRemove { newBuffer := optIter.RemainingBuffer()[:len(*option)] _ = copy(newBuffer, option.Contents()) @@ -1244,7 +1270,7 @@ func processIPOptions(r *stack.Route, orig header.IPv4Options, usage optionsUsag } default: - r.Stats().IP.OptionUnknownReceived.Increment() + stats.IP.OptionUnknownReceived.Increment() if usage.actions().unknown == optionPass { newBuffer := optIter.RemainingBuffer()[:optLen] // Arguments already heavily checked.. ignore result. diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 3c15e41a7..8502b848c 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -124,8 +124,8 @@ func getTargetLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) { }) } -func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragmentHeader bool) { - stats := r.Stats().ICMP +func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { + stats := e.protocol.stack.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived // TODO(gvisor.dev/issue/170): ICMP packets don't have their @@ -138,13 +138,15 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme } h := header.ICMPv6(v) iph := header.IPv6(pkt.NetworkHeader().View()) + srcAddr := iph.SourceAddress() + dstAddr := iph.DestinationAddress() // Validate ICMPv6 checksum before processing the packet. // // This copy is used as extra payload during the checksum calculation. payload := pkt.Data.Clone(nil) payload.TrimFront(len(h)) - if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { + if got, want := h.Checksum(), header.ICMPv6Checksum(h, srcAddr, dstAddr, payload); got != want { received.Invalid.Increment() return } @@ -224,7 +226,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // we know we are also performing DAD on it). In this case we let the // stack know so it can handle such a scenario and do nothing further with // the NS. - if r.RemoteAddress == header.IPv6Any { + if srcAddr == header.IPv6Any { // We would get an error if the address no longer exists or the address // is no longer tentative (DAD resolved between the call to // hasTentativeAddr and this point). Both of these are valid scenarios: @@ -251,7 +253,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // section 5.4.3. // Is the NS targeting us? - if r.Stack().CheckLocalAddress(e.nic.ID(), ProtocolNumber, targetAddr) == 0 { + if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, targetAddr) == 0 { return } @@ -277,9 +279,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // Otherwise, on link layers that have addresses this option MUST be // included in multicast solicitations and SHOULD be included in unicast // solicitations. - unspecifiedSource := r.RemoteAddress == header.IPv6Any + unspecifiedSource := srcAddr == header.IPv6Any if len(sourceLinkAddr) == 0 { - if header.IsV6MulticastAddress(r.LocalAddress) && !unspecifiedSource { + if header.IsV6MulticastAddress(dstAddr) && !unspecifiedSource { received.Invalid.Increment() return } @@ -287,9 +289,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme received.Invalid.Increment() return } else if e.nud != nil { - e.nud.HandleProbe(r.RemoteAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) + e.nud.HandleProbe(srcAddr, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) } else { - e.linkAddrCache.AddLinkAddress(e.nic.ID(), r.RemoteAddress, sourceLinkAddr) + e.linkAddrCache.AddLinkAddress(e.nic.ID(), srcAddr, sourceLinkAddr) } // As per RFC 4861 section 7.1.1: @@ -298,7 +300,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // ... // - If the IP source address is the unspecified address, the IP // destination address is a solicited-node multicast address. - if unspecifiedSource && !header.IsSolicitedNodeAddr(r.LocalAddress) { + if unspecifiedSource && !header.IsSolicitedNodeAddr(dstAddr) { received.Invalid.Increment() return } @@ -308,7 +310,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // If the source of the solicitation is the unspecified address, the node // MUST [...] and multicast the advertisement to the all-nodes address. // - remoteAddr := r.RemoteAddress + remoteAddr := srcAddr if unspecifiedSource { remoteAddr = header.IPv6AllNodesMulticastAddress } @@ -465,12 +467,12 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // As per RFC 4291 section 2.7, multicast addresses must not be used as // source addresses in IPv6 packets. - localAddr := r.LocalAddress - if header.IsV6MulticastAddress(r.LocalAddress) { + localAddr := dstAddr + if header.IsV6MulticastAddress(dstAddr) { localAddr = "" } - r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + r, err := e.protocol.stack.FindRoute(e.nic.ID(), localAddr, srcAddr, ProtocolNumber, false /* multicastLoop */) if err != nil { // If we cannot find a route to the destination, silently drop the packet. return @@ -486,7 +488,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme copy(packet, icmpHdr) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, replyPkt); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: r.DefaultTTL(), + TOS: stack.DefaultTOS, + }, replyPkt); err != nil { sent.Dropped.Increment() return } @@ -498,7 +504,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme received.Invalid.Increment() return } - e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, pkt) + e.dispatcher.DeliverTransportPacket(header.ICMPv6ProtocolNumber, pkt) case header.ICMPv6TimeExceeded: received.TimeExceeded.Increment() @@ -519,7 +525,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme return } - stack := r.Stack() + stack := e.protocol.stack // Is the networking stack operating as a router? if !stack.Forwarding(ProtocolNumber) { @@ -550,7 +556,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // As per RFC 4861 section 4.1, the Source Link-Layer Address Option MUST // NOT be included when the source IP address is the unspecified address. // Otherwise, it SHOULD be included on link layers that have addresses. - if r.RemoteAddress == header.IPv6Any { + if srcAddr == header.IPv6Any { received.Invalid.Increment() return } @@ -558,7 +564,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme if e.nud != nil { // A RS with a specified source IP address modifies the NUD state // machine in the same way a reachability probe would. - e.nud.HandleProbe(r.RemoteAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) + e.nud.HandleProbe(srcAddr, ProtocolNumber, sourceLinkAddr, e.protocol) } } @@ -575,7 +581,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme return } - routerAddr := iph.SourceAddress() + routerAddr := srcAddr // Is the IP Source Address a link-local address? if !header.IsV6LinkLocalAddress(routerAddr) { @@ -608,7 +614,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // If the RA has the source link layer option, update the link address // cache with the link address for the advertised router. if len(sourceLinkAddr) != 0 && e.nud != nil { - e.nud.HandleProbe(routerAddr, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) + e.nud.HandleProbe(routerAddr, ProtocolNumber, sourceLinkAddr, e.protocol) } e.mu.Lock() @@ -753,7 +759,11 @@ func (*icmpReasonReassemblyTimeout) isICMPReason() {} // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv6 and sends it. -func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { +func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { + origIPHdr := header.IPv6(pkt.NetworkHeader().View()) + origIPHdrSrc := origIPHdr.SourceAddress() + origIPHdrDst := origIPHdr.DestinationAddress() + // Only send ICMP error if the address is not a multicast v6 // address and the source is not the unspecified address. // @@ -780,7 +790,7 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac allowResponseToMulticast = reason.respondToMulticast } - if (!allowResponseToMulticast && header.IsV6MulticastAddress(r.LocalAddress)) || r.RemoteAddress == header.IPv6Any { + if (!allowResponseToMulticast && header.IsV6MulticastAddress(origIPHdrDst)) || origIPHdrSrc == header.IPv6Any { return nil } @@ -788,14 +798,11 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac // a route to it - the remote may be blocked via routing rules. We must always // consult our routing table and find a route to the remote before sending any // packet. - route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + route, err := p.stack.FindRoute(pkt.NICID, origIPHdrDst, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */) if err != nil { return err } defer route.Release() - // From this point on, the incoming route should no longer be used; route - // must be used to send the ICMP error. - r = nil stats := p.stack.Stats().ICMP sent := stats.V6PacketsSent diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index aa8b5f2e5..76013daa1 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -87,7 +87,7 @@ type stubDispatcher struct { stack.TransportDispatcher } -func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, *stack.PacketBuffer) stack.TransportPacketDisposition { +func (*stubDispatcher) DeliverTransportPacket(tcpip.TransportProtocolNumber, *stack.PacketBuffer) stack.TransportPacketDisposition { return stack.TransportPacketHandled } @@ -282,7 +282,8 @@ func TestICMPCounts(t *testing.T) { SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(&r, pkt) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) } for _, typ := range types { @@ -424,7 +425,8 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(&r, pkt) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) } for _, typ := range types { @@ -1796,7 +1798,8 @@ func TestCallsToNeighborCache(t *testing.T) { SrcAddr: r.RemoteAddress, DstAddr: r.LocalAddress, }) - ep.HandlePacket(&r, pkt) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) // Confirm the endpoint calls the correct NUDHandler method. if nudHandler.probeCount != test.wantProbeCount { diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 1e38f3a9d..68ad35bfe 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -465,21 +465,27 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet if pkt.NatDone { netHeader := header.IPv6(pkt.NetworkHeader().View()) if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { - route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) - ep.HandlePacket(&route, pkt) + pkt := pkt.CloneToInbound() + if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { + route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) + route.PopulatePacketInfo(pkt) + // Since we rewrote the packet but it is being routed back to us, we can + // safely assume the checksum is valid. + pkt.RXTransportChecksumValidated = true + ep.HandlePacket(pkt) + } return nil } } if r.Loop&stack.PacketLoop != 0 { - loopedR := r.MakeLoopedRoute() - - e.HandlePacket(&loopedR, stack.NewPacketBuffer(stack.PacketBufferOptions{ - // The inbound path expects an unparsed packet. - Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), - })) - - loopedR.Release() + pkt := pkt.CloneToInbound() + if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { + loopedR := r.MakeLoopedRoute() + loopedR.PopulatePacketInfo(pkt) + loopedR.Release() + e.HandlePacket(pkt) + } } if r.Loop&stack.PacketOut == 0 { return nil @@ -576,10 +582,12 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe if _, ok := natPkts[pkt]; ok { netHeader := header.IPv6(pkt.NetworkHeader().View()) if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { - src := netHeader.SourceAddress() - dst := netHeader.DestinationAddress() - route := r.ReverseRoute(src, dst) - ep.HandlePacket(&route, pkt) + pkt := pkt.CloneToInbound() + if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { + route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) + route.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) + } n++ continue } @@ -637,22 +645,27 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { +func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { if !e.isEnabled() { return } + pkt.NICID = e.nic.ID() + stats := e.protocol.stack.Stats() + h := header.IPv6(pkt.NetworkHeader().View()) if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { - r.Stats().IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() return } + srcAddr := h.SourceAddress() + dstAddr := h.DestinationAddress() // As per RFC 4291 section 2.7: // Multicast addresses must not be used as source addresses in IPv6 // packets or appear in any Routing header. - if header.IsV6MulticastAddress(r.RemoteAddress) { - r.Stats().IP.InvalidSourceAddressesReceived.Increment() + if header.IsV6MulticastAddress(srcAddr) { + stats.IP.InvalidSourceAddressesReceived.Increment() return } @@ -671,7 +684,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { ipt := e.protocol.stack.IPTables() if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. - r.Stats().IP.IPTablesInputDropped.Increment() + stats.IP.IPTablesInputDropped.Increment() return } @@ -681,7 +694,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { previousHeaderStart := it.HeaderOffset() extHdr, done, err := it.Next() if err != nil { - r.Stats().IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() return } if done { @@ -693,7 +706,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // As per RFC 8200 section 4.1, the Hop By Hop extension header is // restricted to appear immediately after an IPv6 fixed header. if previousHeaderStart != 0 { - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6UnknownHeader, pointer: previousHeaderStart, }, pkt) @@ -705,7 +718,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { for { opt, done, err := optsIt.Next() if err != nil { - r.Stats().IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() return } if done { @@ -719,7 +732,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { case header.IPv6OptionUnknownActionDiscard: return case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest: - if header.IsV6MulticastAddress(r.LocalAddress) { + if header.IsV6MulticastAddress(dstAddr) { return } fallthrough @@ -732,7 +745,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // ICMP Parameter Problem, Code 2, message to the packet's // Source Address, pointing to the unrecognized Option Type. // - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6UnknownOption, pointer: it.ParseOffset() + optsIt.OptionOffset(), respondToMulticast: true, @@ -757,7 +770,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // header, so we just make sure Segments Left is zero before processing // the next extension header. if extHdr.SegmentsLeft() != 0 { - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6ErroneousHeader, pointer: it.ParseOffset(), }, pkt) @@ -794,8 +807,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { for { it, done, err := it.Next() if err != nil { - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() return } if done { @@ -822,8 +835,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { switch lastHdr.(type) { case header.IPv6RawPayloadHeader: default: - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() return } } @@ -831,8 +844,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { fragmentPayloadLen := rawPayload.Buf.Size() if fragmentPayloadLen == 0 { // Drop the packet as it's marked as a fragment but has no payload. - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() return } @@ -845,9 +858,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // of the fragment, pointing to the Payload Length field of the // fragment packet. if extHdr.More() && fragmentPayloadLen%header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit != 0 { - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() + _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6ErroneousHeader, pointer: header.IPv6PayloadLenOffset, }, pkt) @@ -866,9 +879,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // the fragment, pointing to the Fragment Offset field of the fragment // packet. if int(start)+fragmentPayloadLen > header.IPv6MaximumPayloadSize { - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() + _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6ErroneousHeader, pointer: fragmentFieldOffset, }, pkt) @@ -880,12 +893,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { var releaseCB func(bool) if start == 0 { pkt := pkt.Clone() - r := r.Clone() releaseCB = func(timedOut bool) { if timedOut { - _ = e.protocol.returnError(&r, &icmpReasonReassemblyTimeout{}, pkt) + _ = e.protocol.returnError(&icmpReasonReassemblyTimeout{}, pkt) } - r.Release() } } @@ -895,8 +906,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // IPv6 ignores the Protocol field since the ID only needs to be unique // across source-destination pairs, as per RFC 8200 section 4.5. fragmentation.FragmentID{ - Source: h.SourceAddress(), - Destination: h.DestinationAddress(), + Source: srcAddr, + Destination: dstAddr, ID: extHdr.ID(), }, start, @@ -907,8 +918,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { releaseCB, ) if err != nil { - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() return } pkt.Data = data @@ -927,7 +938,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { for { opt, done, err := optsIt.Next() if err != nil { - r.Stats().IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() return } if done { @@ -941,7 +952,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { case header.IPv6OptionUnknownActionDiscard: return case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest: - if header.IsV6MulticastAddress(r.LocalAddress) { + if header.IsV6MulticastAddress(dstAddr) { return } fallthrough @@ -954,7 +965,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // ICMP Parameter Problem, Code 2, message to the packet's // Source Address, pointing to the unrecognized Option Type. // - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6UnknownOption, pointer: it.ParseOffset() + optsIt.OptionOffset(), respondToMulticast: true, @@ -977,13 +988,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size()) pkt.Data = extHdr.Buf - r.Stats().IP.PacketsDelivered.Increment() + stats.IP.PacketsDelivered.Increment() if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { pkt.TransportProtocolNumber = p - e.handleICMP(r, pkt, hasFragmentHeader) + e.handleICMP(pkt, hasFragmentHeader) } else { - r.Stats().IP.PacketsDelivered.Increment() - switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res { + stats.IP.PacketsDelivered.Increment() + switch res := e.dispatcher.DeliverTransportPacket(p, pkt); res { case stack.TransportPacketHandled: case stack.TransportPacketDestinationPortUnreachable: // As per RFC 4443 section 3.1: @@ -991,7 +1002,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // message with Code 4 in response to a packet for which the // transport protocol (e.g., UDP) has no listener, if that transport // protocol has no alternative means to inform the sender. - _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt) + _ = e.protocol.returnError(&icmpReasonPortUnreachable{}, pkt) case stack.TransportPacketProtocolUnreachable: // As per RFC 8200 section 4. (page 7): // Extension headers are numbered from IANA IP Protocol Numbers @@ -1012,7 +1023,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // // Which when taken together indicate that an unknown protocol should // be treated as an unrecognized next header value. - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6UnknownHeader, pointer: it.ParseOffset(), }, pkt) @@ -1022,11 +1033,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { } default: - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6UnknownHeader, pointer: it.ParseOffset(), }, pkt) - r.Stats().UnknownProtocolRcvdPackets.Increment() + stats.UnknownProtocolRcvdPackets.Increment() return } } @@ -1635,6 +1646,7 @@ func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeaders hea originalIPHeadersLength := len(originalIPHeaders) fragmentIPHeadersLength := originalIPHeadersLength + header.IPv6FragmentHeaderSize fragmentIPHeaders := header.IPv6(fragPkt.NetworkHeader().Push(fragmentIPHeadersLength)) + fragPkt.NetworkProtocolNumber = ProtocolNumber // Copy the IPv6 header and any extension headers already populated. if copied := copy(fragmentIPHeaders, originalIPHeaders); copied != originalIPHeadersLength { diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 7f2ebc0cb..981d1371a 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -573,6 +573,13 @@ func TestNeighorSolicitationResponse(t *testing.T) { t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err) } + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: header.IPv6EmptySubnet, + NIC: 1, + }, + }) + ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length() hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) @@ -993,7 +1000,8 @@ func TestNDPValidation(t *testing.T) { if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) { t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n) } - ep.HandlePacket(r, pkt) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) } var tllData [header.NDPLinkLayerAddressSize]byte diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 4f4065f48..9a17efcba 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -401,12 +401,12 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d // Calculate the TCP checksum and set it. tcpHeader.SetChecksum(0) - length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) - xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length) + length := uint16(len(tcpHeader) + pkt.Data.Size()) + xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) if gso != nil && gso.NeedsCsum { tcpHeader.SetChecksum(xsum) - } else if r.Capabilities()&CapabilityTXChecksumOffload == 0 { - xsum = header.ChecksumVVWithOffset(pkt.Data, xsum, int(tcpHeader.DataOffset()), pkt.Data.Size()) + } else if r.RequiresTXTransportChecksum() { + xsum = header.ChecksumVV(pkt.Data, xsum) tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) } diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 380688038..7a501acdc 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -73,9 +73,9 @@ func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 { return 123 } -func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) { +func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) { // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) + f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) } func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index ff55ef1a3..d63e9757c 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -146,21 +146,18 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs // Calculate UDP checksum and set it. if hook == Output { udpHeader.SetChecksum(0) + netHeader := pkt.Network() + netHeader.SetDestinationAddress(address) // Only calculate the checksum if offloading isn't supported. - if r.Capabilities()&CapabilityTXChecksumOffload == 0 { + if r.RequiresTXTransportChecksum() { length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) - xsum := r.PseudoHeaderChecksum(protocol, length) - for _, v := range pkt.Data.Views() { - xsum = header.Checksum(v, xsum) - } - udpHeader.SetChecksum(0) + xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) + xsum = header.ChecksumVV(pkt.Data, xsum) udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) } } - pkt.Network().SetDestinationAddress(address) - // After modification, IPv4 packets need a valid checksum. if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { netHeader := header.IPv4(pkt.NetworkHeader().View()) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index ff02c7c65..60c81a3aa 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -348,6 +348,16 @@ func (n *NIC) getAddress(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address return n.getAddressOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous) } +func (n *NIC) hasAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { + ep := n.getAddressOrCreateTempInner(protocol, addr, false, NeverPrimaryEndpoint) + if ep != nil { + ep.DecRef() + return true + } + + return false +} + // findEndpoint finds the endpoint, if any, with the given address. func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) AssignableAddressEndpoint { return n.getAddressOrCreateTemp(protocol, address, peb, spoofing) @@ -555,10 +565,10 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool { } func (n *NIC) handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, remotelinkAddr tcpip.LinkAddress, addressEndpoint AssignableAddressEndpoint, pkt *PacketBuffer) { - r := makeRoute(protocol, dst, src, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */) + r := makeRoute(protocol, dst, src, n, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */) defer r.Release() - r.RemoteLinkAddress = remotelinkAddr - n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt) + r.PopulatePacketInfo(pkt) + n.getNetworkEndpoint(protocol).HandlePacket(pkt) } // DeliverNetworkPacket finds the appropriate network protocol endpoint and @@ -594,6 +604,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp if local == "" { local = n.LinkEndpoint.LinkAddress() } + pkt.RXTransportChecksumValidated = n.LinkEndpoint.Capabilities()&CapabilityRXChecksumOffload != 0 // Are any packet type sockets listening for this network protocol? packetEPs := n.mu.packetEPs[protocol] @@ -669,14 +680,13 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } // Found a NIC. - n := r.nic + n := r.localAddressNIC if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil { if n.isValidForOutgoing(addressEndpoint) { - r.LocalLinkAddress = n.LinkEndpoint.LinkAddress() - r.RemoteLinkAddress = remote + pkt.NICID = n.ID() r.RemoteAddress = src - // TODO(b/123449044): Update the source NIC as well. - n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt) + pkt.NetworkPacketInfo = r.networkPacketInfo() + n.getNetworkEndpoint(protocol).HandlePacket(pkt) addressEndpoint.DecRef() r.Release() return @@ -735,7 +745,7 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. -func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { +func (n *NIC) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { state, ok := n.stack.transportProtocols[protocol] if !ok { n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -747,7 +757,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // Raw socket packets are delivered based solely on the transport // protocol number. We do not inspect the payload to ensure it's // validly formed. - n.stack.demux.deliverRawPacket(r, protocol, pkt) + n.stack.demux.deliverRawPacket(protocol, pkt) // TransportHeader is empty only when pkt is an ICMP packet or was reassembled // from fragments. @@ -776,14 +786,25 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN return TransportPacketHandled } - id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress} - if n.stack.demux.deliverPacket(r, protocol, pkt, id) { + netProto, ok := n.stack.networkProtocols[pkt.NetworkProtocolNumber] + if !ok { + panic(fmt.Sprintf("expected network protocol = %d, have = %#v", pkt.NetworkProtocolNumber, n.stack.networkProtocolNumbers())) + } + + src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View()) + id := TransportEndpointID{ + LocalPort: dstPort, + LocalAddress: dst, + RemotePort: srcPort, + RemoteAddress: src, + } + if n.stack.demux.deliverPacket(protocol, pkt, id) { return TransportPacketHandled } // Try to deliver to per-stack default handler. if state.defaultHandler != nil { - if state.defaultHandler(r, id, pkt) { + if state.defaultHandler(id, pkt) { return TransportPacketHandled } } @@ -791,7 +812,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // We could not find an appropriate destination for this packet so // give the protocol specific error handler a chance to handle it. // If it doesn't handle it then we should do so. - switch res := transProto.HandleUnknownDestinationPacket(r, id, pkt); res { + switch res := transProto.HandleUnknownDestinationPacket(id, pkt); res { case UnknownDestinationPacketMalformed: n.stack.stats.MalformedRcvdPackets.Increment() return TransportPacketHandled diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 4af04846f..5b5c58afb 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -83,8 +83,7 @@ func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip } // HandlePacket implements NetworkEndpoint.HandlePacket. -func (*testIPv6Endpoint) HandlePacket(*Route, *PacketBuffer) { -} +func (*testIPv6Endpoint) HandlePacket(*PacketBuffer) {} // Close implements NetworkEndpoint.Close. func (e *testIPv6Endpoint) Close() { diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 7f54a6de8..664cc6fa0 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -112,6 +112,16 @@ type PacketBuffer struct { // PktType indicates the SockAddrLink.PacketType of the packet as defined in // https://www.man7.org/linux/man-pages/man7/packet.7.html. PktType tcpip.PacketType + + // NICID is the ID of the interface the network packet was received at. + NICID tcpip.NICID + + // RXTransportChecksumValidated indicates that transport checksum verification + // may be safely skipped. + RXTransportChecksumValidated bool + + // NetworkPacketInfo holds an incoming packet's network-layer information. + NetworkPacketInfo NetworkPacketInfo } // NewPacketBuffer creates a new PacketBuffer with opts. @@ -240,20 +250,33 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum // Clone should be called in such cases so that no modifications is done to // underlying packet payload. func (pk *PacketBuffer) Clone() *PacketBuffer { - newPk := &PacketBuffer{ - PacketBufferEntry: pk.PacketBufferEntry, - Data: pk.Data.Clone(nil), - headers: pk.headers, - header: pk.header, - Hash: pk.Hash, - Owner: pk.Owner, - EgressRoute: pk.EgressRoute, - GSOOptions: pk.GSOOptions, - NetworkProtocolNumber: pk.NetworkProtocolNumber, - NatDone: pk.NatDone, - TransportProtocolNumber: pk.TransportProtocolNumber, + return &PacketBuffer{ + PacketBufferEntry: pk.PacketBufferEntry, + Data: pk.Data.Clone(nil), + headers: pk.headers, + header: pk.header, + Hash: pk.Hash, + Owner: pk.Owner, + GSOOptions: pk.GSOOptions, + NetworkProtocolNumber: pk.NetworkProtocolNumber, + NatDone: pk.NatDone, + TransportProtocolNumber: pk.TransportProtocolNumber, + PktType: pk.PktType, + NICID: pk.NICID, + RXTransportChecksumValidated: pk.RXTransportChecksumValidated, + NetworkPacketInfo: pk.NetworkPacketInfo, } - return newPk +} + +// SourceLinkAddress returns the source link address of the packet. +func (pk *PacketBuffer) SourceLinkAddress() tcpip.LinkAddress { + link := pk.LinkHeader().View() + + if link.IsEmpty() { + return "" + } + + return header.Ethernet(link).SourceAddress() } // Network returns the network header as a header.Network. @@ -270,6 +293,17 @@ func (pk *PacketBuffer) Network() header.Network { } } +// CloneToInbound makes a shallow copy of the packet buffer to be used as an +// inbound packet. +// +// See PacketBuffer.Data for details about how a packet buffer holds an inbound +// packet. +func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { + return NewPacketBuffer(PacketBufferOptions{ + Data: buffer.NewVectorisedView(pk.Size(), pk.Views()), + }) +} + // headerInfo stores metadata about a header in a packet. type headerInfo struct { // buf is the memorized slice for both prepended and consumed header. diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go index f838eda8d..5d364a2b0 100644 --- a/pkg/tcpip/stack/pending_packets.go +++ b/pkg/tcpip/stack/pending_packets.go @@ -106,7 +106,7 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro } else if _, err := p.route.Resolve(nil); err != nil { p.route.Stats().IP.OutgoingPacketErrors.Increment() } else { - p.route.nic.writePacket(p.route, nil /* gso */, p.proto, p.pkt) + p.route.outgoingNIC.writePacket(p.route, nil /* gso */, p.proto, p.pkt) } p.route.Release() } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 203f3b51f..72131ca24 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -63,17 +63,28 @@ const ( ControlUnknown ) +// NetworkPacketInfo holds information about a network layer packet. +type NetworkPacketInfo struct { + // RemoteAddressBroadcast is true if the packet's remote address is a + // broadcast address. + RemoteAddressBroadcast bool + + // LocalAddressBroadcast is true if the packet's local address is a broadcast + // address. + LocalAddressBroadcast bool +} + // TransportEndpoint is the interface that needs to be implemented by transport // protocol (e.g., tcp, udp) endpoints that can handle packets. type TransportEndpoint interface { // UniqueID returns an unique ID for this transport endpoint. UniqueID() uint64 - // HandlePacket is called by the stack when new packets arrive to - // this transport endpoint. It sets pkt.TransportHeader. + // HandlePacket is called by the stack when new packets arrive to this + // transport endpoint. It sets the packet buffer's transport header. // - // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) + // HandlePacket takes ownership of the packet. + HandlePacket(TransportEndpointID, *PacketBuffer) // HandleControlPacket is called by the stack when new control (e.g. // ICMP) packets arrive to this transport endpoint. @@ -105,8 +116,8 @@ type RawTransportEndpoint interface { // this transport endpoint. The packet contains all data from the link // layer up. // - // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, pkt *PacketBuffer) + // HandlePacket takes ownership of the packet. + HandlePacket(*PacketBuffer) } // PacketEndpoint is the interface that needs to be implemented by packet @@ -172,9 +183,9 @@ type TransportProtocol interface { // protocol that don't match any existing endpoint. For example, // it is targeted at a port that has no listeners. // - // HandleUnknownDestinationPacket takes ownership of pkt if it handles + // HandleUnknownDestinationPacket takes ownership of the packet if it handles // the issue. - HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) UnknownDestinationPacketDisposition + HandleUnknownDestinationPacket(TransportEndpointID, *PacketBuffer) UnknownDestinationPacketDisposition // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -227,8 +238,8 @@ type TransportDispatcher interface { // // pkt.NetworkHeader must be set before calling DeliverTransportPacket. // - // DeliverTransportPacket takes ownership of pkt. - DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition + // DeliverTransportPacket takes ownership of the packet. + DeliverTransportPacket(tcpip.TransportProtocolNumber, *PacketBuffer) TransportPacketDisposition // DeliverTransportControlPacket delivers control packets to the // appropriate transport protocol endpoint. @@ -547,7 +558,7 @@ type NetworkEndpoint interface { // this network endpoint. It sets pkt.NetworkHeader. // // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, pkt *PacketBuffer) + HandlePacket(pkt *PacketBuffer) // Close is called when the endpoint is reomved from a stack. Close() diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 87f7008f7..2e698f92f 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -15,6 +15,8 @@ package stack import ( + "fmt" + "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -45,11 +47,16 @@ type Route struct { // Loop controls where WritePacket should send packets. Loop PacketLooping - // nic is the NIC the route goes through. - nic *NIC + // localAddressNIC is the interface the address is associated with. + // TODO(gvisor.dev/issue/4548): Remove this field once we can query the + // address's assigned status without the NIC. + localAddressNIC *NIC + + // localAddressEndpoint is the local address this route is associated with. + localAddressEndpoint AssignableAddressEndpoint - // addressEndpoint is the local address this route is associated with. - addressEndpoint AssignableAddressEndpoint + // outgoingNIC is the interface this route uses to write packets. + outgoingNIC *NIC // linkCache is set if link address resolution is enabled for this protocol on // the route's NIC. @@ -60,51 +67,144 @@ type Route struct { linkRes LinkAddressResolver } +// constructAndValidateRoute validates and initializes a route. It takes +// ownership of the provided local address. +// +// Returns an empty route if validation fails. +func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) Route { + addrWithPrefix := addressEndpoint.AddressWithPrefix() + + if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(addrWithPrefix.Address) { + addressEndpoint.DecRef() + return Route{} + } + + // If no remote address is provided, use the local address. + if len(remoteAddr) == 0 { + remoteAddr = addrWithPrefix.Address + } + + r := makeRoute( + netProto, + addrWithPrefix.Address, + remoteAddr, + outgoingNIC, + localAddressNIC, + addressEndpoint, + handleLocal, + multicastLoop, + ) + + // If the route requires us to send a packet through some gateway, do not + // broadcast it. + if len(gateway) > 0 { + r.NextHop = gateway + } else if subnet := addrWithPrefix.Subnet(); subnet.IsBroadcast(remoteAddr) { + r.RemoteLinkAddress = header.EthernetBroadcastAddress + } + + return r +} + // makeRoute initializes a new route. It takes ownership of the provided // AssignableAddressEndpoint. -func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, nic *NIC, addressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route { +func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route { + if localAddressNIC.stack != outgoingNIC.stack { + panic(fmt.Sprintf("cannot create a route with NICs from different stacks")) + } + loop := PacketOut - if handleLocal && localAddr != "" && remoteAddr == localAddr { - loop = PacketLoop - } else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) { - loop |= PacketLoop - } else if remoteAddr == header.IPv4Broadcast { - loop |= PacketLoop + + // TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the + // link endpoint level. We can remove this check once loopback interfaces + // loop back packets at the network layer. + if !outgoingNIC.IsLoopback() { + if handleLocal && localAddr != "" && remoteAddr == localAddr { + loop = PacketLoop + } else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) { + loop |= PacketLoop + } else if remoteAddr == header.IPv4Broadcast { + loop |= PacketLoop + } else if subnet := localAddressEndpoint.AddressWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) { + loop |= PacketLoop + } } + return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop) +} + +func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) Route { r := Route{ - NetProto: netProto, - LocalAddress: localAddr, - LocalLinkAddress: nic.LinkEndpoint.LinkAddress(), - RemoteAddress: remoteAddr, - addressEndpoint: addressEndpoint, - nic: nic, - Loop: loop, + NetProto: netProto, + LocalAddress: localAddr, + LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(), + RemoteAddress: remoteAddr, + localAddressNIC: localAddressNIC, + localAddressEndpoint: localAddressEndpoint, + outgoingNIC: outgoingNIC, + Loop: loop, } - if r.nic.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 { - if linkRes, ok := r.nic.stack.linkAddrResolvers[r.NetProto]; ok { + if r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 { + if linkRes, ok := r.outgoingNIC.stack.linkAddrResolvers[r.NetProto]; ok { r.linkRes = linkRes - r.linkCache = r.nic.stack + r.linkCache = r.outgoingNIC.stack } } return r } +// makeLocalRoute initializes a new local route. It takes ownership of the +// provided AssignableAddressEndpoint. +// +// A local route is a route to a destination that is local to the stack. +func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) Route { + loop := PacketLoop + // TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the + // link endpoint level. We can remove this check once loopback interfaces + // loop back packets at the network layer. + if outgoingNIC.IsLoopback() { + loop = PacketOut + } + return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop) +} + +// PopulatePacketInfo populates a packet buffer's packet information fields. +// +// TODO(gvisor.dev/issue/4688): Remove this once network packets are handled by +// the network layer. +func (r *Route) PopulatePacketInfo(pkt *PacketBuffer) { + if r.local() { + pkt.RXTransportChecksumValidated = true + } + pkt.NetworkPacketInfo = r.networkPacketInfo() +} + +// networkPacketInfo returns the network packet information of the route. +// +// TODO(gvisor.dev/issue/4688): Remove this once network packets are handled by +// the network layer. +func (r *Route) networkPacketInfo() NetworkPacketInfo { + return NetworkPacketInfo{ + RemoteAddressBroadcast: r.IsOutboundBroadcast(), + LocalAddressBroadcast: r.isInboundBroadcast(), + } +} + // NICID returns the id of the NIC from which this route originates. func (r *Route) NICID() tcpip.NICID { - return r.nic.ID() + return r.outgoingNIC.ID() } // MaxHeaderLength forwards the call to the network endpoint's implementation. func (r *Route) MaxHeaderLength() uint16 { - return r.nic.getNetworkEndpoint(r.NetProto).MaxHeaderLength() + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).MaxHeaderLength() } // Stats returns a mutable copy of current stats. func (r *Route) Stats() tcpip.Stats { - return r.nic.stack.Stats() + return r.outgoingNIC.stack.Stats() } // PseudoHeaderChecksum forwards the call to the network endpoint's @@ -113,14 +213,38 @@ func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, tot return header.PseudoHeaderChecksum(protocol, r.LocalAddress, r.RemoteAddress, totalLen) } -// Capabilities returns the link-layer capabilities of the route. -func (r *Route) Capabilities() LinkEndpointCapabilities { - return r.nic.LinkEndpoint.Capabilities() +// RequiresTXTransportChecksum returns false if the route does not require +// transport checksums to be populated. +func (r *Route) RequiresTXTransportChecksum() bool { + if r.local() { + return false + } + return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityTXChecksumOffload == 0 +} + +// HasSoftwareGSOCapability returns true if the route supports software GSO. +func (r *Route) HasSoftwareGSOCapability() bool { + return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilitySoftwareGSO != 0 +} + +// HasHardwareGSOCapability returns true if the route supports hardware GSO. +func (r *Route) HasHardwareGSOCapability() bool { + return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityHardwareGSO != 0 +} + +// HasSaveRestoreCapability returns true if the route supports save/restore. +func (r *Route) HasSaveRestoreCapability() bool { + return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilitySaveRestore != 0 +} + +// HasDisconncetOkCapability returns true if the route supports disconnecting. +func (r *Route) HasDisconncetOkCapability() bool { + return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityDisconnectOk != 0 } // GSOMaxSize returns the maximum GSO packet size. func (r *Route) GSOMaxSize() uint32 { - if gso, ok := r.nic.LinkEndpoint.(GSOEndpoint); ok { + if gso, ok := r.outgoingNIC.LinkEndpoint.(GSOEndpoint); ok { return gso.GSOMaxSize() } return 0 @@ -158,8 +282,15 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { nextAddr = r.RemoteAddress } - if neigh := r.nic.neigh; neigh != nil { - entry, ch, err := neigh.entry(nextAddr, r.LocalAddress, r.linkRes, waker) + // If specified, the local address used for link address resolution must be an + // address on the outgoing interface. + var linkAddressResolutionRequestLocalAddr tcpip.Address + if r.localAddressNIC == r.outgoingNIC { + linkAddressResolutionRequestLocalAddr = r.LocalAddress + } + + if neigh := r.outgoingNIC.neigh; neigh != nil { + entry, ch, err := neigh.entry(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, waker) if err != nil { return ch, err } @@ -167,7 +298,7 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { return nil, nil } - linkAddr, ch, err := r.linkCache.GetLinkAddress(r.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker) + linkAddr, ch, err := r.linkCache.GetLinkAddress(r.outgoingNIC.ID(), nextAddr, linkAddressResolutionRequestLocalAddr, r.NetProto, waker) if err != nil { return ch, err } @@ -182,76 +313,102 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) { nextAddr = r.RemoteAddress } - if neigh := r.nic.neigh; neigh != nil { + if neigh := r.outgoingNIC.neigh; neigh != nil { neigh.removeWaker(nextAddr, waker) return } - r.linkCache.RemoveWaker(r.nic.ID(), nextAddr, waker) + r.linkCache.RemoveWaker(r.outgoingNIC.ID(), nextAddr, waker) +} + +// local returns true if the route is a local route. +func (r *Route) local() bool { + return r.Loop == PacketLoop || r.outgoingNIC.IsLoopback() } // IsResolutionRequired returns true if Resolve() must be called to resolve -// the link address before r can be written to. +// the link address before the route can be written to. // -// The NIC r uses must not be locked. +// The NICs the route is associated with must not be locked. func (r *Route) IsResolutionRequired() bool { - if r.nic.neigh != nil { - return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkRes != nil && r.RemoteLinkAddress == "" + if !r.isValidForOutgoing() || r.RemoteLinkAddress != "" || r.local() { + return false } - return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkCache != nil && r.RemoteLinkAddress == "" + + return (r.outgoingNIC.neigh != nil && r.linkRes != nil) || r.linkCache != nil +} + +func (r *Route) isValidForOutgoing() bool { + if !r.outgoingNIC.Enabled() { + return false + } + + if !r.localAddressNIC.isValidForOutgoing(r.localAddressEndpoint) { + return false + } + + // If the source NIC and outgoing NIC are different, make sure the stack has + // forwarding enabled, or the packet will be handled locally. + if r.outgoingNIC != r.localAddressNIC && !r.outgoingNIC.stack.Forwarding(r.NetProto) && (!r.outgoingNIC.stack.handleLocal || !r.outgoingNIC.hasAddress(r.NetProto, r.RemoteAddress)) { + return false + } + + return true } // WritePacket writes the packet through the given route. func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { - if !r.nic.isValidForOutgoing(r.addressEndpoint) { + if !r.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } - return r.nic.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt) + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt) } // WritePackets writes a list of n packets through the given route and returns // the number of packets written. func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) { - if !r.nic.isValidForOutgoing(r.addressEndpoint) { + if !r.isValidForOutgoing() { return 0, tcpip.ErrInvalidEndpointState } - return r.nic.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params) + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params) } // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error { - if !r.nic.isValidForOutgoing(r.addressEndpoint) { + if !r.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } - return r.nic.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt) + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt) } // DefaultTTL returns the default TTL of the underlying network endpoint. func (r *Route) DefaultTTL() uint8 { - return r.nic.getNetworkEndpoint(r.NetProto).DefaultTTL() + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).DefaultTTL() } // MTU returns the MTU of the underlying network endpoint. func (r *Route) MTU() uint32 { - return r.nic.getNetworkEndpoint(r.NetProto).MTU() + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).MTU() } // Release frees all resources associated with the route. func (r *Route) Release() { - if r.addressEndpoint != nil { - r.addressEndpoint.DecRef() - r.addressEndpoint = nil + if r.localAddressEndpoint != nil { + r.localAddressEndpoint.DecRef() + r.localAddressEndpoint = nil } } // Clone clones the route. func (r *Route) Clone() Route { - if r.addressEndpoint != nil { - _ = r.addressEndpoint.IncRef() + if r.localAddressEndpoint != nil { + if !r.localAddressEndpoint.IncRef() { + panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", r.LocalAddress)) + } } return *r } @@ -275,7 +432,7 @@ func (r *Route) MakeLoopedRoute() Route { // Stack returns the instance of the Stack that owns this route. func (r *Route) Stack() *Stack { - return r.nic.stack + return r.outgoingNIC.stack } func (r *Route) isV4Broadcast(addr tcpip.Address) bool { @@ -283,7 +440,7 @@ func (r *Route) isV4Broadcast(addr tcpip.Address) bool { return true } - subnet := r.addressEndpoint.AddressWithPrefix().Subnet() + subnet := r.localAddressEndpoint.AddressWithPrefix().Subnet() return subnet.IsBroadcast(addr) } @@ -294,9 +451,9 @@ func (r *Route) IsOutboundBroadcast() bool { return r.isV4Broadcast(r.RemoteAddress) } -// IsInboundBroadcast returns true if the route is for an inbound broadcast +// isInboundBroadcast returns true if the route is for an inbound broadcast // packet. -func (r *Route) IsInboundBroadcast() bool { +func (r *Route) isInboundBroadcast() bool { // Only IPv4 has a notion of broadcast. return r.isV4Broadcast(r.LocalAddress) } @@ -304,15 +461,16 @@ func (r *Route) IsInboundBroadcast() bool { // ReverseRoute returns new route with given source and destination address. func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route { return Route{ - NetProto: r.NetProto, - LocalAddress: dst, - LocalLinkAddress: r.RemoteLinkAddress, - RemoteAddress: src, - RemoteLinkAddress: r.LocalLinkAddress, - Loop: r.Loop, - addressEndpoint: r.addressEndpoint, - nic: r.nic, - linkCache: r.linkCache, - linkRes: r.linkRes, + NetProto: r.NetProto, + LocalAddress: dst, + LocalLinkAddress: r.RemoteLinkAddress, + RemoteAddress: src, + RemoteLinkAddress: r.LocalLinkAddress, + Loop: r.Loop, + localAddressNIC: r.localAddressNIC, + localAddressEndpoint: r.localAddressEndpoint, + outgoingNIC: r.outgoingNIC, + linkCache: r.linkCache, + linkRes: r.linkRes, } } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index ba0e1a7ec..a23fb97ff 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -22,6 +22,7 @@ package stack import ( "bytes" "encoding/binary" + "fmt" mathrand "math/rand" "sync/atomic" "time" @@ -52,7 +53,7 @@ const ( type transportProtocolState struct { proto TransportProtocol - defaultHandler func(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool + defaultHandler func(id TransportEndpointID, pkt *PacketBuffer) bool } // TCPProbeFunc is the expected function type for a TCP probe function to be @@ -759,7 +760,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, // // It must be called only during initialization of the stack. Changing it as the // stack is operating is not supported. -func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, *PacketBuffer) bool) { +func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(TransportEndpointID, *PacketBuffer) bool) { state := s.transportProtocols[p] if state != nil { state.defaultHandler = h @@ -1202,59 +1203,225 @@ func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netP return nic.findEndpoint(netProto, localAddr, CanBePrimaryEndpoint) } +// findLocalRouteFromNICRLocked is like findLocalRouteRLocked but finds a route +// from the specified NIC. +// +// Precondition: s.mu must be read locked. +func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) { + localAddressEndpoint := localAddressNIC.getAddressOrCreateTempInner(netProto, localAddr, false /* createTemp */, NeverPrimaryEndpoint) + if localAddressEndpoint == nil { + return Route{}, false + } + + var outgoingNIC *NIC + // Prefer a local route to the same interface as the local address. + if localAddressNIC.hasAddress(netProto, remoteAddr) { + outgoingNIC = localAddressNIC + } + + // If the remote address isn't owned by the local address's NIC, check all + // NICs. + if outgoingNIC == nil { + for _, nic := range s.nics { + if nic.hasAddress(netProto, remoteAddr) { + outgoingNIC = nic + break + } + } + } + + // If the remote address is not owned by the stack, we can't return a local + // route. + if outgoingNIC == nil { + localAddressEndpoint.DecRef() + return Route{}, false + } + + r := makeLocalRoute( + netProto, + localAddressEndpoint.AddressWithPrefix().Address, + remoteAddr, + outgoingNIC, + localAddressNIC, + localAddressEndpoint, + ) + + if r.IsOutboundBroadcast() { + r.Release() + return Route{}, false + } + + return r, true +} + +// findLocalRouteRLocked returns a local route. +// +// A local route is a route to some remote address which the stack owns. That +// is, a local route is a route where packets never have to leave the stack. +// +// Precondition: s.mu must be read locked. +func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) { + if len(localAddr) == 0 { + localAddr = remoteAddr + } + + if localAddressNICID == 0 { + for _, localAddressNIC := range s.nics { + if r, ok := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); ok { + return r, true + } + } + + return Route{}, false + } + + if localAddressNIC, ok := s.nics[localAddressNICID]; ok { + return s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto) + } + + return Route{}, false +} + // FindRoute creates a route to the given destination address, leaving through -// the given nic and local address (if provided). +// the given NIC and local address (if provided). +// +// If a NIC is not specified, the returned route will leave through the same +// NIC as the NIC that has the local address assigned when forwarding is +// disabled. If forwarding is enabled and the NIC is unspecified, the route may +// leave through any interface unless the route is link-local. +// +// If no local address is provided, the stack will select a local address. If no +// remote address is provided, the stack wil use a remote address equal to the +// local address. func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (Route, *tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() + isLinkLocal := header.IsV6LinkLocalAddress(remoteAddr) || header.IsV6LinkLocalMulticastAddress(remoteAddr) isLocalBroadcast := remoteAddr == header.IPv4Broadcast isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr) - isLinkLocal := header.IsV6LinkLocalAddress(remoteAddr) || header.IsV6LinkLocalMulticastAddress(remoteAddr) - IsLoopback := header.IsV4LoopbackAddress(remoteAddr) || header.IsV6LoopbackAddress(remoteAddr) - needRoute := !(isLocalBroadcast || isMulticast || isLinkLocal || IsLoopback) + isLoopback := header.IsV4LoopbackAddress(remoteAddr) || header.IsV6LoopbackAddress(remoteAddr) + needRoute := !(isLocalBroadcast || isMulticast || isLinkLocal || isLoopback) + + if s.handleLocal && !isMulticast && !isLocalBroadcast { + if r, ok := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); ok { + return r, nil + } + } + + // If the interface is specified and we do not need a route, return a route + // through the interface if the interface is valid and enabled. if id != 0 && !needRoute { if nic, ok := s.nics[id]; ok && nic.Enabled() { if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { - return makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback()), nil + return makeRoute( + netProto, + addressEndpoint.AddressWithPrefix().Address, + remoteAddr, + nic, /* outboundNIC */ + nic, /* localAddressNIC*/ + addressEndpoint, + s.handleLocal, + multicastLoop, + ), nil } } - } else { - for _, route := range s.routeTable { - if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) { - continue + + if isLoopback { + return Route{}, tcpip.ErrBadLocalAddress + } + return Route{}, tcpip.ErrNetworkUnreachable + } + + canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal + + // Find a route to the remote with the route table. + var chosenRoute tcpip.Route + for _, route := range s.routeTable { + if len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr) { + continue + } + + nic, ok := s.nics[route.NIC] + if !ok || !nic.Enabled() { + continue + } + + if id == 0 || id == route.NIC { + if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { + var gateway tcpip.Address + if needRoute { + gateway = route.Gateway + } + r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop) + if r == (Route{}) { + panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr)) + } + return r, nil } - if nic, ok := s.nics[route.NIC]; ok && nic.Enabled() { - if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { - if len(remoteAddr) == 0 { - // If no remote address was provided, then the route - // provided will refer to the link local address. - remoteAddr = addressEndpoint.AddressWithPrefix().Address - } + } + + // If the stack has forwarding enabled and we haven't found a valid route to + // the remote address yet, keep track of the first valid route. We keep + // iterating because we prefer routes that let us use a local address that + // is assigned to the outgoing interface. There is no requirement to do this + // from any RFC but simply a choice made to better follow a strong host + // model which the netstack follows at the time of writing. + if canForward && chosenRoute == (tcpip.Route{}) { + chosenRoute = route + } + } + + if chosenRoute != (tcpip.Route{}) { + // At this point we know the stack has forwarding enabled since chosenRoute is + // only set when forwarding is enabled. + nic, ok := s.nics[chosenRoute.NIC] + if !ok { + // If the route's NIC was invalid, we should not have chosen the route. + panic(fmt.Sprintf("chosen route must have a valid NIC with ID = %d", chosenRoute.NIC)) + } + + var gateway tcpip.Address + if needRoute { + gateway = chosenRoute.Gateway + } - r := makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback()) - if len(route.Gateway) > 0 { - if needRoute { - r.NextHop = route.Gateway - } - } else if subnet := addressEndpoint.AddressWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) { - r.RemoteLinkAddress = header.EthernetBroadcastAddress + // Use the specified NIC to get the local address endpoint. + if id != 0 { + if aNIC, ok := s.nics[id]; ok { + if addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto); addressEndpoint != nil { + if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) { + return r, nil } + } + } + + return Route{}, tcpip.ErrNoRoute + } + if id == 0 { + // If an interface is not specified, try to find a NIC that holds the local + // address endpoint to construct a route. + for _, aNIC := range s.nics { + addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto) + if addressEndpoint == nil { + continue + } + + if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) { return r, nil } } } } - if !needRoute { - if IsLoopback { - return Route{}, tcpip.ErrBadLocalAddress - } - return Route{}, tcpip.ErrNetworkUnreachable + if needRoute { + return Route{}, tcpip.ErrNoRoute } - - return Route{}, tcpip.ErrNoRoute + if isLoopback { + return Route{}, tcpip.ErrBadLocalAddress + } + return Route{}, tcpip.ErrNetworkUnreachable } // CheckNetworkProtocol checks if a given network protocol is enabled in the @@ -1470,8 +1637,8 @@ func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) { // FindTransportEndpoint finds an endpoint that most closely matches the provided // id. If no endpoint is found it returns nil. -func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { - return s.demux.findTransportEndpoint(netProto, transProto, id, r) +func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, nicID tcpip.NICID) TransportEndpoint { + return s.demux.findTransportEndpoint(netProto, transProto, id, nicID) } // RegisterRawTransportEndpoint registers the given endpoint with the stack @@ -1923,3 +2090,71 @@ func (s *Stack) FindNICNameFromID(id tcpip.NICID) string { func (s *Stack) NewJob(l sync.Locker, f func()) *tcpip.Job { return tcpip.NewJob(s.clock, l, f) } + +// ParseResult indicates the result of a parsing attempt. +type ParseResult int + +const ( + // ParsedOK indicates that a packet was successfully parsed. + ParsedOK ParseResult = iota + + // UnknownNetworkProtocol indicates that the network protocol is unknown. + UnknownNetworkProtocol + + // NetworkLayerParseError indicates that the network packet was not + // successfully parsed. + NetworkLayerParseError + + // UnknownTransportProtocol indicates that the transport protocol is unknown. + UnknownTransportProtocol + + // TransportLayerParseError indicates that the transport packet was not + // successfully parsed. + TransportLayerParseError +) + +// ParsePacketBuffer parses the provided packet buffer. +func (s *Stack) ParsePacketBuffer(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) ParseResult { + netProto, ok := s.networkProtocols[protocol] + if !ok { + return UnknownNetworkProtocol + } + + transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt) + if !ok { + return NetworkLayerParseError + } + if !hasTransportHdr { + return ParsedOK + } + + // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader + // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a + // full explanation. + if transProtoNum == header.ICMPv4ProtocolNumber || transProtoNum == header.ICMPv6ProtocolNumber { + return ParsedOK + } + + pkt.TransportProtocolNumber = transProtoNum + // Parse the transport header if present. + state, ok := s.transportProtocols[transProtoNum] + if !ok { + return UnknownTransportProtocol + } + + if !state.proto.Parse(pkt) { + return TransportLayerParseError + } + + return ParsedOK +} + +// networkProtocolNumbers returns the network protocol numbers the stack is +// configured with. +func (s *Stack) networkProtocolNumbers() []tcpip.NetworkProtocolNumber { + protos := make([]tcpip.NetworkProtocolNumber, 0, len(s.networkProtocols)) + for p := range s.networkProtocols { + protos = append(protos, p) + } + return protos +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 4eed4ced4..dedfdd435 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -21,6 +21,7 @@ import ( "bytes" "fmt" "math" + "net" "sort" "testing" "time" @@ -108,12 +109,13 @@ func (*fakeNetworkEndpoint) DefaultTTL() uint8 { return 123 } -func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { +func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { // Increment the received packet count in the protocol descriptor. - f.proto.packetCount[int(r.LocalAddress[0])%len(f.proto.packetCount)]++ + netHdr := pkt.NetworkHeader().View() + f.proto.packetCount[int(netHdr[dstAddrOffset])%len(f.proto.packetCount)]++ // Handle control packets. - if pkt.NetworkHeader().View()[protocolNumberOffset] == uint8(fakeControlProtocol) { + if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) { nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) if !ok { return @@ -129,7 +131,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff } // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) + f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { @@ -151,12 +153,15 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params // Add the protocol's header to the packet and send it to the link // endpoint. hdr := pkt.NetworkHeader().Push(fakeNetHeaderLen) + pkt.NetworkProtocolNumber = fakeNetNumber hdr[dstAddrOffset] = r.RemoteAddress[0] hdr[srcAddrOffset] = r.LocalAddress[0] hdr[protocolNumberOffset] = byte(params.Protocol) if r.Loop&stack.PacketLoop != 0 { - f.HandlePacket(r, pkt) + pkt := pkt.Clone() + r.PopulatePacketInfo(pkt) + f.HandlePacket(pkt) } if r.Loop&stack.PacketOut == 0 { return nil @@ -254,6 +259,7 @@ func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProto if !ok { return 0, false, false } + pkt.NetworkProtocolNumber = fakeNetNumber return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true } @@ -1334,6 +1340,106 @@ func TestPromiscuousMode(t *testing.T) { testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } +// TestExternalSendWithHandleLocal tests that the stack creates a non-local +// route when spoofing or promiscuous mode are enabled. +// +// This test makes sure that packets are transmitted from the stack. +func TestExternalSendWithHandleLocal(t *testing.T) { + const ( + unspecifiedNICID = 0 + nicID = 1 + + localAddr = tcpip.Address("\x01") + dstAddr = tcpip.Address("\x03") + ) + + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + + tests := []struct { + name string + configureStack func(*testing.T, *stack.Stack) + }{ + { + name: "Default", + configureStack: func(*testing.T, *stack.Stack) {}, + }, + { + name: "Spoofing", + configureStack: func(t *testing.T, s *stack.Stack) { + if err := s.SetSpoofing(nicID, true); err != nil { + t.Fatalf("s.SetSpoofing(%d, true): %s", nicID, err) + } + }, + }, + { + name: "Promiscuous", + configureStack: func(t *testing.T, s *stack.Stack) { + if err := s.SetPromiscuousMode(nicID, true); err != nil { + t.Fatalf("s.SetPromiscuousMode(%d, true): %s", nicID, err) + } + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, handleLocal := range []bool{true, false} { + t.Run(fmt.Sprintf("HandleLocal=%t", handleLocal), func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + HandleLocal: handleLocal, + }) + + ep := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID, ep); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err) + } + + s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: nicID}}) + + test.configureStack(t, s) + + r, err := s.FindRoute(unspecifiedNICID, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", unspecifiedNICID, localAddr, dstAddr, fakeNetNumber, err) + } + defer r.Release() + + if r.LocalAddress != localAddr { + t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, localAddr) + } + if r.RemoteAddress != dstAddr { + t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) + } + + if n := ep.Drain(); n != 0 { + t.Fatalf("got ep.Drain() = %d, want = 0", n) + } + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: fakeTransNumber, + TTL: 123, + TOS: stack.DefaultTOS, + }, stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: buffer.NewView(10).ToVectorisedView(), + })); err != nil { + t.Fatalf("r.WritePacket(nil, _, _): %s", err) + } + if n := ep.Drain(); n != 1 { + t.Fatalf("got ep.Drain() = %d, want = 1", n) + } + }) + } + }) + } +} + func TestSpoofingWithAddress(t *testing.T) { localAddr := tcpip.Address("\x01") nonExistentLocalAddr := tcpip.Address("\x02") @@ -3346,7 +3452,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { RemoteAddress: ipv4SubnetBcast, RemoteLinkAddress: header.EthernetBroadcastAddress, NetProto: header.IPv4ProtocolNumber, - Loop: stack.PacketOut, + Loop: stack.PacketOut | stack.PacketLoop, }, }, // Broadcast to a locally attached /31 subnet does not populate the @@ -3756,3 +3862,369 @@ func TestRemoveRoutes(t *testing.T) { } } } + +func TestFindRouteWithForwarding(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + + nic1Addr = tcpip.Address("\x01") + nic2Addr = tcpip.Address("\x02") + remoteAddr = tcpip.Address("\x03") + ) + + type netCfg struct { + proto tcpip.NetworkProtocolNumber + factory stack.NetworkProtocolFactory + nic1Addr tcpip.Address + nic2Addr tcpip.Address + remoteAddr tcpip.Address + } + + fakeNetCfg := netCfg{ + proto: fakeNetNumber, + factory: fakeNetFactory, + nic1Addr: nic1Addr, + nic2Addr: nic2Addr, + remoteAddr: remoteAddr, + } + + globalIPv6Addr1 := tcpip.Address(net.ParseIP("a::1").To16()) + globalIPv6Addr2 := tcpip.Address(net.ParseIP("a::2").To16()) + + ipv6LinkLocalNIC1WithGlobalRemote := netCfg{ + proto: ipv6.ProtocolNumber, + factory: ipv6.NewProtocol, + nic1Addr: llAddr1, + nic2Addr: globalIPv6Addr2, + remoteAddr: globalIPv6Addr1, + } + ipv6GlobalNIC1WithLinkLocalRemote := netCfg{ + proto: ipv6.ProtocolNumber, + factory: ipv6.NewProtocol, + nic1Addr: globalIPv6Addr1, + nic2Addr: llAddr1, + remoteAddr: llAddr2, + } + ipv6GlobalNIC1WithLinkLocalMulticastRemote := netCfg{ + proto: ipv6.ProtocolNumber, + factory: ipv6.NewProtocol, + nic1Addr: globalIPv6Addr1, + nic2Addr: globalIPv6Addr2, + remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + } + + tests := []struct { + name string + + netCfg netCfg + forwardingEnabled bool + + addrNIC tcpip.NICID + localAddr tcpip.Address + + findRouteErr *tcpip.Error + dependentOnForwarding bool + }{ + { + name: "forwarding disabled and localAddr not on specified NIC but route from different NIC", + netCfg: fakeNetCfg, + forwardingEnabled: false, + addrNIC: nicID1, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr not on specified NIC but route from different NIC", + netCfg: fakeNetCfg, + forwardingEnabled: true, + addrNIC: nicID1, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and localAddr on specified NIC but route from different NIC", + netCfg: fakeNetCfg, + forwardingEnabled: false, + addrNIC: nicID1, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr on specified NIC but route from different NIC", + netCfg: fakeNetCfg, + forwardingEnabled: true, + addrNIC: nicID1, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: nil, + dependentOnForwarding: true, + }, + { + name: "forwarding disabled and localAddr on specified NIC and route from same NIC", + netCfg: fakeNetCfg, + forwardingEnabled: false, + addrNIC: nicID2, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr on specified NIC and route from same NIC", + netCfg: fakeNetCfg, + forwardingEnabled: true, + addrNIC: nicID2, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and localAddr not on specified NIC but route from same NIC", + netCfg: fakeNetCfg, + forwardingEnabled: false, + addrNIC: nicID2, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr not on specified NIC but route from same NIC", + netCfg: fakeNetCfg, + forwardingEnabled: true, + addrNIC: nicID2, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and localAddr on same NIC as route", + netCfg: fakeNetCfg, + forwardingEnabled: false, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr on same NIC as route", + netCfg: fakeNetCfg, + forwardingEnabled: false, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and localAddr on different NIC as route", + netCfg: fakeNetCfg, + forwardingEnabled: false, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr on different NIC as route", + netCfg: fakeNetCfg, + forwardingEnabled: true, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: nil, + dependentOnForwarding: true, + }, + { + name: "forwarding disabled and specified NIC only has link-local addr with route on different NIC", + netCfg: ipv6LinkLocalNIC1WithGlobalRemote, + forwardingEnabled: false, + addrNIC: nicID1, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and specified NIC only has link-local addr with route on different NIC", + netCfg: ipv6LinkLocalNIC1WithGlobalRemote, + forwardingEnabled: true, + addrNIC: nicID1, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and link-local local addr with route on different NIC", + netCfg: ipv6LinkLocalNIC1WithGlobalRemote, + forwardingEnabled: false, + localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and link-local local addr with route on same NIC", + netCfg: ipv6LinkLocalNIC1WithGlobalRemote, + forwardingEnabled: true, + localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and global local addr with route on same NIC", + netCfg: ipv6LinkLocalNIC1WithGlobalRemote, + forwardingEnabled: true, + localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and link-local local addr with route on same NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalRemote, + forwardingEnabled: false, + localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and link-local local addr with route on same NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalRemote, + forwardingEnabled: true, + localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and global local addr with link-local remote on different NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalRemote, + forwardingEnabled: false, + localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, + findRouteErr: tcpip.ErrNetworkUnreachable, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and global local addr with link-local remote on different NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalRemote, + forwardingEnabled: true, + localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, + findRouteErr: tcpip.ErrNetworkUnreachable, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and global local addr with link-local multicast remote on different NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, + forwardingEnabled: false, + localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, + findRouteErr: tcpip.ErrNetworkUnreachable, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and global local addr with link-local multicast remote on different NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, + forwardingEnabled: true, + localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, + findRouteErr: tcpip.ErrNetworkUnreachable, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and global local addr with link-local multicast remote on same NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, + forwardingEnabled: false, + localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and global local addr with link-local multicast remote on same NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, + forwardingEnabled: true, + localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{test.netCfg.factory}, + }) + + ep1 := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID1, ep1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s:", nicID1, err) + } + + ep2 := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID2, ep2); err != nil { + t.Fatalf("CreateNIC(%d, _): %s:", nicID2, err) + } + + if err := s.AddAddress(nicID1, test.netCfg.proto, test.netCfg.nic1Addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, test.netCfg.proto, test.netCfg.nic1Addr, err) + } + + if err := s.AddAddress(nicID2, test.netCfg.proto, test.netCfg.nic2Addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, test.netCfg.proto, test.netCfg.nic2Addr, err) + } + + if err := s.SetForwarding(test.netCfg.proto, test.forwardingEnabled); err != nil { + t.Fatalf("SetForwarding(%d, %t): %s", test.netCfg.proto, test.forwardingEnabled, err) + } + + s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}}) + + r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */) + if err != test.findRouteErr { + t.Fatalf("FindRoute(%d, %s, %s, %d, false) = %s, want = %s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, err, test.findRouteErr) + } + defer r.Release() + + if test.findRouteErr != nil { + return + } + + if r.LocalAddress != test.localAddr { + t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, test.localAddr) + } + if r.RemoteAddress != test.netCfg.remoteAddr { + t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, test.netCfg.remoteAddr) + } + + if t.Failed() { + t.FailNow() + } + + // Sending a packet should always go through NIC2 since we only install a + // route to test.netCfg.remoteAddr through NIC2. + data := buffer.View([]byte{1, 2, 3, 4}) + if err := send(r, data); err != nil { + t.Fatalf("send(_, _): %s", err) + } + if n := ep1.Drain(); n != 0 { + t.Errorf("got %d unexpected packets from ep1", n) + } + pkt, ok := ep2.Read() + if !ok { + t.Fatal("packet not sent through ep2") + } + if pkt.Route.LocalAddress != test.localAddr { + t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddr) + } + if pkt.Route.RemoteAddress != test.netCfg.remoteAddr { + t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.netCfg.remoteAddr) + } + + if !test.forwardingEnabled || !test.dependentOnForwarding { + return + } + + // Disabling forwarding when the route is dependent on forwarding being + // enabled should make the route invalid. + if err := s.SetForwarding(test.netCfg.proto, false); err != nil { + t.Fatalf("SetForwarding(%d, false): %s", test.netCfg.proto, err) + } + if err := send(r, data); err != tcpip.ErrInvalidEndpointState { + t.Fatalf("got send(_, _) = %s, want = %s", err, tcpip.ErrInvalidEndpointState) + } + if n := ep1.Drain(); n != 0 { + t.Errorf("got %d unexpected packets from ep1", n) + } + if n := ep2.Drain(); n != 0 { + t.Errorf("got %d unexpected packets from ep2", n) + } + }) + } +} diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 35e5b1a2e..f183ec6e4 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -152,10 +152,10 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) { +func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *PacketBuffer) { epsByNIC.mu.RLock() - mpep, ok := epsByNIC.endpoints[r.nic.ID()] + mpep, ok := epsByNIC.endpoints[pkt.NICID] if !ok { if mpep, ok = epsByNIC.endpoints[0]; !ok { epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. @@ -165,20 +165,20 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p // If this is a broadcast or multicast datagram, deliver the datagram to all // endpoints bound to the right device. - if isInboundMulticastOrBroadcast(r) { - mpep.handlePacketAll(r, id, pkt) + if isInboundMulticastOrBroadcast(pkt, id.LocalAddress) { + mpep.handlePacketAll(id, pkt) epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. return } // multiPortEndpoints are guaranteed to have at least one element. transEP := selectEndpoint(id, mpep, epsByNIC.seed) if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue { - queuedProtocol.QueuePacket(r, transEP, id, pkt) + queuedProtocol.QueuePacket(transEP, id, pkt) epsByNIC.mu.RUnlock() return } - transEP.HandlePacket(r, id, pkt) + transEP.HandlePacket(id, pkt) epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. } @@ -253,6 +253,8 @@ func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t T // based on endpoints IDs. It should only be instantiated via // newTransportDemuxer. type transportDemuxer struct { + stack *Stack + // protocol is immutable. protocol map[protocolIDs]*transportEndpoints queuedProtocols map[protocolIDs]queuedTransportProtocol @@ -262,11 +264,12 @@ type transportDemuxer struct { // the dispatcher to delivery packets to the QueuePacket method instead of // calling HandlePacket directly on the endpoint. type queuedTransportProtocol interface { - QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer) + QueuePacket(ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer) } func newTransportDemuxer(stack *Stack) *transportDemuxer { d := &transportDemuxer{ + stack: stack, protocol: make(map[protocolIDs]*transportEndpoints), queuedProtocols: make(map[protocolIDs]queuedTransportProtocol), } @@ -377,22 +380,22 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 return mpep.endpoints[idx] } -func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt *PacketBuffer) { +func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *PacketBuffer) { ep.mu.RLock() queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}] // HandlePacket takes ownership of pkt, so each endpoint needs // its own copy except for the final one. for _, endpoint := range ep.endpoints[:len(ep.endpoints)-1] { if mustQueue { - queuedProtocol.QueuePacket(r, endpoint, id, pkt.Clone()) + queuedProtocol.QueuePacket(endpoint, id, pkt.Clone()) } else { - endpoint.HandlePacket(r, id, pkt.Clone()) + endpoint.HandlePacket(id, pkt.Clone()) } } if endpoint := ep.endpoints[len(ep.endpoints)-1]; mustQueue { - queuedProtocol.QueuePacket(r, endpoint, id, pkt) + queuedProtocol.QueuePacket(endpoint, id, pkt) } else { - endpoint.HandlePacket(r, id, pkt) + endpoint.HandlePacket(id, pkt) } ep.mu.RUnlock() // Don't use defer for performance reasons. } @@ -518,29 +521,29 @@ func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolN // deliverPacket attempts to find one or more matching transport endpoints, and // then, if matches are found, delivers the packet to them. Returns true if // the packet no longer needs to be handled. -func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool { - eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] +func (d *transportDemuxer) deliverPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool { + eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}] if !ok { return false } // If the packet is a UDP broadcast or multicast, then find all matching // transport endpoints. - if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(r) { + if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(pkt, id.LocalAddress) { eps.mu.RLock() destEPs := eps.findAllEndpointsLocked(id) eps.mu.RUnlock() // Fail if we didn't find at least one matching transport endpoint. if len(destEPs) == 0 { - r.Stats().UDP.UnknownPortErrors.Increment() + d.stack.stats.UDP.UnknownPortErrors.Increment() return false } // handlePacket takes ownership of pkt, so each endpoint needs its own // copy except for the final one. for _, ep := range destEPs[:len(destEPs)-1] { - ep.handlePacket(r, id, pkt.Clone()) + ep.handlePacket(id, pkt.Clone()) } - destEPs[len(destEPs)-1].handlePacket(r, id, pkt) + destEPs[len(destEPs)-1].handlePacket(id, pkt) return true } @@ -548,10 +551,10 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // destination address, then do nothing further and instruct the caller to do // the same. The network layer handles address validation for specified source // addresses. - if protocol == header.TCPProtocolNumber && (!isSpecified(r.LocalAddress) || !isSpecified(r.RemoteAddress) || isInboundMulticastOrBroadcast(r)) { + if protocol == header.TCPProtocolNumber && (!isSpecified(id.LocalAddress) || !isSpecified(id.RemoteAddress) || isInboundMulticastOrBroadcast(pkt, id.LocalAddress)) { // TCP can only be used to communicate between a single source and a - // single destination; the addresses must be unicast. - r.Stats().TCP.InvalidSegmentsReceived.Increment() + // single destination; the addresses must be unicast.e + d.stack.stats.TCP.InvalidSegmentsReceived.Increment() return true } @@ -560,18 +563,18 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto eps.mu.RUnlock() if ep == nil { if protocol == header.UDPProtocolNumber { - r.Stats().UDP.UnknownPortErrors.Increment() + d.stack.stats.UDP.UnknownPortErrors.Increment() } return false } - ep.handlePacket(r, id, pkt) + ep.handlePacket(id, pkt) return true } // deliverRawPacket attempts to deliver the given packet and returns whether it // was delivered successfully. -func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool { - eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] +func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool { + eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}] if !ok { return false } @@ -584,7 +587,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr for _, rawEP := range eps.rawEndpoints { // Each endpoint gets its own copy of the packet for the sake // of save/restore. - rawEP.HandlePacket(r, pkt) + rawEP.HandlePacket(pkt.Clone()) foundRaw = true } eps.mu.RUnlock() @@ -612,7 +615,7 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco } // findTransportEndpoint find a single endpoint that most closely matches the provided id. -func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { +func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, nicID tcpip.NICID) TransportEndpoint { eps, ok := d.protocol[protocolIDs{netProto, transProto}] if !ok { return nil @@ -628,7 +631,7 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN epsByNIC.mu.RLock() eps.mu.RUnlock() - mpep, ok := epsByNIC.endpoints[r.nic.ID()] + mpep, ok := epsByNIC.endpoints[nicID] if !ok { if mpep, ok = epsByNIC.endpoints[0]; !ok { epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. @@ -679,8 +682,8 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN eps.mu.Unlock() } -func isInboundMulticastOrBroadcast(r *Route) bool { - return r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || header.IsV6MulticastAddress(r.LocalAddress) +func isInboundMulticastOrBroadcast(pkt *PacketBuffer, localAddr tcpip.Address) bool { + return pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(localAddr) || header.IsV6MulticastAddress(localAddr) } func isSpecified(addr tcpip.Address) bool { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 6b8071467..c457b67a2 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -213,20 +213,29 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro return tcpip.FullAddress{}, nil } -func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ *stack.PacketBuffer) { +func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Increment the number of received packets. f.proto.packetCount++ - if f.acceptQueue != nil { - f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{ - TransportEndpointInfo: stack.TransportEndpointInfo{ - ID: f.ID, - NetProto: f.NetProto, - }, - proto: f.proto, - peerAddr: r.RemoteAddress, - route: r.Clone(), - }) + if f.acceptQueue == nil { + return } + + netHdr := pkt.NetworkHeader().View() + route, err := f.proto.stack.FindRoute(pkt.NICID, tcpip.Address(netHdr[dstAddrOffset]), tcpip.Address(netHdr[srcAddrOffset]), pkt.NetworkProtocolNumber, false /* multicastLoop */) + if err != nil { + return + } + route.ResolveWith(pkt.SourceLinkAddress()) + + f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{ + TransportEndpointInfo: stack.TransportEndpointInfo{ + ID: f.ID, + NetProto: f.NetProto, + }, + proto: f.proto, + peerAddr: route.RemoteAddress, + route: route, + }) } func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *stack.PacketBuffer) { @@ -288,7 +297,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { return stack.UnknownDestinationPacketHandled } diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 34aab32d0..9b0f3b675 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -10,6 +10,7 @@ go_test( "link_resolution_test.go", "loopback_test.go", "multicast_broadcast_test.go", + "route_test.go", ], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index 0dcef7b04..bf7594268 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -33,11 +33,6 @@ import ( func TestForwarding(t *testing.T) { const ( - host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") - routerNIC1LinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x07") - routerNIC2LinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x08") - host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") - host1NICID = 1 routerNICID1 = 2 routerNICID2 = 3 @@ -166,6 +161,38 @@ func TestForwarding(t *testing.T) { } }, }, + { + name: "IPv4 host2 server with routerNIC1 client", + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses { + ep1, ep1WECH := newEP(t, host2Stack, udp.ProtocolNumber, ipv4.ProtocolNumber) + ep2, ep2WECH := newEP(t, routerStack, udp.ProtocolNumber, ipv4.ProtocolNumber) + return endpointAndAddresses{ + serverEP: ep1, + serverAddr: host2IPv4Addr.AddressWithPrefix.Address, + serverReadableCH: ep1WECH, + + clientEP: ep2, + clientAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + clientReadableCH: ep2WECH, + } + }, + }, + { + name: "IPv6 routerNIC2 server with host1 client", + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses { + ep1, ep1WECH := newEP(t, routerStack, udp.ProtocolNumber, ipv6.ProtocolNumber) + ep2, ep2WECH := newEP(t, host1Stack, udp.ProtocolNumber, ipv6.ProtocolNumber) + return endpointAndAddresses{ + serverEP: ep1, + serverAddr: routerNIC2IPv6Addr.AddressWithPrefix.Address, + serverReadableCH: ep1WECH, + + clientEP: ep2, + clientAddr: host1IPv6Addr.AddressWithPrefix.Address, + clientReadableCH: ep2WECH, + } + }, + }, } for _, test := range tests { @@ -179,8 +206,8 @@ func TestForwarding(t *testing.T) { routerStack := stack.New(stackOpts) host2Stack := stack.New(stackOpts) - host1NIC, routerNIC1 := pipe.New(host1NICLinkAddr, routerNIC1LinkAddr) - routerNIC2, host2NIC := pipe.New(routerNIC2LinkAddr, host2NICLinkAddr) + host1NIC, routerNIC1 := pipe.New(linkAddr1, linkAddr2) + routerNIC2, host2NIC := pipe.New(linkAddr3, linkAddr4) if err := host1Stack.CreateNIC(host1NICID, ethernet.New(host1NIC)); err != nil { t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) @@ -321,12 +348,8 @@ func TestForwarding(t *testing.T) { if err == tcpip.ErrNoLinkAddress { // Wait for link resolution to complete. <-ch - n, _, err = ep.Write(dataPayload, wOpts) - } else if err != nil { - t.Fatalf("ep.Write(_, _): %s", err) } - if err != nil { t.Fatalf("ep.Write(_, _): %s", err) } @@ -343,7 +366,6 @@ func TestForwarding(t *testing.T) { // Wait for the endpoint to be readable. <-ch - var addr tcpip.FullAddress v, _, err := ep.Read(&addr) if err != nil { diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index 6ddcda70c..fe7c1bb3d 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -32,32 +32,36 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -var ( - host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") - host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") +const ( + linkAddr1 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") + linkAddr2 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x07") + linkAddr3 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x08") + linkAddr4 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") +) - host1IPv4Addr = tcpip.ProtocolAddress{ +var ( + ipv4Addr1 = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), PrefixLen: 24, }, } - host2IPv4Addr = tcpip.ProtocolAddress{ + ipv4Addr2 = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), PrefixLen: 8, }, } - host1IPv6Addr = tcpip.ProtocolAddress{ + ipv6Addr1 = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("a::1").To16()), PrefixLen: 64, }, } - host2IPv6Addr = tcpip.ProtocolAddress{ + ipv6Addr2 = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("a::2").To16()), @@ -89,7 +93,7 @@ func TestPing(t *testing.T) { name: "IPv4 Ping", transProto: icmp.ProtocolNumber4, netProto: ipv4.ProtocolNumber, - remoteAddr: host2IPv4Addr.AddressWithPrefix.Address, + remoteAddr: ipv4Addr2.AddressWithPrefix.Address, icmpBuf: func(t *testing.T) buffer.View { data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data))) @@ -104,7 +108,7 @@ func TestPing(t *testing.T) { name: "IPv6 Ping", transProto: icmp.ProtocolNumber6, netProto: ipv6.ProtocolNumber, - remoteAddr: host2IPv6Addr.AddressWithPrefix.Address, + remoteAddr: ipv6Addr2.AddressWithPrefix.Address, icmpBuf: func(t *testing.T) buffer.View { data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data))) @@ -127,7 +131,7 @@ func TestPing(t *testing.T) { host1Stack := stack.New(stackOpts) host2Stack := stack.New(stackOpts) - host1NIC, host2NIC := pipe.New(host1NICLinkAddr, host2NICLinkAddr) + host1NIC, host2NIC := pipe.New(linkAddr1, linkAddr2) if err := host1Stack.CreateNIC(host1NICID, ethernet.New(host1NIC)); err != nil { t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) @@ -143,36 +147,36 @@ func TestPing(t *testing.T) { t.Fatalf("host2Stack.AddAddress(%d, %d, %s): %s", host2NICID, arp.ProtocolNumber, arp.ProtocolAddress, err) } - if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err) + if err := host1Stack.AddProtocolAddress(host1NICID, ipv4Addr1); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, ipv4Addr1, err) } - if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err) + if err := host2Stack.AddProtocolAddress(host2NICID, ipv4Addr2); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, ipv4Addr2, err) } - if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err) + if err := host1Stack.AddProtocolAddress(host1NICID, ipv6Addr1); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, ipv6Addr1, err) } - if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err) + if err := host2Stack.AddProtocolAddress(host2NICID, ipv6Addr2); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, ipv6Addr2, err) } host1Stack.SetRouteTable([]tcpip.Route{ tcpip.Route{ - Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + Destination: ipv4Addr1.AddressWithPrefix.Subnet(), NIC: host1NICID, }, tcpip.Route{ - Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + Destination: ipv6Addr1.AddressWithPrefix.Subnet(), NIC: host1NICID, }, }) host2Stack.SetRouteTable([]tcpip.Route{ tcpip.Route{ - Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), + Destination: ipv4Addr2.AddressWithPrefix.Subnet(), NIC: host2NICID, }, tcpip.Route{ - Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), + Destination: ipv6Addr2.AddressWithPrefix.Subnet(), NIC: host2NICID, }, }) diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index e8caf09ba..421da1add 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -204,7 +204,7 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) { }, }) - wq := waiter.Queue{} + var wq waiter.Queue rep, err := s.NewEndpoint(udp.ProtocolNumber, test.addAddress.Protocol, &wq) if err != nil { t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err) diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index f1028823b..cdf0459e3 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -409,7 +409,7 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { t.Fatalf("got unexpected address length = %d bytes", l) } - wq := waiter.Queue{} + var wq waiter.Queue ep, err := s.NewEndpoint(udp.ProtocolNumber, netproto, &wq) if err != nil { t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netproto, err) @@ -447,8 +447,6 @@ func TestReuseAddrAndBroadcast(t *testing.T) { loopbackBroadcast = tcpip.Address("\x7f\xff\xff\xff") ) - data := tcpip.SlicePayload([]byte{1, 2, 3, 4}) - tests := []struct { name string broadcastAddr tcpip.Address @@ -492,16 +490,22 @@ func TestReuseAddrAndBroadcast(t *testing.T) { }, }) + type endpointAndWaiter struct { + ep tcpip.Endpoint + ch chan struct{} + } + var eps []endpointAndWaiter // We create endpoints that bind to both the wildcard address and the // broadcast address to make sure both of these types of "broadcast // interested" endpoints receive broadcast packets. - wq := waiter.Queue{} - var eps []tcpip.Endpoint for _, bindWildcard := range []bool{false, true} { // Create multiple endpoints for each type of "broadcast interested" // endpoint so we can test that all endpoints receive the broadcast // packet. for i := 0; i < 2; i++ { + var wq waiter.Queue + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) if err != nil { t.Fatalf("(eps[%d]) NewEndpoint(%d, %d, _): %s", len(eps), udp.ProtocolNumber, ipv4.ProtocolNumber, err) @@ -528,7 +532,7 @@ func TestReuseAddrAndBroadcast(t *testing.T) { } } - eps = append(eps, ep) + eps = append(eps, endpointAndWaiter{ep: ep, ch: ch}) } } @@ -539,14 +543,18 @@ func TestReuseAddrAndBroadcast(t *testing.T) { Port: localPort, }, } - if n, _, err := wep.Write(data, writeOpts); err != nil { + data := tcpip.SlicePayload([]byte{byte(i), 2, 3, 4}) + if n, _, err := wep.ep.Write(data, writeOpts); err != nil { t.Fatalf("eps[%d].Write(_, _): %s", i, err) } else if want := int64(len(data)); n != want { t.Fatalf("got eps[%d].Write(_, _) = (%d, nil, nil), want = (%d, nil, nil)", i, n, want) } for j, rep := range eps { - if gotPayload, _, err := rep.Read(nil); err != nil { + // Wait for the endpoint to become readable. + <-rep.ch + + if gotPayload, _, err := rep.ep.Read(nil); err != nil { t.Errorf("(eps[%d] write) eps[%d].Read(nil): %s", i, j, err) } else if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" { t.Errorf("(eps[%d] write) got UDP payload from eps[%d] mismatch (-want +got):\n%s", i, j, diff) diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go new file mode 100644 index 000000000..02fc47015 --- /dev/null +++ b/pkg/tcpip/tests/integration/route_test.go @@ -0,0 +1,388 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +// TestLocalPing tests pinging a remote that is local the stack. +// +// This tests that a local route is created and packets do not leave the stack. +func TestLocalPing(t *testing.T) { + const ( + nicID = 1 + ipv4Loopback = tcpip.Address("\x7f\x00\x00\x01") + + // icmpDataOffset is the offset to the data in both ICMPv4 and ICMPv6 echo + // request/reply packets. + icmpDataOffset = 8 + ) + + channelEP := func() stack.LinkEndpoint { return channel.New(1, header.IPv6MinimumMTU, "") } + channelEPCheck := func(t *testing.T, e stack.LinkEndpoint) { + channelEP := e.(*channel.Endpoint) + if n := channelEP.Drain(); n != 0 { + t.Fatalf("got channelEP.Drain() = %d, want = 0", n) + } + } + + ipv4ICMPBuf := func(t *testing.T) buffer.View { + data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} + hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data))) + hdr.SetType(header.ICMPv4Echo) + if n := copy(hdr.Payload(), data[:]); n != len(data) { + t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data)) + } + return buffer.View(hdr) + } + + ipv6ICMPBuf := func(t *testing.T) buffer.View { + data := [8]byte{1, 2, 3, 4, 5, 6, 7, 9} + hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data))) + hdr.SetType(header.ICMPv6EchoRequest) + if n := copy(hdr.Payload(), data[:]); n != len(data) { + t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data)) + } + return buffer.View(hdr) + } + + tests := []struct { + name string + transProto tcpip.TransportProtocolNumber + netProto tcpip.NetworkProtocolNumber + linkEndpoint func() stack.LinkEndpoint + localAddr tcpip.Address + icmpBuf func(*testing.T) buffer.View + expectedConnectErr *tcpip.Error + checkLinkEndpoint func(t *testing.T, e stack.LinkEndpoint) + }{ + { + name: "IPv4 loopback", + transProto: icmp.ProtocolNumber4, + netProto: ipv4.ProtocolNumber, + linkEndpoint: loopback.New, + localAddr: ipv4Loopback, + icmpBuf: ipv4ICMPBuf, + checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, + }, + { + name: "IPv6 loopback", + transProto: icmp.ProtocolNumber6, + netProto: ipv6.ProtocolNumber, + linkEndpoint: loopback.New, + localAddr: header.IPv6Loopback, + icmpBuf: ipv6ICMPBuf, + checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, + }, + { + name: "IPv4 non-loopback", + transProto: icmp.ProtocolNumber4, + netProto: ipv4.ProtocolNumber, + linkEndpoint: channelEP, + localAddr: ipv4Addr.Address, + icmpBuf: ipv4ICMPBuf, + checkLinkEndpoint: channelEPCheck, + }, + { + name: "IPv6 non-loopback", + transProto: icmp.ProtocolNumber6, + netProto: ipv6.ProtocolNumber, + linkEndpoint: channelEP, + localAddr: ipv6Addr.Address, + icmpBuf: ipv6ICMPBuf, + checkLinkEndpoint: channelEPCheck, + }, + { + name: "IPv4 loopback without local address", + transProto: icmp.ProtocolNumber4, + netProto: ipv4.ProtocolNumber, + linkEndpoint: loopback.New, + icmpBuf: ipv4ICMPBuf, + expectedConnectErr: tcpip.ErrNoRoute, + checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, + }, + { + name: "IPv6 loopback without local address", + transProto: icmp.ProtocolNumber6, + netProto: ipv6.ProtocolNumber, + linkEndpoint: loopback.New, + icmpBuf: ipv6ICMPBuf, + expectedConnectErr: tcpip.ErrNoRoute, + checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, + }, + { + name: "IPv4 non-loopback without local address", + transProto: icmp.ProtocolNumber4, + netProto: ipv4.ProtocolNumber, + linkEndpoint: channelEP, + icmpBuf: ipv4ICMPBuf, + expectedConnectErr: tcpip.ErrNoRoute, + checkLinkEndpoint: channelEPCheck, + }, + { + name: "IPv6 non-loopback without local address", + transProto: icmp.ProtocolNumber6, + netProto: ipv6.ProtocolNumber, + linkEndpoint: channelEP, + icmpBuf: ipv6ICMPBuf, + expectedConnectErr: tcpip.ErrNoRoute, + checkLinkEndpoint: channelEPCheck, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, + HandleLocal: true, + }) + e := test.linkEndpoint() + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + + if len(test.localAddr) != 0 { + if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err) + } + } + + var wq waiter.Queue + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + ep, err := s.NewEndpoint(test.transProto, test.netProto, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err) + } + defer ep.Close() + + connAddr := tcpip.FullAddress{Addr: test.localAddr} + if err := ep.Connect(connAddr); err != test.expectedConnectErr { + t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr) + } + + if test.expectedConnectErr != nil { + return + } + + payload := tcpip.SlicePayload(test.icmpBuf(t)) + var wOpts tcpip.WriteOptions + if n, _, err := ep.Write(payload, wOpts); err != nil { + t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err) + } else if n != int64(len(payload)) { + t.Fatalf("got ep.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", payload, wOpts, n, len(payload)) + } + + // Wait for the endpoint to become readable. + <-ch + + var addr tcpip.FullAddress + v, _, err := ep.Read(&addr) + if err != nil { + t.Fatalf("ep.Read(_): %s", err) + } + if diff := cmp.Diff(v[icmpDataOffset:], buffer.View(payload[icmpDataOffset:])); diff != "" { + t.Errorf("received data mismatch (-want +got):\n%s", diff) + } + if addr.Addr != test.localAddr { + t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.localAddr) + } + + test.checkLinkEndpoint(t, e) + }) + } +} + +// TestLocalUDP tests sending UDP packets between two endpoints that are local +// to the stack. +// +// This tests that that packets never leave the stack and the addresses +// used when sending a packet. +func TestLocalUDP(t *testing.T) { + const ( + nicID = 1 + ) + + tests := []struct { + name string + canBePrimaryAddr tcpip.ProtocolAddress + firstPrimaryAddr tcpip.ProtocolAddress + }{ + { + name: "IPv4", + canBePrimaryAddr: ipv4Addr1, + firstPrimaryAddr: ipv4Addr2, + }, + { + name: "IPv6", + canBePrimaryAddr: ipv6Addr1, + firstPrimaryAddr: ipv6Addr2, + }, + } + + subTests := []struct { + name string + addAddress bool + expectedWriteErr *tcpip.Error + }{ + { + name: "Unassigned local address", + addAddress: false, + expectedWriteErr: tcpip.ErrNoRoute, + }, + { + name: "Assigned local address", + addAddress: true, + expectedWriteErr: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + HandleLocal: true, + } + + s := stack.New(stackOpts) + ep := channel.New(1, header.IPv6MinimumMTU, "") + + if err := s.CreateNIC(nicID, ep); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + + if subTest.addAddress { + if err := s.AddProtocolAddressWithOptions(nicID, test.canBePrimaryAddr, stack.CanBePrimaryEndpoint); err != nil { + t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.canBePrimaryAddr, stack.FirstPrimaryEndpoint, err) + } + if err := s.AddProtocolAddressWithOptions(nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint); err != nil { + t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint, err) + } + } + + var serverWQ waiter.Queue + serverWE, serverCH := waiter.NewChannelEntry(nil) + serverWQ.EventRegister(&serverWE, waiter.EventIn) + server, err := s.NewEndpoint(udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, &serverWQ) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d): %s", udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, err) + } + defer server.Close() + + bindAddr := tcpip.FullAddress{Port: 80} + if err := server.Bind(bindAddr); err != nil { + t.Fatalf("server.Bind(%#v): %s", bindAddr, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventIn) + client, err := s.NewEndpoint(udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, &clientWQ) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d): %s", udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, err) + } + defer client.Close() + + serverAddr := tcpip.FullAddress{ + Addr: test.canBePrimaryAddr.AddressWithPrefix.Address, + Port: 80, + } + + clientPayload := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + { + wOpts := tcpip.WriteOptions{ + To: &serverAddr, + } + if n, _, err := client.Write(clientPayload, wOpts); err != subTest.expectedWriteErr { + t.Fatalf("got client.Write(%#v, %#v) = (%d, _, %s_), want = (_, _, %s)", clientPayload, wOpts, n, err, subTest.expectedWriteErr) + } else if subTest.expectedWriteErr != nil { + // Nothing else to test if we expected not to be able to send the + // UDP packet. + return + } else if n != int64(len(clientPayload)) { + t.Fatalf("got client.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", clientPayload, wOpts, n, len(clientPayload)) + } + } + + // Wait for the server endpoint to become readable. + <-serverCH + + var clientAddr tcpip.FullAddress + if v, _, err := server.Read(&clientAddr); err != nil { + t.Fatalf("server.Read(_): %s", err) + } else { + if diff := cmp.Diff(buffer.View(clientPayload), v); diff != "" { + t.Errorf("server read clientPayload mismatch (-want +got):\n%s", diff) + } + if clientAddr.Addr != test.canBePrimaryAddr.AddressWithPrefix.Address { + t.Errorf("got clientAddr.Addr = %s, want = %s", clientAddr.Addr, test.canBePrimaryAddr.AddressWithPrefix.Address) + } + if t.Failed() { + t.FailNow() + } + } + + serverPayload := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + { + wOpts := tcpip.WriteOptions{ + To: &clientAddr, + } + if n, _, err := server.Write(serverPayload, wOpts); err != nil { + t.Fatalf("server.Write(%#v, %#v): %s", serverPayload, wOpts, err) + } else if n != int64(len(serverPayload)) { + t.Fatalf("got server.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", serverPayload, wOpts, n, len(serverPayload)) + } + } + + // Wait for the client endpoint to become readable. + <-clientCH + + var gotServerAddr tcpip.FullAddress + if v, _, err := client.Read(&gotServerAddr); err != nil { + t.Fatalf("client.Read(_): %s", err) + } else { + if diff := cmp.Diff(buffer.View(serverPayload), v); diff != "" { + t.Errorf("client read serverPayload mismatch (-want +got):\n%s", diff) + } + if gotServerAddr.Addr != serverAddr.Addr { + t.Errorf("got gotServerAddr.Addr = %s, want = %s", gotServerAddr.Addr, serverAddr.Addr) + } + if t.Failed() { + t.FailNow() + } + } + }) + } + }) + } +} diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index a17234946..763cd8f84 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -755,7 +755,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { +func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: @@ -800,7 +800,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Push new packet into receive list and increment the buffer size. packet := &icmpPacket{ senderAddress: tcpip.FullAddress{ - NIC: r.NICID(), + NIC: pkt.NICID, Addr: id.RemoteAddress, }, } diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 87d510f96..3820e5dc7 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -101,7 +101,7 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { +func (*protocol) HandleUnknownDestinationPacket(stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { return stack.UnknownDestinationPacketHandled } diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 79f688129..7b6a87ba9 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -646,7 +646,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. -func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { +func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { e.rcvMu.Lock() // Drop the packet if our buffer is currently full or if this is an unassociated @@ -671,14 +671,16 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { return } + remoteAddr := pkt.Network().SourceAddress() + if e.bound { // If bound to a NIC, only accept data for that NIC. - if e.BindNICID != 0 && e.BindNICID != route.NICID() { + if e.BindNICID != 0 && e.BindNICID != pkt.NICID { e.rcvMu.Unlock() return } // If bound to an address, only accept data for that address. - if e.BindAddr != "" && e.BindAddr != route.RemoteAddress { + if e.BindAddr != "" && e.BindAddr != remoteAddr { e.rcvMu.Unlock() return } @@ -686,7 +688,7 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { // If connected, only accept packets from the remote address we // connected to. - if e.connected && e.route.RemoteAddress != route.RemoteAddress { + if e.connected && e.route.RemoteAddress != remoteAddr { e.rcvMu.Unlock() return } @@ -696,8 +698,8 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { // Push new packet into receive list and increment the buffer size. packet := &rawPacket{ senderAddr: tcpip.FullAddress{ - NIC: route.NICID(), - Addr: route.RemoteAddress, + NIC: pkt.NICID, + Addr: remoteAddr, }, } diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 6b3238d6b..47982ca41 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -199,18 +199,25 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu // createConnectingEndpoint creates a new endpoint in a connecting state, with // the connection parameters given by the arguments. -func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) *endpoint { +func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) { // Create a new endpoint. netProto := l.netProto if netProto == 0 { - netProto = s.route.NetProto + netProto = s.netProto } + + route, err := l.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) + if err != nil { + return nil, err + } + route.ResolveWith(s.remoteLinkAddr) + n := newEndpoint(l.stack, netProto, queue) n.v6only = l.v6Only n.ID = s.id - n.boundNICID = s.route.NICID() - n.route = s.route.Clone() - n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto} + n.boundNICID = s.nicID + n.route = route + n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.netProto} n.rcvBufSize = int(l.rcvWnd) n.amss = calculateAdvertisedMSS(n.userMSS, n.route) n.setEndpointState(StateConnecting) @@ -225,7 +232,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i // window to grow to a really large value. n.rcvAutoParams.prevCopied = n.initialReceiveWindow() - return n + return n, nil } // createEndpointAndPerformHandshake creates a new endpoint in connected state @@ -236,7 +243,10 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // Create new endpoint. irs := s.sequenceNumber isn := generateSecureISN(s.id, l.stack.Seed()) - ep := l.createConnectingEndpoint(s, isn, irs, opts, queue) + ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue) + if err != nil { + return nil, err + } // Lock the endpoint before registering to ensure that no out of // band changes are possible due to incoming packets etc till @@ -467,7 +477,7 @@ func (e *endpoint) acceptQueueIsFull() bool { // handleListenSegment is called when a listening endpoint receives a segment // and needs to handle it. -func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { +func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Error { e.rcvListMu.Lock() rcvClosed := e.rcvClosed e.rcvListMu.Unlock() @@ -477,8 +487,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // RFC 793 section 3.4 page 35 (figure 12) outlines that a RST // must be sent in response to a SYN-ACK while in the listen // state to prevent completing a handshake from an old SYN. - replyWithReset(s, e.sendTOS, e.ttl) - return + return replyWithReset(e.stack, s, e.sendTOS, e.ttl) } switch { @@ -492,13 +501,13 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { if !e.acceptQueueIsFull() && e.incSynRcvdCount() { s.incRef() go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier. - return + return nil } ctx.synRcvdCount.dec() e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment() e.stack.Stats().DroppedPackets.Increment() - return + return nil } else { // If cookies are in use but the endpoint accept queue // is full then drop the syn. @@ -506,10 +515,17 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment() e.stack.Stats().DroppedPackets.Increment() - return + return nil } cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS)) + route, err := e.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) + if err != nil { + return err + } + defer route.Release() + route.ResolveWith(s.remoteLinkAddr) + // Send SYN without window scaling because we currently // don't encode this information in the cookie. // @@ -523,9 +539,9 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { TS: opts.TS, TSVal: tcpTimeStamp(time.Now(), timeStampOffset()), TSEcr: opts.TSVal, - MSS: calculateAdvertisedMSS(e.userMSS, s.route), + MSS: calculateAdvertisedMSS(e.userMSS, route), } - e.sendSynTCP(&s.route, tcpFields{ + fields := tcpFields{ id: s.id, ttl: e.ttl, tos: e.sendTOS, @@ -533,8 +549,12 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { seq: cookie, ack: s.sequenceNumber + 1, rcvWnd: ctx.rcvWnd, - }, synOpts) + } + if err := e.sendSynTCP(&route, fields, synOpts); err != nil { + return err + } e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() + return nil } case (s.flags & header.TCPFlagAck) != 0: @@ -547,7 +567,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { e.stack.Stats().TCP.ListenOverflowAckDrop.Increment() e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment() e.stack.Stats().DroppedPackets.Increment() - return + return nil } if !ctx.synRcvdCount.synCookiesInUse() { @@ -566,8 +586,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // The only time we should reach here when a connection // was opened and closed really quickly and a delayed // ACK was received from the sender. - replyWithReset(s, e.sendTOS, e.ttl) - return + return replyWithReset(e.stack, s, e.sendTOS, e.ttl) } iss := s.ackNumber - 1 @@ -587,7 +606,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { if !ok || int(data) >= len(mssTable) { e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment() e.stack.Stats().DroppedPackets.Increment() - return + return nil } e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment() // Create newly accepted endpoint and deliver it. @@ -608,7 +627,10 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr } - n := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{}) + n, err := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{}) + if err != nil { + return err + } n.mu.Lock() @@ -622,7 +644,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() - return + return nil } // Register new endpoint so that packets are routed to it. @@ -632,7 +654,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() - return + return err } n.isRegistered = true @@ -670,12 +692,16 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { n.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() go e.deliverAccepted(n) + return nil + + default: + return nil } } // protocolListenLoop is the main loop of a listening TCP endpoint. It runs in // its own goroutine and is responsible for handling connection requests. -func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { +func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) { e.mu.Lock() v6Only := e.v6only ctx := newListenContext(e.stack, e, rcvWnd, v6Only, e.NetProto) @@ -714,12 +740,14 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { case wakerForNotification: n := e.fetchNotifications() if n¬ifyClose != 0 { - return nil + return } if n¬ifyDrain != 0 { for !e.segmentQueue.empty() { s := e.segmentQueue.dequeue() - e.handleListenSegment(ctx, s) + // TODO(gvisor.dev/issue/4690): Better handle errors instead of + // silently dropping. + _ = e.handleListenSegment(ctx, s) s.decRef() } close(e.drainDone) @@ -738,7 +766,9 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { break } - e.handleListenSegment(ctx, s) + // TODO(gvisor.dev/issue/4690): Better handle errors instead of + // silently dropping. + _ = e.handleListenSegment(ctx, s) s.decRef() } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index c890e2326..2facbebec 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -293,9 +293,9 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { MSS: amss, } if ttl == 0 { - ttl = s.route.DefaultTTL() + ttl = h.ep.route.DefaultTTL() } - h.ep.sendSynTCP(&s.route, tcpFields{ + h.ep.sendSynTCP(&h.ep.route, tcpFields{ id: h.ep.ID, ttl: ttl, tos: h.ep.sendTOS, @@ -356,7 +356,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { SACKPermitted: h.ep.sackPermitted, MSS: h.ep.amss, } - h.ep.sendSynTCP(&s.route, tcpFields{ + h.ep.sendSynTCP(&h.ep.route, tcpFields{ id: h.ep.ID, ttl: h.ep.ttl, tos: h.ep.sendTOS, @@ -771,7 +771,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta // TCP header, then the kernel calculate a checksum of the // header and data and get the right sum of the TCP packet. tcp.SetChecksum(xsum) - } else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 { + } else if r.RequiresTXTransportChecksum() { xsum = header.ChecksumVV(pkt.Data, xsum) tcp.SetChecksum(^tcp.CalculateChecksum(xsum)) } @@ -1044,13 +1044,13 @@ func (e *endpoint) transitionToStateCloseLocked() { // only when the endpoint is in StateClose and we want to deliver the segment // to any other listening endpoint. We reply with RST if we cannot find one. func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) { - ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, &s.route) + ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, s.nicID) if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.EndpointInfo.TransportEndpointInfo.ID.LocalAddress.To4() != "" { // Dual-stack socket, try IPv4. - ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, &s.route) + ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, s.nicID) } if ep == nil { - replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL()) + replyWithReset(e.stack, s, stack.DefaultTOS, 0 /* ttl */) s.decRef() return } @@ -1626,7 +1626,7 @@ func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func() netProtos = []tcpip.NetworkProtocolNumber{header.IPv4ProtocolNumber, header.IPv6ProtocolNumber} } for _, netProto := range netProtos { - if listenEP := e.stack.FindTransportEndpoint(netProto, info.TransProto, newID, &s.route); listenEP != nil { + if listenEP := e.stack.FindTransportEndpoint(netProto, info.TransProto, newID, s.nicID); listenEP != nil { tcpEP := listenEP.(*endpoint) if EndpointState(tcpEP.State()) == StateListen { reuseTW = func() { diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index 98aecab9e..21162f01a 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -172,10 +172,11 @@ func (d *dispatcher) wait() { d.wg.Wait() } -func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { +func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { ep := stackEP.(*endpoint) - s := newSegment(r, id, pkt) - if !s.parse() { + + s := newIncomingSegment(id, pkt) + if !s.parse(pkt.RXTransportChecksumValidated) { ep.stack.Stats().MalformedRcvdPackets.Increment() ep.stack.Stats().TCP.InvalidSegmentsReceived.Increment() ep.stats.ReceiveErrors.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index b817ab6ef..258f9f1bb 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1425,7 +1425,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c queueAndSend := func() (int64, <-chan struct{}, *tcpip.Error) { // Add data to the send queue. - s := newSegmentFromView(&e.route, e.ID, v) + s := newOutgoingSegment(e.ID, v) e.sndBufUsed += len(v) e.sndBufInQueue += seqnum.Size(len(v)) e.sndQueue.PushBack(s) @@ -2316,7 +2316,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc // done yet) or the reservation was freed between the check above and // the FindTransportEndpoint below. But rather than retry the same port // we just skip it and move on. - transEP := e.stack.FindTransportEndpoint(netProto, ProtocolNumber, transEPID, &r) + transEP := e.stack.FindTransportEndpoint(netProto, ProtocolNumber, transEPID, r.NICID()) if transEP == nil { // ReservePort failed but there is no registered endpoint with // demuxer. Which indicates there is at least some endpoint that has @@ -2385,7 +2385,6 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc for _, l := range []segmentList{e.segmentQueue.list, e.sndQueue, e.snd.writeList} { for s := l.Front(); s != nil; s = s.Next() { s.id = e.ID - s.route = r.Clone() e.sndWaker.Assert() } } @@ -2451,7 +2450,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { } // Queue fin segment. - s := newSegmentFromView(&e.route, e.ID, nil) + s := newOutgoingSegment(e.ID, nil) e.sndQueue.PushBack(s) e.sndBufInQueue++ // Mark endpoint as closed. @@ -2723,7 +2722,7 @@ func (e *endpoint) getRemoteAddress() tcpip.FullAddress { } } -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { +func (*endpoint) HandlePacket(stack.TransportEndpointID, *stack.PacketBuffer) { // TCP HandlePacket is not required anymore as inbound packets first // land at the Dispatcher which then can either delivery using the // worker go routine or directly do the invoke the tcp processing inline @@ -3082,9 +3081,9 @@ func (e *endpoint) initHardwareGSO() { } func (e *endpoint) initGSO() { - if e.route.Capabilities()&stack.CapabilityHardwareGSO != 0 { + if e.route.HasHardwareGSOCapability() { e.initHardwareGSO() - } else if e.route.Capabilities()&stack.CapabilitySoftwareGSO != 0 { + } else if e.route.HasSoftwareGSOCapability() { e.gso = &stack.GSO{ MaxSize: e.route.GSOMaxSize(), Type: stack.GSOSW, diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index b25431467..2bcc5e1c2 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -53,8 +53,8 @@ func (e *endpoint) beforeSave() { switch { case epState == StateInitial || epState == StateBound: case epState.connected() || epState.handshake(): - if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 { - if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 { + if !e.route.HasSaveRestoreCapability() { + if !e.route.HasDisconncetOkCapability() { panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)}) } e.resetConnectionLocked(tcpip.ErrConnectionAborted) diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 070b634b4..0664789da 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -30,6 +30,8 @@ import ( // The canonical way of using it is to pass the Forwarder.HandlePacket function // to stack.SetTransportProtocolHandler. type Forwarder struct { + stack *stack.Stack + maxInFlight int handler func(*ForwarderRequest) @@ -48,6 +50,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward rcvWnd = DefaultReceiveBufferSize } return &Forwarder{ + stack: s, maxInFlight: maxInFlight, handler: handler, inFlight: make(map[stack.TransportEndpointID]struct{}), @@ -61,12 +64,12 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { - s := newSegment(r, id, pkt) +func (f *Forwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + s := newIncomingSegment(id, pkt) defer s.decRef() // We only care about well-formed SYN packets. - if !s.parse() || !s.csumValid || s.flags != header.TCPFlagSyn { + if !s.parse(pkt.RXTransportChecksumValidated) || !s.csumValid || s.flags != header.TCPFlagSyn { return false } @@ -128,9 +131,8 @@ func (r *ForwarderRequest) Complete(sendReset bool) { delete(r.forwarder.inFlight, r.segment.id) r.forwarder.mu.Unlock() - // If the caller requested, send a reset. if sendReset { - replyWithReset(r.segment, stack.DefaultTOS, r.segment.route.DefaultTTL()) + replyWithReset(r.forwarder.stack, r.segment, stack.DefaultTOS, 0 /* ttl */) } // Release all resources. diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 5bce73605..2329aca4b 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -187,8 +187,8 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // to a specific processing queue. Each queue is serviced by its own processor // goroutine which is responsible for dequeuing and doing full TCP dispatch of // the packet. -func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { - p.dispatcher.queuePacket(r, ep, id, pkt) +func (p *protocol) QueuePacket(ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { + p.dispatcher.queuePacket(ep, id, pkt) } // HandleUnknownDestinationPacket handles packets targeted at this protocol but @@ -198,24 +198,32 @@ func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id st // a reset is sent in response to any incoming segment except another reset. In // particular, SYNs addressed to a non-existent connection are rejected by this // means." - -func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { - s := newSegment(r, id, pkt) +func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { + s := newIncomingSegment(id, pkt) defer s.decRef() - if !s.parse() || !s.csumValid { + if !s.parse(pkt.RXTransportChecksumValidated) || !s.csumValid { return stack.UnknownDestinationPacketMalformed } if !s.flagIsSet(header.TCPFlagRst) { - replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL()) + replyWithReset(p.stack, s, stack.DefaultTOS, 0) } return stack.UnknownDestinationPacketHandled } // replyWithReset replies to the given segment with a reset segment. -func replyWithReset(s *segment, tos, ttl uint8) { +// +// If the passed TTL is 0, then the route's default TTL will be used. +func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) *tcpip.Error { + route, err := stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) + if err != nil { + return err + } + defer route.Release() + route.ResolveWith(s.remoteLinkAddr) + // Get the seqnum from the packet if the ack flag is set. seq := seqnum.Value(0) ack := seqnum.Value(0) @@ -237,7 +245,12 @@ func replyWithReset(s *segment, tos, ttl uint8) { flags |= header.TCPFlagAck ack = s.sequenceNumber.Add(s.logicalLen()) } - sendTCP(&s.route, tcpFields{ + + if ttl == 0 { + ttl = route.DefaultTTL() + } + + return sendTCP(&route, tcpFields{ id: s.id, ttl: ttl, tos: tos, diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 1f9c5cf50..2091989cc 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -19,6 +19,7 @@ import ( "sync/atomic" "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" @@ -45,9 +46,18 @@ type segment struct { ep *endpoint qFlags queueFlags id stack.TransportEndpointID `state:"manual"` - route stack.Route `state:"manual"` - data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - hdr header.TCP + + // TODO(gvisor.dev/issue/4417): Hold a stack.PacketBuffer instead of + // individual members for link/network packet info. + srcAddr tcpip.Address + dstAddr tcpip.Address + netProto tcpip.NetworkProtocolNumber + nicID tcpip.NICID + remoteLinkAddr tcpip.LinkAddress + + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + + hdr header.TCP // views is used as buffer for data when its length is large // enough to store a VectorisedView. views [8]buffer.View `state:"nosave"` @@ -76,11 +86,16 @@ type segment struct { acked bool } -func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment { +func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment { + netHdr := pkt.Network() s := &segment{ - refCnt: 1, - id: id, - route: r.Clone(), + refCnt: 1, + id: id, + srcAddr: netHdr.SourceAddress(), + dstAddr: netHdr.DestinationAddress(), + netProto: pkt.NetworkProtocolNumber, + nicID: pkt.NICID, + remoteLinkAddr: pkt.SourceLinkAddress(), } s.data = pkt.Data.Clone(s.views[:]) s.hdr = header.TCP(pkt.TransportHeader().View()) @@ -88,11 +103,10 @@ func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketB return s } -func newSegmentFromView(r *stack.Route, id stack.TransportEndpointID, v buffer.View) *segment { +func newOutgoingSegment(id stack.TransportEndpointID, v buffer.View) *segment { s := &segment{ refCnt: 1, id: id, - route: r.Clone(), } s.rcvdTime = time.Now() if len(v) != 0 { @@ -110,7 +124,9 @@ func (s *segment) clone() *segment { ackNumber: s.ackNumber, flags: s.flags, window: s.window, - route: s.route.Clone(), + netProto: s.netProto, + nicID: s.nicID, + remoteLinkAddr: s.remoteLinkAddr, viewToDeliver: s.viewToDeliver, rcvdTime: s.rcvdTime, xmitTime: s.xmitTime, @@ -160,7 +176,6 @@ func (s *segment) decRef() { panic(fmt.Sprintf("unexpected queue flag %b set for segment", s.qFlags)) } } - s.route.Release() } } @@ -198,10 +213,10 @@ func (s *segment) segMemSize() int { // // Returns boolean indicating if the parsing was successful. // -// If checksum verification is not offloaded then parse also verifies the +// If checksum verification may not be skipped, parse also verifies the // TCP checksum and stores the checksum and result of checksum verification in // the csum and csumValid fields of the segment. -func (s *segment) parse() bool { +func (s *segment) parse(skipChecksumValidation bool) bool { // h is the header followed by the payload. We check that the offset to // the data respects the following constraints: // 1. That it's at least the minimum header size; if we don't do this @@ -220,16 +235,14 @@ func (s *segment) parse() bool { s.options = []byte(s.hdr[header.TCPMinimumSize:]) s.parsedOptions = header.ParseTCPOptions(s.options) - // Query the link capabilities to decide if checksum validation is - // required. verifyChecksum := true - if s.route.Capabilities()&stack.CapabilityRXChecksumOffload != 0 { + if skipChecksumValidation { s.csumValid = true verifyChecksum = false } if verifyChecksum { s.csum = s.hdr.Checksum() - xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size()+len(s.hdr))) + xsum := header.PseudoHeaderChecksum(ProtocolNumber, s.srcAddr, s.dstAddr, uint16(s.data.Size()+len(s.hdr))) xsum = s.hdr.CalculateChecksum(xsum) xsum = header.ChecksumVV(s.data, xsum) s.csumValid = xsum == 0xffff diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index cdb5127ab..56bdf6c34 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1012,7 +1012,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u // On IPv4, UDP checksum is optional, and a zero value indicates the // transmitter skipped the checksum generation (RFC768). // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). - if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 && + if r.RequiresTXTransportChecksum() && (!noChecksum || r.NetProto == header.IPv6ProtocolNumber) { xsum := r.PseudoHeaderChecksum(ProtocolNumber, length) for _, v := range data.Views() { @@ -1382,10 +1382,11 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // On IPv4, UDP checksum is optional, and a zero value means the transmitter // omitted the checksum generation (RFC768). // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). -func verifyChecksum(r *stack.Route, hdr header.UDP, pkt *stack.PacketBuffer) bool { - if r.Capabilities()&stack.CapabilityRXChecksumOffload == 0 && - (hdr.Checksum() != 0 || r.NetProto == header.IPv6ProtocolNumber) { - xsum := r.PseudoHeaderChecksum(ProtocolNumber, hdr.Length()) +func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool { + if !pkt.RXTransportChecksumValidated && + (hdr.Checksum() != 0 || pkt.NetworkProtocolNumber == header.IPv6ProtocolNumber) { + netHdr := pkt.Network() + xsum := header.PseudoHeaderChecksum(ProtocolNumber, netHdr.DestinationAddress(), netHdr.SourceAddress(), hdr.Length()) for _, v := range pkt.Data.Views() { xsum = header.Checksum(v, xsum) } @@ -1396,7 +1397,7 @@ func verifyChecksum(r *stack.Route, hdr header.UDP, pkt *stack.PacketBuffer) boo // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { +func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Get the header then trim it from the view. hdr := header.UDP(pkt.TransportHeader().View()) if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { @@ -1406,7 +1407,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk return } - if !verifyChecksum(r, hdr, pkt) { + if !verifyChecksum(hdr, pkt) { // Checksum Error. e.stack.Stats().UDP.ChecksumErrors.Increment() e.stats.ReceiveErrors.ChecksumErrors.Increment() @@ -1437,7 +1438,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Push new packet into receive list and increment the buffer size. packet := &udpPacket{ senderAddress: tcpip.FullAddress{ - NIC: r.NICID(), + NIC: pkt.NICID, Addr: id.RemoteAddress, Port: header.UDP(hdr).SourcePort(), }, @@ -1447,7 +1448,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk e.rcvBufSize += pkt.Data.Size() // Save any useful information from the network header to the packet. - switch r.NetProto { + switch pkt.NetworkProtocolNumber { case header.IPv4ProtocolNumber: packet.tos, _ = header.IPv4(pkt.NetworkHeader().View()).TOS() case header.IPv6ProtocolNumber: @@ -1457,9 +1458,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // TODO(gvisor.dev/issue/3556): r.LocalAddress may be a multicast or broadcast // address. packetInfo.LocalAddr should hold a unicast address that can be // used to respond to the incoming packet. - packet.packetInfo.LocalAddr = r.LocalAddress - packet.packetInfo.DestinationAddr = r.LocalAddress - packet.packetInfo.NIC = r.NICID() + localAddr := pkt.Network().DestinationAddress() + packet.packetInfo.LocalAddr = localAddr + packet.packetInfo.DestinationAddr = localAddr + packet.packetInfo.NIC = pkt.NICID packet.timestamp = e.stack.Clock().NowNanoseconds() e.rcvMu.Unlock() diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index 3ae6cc221..14e4648cd 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -43,10 +43,9 @@ func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder { // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { +func (f *Forwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { f.handler(&ForwarderRequest{ stack: f.stack, - route: r, id: id, pkt: pkt, }) @@ -59,7 +58,6 @@ func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, p // it via CreateEndpoint. type ForwarderRequest struct { stack *stack.Stack - route *stack.Route id stack.TransportEndpointID pkt *stack.PacketBuffer } @@ -72,17 +70,25 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID { // CreateEndpoint creates a connected UDP endpoint for the session request. func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - ep := newEndpoint(r.stack, r.route.NetProto, queue) - if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil { + netHdr := r.pkt.Network() + route, err := r.stack.FindRoute(r.pkt.NICID, netHdr.DestinationAddress(), netHdr.SourceAddress(), r.pkt.NetworkProtocolNumber, false /* multicastLoop */) + if err != nil { + return nil, err + } + route.ResolveWith(r.pkt.SourceLinkAddress()) + + ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue) + if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil { ep.Close() + route.Release() return nil, err } ep.ID = r.id - ep.route = r.route.Clone() + ep.route = route ep.dstPort = r.id.RemotePort - ep.effectiveNetProtos = []tcpip.NetworkProtocolNumber{r.route.NetProto} - ep.RegisterNICID = r.route.NICID() + ep.effectiveNetProtos = []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber} + ep.RegisterNICID = r.pkt.NICID ep.boundPortFlags = ep.portFlags ep.state = StateConnected @@ -91,7 +97,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, ep.rcvReady = true ep.rcvMu.Unlock() - ep.HandlePacket(r.route, r.id, r.pkt) + ep.HandlePacket(r.id, r.pkt) return ep, nil } diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index da5b1deb2..91420edd3 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -78,15 +78,15 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // HandleUnknownDestinationPacket handles packets that are targeted at this // protocol but don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { +func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { hdr := header.UDP(pkt.TransportHeader().View()) if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { - r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() + p.stack.Stats().UDP.MalformedPacketsReceived.Increment() return stack.UnknownDestinationPacketMalformed } - if !verifyChecksum(r, hdr, pkt) { - r.Stack().Stats().UDP.ChecksumErrors.Increment() + if !verifyChecksum(hdr, pkt) { + p.stack.Stats().UDP.ChecksumErrors.Increment() return stack.UnknownDestinationPacketMalformed } |