From 1a972411b36b8ad2543d3ea614c92e60ccbdffab Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Thu, 12 Nov 2020 17:30:31 -0800 Subject: Move packet handling to NetworkEndpoint The NIC should not hold network-layer state or logic - network packet handling/forwarding should be performed at the network layer instead of the NIC. Fixes #4688 PiperOrigin-RevId: 342166985 --- pkg/tcpip/network/BUILD | 1 + pkg/tcpip/network/arp/arp_test.go | 4 + pkg/tcpip/network/ip_test.go | 289 ++++++++++++++++++++---------------- pkg/tcpip/network/ipv4/ipv4.go | 106 ++++++++++--- pkg/tcpip/network/ipv6/icmp_test.go | 73 ++++----- pkg/tcpip/network/ipv6/ipv6.go | 102 ++++++++++--- pkg/tcpip/network/ipv6/ndp_test.go | 34 ++--- pkg/tcpip/stack/forwarding_test.go | 42 +++++- pkg/tcpip/stack/nic.go | 93 +----------- pkg/tcpip/stack/registration.go | 7 +- pkg/tcpip/stack/route.go | 22 --- pkg/tcpip/stack/stack.go | 2 +- pkg/tcpip/stack/stack_test.go | 96 ++---------- pkg/tcpip/stack/transport_test.go | 85 ----------- pkg/tcpip/tcpip.go | 6 +- pkg/tcpip/transport/udp/udp_test.go | 4 + 16 files changed, 442 insertions(+), 524 deletions(-) (limited to 'pkg') diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index c118a2929..b38aff0b8 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -14,6 +14,7 @@ go_test( "//pkg/tcpip/buffer", "//pkg/tcpip/checker", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", "//pkg/tcpip/network/ipv4", diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 087ee9c66..91d59d83f 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -439,6 +439,10 @@ func (*testInterface) Enabled() bool { return true } +func (*testInterface) Promiscuous() bool { + return false +} + func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { r := stack.Route{ NetProto: protocol, diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index c7d26e14f..576792234 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -34,16 +35,16 @@ import ( ) const ( - localIPv4Addr = "\x0a\x00\x00\x01" - remoteIPv4Addr = "\x0a\x00\x00\x02" - ipv4SubnetAddr = "\x0a\x00\x00\x00" - ipv4SubnetMask = "\xff\xff\xff\x00" - ipv4Gateway = "\x0a\x00\x00\x03" - localIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - remoteIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" - ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00" - ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" + localIPv4Addr = tcpip.Address("\x0a\x00\x00\x01") + remoteIPv4Addr = tcpip.Address("\x0a\x00\x00\x02") + ipv4SubnetAddr = tcpip.Address("\x0a\x00\x00\x00") + ipv4SubnetMask = tcpip.Address("\xff\xff\xff\x00") + ipv4Gateway = tcpip.Address("\x0a\x00\x00\x03") + localIPv6Addr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + remoteIPv6Addr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + ipv6SubnetAddr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00") + ipv6SubnetMask = tcpip.Address("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00") + ipv6Gateway = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") nicID = 1 ) @@ -299,6 +300,10 @@ func (t *testInterface) Enabled() bool { return !t.mu.disabled } +func (*testInterface) Promiscuous() bool { + return false +} + func (t *testInterface) setEnabled(v bool) { t.mu.Lock() defer t.mu.Unlock() @@ -558,59 +563,131 @@ func TestIPv4Send(t *testing.T) { } } -func TestIPv4Receive(t *testing.T) { - s := buildDummyStack(t) - proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) - nic := testInterface{ - testObject: testObject{ - t: t, - v4: true, +func TestReceive(t *testing.T) { + tests := []struct { + name string + protoFactory stack.NetworkProtocolFactory + protoNum tcpip.NetworkProtocolNumber + v4 bool + epAddr tcpip.AddressWithPrefix + handlePacket func(*testing.T, stack.NetworkEndpoint, *testInterface) + }{ + { + name: "IPv4", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + v4: true, + epAddr: localIPv4Addr.WithPrefix(), + handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) { + const totalLen = header.IPv4MinimumSize + 30 /* payload length */ + + view := buffer.NewView(totalLen) + ip := header.IPv4(view) + ip.Encode(&header.IPv4Fields{ + TotalLength: totalLen, + TTL: ipv4.DefaultTTL, + Protocol: 10, + SrcAddr: remoteIPv4Addr, + DstAddr: localIPv4Addr, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + // Make payload be non-zero. + for i := header.IPv4MinimumSize; i < len(view); i++ { + view[i] = uint8(i) + } + + // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. + nic.testObject.protocol = 10 + nic.testObject.srcAddr = remoteIPv4Addr + nic.testObject.dstAddr = localIPv4Addr + nic.testObject.contents = view[header.IPv4MinimumSize:totalLen] + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: view.ToVectorisedView(), + }) + if ok := parse.IPv4(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(pkt) + }, }, - } - ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) - defer ep.Close() + { + name: "IPv6", + protoFactory: ipv6.NewProtocol, + protoNum: ipv6.ProtocolNumber, + v4: false, + epAddr: localIPv6Addr.WithPrefix(), + handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) { + const payloadLen = 30 + view := buffer.NewView(header.IPv6MinimumSize + payloadLen) + ip := header.IPv6(view) + ip.Encode(&header.IPv6Fields{ + PayloadLength: payloadLen, + NextHeader: 10, + HopLimit: ipv6.DefaultTTL, + SrcAddr: remoteIPv6Addr, + DstAddr: localIPv6Addr, + }) - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } + // Make payload be non-zero. + for i := header.IPv6MinimumSize; i < len(view); i++ { + view[i] = uint8(i) + } - totalLen := header.IPv4MinimumSize + 30 - view := buffer.NewView(totalLen) - ip := header.IPv4(view) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLen), - TTL: 20, - Protocol: 10, - SrcAddr: remoteIPv4Addr, - DstAddr: localIPv4Addr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) + // Give packet to ipv6 endpoint, dispatcher will validate that it's ok. + nic.testObject.protocol = 10 + nic.testObject.srcAddr = remoteIPv6Addr + nic.testObject.dstAddr = localIPv6Addr + nic.testObject.contents = view[header.IPv6MinimumSize:][:payloadLen] - // Make payload be non-zero. - for i := header.IPv4MinimumSize; i < totalLen; i++ { - view[i] = uint8(i) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: view.ToVectorisedView(), + }) + if _, _, _, _, ok := parse.IPv6(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(pkt) + }, + }, } - // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - nic.testObject.protocol = 10 - nic.testObject.srcAddr = remoteIPv4Addr - nic.testObject.dstAddr = localIPv4Addr - nic.testObject.contents = view[header.IPv4MinimumSize:totalLen] + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory}, + }) + nic := testInterface{ + testObject: testObject{ + t: t, + v4: test.v4, + }, + } + ep := s.NetworkProtocolInstance(test.protoNum).NewEndpoint(&nic, nil, nil, &nic.testObject) + defer ep.Close() - r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: view.ToVectorisedView(), - }) - if _, _, ok := proto.Parse(pkt); !ok { - t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) - } - r.PopulatePacketInfo(pkt) - ep.HandlePacket(pkt) - if nic.testObject.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + + if ep, err := ep.AddAndAcquirePermanentAddress(test.epAddr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("ep.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", test.epAddr, err) + } else { + ep.DecRef() + } + + stat := s.Stats().IP.PacketsReceived + if got := stat.Value(); got != 0 { + t.Fatalf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 0", got) + } + test.handlePacket(t, ep, &nic) + if nic.testObject.dataCalls != 1 { + t.Errorf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) + } + if got := stat.Value(); got != 1 { + t.Errorf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 1", got) + } + }) } } @@ -634,10 +711,6 @@ func TestIPv4ReceiveControl(t *testing.T) { {"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0}, {"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + 8}, } - r, err := buildIPv4Route(localIPv4Addr, "\x0a\x00\x00\xbb") - if err != nil { - t.Fatal(err) - } for _, c := range cases { t.Run(c.name, func(t *testing.T) { s := buildDummyStack(t) @@ -705,8 +778,14 @@ func TestIPv4ReceiveControl(t *testing.T) { nic.testObject.typ = c.expectedTyp nic.testObject.extra = c.expectedExtra + addr := localIPv4Addr.WithPrefix() + if ep, err := ep.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("ep.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + ep.DecRef() + } + pkt := truncatedPacket(view, c.trunc, header.IPv4MinimumSize) - r.PopulatePacketInfo(pkt) ep.HandlePacket(pkt) if want := c.expectedCount; nic.testObject.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) @@ -716,7 +795,9 @@ func TestIPv4ReceiveControl(t *testing.T) { } func TestIPv4FragmentationReceive(t *testing.T) { - s := buildDummyStack(t) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + }) proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) nic := testInterface{ testObject: testObject{ @@ -774,11 +855,6 @@ func TestIPv4FragmentationReceive(t *testing.T) { nic.testObject.dstAddr = localIPv4Addr nic.testObject.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) - r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - // Send first segment. pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: frag1.ToVectorisedView(), @@ -786,7 +862,14 @@ func TestIPv4FragmentationReceive(t *testing.T) { if _, _, ok := proto.Parse(pkt); !ok { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } - r.PopulatePacketInfo(pkt) + + addr := localIPv4Addr.WithPrefix() + if ep, err := ep.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("ep.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + ep.DecRef() + } + ep.HandlePacket(pkt) if nic.testObject.dataCalls != 0 { t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls) @@ -799,7 +882,6 @@ func TestIPv4FragmentationReceive(t *testing.T) { if _, _, ok := proto.Parse(pkt); !ok { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } - r.PopulatePacketInfo(pkt) ep.HandlePacket(pkt) if nic.testObject.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) @@ -852,61 +934,6 @@ func TestIPv6Send(t *testing.T) { } } -func TestIPv6Receive(t *testing.T) { - s := buildDummyStack(t) - proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) - nic := testInterface{ - testObject: testObject{ - t: t, - }, - } - ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) - defer ep.Close() - - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - - totalLen := header.IPv6MinimumSize + 30 - view := buffer.NewView(totalLen) - ip := header.IPv6(view) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(totalLen - header.IPv6MinimumSize), - NextHeader: 10, - HopLimit: 20, - SrcAddr: remoteIPv6Addr, - DstAddr: localIPv6Addr, - }) - - // Make payload be non-zero. - for i := header.IPv6MinimumSize; i < totalLen; i++ { - view[i] = uint8(i) - } - - // Give packet to ipv6 endpoint, dispatcher will validate that it's ok. - nic.testObject.protocol = 10 - nic.testObject.srcAddr = remoteIPv6Addr - nic.testObject.dstAddr = localIPv6Addr - nic.testObject.contents = view[header.IPv6MinimumSize:totalLen] - - r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: view.ToVectorisedView(), - }) - if _, _, ok := proto.Parse(pkt); !ok { - t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) - } - r.PopulatePacketInfo(pkt) - ep.HandlePacket(pkt) - if nic.testObject.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) - } -} - func TestIPv6ReceiveControl(t *testing.T) { newUint16 := func(v uint16) *uint16 { return &v } @@ -933,13 +960,6 @@ func TestIPv6ReceiveControl(t *testing.T) { {"Non-zero fragment offset", 0, newUint16(100), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0}, {"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8}, } - r, err := buildIPv6Route( - localIPv6Addr, - "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa", - ) - if err != nil { - t.Fatal(err) - } for _, c := range cases { t.Run(c.name, func(t *testing.T) { s := buildDummyStack(t) @@ -1018,8 +1038,13 @@ func TestIPv6ReceiveControl(t *testing.T) { // Set ICMPv6 checksum. icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{})) + addr := localIPv6Addr.WithPrefix() + if ep, err := ep.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("ep.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + ep.DecRef() + } pkt := truncatedPacket(view, c.trunc, header.IPv6MinimumSize) - r.PopulatePacketInfo(pkt) ep.HandlePacket(pkt) if want := c.expectedCount; nic.testObject.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index a376cb8ec..0af646df9 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -260,16 +260,13 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { e.addIPHeader(r, pkt, params) - return e.writePacket(r, gso, pkt) -} -func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer) *tcpip.Error { // iptables filtering. All packets that reach here are locally // generated. nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. - r.Stats().IP.IPTablesOutputDropped.Increment() + e.protocol.stack.Stats().IP.IPTablesOutputDropped.Increment() return nil } @@ -286,24 +283,27 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet if err == nil { pkt := pkt.CloneToInbound() if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) - route.PopulatePacketInfo(pkt) // Since we rewrote the packet but it is being routed back to us, we can // safely assume the checksum is valid. pkt.RXTransportChecksumValidated = true - ep.HandlePacket(pkt) + ep.(*endpoint).handlePacket(pkt) } return nil } } + return e.writePacket(r, gso, pkt, false /* headerIncluded */) +} + +func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, headerIncluded bool) *tcpip.Error { if r.Loop&stack.PacketLoop != 0 { pkt := pkt.CloneToInbound() if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - loopedR := r.MakeLoopedRoute() - loopedR.PopulatePacketInfo(pkt) - loopedR.Release() - e.HandlePacket(pkt) + // If the packet was generated by the stack (not a raw/packet endpoint + // where a packet may be written with the header included), then we can + // safely assume the checksum is valid. + pkt.RXTransportChecksumValidated = !headerIncluded + e.handlePacket(pkt) } } if r.Loop&stack.PacketOut == 0 { @@ -374,8 +374,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) // iptables filtering. All packets that reach here are locally // generated. - ipt := e.protocol.stack.IPTables() - dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName) + dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName) if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. @@ -400,9 +399,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { pkt := pkt.CloneToInbound() if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) - route.PopulatePacketInfo(pkt) - ep.HandlePacket(pkt) + // Since we rewrote the packet but it is being routed back to us, we + // can safely assume the checksum is valid. + pkt.RXTransportChecksumValidated = true + ep.(*endpoint).handlePacket(pkt) } n++ continue @@ -479,16 +479,66 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu return tcpip.ErrMalformedHeader } - return e.writePacket(r, nil /* gso */, pkt) + return e.writePacket(r, nil /* gso */, pkt, true /* headerIncluded */) +} + +// forwardPacket attempts to forward a packet to its final destination. +func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error { + h := header.IPv4(pkt.NetworkHeader().View()) + dstAddr := h.DestinationAddress() + + // Check if the destination is owned by the stack. + networkEndpoint, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr) + if err == nil { + networkEndpoint.(*endpoint).handlePacket(pkt) + return nil + } + if err != tcpip.ErrBadAddress { + return err + } + + r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) + if err != nil { + return err + } + defer r.Release() + + // TODO(b/143425874) Decrease the TTL field in forwarded packets. + return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + // We need to do a deep copy of the IP packet because + // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do + // not own it. + Data: stack.PayloadSince(pkt.NetworkHeader()).ToVectorisedView(), + })) } // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { + stats := e.protocol.stack.Stats() + stats.IP.PacketsReceived.Increment() + if !e.isEnabled() { + stats.IP.DisabledPacketsReceived.Increment() return } + // Loopback traffic skips the prerouting chain. + if !e.nic.IsLoopback() { + if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok { + // iptables is telling us to drop the packet. + stats.IP.IPTablesPreroutingDropped.Increment() + return + } + } + + e.handlePacket(pkt) +} + +// handlePacket is like HandlePacket except it does not perform the prerouting +// iptables hook. +func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { pkt.NICID = e.nic.ID() stats := e.protocol.stack.Stats() @@ -497,6 +547,21 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { stats.IP.MalformedPacketsReceived.Increment() return } + srcAddr := h.SourceAddress() + dstAddr := h.DestinationAddress() + + addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint) + if addressEndpoint == nil { + if !e.protocol.Forwarding() { + stats.IP.InvalidDestinationAddressesReceived.Increment() + return + } + + _ = e.forwardPacket(pkt) + return + } + subnet := addressEndpoint.AddressWithPrefix().Subnet() + addressEndpoint.DecRef() // There has been some confusion regarding verifying checksums. We need // just look for negative 0 (0xffff) as the checksum, as it's not possible to @@ -528,15 +593,16 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // When a host sends any datagram, the IP source address MUST // be one of its own IP addresses (but not a broadcast or // multicast address). - if pkt.NetworkPacketInfo.RemoteAddressBroadcast || header.IsV4MulticastAddress(h.SourceAddress()) { + if directedBroadcast := subnet.IsBroadcast(srcAddr); directedBroadcast || srcAddr == header.IPv4Broadcast || header.IsV4MulticastAddress(srcAddr) { stats.IP.InvalidSourceAddressesReceived.Increment() return } + pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast + // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. - ipt := e.protocol.stack.IPTables() - if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. stats.IP.IPTablesInputDropped.Increment() return diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 76013daa1..001b9d66a 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -144,6 +144,10 @@ func (*testInterface) Enabled() bool { return true } +func (*testInterface) Promiscuous() bool { + return false +} + func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { r := stack.Route{ NetProto: protocol, @@ -174,13 +178,8 @@ func TestICMPCounts(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, UseNeighborCache: test.useNeighborCache, }) - { - if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_, _) = %s", err) - } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } + if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_, _) = %s", err) } { subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) @@ -206,11 +205,12 @@ func TestICMPCounts(t *testing.T) { t.Fatalf("ep.Enable(): %s", err) } - r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) + addr := lladdr0.WithPrefix() + if ep, err := ep.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("ep.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + ep.DecRef() } - defer r.Release() var tllData [header.NDPLinkLayerAddressSize]byte header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ @@ -279,10 +279,9 @@ func TestICMPCounts(t *testing.T) { PayloadLength: uint16(len(icmp)), NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: header.NDPHopLimit, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) - r.PopulatePacketInfo(pkt) ep.HandlePacket(pkt) } @@ -290,7 +289,7 @@ func TestICMPCounts(t *testing.T) { icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) copy(icmp[typ.size:], typ.extraData) icmp.SetType(typ.typ) - icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) handleIPv6Payload(icmp) } @@ -317,13 +316,8 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, UseNeighborCache: true, }) - { - if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_, _) = %s", err) - } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } + if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_, _) = %s", err) } { subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) @@ -349,11 +343,12 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { t.Fatalf("ep.Enable(): %s", err) } - r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) + addr := lladdr0.WithPrefix() + if ep, err := ep.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("ep.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + ep.DecRef() } - defer r.Release() var tllData [header.NDPLinkLayerAddressSize]byte header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ @@ -422,10 +417,9 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { PayloadLength: uint16(len(icmp)), NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: header.NDPHopLimit, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) - r.PopulatePacketInfo(pkt) ep.HandlePacket(pkt) } @@ -433,7 +427,7 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) copy(icmp[typ.size:], typ.extraData) icmp.SetType(typ.typ) - icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) handleIPv6Payload(icmp) } @@ -1775,17 +1769,15 @@ func TestCallsToNeighborCache(t *testing.T) { t.Fatalf("ep.Enable(): %s", err) } - r, err := s.FindRoute(nicID, lladdr0, test.source, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) + addr := lladdr0.WithPrefix() + if ep, err := ep.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("ep.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + ep.DecRef() } - defer r.Release() - - // TODO(gvisor.dev/issue/4517): Remove the need for this manual patch. - r.LocalAddress = test.destination icmp := test.createPacket() - icmp.SetChecksum(header.ICMPv6Checksum(icmp, r.RemoteAddress, r.LocalAddress, buffer.VectorisedView{})) + icmp.SetChecksum(header.ICMPv6Checksum(icmp, test.source, test.destination, buffer.VectorisedView{})) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: header.IPv6MinimumSize, Data: buffer.View(icmp).ToVectorisedView(), @@ -1795,10 +1787,9 @@ func TestCallsToNeighborCache(t *testing.T) { PayloadLength: uint16(len(icmp)), NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: header.NDPHopLimit, - SrcAddr: r.RemoteAddress, - DstAddr: r.LocalAddress, + SrcAddr: test.source, + DstAddr: test.destination, }) - r.PopulatePacketInfo(pkt) ep.HandlePacket(pkt) // Confirm the endpoint calls the correct NUDHandler method. diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 0526190cc..38a0633bd 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -441,17 +441,13 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { e.addIPHeader(r, pkt, params) - return e.writePacket(r, gso, pkt, params.Protocol) -} -func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber) *tcpip.Error { // iptables filtering. All packets that reach here are locally // generated. nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - ipt := e.protocol.stack.IPTables() - if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. - r.Stats().IP.IPTablesOutputDropped.Increment() + e.protocol.stack.Stats().IP.IPTablesOutputDropped.Increment() return nil } @@ -467,24 +463,27 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { pkt := pkt.CloneToInbound() if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) - route.PopulatePacketInfo(pkt) // Since we rewrote the packet but it is being routed back to us, we can // safely assume the checksum is valid. pkt.RXTransportChecksumValidated = true - ep.HandlePacket(pkt) + ep.(*endpoint).handlePacket(pkt) } return nil } } + return e.writePacket(r, gso, pkt, params.Protocol, false /* headerIncluded */) +} + +func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) *tcpip.Error { if r.Loop&stack.PacketLoop != 0 { pkt := pkt.CloneToInbound() if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - loopedR := r.MakeLoopedRoute() - loopedR.PopulatePacketInfo(pkt) - loopedR.Release() - e.HandlePacket(pkt) + // If the packet was generated by the stack (not a raw/packet endpoint + // where a packet may be written with the header included), then we can + // safely assume the checksum is valid. + pkt.RXTransportChecksumValidated = !headerIncluded + e.handlePacket(pkt) } } if r.Loop&stack.PacketOut == 0 { @@ -558,8 +557,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // iptables filtering. All packets that reach here are locally // generated. nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - ipt := e.protocol.stack.IPTables() - dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName) + dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName) if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. @@ -584,9 +582,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { pkt := pkt.CloneToInbound() if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) - route.PopulatePacketInfo(pkt) - ep.HandlePacket(pkt) + // Since we rewrote the packet but it is being routed back to us, we + // can safely assume the checksum is valid. + pkt.RXTransportChecksumValidated = true + ep.(*endpoint).handlePacket(pkt) } n++ continue @@ -640,16 +639,66 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu return tcpip.ErrMalformedHeader } - return e.writePacket(r, nil /* gso */, pkt, proto) + return e.writePacket(r, nil /* gso */, pkt, proto, true /* headerIncluded */) +} + +// forwardPacket attempts to forward a packet to its final destination. +func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error { + h := header.IPv6(pkt.NetworkHeader().View()) + dstAddr := h.DestinationAddress() + + // Check if the destination is owned by the stack. + networkEndpoint, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr) + if err == nil { + networkEndpoint.(*endpoint).handlePacket(pkt) + return nil + } + if err != tcpip.ErrBadAddress { + return err + } + + r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) + if err != nil { + return err + } + defer r.Release() + + // TODO(b/143425874) Decrease the TTL field in forwarded packets. + return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + // We need to do a deep copy of the IP packet because + // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do + // not own it. + Data: stack.PayloadSince(pkt.NetworkHeader()).ToVectorisedView(), + })) } // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { + stats := e.protocol.stack.Stats() + stats.IP.PacketsReceived.Increment() + if !e.isEnabled() { + stats.IP.DisabledPacketsReceived.Increment() return } + // Loopback traffic skips the prerouting chain. + if !e.nic.IsLoopback() { + if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok { + // iptables is telling us to drop the packet. + stats.IP.IPTablesPreroutingDropped.Increment() + return + } + } + + e.handlePacket(pkt) +} + +// handlePacket is like HandlePacket except it does not perform the prerouting +// iptables hook. +func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { pkt.NICID = e.nic.ID() stats := e.protocol.stack.Stats() @@ -669,6 +718,18 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } + addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint) + if addressEndpoint == nil { + if !e.protocol.Forwarding() { + stats.IP.InvalidDestinationAddressesReceived.Increment() + return + } + + _ = e.forwardPacket(pkt) + return + } + addressEndpoint.DecRef() + // vv consists of: // - Any IPv6 header bytes after the first 40 (i.e. extensions). // - The transport header, if present. @@ -681,8 +742,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // iptables filtering. All packets that reach here are intended for // this machine and need not be forwarded. - ipt := e.protocol.stack.IPTables() - if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. stats.IP.IPTablesInputDropped.Increment() return diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 981d1371a..be83e9eb4 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -45,10 +45,6 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { t.Fatalf("CreateNIC(_) = %s", err) } - if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err) - } - { subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr)))) if err != nil { @@ -73,6 +69,13 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig } t.Cleanup(ep.Close) + addr := llladdr.WithPrefix() + if addressEP, err := ep.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("ep.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + addressEP.DecRef() + } + return s, ep } @@ -961,22 +964,17 @@ func TestNDPValidation(t *testing.T) { for _, stackTyp := range stacks { t.Run(stackTyp.name, func(t *testing.T) { - setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) { + setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint) { t.Helper() // Create a stack with the assigned link-local address lladdr0 // and an endpoint to lladdr1. s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1, stackTyp.useNeighborCache) - r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err) - } - - return s, ep, r + return s, ep } - handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) { + handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) { nextHdr := uint8(header.ICMPv6ProtocolNumber) var extensions buffer.View if atomicFragment { @@ -994,13 +992,12 @@ func TestNDPValidation(t *testing.T) { PayloadLength: uint16(len(payload) + len(extensions)), NextHeader: nextHdr, HopLimit: hopLimit, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) { t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n) } - r.PopulatePacketInfo(pkt) ep.HandlePacket(pkt) } @@ -1114,8 +1111,7 @@ func TestNDPValidation(t *testing.T) { t.Run(name, func(t *testing.T) { for _, test := range subTests { t.Run(test.name, func(t *testing.T) { - s, ep, r := setup(t) - defer r.Release() + s, ep := setup(t) if isRouter { // Enabling forwarding makes the stack act as a router. @@ -1131,7 +1127,7 @@ func TestNDPValidation(t *testing.T) { copy(icmp[typ.size:], typ.extraData) icmp.SetType(typ.typ) icmp.SetCode(test.code) - icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) // Rx count of the NDP message should initially be 0. if got := typStat.Value(); got != 0 { @@ -1152,7 +1148,7 @@ func TestNDPValidation(t *testing.T) { t.FailNow() } - handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep, &r) + handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep) // Rx count of the NDP packet should have increased. if got := typStat.Value(); got != 1 { diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 7a501acdc..6dc9e7859 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -74,8 +74,30 @@ func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 { } func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) { - // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) + netHdr := pkt.NetworkHeader().View() + _, dst := f.proto.ParseAddresses(netHdr) + + addressEndpoint := f.AcquireAssignedAddress(dst, f.nic.Promiscuous(), CanBePrimaryEndpoint) + if addressEndpoint != nil { + addressEndpoint.DecRef() + // Dispatch the packet to the transport protocol. + f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(netHdr[protocolNumberOffset]), pkt) + return + } + + r, err := f.proto.stack.FindRoute(0, "", dst, fwdTestNetNumber, false /* multicastLoop */) + if err != nil { + return + } + defer r.Release() + + vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) + pkt = NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: vv.ToView().ToVectorisedView(), + }) + // TODO(b/143425874) Decrease the TTL field in forwarded packets. + _ = r.WriteHeaderIncludedPacket(pkt) } func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { @@ -106,8 +128,13 @@ func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuf panic("not implemented") } -func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error { - return tcpip.ErrNotSupported +func (f *fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error { + // The network header should not already be populated. + if _, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen); !ok { + return tcpip.ErrMalformedHeader + } + + return f.nic.WritePacket(r, nil /* gso */, fwdTestNetNumber, pkt) } func (f *fwdTestNetworkEndpoint) Close() { @@ -117,6 +144,8 @@ func (f *fwdTestNetworkEndpoint) Close() { // fwdTestNetworkProtocol is a network-layer protocol that implements Address // resolution. type fwdTestNetworkProtocol struct { + stack *Stack + addrCache *linkAddrCache neigh *neighborCache addrResolveDelay time.Duration @@ -334,7 +363,10 @@ func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protoco func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborCache bool) (ep1, ep2 *fwdTestLinkEndpoint) { // Create a stack with the network protocol and two NICs. s := New(Options{ - NetworkProtocols: []NetworkProtocolFactory{func(*Stack) NetworkProtocol { return proto }}, + NetworkProtocols: []NetworkProtocolFactory{func(s *Stack) NetworkProtocol { + proto.stack = s + return proto + }}, UseNeighborCache: useNeighborCache, }) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 60c81a3aa..12134238e 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -232,7 +232,8 @@ func (n *NIC) setPromiscuousMode(enable bool) { n.mu.Unlock() } -func (n *NIC) isPromiscuousMode() bool { +// Promiscuous implements NetworkInterface. +func (n *NIC) Promiscuous() bool { n.mu.RLock() rv := n.mu.promiscuous n.mu.RUnlock() @@ -564,13 +565,6 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool { return false } -func (n *NIC) handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, remotelinkAddr tcpip.LinkAddress, addressEndpoint AssignableAddressEndpoint, pkt *PacketBuffer) { - r := makeRoute(protocol, dst, src, n, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */) - defer r.Release() - r.PopulatePacketInfo(pkt) - n.getNetworkEndpoint(protocol).HandlePacket(pkt) -} - // DeliverNetworkPacket finds the appropriate network protocol endpoint and // hands the packet over for further processing. This function is called when // the NIC receives a packet from the link endpoint. @@ -592,7 +586,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp n.stats.Rx.Packets.Increment() n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data.Size())) - netProto, ok := n.stack.networkProtocols[protocol] + networkEndpoint, ok := n.networkEndpoints[protocol] if !ok { n.mu.RUnlock() n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -617,11 +611,8 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp ep.HandlePacket(n.id, local, protocol, p) } - if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber { - n.stack.stats.IP.PacketsReceived.Increment() - } - // Parse headers. + netProto := n.stack.NetworkProtocolInstance(protocol) transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt) if !ok { // The packet is too small to contain a network header. @@ -636,9 +627,8 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } } - src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View()) - if n.stack.handleLocal && !n.IsLoopback() { + src, _ := netProto.ParseAddresses(pkt.NetworkHeader().View()) if r := n.getAddress(protocol, src); r != nil { r.DecRef() @@ -651,78 +641,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } } - // Loopback traffic skips the prerouting chain. - if !n.IsLoopback() { - // iptables filtering. - ipt := n.stack.IPTables() - address := n.primaryAddress(protocol) - if ok := ipt.Check(Prerouting, pkt, nil, nil, address.Address, ""); !ok { - // iptables is telling us to drop the packet. - n.stack.stats.IP.IPTablesPreroutingDropped.Increment() - return - } - } - - if addressEndpoint := n.getAddress(protocol, dst); addressEndpoint != nil { - n.handlePacket(protocol, dst, src, remote, addressEndpoint, pkt) - return - } - - // This NIC doesn't care about the packet. Find a NIC that cares about the - // packet and forward it to the NIC. - // - // TODO: Should we be forwarding the packet even if promiscuous? - if n.stack.Forwarding(protocol) { - r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */) - if err != nil { - n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() - return - } - - // Found a NIC. - n := r.localAddressNIC - if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil { - if n.isValidForOutgoing(addressEndpoint) { - pkt.NICID = n.ID() - r.RemoteAddress = src - pkt.NetworkPacketInfo = r.networkPacketInfo() - n.getNetworkEndpoint(protocol).HandlePacket(pkt) - addressEndpoint.DecRef() - r.Release() - return - } - - addressEndpoint.DecRef() - } - - // n doesn't have a destination endpoint. - // Send the packet out of n. - // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease - // the TTL field for ipv4/ipv6. - - // pkt may have set its header and may not have enough headroom for - // link-layer header for the other link to prepend. Here we create a new - // packet to forward. - fwdPkt := NewPacketBuffer(PacketBufferOptions{ - ReserveHeaderBytes: int(n.LinkEndpoint.MaxHeaderLength()), - // We need to do a deep copy of the IP packet because WritePacket (and - // friends) take ownership of the packet buffer, but we do not own it. - Data: PayloadSince(pkt.NetworkHeader()).ToVectorisedView(), - }) - - // TODO(b/143425874) Decrease the TTL field in forwarded packets. - if err := n.WritePacket(&r, nil, protocol, fwdPkt); err != nil { - n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() - } - - r.Release() - return - } - - // If a packet socket handled the packet, don't treat it as invalid. - if len(packetEPs) == 0 { - n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() - } + networkEndpoint.HandlePacket(pkt) } // DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket. diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 00e9a82ae..1b1603bee 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -65,10 +65,6 @@ const ( // NetworkPacketInfo holds information about a network layer packet. type NetworkPacketInfo struct { - // RemoteAddressBroadcast is true if the packet's remote address is a - // broadcast address. - RemoteAddressBroadcast bool - // LocalAddressBroadcast is true if the packet's local address is a broadcast // address. LocalAddressBroadcast bool @@ -518,6 +514,9 @@ type NetworkInterface interface { // Enabled returns true if the interface is enabled. Enabled() bool + // Promiscuous returns true if the interface is in promiscuous mode. + Promiscuous() bool + // WritePacketToRemote writes the packet to the given remote link address. WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error } diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 15ff437c7..53cb6694f 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -170,28 +170,6 @@ func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop) } -// PopulatePacketInfo populates a packet buffer's packet information fields. -// -// TODO(gvisor.dev/issue/4688): Remove this once network packets are handled by -// the network layer. -func (r *Route) PopulatePacketInfo(pkt *PacketBuffer) { - if r.local() { - pkt.RXTransportChecksumValidated = true - } - pkt.NetworkPacketInfo = r.networkPacketInfo() -} - -// networkPacketInfo returns the network packet information of the route. -// -// TODO(gvisor.dev/issue/4688): Remove this once network packets are handled by -// the network layer. -func (r *Route) networkPacketInfo() NetworkPacketInfo { - return NetworkPacketInfo{ - RemoteAddressBroadcast: r.IsOutboundBroadcast(), - LocalAddressBroadcast: r.isInboundBroadcast(), - } -} - // NICID returns the id of the NIC from which this route originates. func (r *Route) NICID() tcpip.NICID { return r.outgoingNIC.ID() diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 0fe157128..ca9f41af4 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1080,7 +1080,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { flags := NICStateFlags{ Up: true, // Netstack interfaces are always up. Running: nic.Enabled(), - Promiscuous: nic.isPromiscuousMode(), + Promiscuous: nic.Promiscuous(), Loopback: nic.IsLoopback(), } nics[id] = NICInfo{ diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index dedfdd435..61db3164b 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -112,7 +112,15 @@ func (*fakeNetworkEndpoint) DefaultTTL() uint8 { func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { // Increment the received packet count in the protocol descriptor. netHdr := pkt.NetworkHeader().View() - f.proto.packetCount[int(netHdr[dstAddrOffset])%len(f.proto.packetCount)]++ + + dst := tcpip.Address(netHdr[dstAddrOffset:][:1]) + addressEndpoint := f.AcquireAssignedAddress(dst, f.nic.Promiscuous(), stack.CanBePrimaryEndpoint) + if addressEndpoint == nil { + return + } + addressEndpoint.DecRef() + + f.proto.packetCount[int(dst[0])%len(f.proto.packetCount)]++ // Handle control packets. if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) { @@ -159,9 +167,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params hdr[protocolNumberOffset] = byte(params.Protocol) if r.Loop&stack.PacketLoop != 0 { - pkt := pkt.Clone() - r.PopulatePacketInfo(pkt) - f.HandlePacket(pkt) + f.HandlePacket(pkt.Clone()) } if r.Loop&stack.PacketOut == 0 { return nil @@ -2214,88 +2220,6 @@ func TestNICStats(t *testing.T) { } } -func TestNICForwarding(t *testing.T) { - const nicID1 = 1 - const nicID2 = 2 - const dstAddr = tcpip.Address("\x03") - - tests := []struct { - name string - headerLen uint16 - }{ - { - name: "Zero header length", - }, - { - name: "Non-zero header length", - headerLen: 16, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - s.SetForwarding(fakeNetNumber, true) - - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID1, ep1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) - } - if err := s.AddAddress(nicID1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress(%d, %d, 0x01): %s", nicID1, fakeNetNumber, err) - } - - ep2 := channelLinkWithHeaderLength{ - Endpoint: channel.New(10, defaultMTU, ""), - headerLength: test.headerLen, - } - if err := s.CreateNIC(nicID2, &ep2); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) - } - if err := s.AddAddress(nicID2, fakeNetNumber, "\x02"); err != nil { - t.Fatalf("AddAddress(%d, %d, 0x02): %s", nicID2, fakeNetNumber, err) - } - - // Route all packets to dstAddr to NIC 2. - { - subnet, err := tcpip.NewSubnet(dstAddr, "\xff") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID2}}) - } - - // Send a packet to dstAddr. - buf := buffer.NewView(30) - buf[dstAddrOffset] = dstAddr[0] - ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - pkt, ok := ep2.Read() - if !ok { - t.Fatal("packet not forwarded") - } - - // Test that the link's MaxHeaderLength is honoured. - if capacity, want := pkt.Pkt.AvailableHeaderBytes(), int(test.headerLen); capacity != want { - t.Errorf("got LinkHeader.AvailableLength() = %d, want = %d", capacity, want) - } - - // Test that forwarding increments Tx stats correctly. - if got, want := s.NICInfo()[nicID2].Stats.Tx.Packets.Value(), uint64(1); got != want { - t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want) - } - - if got, want := s.NICInfo()[nicID2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want { - t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) - } - }) - } -} - // TestNICContextPreservation tests that you can read out via stack.NICInfo the // Context data you pass via NICContext.Context in stack.CreateNICWithOptions. func TestNICContextPreservation(t *testing.T) { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index c457b67a2..c1c3d8541 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -20,7 +20,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" @@ -553,87 +552,3 @@ func TestTransportOptions(t *testing.T) { t.Fatalf("got tcpip.TCPModerateReceiveBufferOption = false, want = true") } } - -func TestTransportForwarding(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, - }) - s.SetForwarding(fakeNetNumber, true) - - // TODO(b/123449044): Change this to a channel NIC. - ep1 := loopback.New() - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatalf("CreateNIC #1 failed: %v", err) - } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress #1 failed: %v", err) - } - - ep2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, ep2); err != nil { - t.Fatalf("CreateNIC #2 failed: %v", err) - } - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatalf("AddAddress #2 failed: %v", err) - } - - // Route all packets to address 3 to NIC 2 and all packets to address - // 1 to NIC 1. - { - subnet0, err := tcpip.NewSubnet("\x03", "\xff") - if err != nil { - t.Fatal(err) - } - subnet1, err := tcpip.NewSubnet("\x01", "\xff") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{ - {Destination: subnet0, Gateway: "\x00", NIC: 2}, - {Destination: subnet1, Gateway: "\x00", NIC: 1}, - }) - } - - wq := waiter.Queue{} - ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Bind(tcpip.FullAddress{Addr: "\x01", NIC: 1}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Send a packet to address 1 from address 3. - req := buffer.NewView(30) - req[0] = 1 - req[1] = 3 - req[2] = byte(fakeTransNumber) - ep2.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: req.ToVectorisedView(), - })) - - aep, _, err := ep.Accept(nil) - if err != nil || aep == nil { - t.Fatalf("Accept failed: %v, %v", aep, err) - } - - resp := buffer.NewView(30) - if _, _, err := aep.Write(tcpip.SlicePayload(resp), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - p, ok := ep2.Read() - if !ok { - t.Fatal("Response packet not forwarded") - } - - nh := stack.PayloadSince(p.Pkt.NetworkHeader()) - if dst := nh[0]; dst != 3 { - t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) - } - if src := nh[1]; src != 1 { - t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) - } -} diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 3ab2b7654..9a0c63ae4 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -1464,9 +1464,13 @@ type ICMPStats struct { // IPStats collects IP-specific stats (both v4 and v6). type IPStats struct { // PacketsReceived is the total number of IP packets received from the - // link layer in nic.DeliverNetworkPacket. + // link layer. PacketsReceived *StatCounter + // DisabledPacketsReceived is the total number of IP packets received from the + // link layer when the IP layer is disabled. + DisabledPacketsReceived *StatCounter + // InvalidDestinationAddressesReceived is the total number of IP packets // received with an unknown or invalid destination address. InvalidDestinationAddressesReceived *StatCounter diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index c09c7aa86..df62177cd 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -1451,6 +1451,10 @@ func (*testInterface) Enabled() bool { return true } +func (*testInterface) Promiscuous() bool { + return false +} + func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } -- cgit v1.2.3