summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/network/BUILD1
-rw-r--r--pkg/tcpip/network/arp/arp.go5
-rw-r--r--pkg/tcpip/network/ip_test.go14
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go103
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go71
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go103
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go14
-rw-r--r--pkg/tcpip/stack/forwarding_test.go4
-rw-r--r--pkg/tcpip/stack/nic.go30
-rw-r--r--pkg/tcpip/stack/stack.go44
-rw-r--r--pkg/tcpip/stack/stack_test.go4
11 files changed, 201 insertions, 192 deletions
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index 0caa65251..fa8814bac 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -16,7 +16,6 @@ go_test(
"//pkg/tcpip/checker",
"//pkg/tcpip/faketime",
"//pkg/tcpip/header",
- "//pkg/tcpip/header/parse",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
"//pkg/tcpip/network/ipv4",
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 0d7fadc31..bd9b9c020 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -129,6 +129,11 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return
}
+ if _, _, ok := e.protocol.Parse(pkt); !ok {
+ stats.malformedPacketsReceived.Increment()
+ return
+ }
+
h := header.ARP(pkt.NetworkHeader().View())
if !h.IsValid() {
stats.malformedPacketsReceived.Increment()
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 6a1f11a36..a176ef2b9 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -24,7 +24,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
@@ -626,9 +625,6 @@ func TestReceive(t *testing.T) {
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: view.ToVectorisedView(),
})
- if ok := parse.IPv4(pkt); !ok {
- t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
- }
ep.HandlePacket(pkt)
},
},
@@ -664,9 +660,6 @@ func TestReceive(t *testing.T) {
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: view.ToVectorisedView(),
})
- if _, _, _, _, ok := parse.IPv6(pkt); !ok {
- t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
- }
ep.HandlePacket(pkt)
},
},
@@ -943,9 +936,6 @@ func TestIPv4FragmentationReceive(t *testing.T) {
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: frag1.ToVectorisedView(),
})
- if _, _, ok := proto.Parse(pkt); !ok {
- t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
- }
addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
if !ok {
@@ -967,9 +957,6 @@ func TestIPv4FragmentationReceive(t *testing.T) {
pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: frag2.ToVectorisedView(),
})
- if _, _, ok := proto.Parse(pkt); !ok {
- t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
- }
ep.HandlePacket(pkt)
if nic.testObject.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
@@ -1234,7 +1221,6 @@ func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: v.ToVectorisedView(),
})
- _, _ = pkt.NetworkHeader().Consume(netHdrLen)
return pkt
}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index b2d626107..e1e05e39c 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -347,15 +347,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// short circuits broadcasts before they are sent out to other hosts.
if pkt.NatDone {
netHeader := header.IPv4(pkt.NetworkHeader().View())
- ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress())
- if err == nil {
- pkt := pkt.CloneToInbound()
- if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- // Since we rewrote the packet but it is being routed back to us, we can
- // safely assume the checksum is valid.
- pkt.RXTransportChecksumValidated = true
- ep.(*endpoint).handlePacket(pkt)
- }
+ if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); ep != nil {
+ // Since we rewrote the packet but it is being routed back to us, we
+ // can safely assume the checksum is valid.
+ ep.(*endpoint).handleLocalPacket(pkt, true /* canSkipRXChecksum */)
return nil
}
}
@@ -365,14 +360,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, headerIncluded bool) tcpip.Error {
if r.Loop&stack.PacketLoop != 0 {
- pkt := pkt.CloneToInbound()
- if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- // If the packet was generated by the stack (not a raw/packet endpoint
- // where a packet may be written with the header included), then we can
- // safely assume the checksum is valid.
- pkt.RXTransportChecksumValidated = !headerIncluded
- e.handlePacket(pkt)
- }
+ // If the packet was generated by the stack (not a raw/packet endpoint
+ // where a packet may be written with the header included), then we can
+ // safely assume the checksum is valid.
+ e.handleLocalPacket(pkt, !headerIncluded /* canSkipRXChecksum */)
}
if r.Loop&stack.PacketOut == 0 {
return nil
@@ -471,14 +462,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
if _, ok := natPkts[pkt]; ok {
netHeader := header.IPv4(pkt.NetworkHeader().View())
- if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
- pkt := pkt.CloneToInbound()
- if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- // Since we rewrote the packet but it is being routed back to us, we
- // can safely assume the checksum is valid.
- pkt.RXTransportChecksumValidated = true
- ep.(*endpoint).handlePacket(pkt)
- }
+ if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); ep != nil {
+ // Since we rewrote the packet but it is being routed back to us, we
+ // can safely assume the checksum is valid.
+ ep.(*endpoint).handleLocalPacket(pkt, true /* canSkipRXChecksum */)
n++
continue
}
@@ -573,14 +560,10 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
dstAddr := h.DestinationAddress()
// Check if the destination is owned by the stack.
- networkEndpoint, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr)
- if err == nil {
- networkEndpoint.(*endpoint).handlePacket(pkt)
+ if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr); ep != nil {
+ ep.(*endpoint).handlePacket(pkt)
return nil
}
- if _, ok := err.(*tcpip.ErrBadAddress); !ok {
- return err
- }
r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */)
if err != nil {
@@ -619,8 +602,26 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return
}
- // Loopback traffic skips the prerouting chain.
+ if !e.protocol.parse(pkt) {
+ stats.MalformedPacketsReceived.Increment()
+ return
+ }
+
if !e.nic.IsLoopback() {
+ if e.protocol.stack.HandleLocal() {
+ addressEndpoint := e.AcquireAssignedAddress(header.IPv4(pkt.NetworkHeader().View()).SourceAddress(), e.nic.Promiscuous(), stack.CanBePrimaryEndpoint)
+ if addressEndpoint != nil {
+ addressEndpoint.DecRef()
+
+ // The source address is one of our own, so we never should have gotten
+ // a packet like this unless HandleLocal is false or our NIC is the
+ // loopback interface.
+ stats.InvalidSourceAddressesReceived.Increment()
+ return
+ }
+ }
+
+ // Loopback traffic skips the prerouting chain.
inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok {
// iptables is telling us to drop the packet.
@@ -632,6 +633,21 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
e.handlePacket(pkt)
}
+func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum bool) {
+ stats := e.stats.ip
+
+ stats.PacketsReceived.Increment()
+
+ pkt = pkt.CloneToInbound()
+ if e.protocol.parse(pkt) {
+ pkt.RXTransportChecksumValidated = canSkipRXChecksum
+ e.handlePacket(pkt)
+ return
+ }
+
+ stats.MalformedPacketsReceived.Increment()
+}
+
// handlePacket is like HandlePacket except it does not perform the prerouting
// iptables hook.
func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
@@ -1043,6 +1059,29 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
+// parse is like Parse but also attempts to parse the transport layer.
+//
+// Returns true if the network header was successfully parsed.
+func (p *protocol) parse(pkt *stack.PacketBuffer) bool {
+ transProtoNum, hasTransportHdr, ok := p.Parse(pkt)
+ if !ok {
+ return false
+ }
+
+ if hasTransportHdr {
+ switch err := p.stack.ParsePacketBufferTransport(transProtoNum, pkt); err {
+ case stack.ParsedOK:
+ case stack.UnknownTransportProtocol, stack.TransportLayerParseError:
+ // The transport layer will handle unknown protocols and transport layer
+ // parsing errors.
+ default:
+ panic(fmt.Sprintf("unexpected error parsing transport header = %d", err))
+ }
+ }
+
+ return true
+}
+
// Parse implements stack.NetworkProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
if ok := parse.IPv4(pkt); !ok {
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 92f9ee2c2..ca46ec61f 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -149,6 +149,23 @@ func (t *testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber,
return nil
}
+func handleICMPInIPv6(ep stack.NetworkEndpoint, src, dst tcpip.Address, icmp header.ICMPv6) {
+ ip := buffer.NewView(header.IPv6MinimumSize)
+ header.IPv6(ip).Encode(&header.IPv6Fields{
+ PayloadLength: uint16(len(icmp)),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: src,
+ DstAddr: dst,
+ })
+ vv := ip.ToVectorisedView()
+ vv.AppendView(buffer.View(icmp))
+ ep.HandlePacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.IPv6MinimumSize,
+ Data: vv,
+ }))
+}
+
func TestICMPCounts(t *testing.T) {
tests := []struct {
name string
@@ -282,33 +299,17 @@ func TestICMPCounts(t *testing.T) {
},
}
- handleIPv6Payload := func(icmp header.ICMPv6) {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.IPv6MinimumSize,
- Data: buffer.View(icmp).ToVectorisedView(),
- })
- ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(icmp)),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
- })
- ep.HandlePacket(pkt)
- }
-
for _, typ := range types {
icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
copy(icmp[typ.size:], typ.extraData)
icmp.SetType(typ.typ)
icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView()))
- handleIPv6Payload(icmp)
+ handleICMPInIPv6(ep, lladdr1, lladdr0, icmp)
}
// Construct an empty ICMP packet so that
// Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented.
- handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize)))
+ handleICMPInIPv6(ep, lladdr1, lladdr0, header.ICMPv6(buffer.NewView(header.IPv6MinimumSize)))
icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived
visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) {
@@ -440,33 +441,17 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
},
}
- handleIPv6Payload := func(icmp header.ICMPv6) {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.IPv6MinimumSize,
- Data: buffer.View(icmp).ToVectorisedView(),
- })
- ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(icmp)),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
- })
- ep.HandlePacket(pkt)
- }
-
for _, typ := range types {
icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
copy(icmp[typ.size:], typ.extraData)
icmp.SetType(typ.typ)
icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView()))
- handleIPv6Payload(icmp)
+ handleICMPInIPv6(ep, lladdr1, lladdr0, icmp)
}
// Construct an empty ICMP packet so that
// Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented.
- handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize)))
+ handleICMPInIPv6(ep, lladdr1, lladdr0, header.ICMPv6(buffer.NewView(header.IPv6MinimumSize)))
icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived
visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) {
@@ -1818,19 +1803,7 @@ func TestCallsToNeighborCache(t *testing.T) {
icmp := test.createPacket()
icmp.SetChecksum(header.ICMPv6Checksum(icmp, test.source, test.destination, buffer.VectorisedView{}))
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.IPv6MinimumSize,
- Data: buffer.View(icmp).ToVectorisedView(),
- })
- ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(icmp)),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: header.NDPHopLimit,
- SrcAddr: test.source,
- DstAddr: test.destination,
- })
- ep.HandlePacket(pkt)
+ handleICMPInIPv6(ep, test.source, test.destination, icmp)
// Confirm the endpoint calls the correct NUDHandler method.
if testInterface.probeCount != test.wantProbeCount {
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index c2e8c3ea7..5cad546b8 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -648,14 +648,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// short circuits broadcasts before they are sent out to other hosts.
if pkt.NatDone {
netHeader := header.IPv6(pkt.NetworkHeader().View())
- if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
- pkt := pkt.CloneToInbound()
- if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- // Since we rewrote the packet but it is being routed back to us, we can
- // safely assume the checksum is valid.
- pkt.RXTransportChecksumValidated = true
- ep.(*endpoint).handlePacket(pkt)
- }
+ if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); ep != nil {
+ // Since we rewrote the packet but it is being routed back to us, we
+ // can safely assume the checksum is valid.
+ ep.(*endpoint).handleLocalPacket(pkt, true /* canSkipRXChecksum */)
return nil
}
}
@@ -665,14 +661,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) tcpip.Error {
if r.Loop&stack.PacketLoop != 0 {
- pkt := pkt.CloneToInbound()
- if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- // If the packet was generated by the stack (not a raw/packet endpoint
- // where a packet may be written with the header included), then we can
- // safely assume the checksum is valid.
- pkt.RXTransportChecksumValidated = !headerIncluded
- e.handlePacket(pkt)
- }
+ // If the packet was generated by the stack (not a raw/packet endpoint
+ // where a packet may be written with the header included), then we can
+ // safely assume the checksum is valid.
+ e.handleLocalPacket(pkt, !headerIncluded /* canSkipRXChecksum */)
}
if r.Loop&stack.PacketOut == 0 {
return nil
@@ -771,14 +763,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
if _, ok := natPkts[pkt]; ok {
netHeader := header.IPv6(pkt.NetworkHeader().View())
- if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
- pkt := pkt.CloneToInbound()
- if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- // Since we rewrote the packet but it is being routed back to us, we
- // can safely assume the checksum is valid.
- pkt.RXTransportChecksumValidated = true
- ep.(*endpoint).handlePacket(pkt)
- }
+ if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); ep != nil {
+ // Since we rewrote the packet but it is being routed back to us, we
+ // can safely assume the checksum is valid.
+ ep.(*endpoint).handleLocalPacket(pkt, true /* canSkipRXChecksum */)
n++
continue
}
@@ -852,14 +840,11 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
dstAddr := h.DestinationAddress()
// Check if the destination is owned by the stack.
- networkEndpoint, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr)
- if err == nil {
- networkEndpoint.(*endpoint).handlePacket(pkt)
+
+ if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr); ep != nil {
+ ep.(*endpoint).handlePacket(pkt)
return nil
}
- if _, ok := err.(*tcpip.ErrBadAddress); !ok {
- return err
- }
r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */)
if err != nil {
@@ -896,8 +881,26 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return
}
- // Loopback traffic skips the prerouting chain.
+ if !e.protocol.parse(pkt) {
+ stats.MalformedPacketsReceived.Increment()
+ return
+ }
+
if !e.nic.IsLoopback() {
+ if e.protocol.stack.HandleLocal() {
+ addressEndpoint := e.AcquireAssignedAddress(header.IPv6(pkt.NetworkHeader().View()).SourceAddress(), e.nic.Promiscuous(), stack.CanBePrimaryEndpoint)
+ if addressEndpoint != nil {
+ addressEndpoint.DecRef()
+
+ // The source address is one of our own, so we never should have gotten
+ // a packet like this unless HandleLocal is false or our NIC is the
+ // loopback interface.
+ stats.InvalidSourceAddressesReceived.Increment()
+ return
+ }
+ }
+
+ // Loopback traffic skips the prerouting chain.
inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok {
// iptables is telling us to drop the packet.
@@ -909,6 +912,21 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
e.handlePacket(pkt)
}
+func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum bool) {
+ stats := e.stats.ip
+
+ stats.PacketsReceived.Increment()
+
+ pkt = pkt.CloneToInbound()
+ if e.protocol.parse(pkt) {
+ pkt.RXTransportChecksumValidated = canSkipRXChecksum
+ e.handlePacket(pkt)
+ return
+ }
+
+ stats.MalformedPacketsReceived.Increment()
+}
+
// handlePacket is like HandlePacket except it does not perform the prerouting
// iptables hook.
func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
@@ -1798,6 +1816,29 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
+// parse is like Parse but also attempts to parse the transport layer.
+//
+// Returns true if the network header was successfully parsed.
+func (p *protocol) parse(pkt *stack.PacketBuffer) bool {
+ transProtoNum, hasTransportHdr, ok := p.Parse(pkt)
+ if !ok {
+ return false
+ }
+
+ if hasTransportHdr {
+ switch err := p.stack.ParsePacketBufferTransport(transProtoNum, pkt); err {
+ case stack.ParsedOK:
+ case stack.UnknownTransportProtocol, stack.TransportLayerParseError:
+ // The transport layer will handle unknown protocols and transport layer
+ // parsing errors.
+ default:
+ panic(fmt.Sprintf("unexpected error parsing transport header = %d", err))
+ }
+ }
+
+ return true
+}
+
// Parse implements stack.NetworkProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
proto, _, fragOffset, fragMore, ok := parse.IPv6(pkt)
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 8edaa9508..104fe2139 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -977,12 +977,8 @@ func TestNDPValidation(t *testing.T) {
}
extHdrsLen := extHdrs.Length()
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.IPv6MinimumSize + extHdrsLen,
- Data: payload.ToVectorisedView(),
- })
- ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen))
- ip.Encode(&header.IPv6Fields{
+ ip := buffer.NewView(header.IPv6MinimumSize + extHdrsLen)
+ header.IPv6(ip).Encode(&header.IPv6Fields{
PayloadLength: uint16(len(payload) + extHdrsLen),
TransportProtocol: header.ICMPv6ProtocolNumber,
HopLimit: hopLimit,
@@ -990,7 +986,11 @@ func TestNDPValidation(t *testing.T) {
DstAddr: lladdr0,
ExtensionHeaders: extHdrs,
})
- ep.HandlePacket(pkt)
+ vv := ip.ToVectorisedView()
+ vv.AppendView(payload)
+ ep.HandlePacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
+ }))
}
var tllData [header.NDPLinkLayerAddressSize]byte
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index c24f56ece..0cb9ec3a3 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -75,6 +75,10 @@ func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 {
}
func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) {
+ if _, _, ok := f.proto.Parse(pkt); !ok {
+ return
+ }
+
netHdr := pkt.NetworkHeader().View()
_, dst := f.proto.ParseAddresses(netHdr)
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 41a489047..6f2a0e487 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -777,36 +777,6 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
anyEPs.forEach(deliverPacketEPs)
}
- // Parse headers.
- netProto := n.stack.NetworkProtocolInstance(protocol)
- transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt)
- if !ok {
- // The packet is too small to contain a network header.
- n.stack.stats.MalformedRcvdPackets.Increment()
- return
- }
- if hasTransportHdr {
- pkt.TransportProtocolNumber = transProtoNum
- // Parse the transport header if present.
- if state, ok := n.stack.transportProtocols[transProtoNum]; ok {
- state.proto.Parse(pkt)
- }
- }
-
- if n.stack.handleLocal && !n.IsLoopback() {
- src, _ := netProto.ParseAddresses(pkt.NetworkHeader().View())
- if r := n.getAddress(protocol, src); r != nil {
- r.DecRef()
-
- // The source address is one of our own, so we never should have gotten a
- // packet like this unless handleLocal is false. Loopback also calls this
- // function even though the packets didn't come from the physical interface
- // so don't drop those.
- n.stack.stats.IP.InvalidSourceAddressesReceived.Increment()
- return
- }
- }
-
networkEndpoint.HandlePacket(pkt)
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index a51d758d0..035ab33ca 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -1319,6 +1319,11 @@ func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr,
return nil
}
+// HandleLocal returns true if non-loopback interfaces are allowed to loop packets.
+func (s *Stack) HandleLocal() bool {
+ return s.handleLocal
+}
+
// FindRoute creates a route to the given destination address, leaving through
// the given NIC and local address (if provided).
//
@@ -2063,7 +2068,9 @@ func generateRandInt64() int64 {
}
// FindNetworkEndpoint returns the network endpoint for the given address.
-func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) (NetworkEndpoint, tcpip.Error) {
+//
+// Returns nil if the address is not associated with any network endpoint.
+func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) NetworkEndpoint {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -2073,9 +2080,9 @@ func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, addres
continue
}
addressEndpoint.DecRef()
- return nic.getNetworkEndpoint(netProto), nil
+ return nic.getNetworkEndpoint(netProto)
}
- return nil, &tcpip.ErrBadAddress{}
+ return nil
}
// FindNICNameFromID returns the name of the NIC for the given NICID.
@@ -2103,13 +2110,6 @@ const (
// ParsedOK indicates that a packet was successfully parsed.
ParsedOK ParseResult = iota
- // UnknownNetworkProtocol indicates that the network protocol is unknown.
- UnknownNetworkProtocol
-
- // NetworkLayerParseError indicates that the network packet was not
- // successfully parsed.
- NetworkLayerParseError
-
// UnknownTransportProtocol indicates that the transport protocol is unknown.
UnknownTransportProtocol
@@ -2118,31 +2118,19 @@ const (
TransportLayerParseError
)
-// ParsePacketBuffer parses the provided packet buffer.
-func (s *Stack) ParsePacketBuffer(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) ParseResult {
- netProto, ok := s.networkProtocols[protocol]
- if !ok {
- return UnknownNetworkProtocol
- }
-
- transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt)
- if !ok {
- return NetworkLayerParseError
- }
- if !hasTransportHdr {
- return ParsedOK
- }
-
+// ParsePacketBufferTransport parses the provided packet buffer's transport
+// header.
+func (s *Stack) ParsePacketBufferTransport(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) ParseResult {
// TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader
// fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a
// full explanation.
- if transProtoNum == header.ICMPv4ProtocolNumber || transProtoNum == header.ICMPv6ProtocolNumber {
+ if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber {
return ParsedOK
}
- pkt.TransportProtocolNumber = transProtoNum
+ pkt.TransportProtocolNumber = protocol
// Parse the transport header if present.
- state, ok := s.transportProtocols[transProtoNum]
+ state, ok := s.transportProtocols[protocol]
if !ok {
return UnknownTransportProtocol
}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index b641a4aaa..b3386f705 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -119,6 +119,10 @@ func (*fakeNetworkEndpoint) DefaultTTL() uint8 {
}
func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) {
+ if _, _, ok := f.proto.Parse(pkt); !ok {
+ return
+ }
+
// Increment the received packet count in the protocol descriptor.
netHdr := pkt.NetworkHeader().View()