summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/network/ipv4/ipv4.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/network/ipv4/ipv4.go')
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go114
1 files changed, 70 insertions, 44 deletions
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 4592984a5..1bc2c4aff 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -252,8 +252,7 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
// iptables filtering. All packets that reach here are locally
// generated.
nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- ipt := e.protocol.stack.IPTables()
- if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
+ if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok {
// iptables is telling us to drop the packet.
r.Stats().IP.IPTablesOutputDropped.Increment()
return nil
@@ -270,16 +269,27 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
netHeader := header.IPv4(pkt.NetworkHeader().View())
ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress())
if err == nil {
- route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
- ep.HandlePacket(&route, pkt)
+ pkt := pkt.CloneToInbound()
+ if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
+ route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
+ route.PopulatePacketInfo(pkt)
+ // Since we rewrote the packet but it is being routed back to us, we can
+ // safely assume the checksum is valid.
+ pkt.RXTransportChecksumValidated = true
+ ep.HandlePacket(pkt)
+ }
return nil
}
}
if r.Loop&stack.PacketLoop != 0 {
- loopedR := r.MakeLoopedRoute()
- e.HandlePacket(&loopedR, pkt)
- loopedR.Release()
+ pkt := pkt.CloneToInbound()
+ if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
+ loopedR := r.MakeLoopedRoute()
+ loopedR.PopulatePacketInfo(pkt)
+ loopedR.Release()
+ e.HandlePacket(pkt)
+ }
}
if r.Loop&stack.PacketOut == 0 {
return nil
@@ -373,10 +383,12 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
if _, ok := natPkts[pkt]; ok {
netHeader := header.IPv4(pkt.NetworkHeader().View())
if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
- src := netHeader.SourceAddress()
- dst := netHeader.DestinationAddress()
- route := r.ReverseRoute(src, dst)
- ep.HandlePacket(&route, pkt)
+ pkt := pkt.CloneToInbound()
+ if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
+ route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
+ route.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
+ }
n++
continue
}
@@ -403,6 +415,16 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
if !ok {
return tcpip.ErrMalformedHeader
}
+
+ hdrLen := header.IPv4(h).HeaderLength()
+ if hdrLen < header.IPv4MinimumSize {
+ return tcpip.ErrMalformedHeader
+ }
+
+ h, ok = pkt.Data.PullUp(int(hdrLen))
+ if !ok {
+ return tcpip.ErrMalformedHeader
+ }
ip := header.IPv4(h)
// Always set the total length.
@@ -447,14 +469,17 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
// HandlePacket is called by the link layer when new ipv4 packets arrive for
// this endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
if !e.isEnabled() {
return
}
+ pkt.NICID = e.nic.ID()
+ stats := e.protocol.stack.Stats()
+
h := header.IPv4(pkt.NetworkHeader().View())
if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) {
- r.Stats().IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
return
}
@@ -480,7 +505,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// is all 1 bits (-0 in 1's complement arithmetic), the check
// succeeds.
if h.CalculateChecksum() != 0xffff {
- r.Stats().IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
return
}
@@ -488,8 +513,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// When a host sends any datagram, the IP source address MUST
// be one of its own IP addresses (but not a broadcast or
// multicast address).
- if r.IsOutboundBroadcast() || header.IsV4MulticastAddress(r.RemoteAddress) {
- r.Stats().IP.InvalidSourceAddressesReceived.Increment()
+ if pkt.NetworkPacketInfo.RemoteAddressBroadcast || header.IsV4MulticastAddress(h.SourceAddress()) {
+ stats.IP.InvalidSourceAddressesReceived.Increment()
return
}
@@ -498,7 +523,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
ipt := e.protocol.stack.IPTables()
if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
- r.Stats().IP.IPTablesInputDropped.Increment()
+ stats.IP.IPTablesInputDropped.Increment()
return
}
@@ -506,8 +531,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
if pkt.Data.Size()+pkt.TransportHeader().View().Size() == 0 {
// Drop the packet as it's marked as a fragment but has
// no payload.
- r.Stats().IP.MalformedPacketsReceived.Increment()
- r.Stats().IP.MalformedFragmentsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedFragmentsReceived.Increment()
return
}
// The packet is a fragment, let's try to reassemble it.
@@ -520,8 +545,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// size). Otherwise the packet would've been rejected as invalid before
// reaching here.
if int(start)+pkt.Data.Size() > header.IPv4MaximumPayloadSize {
- r.Stats().IP.MalformedPacketsReceived.Increment()
- r.Stats().IP.MalformedFragmentsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedFragmentsReceived.Increment()
return
}
@@ -537,12 +562,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
var releaseCB func(bool)
if start == 0 {
pkt := pkt.Clone()
- r := r.Clone()
releaseCB = func(timedOut bool) {
if timedOut {
- _ = e.protocol.returnError(&r, &icmpReasonReassemblyTimeout{}, pkt)
+ _ = e.protocol.returnError(&icmpReasonReassemblyTimeout{}, pkt)
}
- r.Release()
}
}
@@ -566,8 +589,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
releaseCB,
)
if err != nil {
- r.Stats().IP.MalformedPacketsReceived.Increment()
- r.Stats().IP.MalformedFragmentsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedFragmentsReceived.Increment()
return
}
if !ready {
@@ -579,7 +602,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
h.SetTotalLength(uint16(pkt.Data.Size() + len((h))))
h.SetFlagsFragmentOffset(0, 0)
}
- r.Stats().IP.PacketsDelivered.Increment()
+ stats.IP.PacketsDelivered.Increment()
p := h.TransportProtocol()
if p == header.ICMPv4ProtocolNumber {
@@ -587,14 +610,14 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// headers, the setting of the transport number here should be
// unnecessary and removed.
pkt.TransportProtocolNumber = p
- e.handleICMP(r, pkt)
+ e.handleICMP(pkt)
return
}
if len(h.Options()) != 0 {
// TODO(gvisor.dev/issue/4586):
// When we add forwarding support we should use the verified options
// rather than just throwing them away.
- aux, _, err := processIPOptions(r, h.Options(), &optionUsageReceive{})
+ aux, _, err := e.processIPOptions(pkt, h.Options(), &optionUsageReceive{})
if err != nil {
switch {
case
@@ -604,15 +627,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
errors.Is(err, errIPv4TimestampOptInvalidLength),
errors.Is(err, errIPv4TimestampOptInvalidPointer),
errors.Is(err, errIPv4TimestampOptOverflow):
- _ = e.protocol.returnError(r, &icmpReasonParamProblem{pointer: aux}, pkt)
- e.protocol.stack.Stats().MalformedRcvdPackets.Increment()
- r.Stats().IP.MalformedPacketsReceived.Increment()
+ _ = e.protocol.returnError(&icmpReasonParamProblem{pointer: aux}, pkt)
+ stats.MalformedRcvdPackets.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
}
return
}
}
- switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res {
+ switch res := e.dispatcher.DeliverTransportPacket(p, pkt); res {
case stack.TransportPacketHandled:
case stack.TransportPacketDestinationPortUnreachable:
// As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination
@@ -620,13 +643,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// 3 (Port Unreachable), when the designated transport protocol
// (e.g., UDP) is unable to demultiplex the datagram but has no
// protocol mechanism to inform the sender.
- _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt)
+ _ = e.protocol.returnError(&icmpReasonPortUnreachable{}, pkt)
case stack.TransportPacketProtocolUnreachable:
// As per RFC: 1122 Section 3.2.2.1
// A host SHOULD generate Destination Unreachable messages with code:
// 2 (Protocol Unreachable), when the designated transport protocol
// is not supported
- _ = e.protocol.returnError(r, &icmpReasonProtoUnreachable{}, pkt)
+ _ = e.protocol.returnError(&icmpReasonProtoUnreachable{}, pkt)
default:
panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res))
}
@@ -919,6 +942,7 @@ func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader head
originalIPHeaderLength := len(originalIPHeader)
nextFragIPHeader := header.IPv4(fragPkt.NetworkHeader().Push(originalIPHeaderLength))
+ fragPkt.NetworkProtocolNumber = ProtocolNumber
if copied := copy(nextFragIPHeader, originalIPHeader); copied != len(originalIPHeader) {
panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got = %d, want = %d", copied, originalIPHeaderLength))
@@ -1172,8 +1196,8 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad
// - The location of an error if there was one (or 0 if no error)
// - If there is an error, information as to what it was was.
// - The replacement option set.
-func processIPOptions(r *stack.Route, orig header.IPv4Options, usage optionsUsage) (uint8, header.IPv4Options, error) {
-
+func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Options, usage optionsUsage) (uint8, header.IPv4Options, error) {
+ stats := e.protocol.stack.Stats()
opts := header.IPv4Options(orig)
optIter := opts.MakeIterator()
@@ -1186,13 +1210,15 @@ func processIPOptions(r *stack.Route, orig header.IPv4Options, usage optionsUsag
// This will need tweaking when we start really forwarding packets
// as we may need to get two addresses, for rx and tx interfaces.
// We will also have to take usage into account.
- prefixedAddress, err := r.Stack().GetMainNICAddress(r.NICID(), ProtocolNumber)
+ prefixedAddress, err := e.protocol.stack.GetMainNICAddress(e.nic.ID(), ProtocolNumber)
localAddress := prefixedAddress.Address
if err != nil {
- if r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) {
+ h := header.IPv4(pkt.NetworkHeader().View())
+ dstAddr := h.DestinationAddress()
+ if pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(dstAddr) {
return 0 /* errCursor */, nil, header.ErrIPv4OptionAddress
}
- localAddress = r.LocalAddress
+ localAddress = dstAddr
}
for {
@@ -1219,9 +1245,9 @@ func processIPOptions(r *stack.Route, orig header.IPv4Options, usage optionsUsag
optLen := int(option.Size())
switch option := option.(type) {
case *header.IPv4OptionTimestamp:
- r.Stats().IP.OptionTSReceived.Increment()
+ stats.IP.OptionTSReceived.Increment()
if usage.actions().timestamp != optionRemove {
- clock := r.Stack().Clock()
+ clock := e.protocol.stack.Clock()
newBuffer := optIter.RemainingBuffer()[:len(*option)]
_ = copy(newBuffer, option.Contents())
offset, err := handleTimestamp(header.IPv4OptionTimestamp(newBuffer), localAddress, clock, usage)
@@ -1232,7 +1258,7 @@ func processIPOptions(r *stack.Route, orig header.IPv4Options, usage optionsUsag
}
case *header.IPv4OptionRecordRoute:
- r.Stats().IP.OptionRRReceived.Increment()
+ stats.IP.OptionRRReceived.Increment()
if usage.actions().recordRoute != optionRemove {
newBuffer := optIter.RemainingBuffer()[:len(*option)]
_ = copy(newBuffer, option.Contents())
@@ -1244,7 +1270,7 @@ func processIPOptions(r *stack.Route, orig header.IPv4Options, usage optionsUsag
}
default:
- r.Stats().IP.OptionUnknownReceived.Increment()
+ stats.IP.OptionUnknownReceived.Increment()
if usage.actions().unknown == optionPass {
newBuffer := optIter.RemainingBuffer()[:optLen]
// Arguments already heavily checked.. ignore result.