From 23574b1b87ce5aed7b78a53663eac61ae030e9d5 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Thu, 14 Nov 2019 22:54:01 -0800 Subject: Fix panic when logging raw packets via sniffer. Sniffer assumed that outgoing packets have transport headers, but users can write packets via SOCK_RAW with arbitrary transport headers that netstack doesn't know about. We now explicitly check for the presence of network and transport headers before assuming they exist. PiperOrigin-RevId: 280594395 --- pkg/tcpip/link/sniffer/sniffer.go | 277 +++++++++++++++++++------------------- 1 file changed, 140 insertions(+), 137 deletions(-) (limited to 'pkg/tcpip/link') diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 122680e10..147d4e242 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -118,7 +118,7 @@ func NewWithFile(lower stack.LinkEndpoint, file *os.File, snapLen uint32) (stack // logs the packet before forwarding to the actual dispatcher. func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { - logPacket("recv", protocol, pkt.Data.First(), nil) + logPacket("recv", protocol, pkt, nil) } if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 { vs := pkt.Data.Views() @@ -195,7 +195,7 @@ func (e *endpoint) GSOMaxSize() uint32 { func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { - logPacket("send", protocol, pkt.Header.View(), gso) + logPacket("send", protocol, pkt, gso) } if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 { hdrBuf := pkt.Header.View() @@ -247,7 +247,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.Pac // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { - logPacket("send", 0, buffer.View("[raw packet, no header available]"), nil /* gso */) + logPacket("send raw packet", 0, tcpip.PacketBuffer{}, nil /* gso */) } if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 { length := vv.Size() @@ -289,7 +289,7 @@ func logVectorisedView(vv buffer.VectorisedView, length int, buf *bytes.Buffer) // Wait implements stack.LinkEndpoint.Wait. func (*endpoint) Wait() {} -func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) { +func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer, gso *stack.GSO) { // Figure out the network layer info. var transProto uint8 src := tcpip.Address("unknown") @@ -298,39 +298,40 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie size := uint16(0) var fragmentOffset uint16 var moreFragments bool - switch protocol { - case header.IPv4ProtocolNumber: - ipv4 := header.IPv4(b) - fragmentOffset = ipv4.FragmentOffset() - moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments - src = ipv4.SourceAddress() - dst = ipv4.DestinationAddress() - transProto = ipv4.Protocol() - size = ipv4.TotalLength() - uint16(ipv4.HeaderLength()) - b = b[ipv4.HeaderLength():] - id = int(ipv4.ID()) - - case header.IPv6ProtocolNumber: - ipv6 := header.IPv6(b) - src = ipv6.SourceAddress() - dst = ipv6.DestinationAddress() - transProto = ipv6.NextHeader() - size = ipv6.PayloadLength() - b = b[header.IPv6MinimumSize:] - - case header.ARPProtocolNumber: - arp := header.ARP(b) - log.Infof( - "%s arp %v (%v) -> %v (%v) valid:%v", - prefix, - tcpip.Address(arp.ProtocolAddressSender()), tcpip.LinkAddress(arp.HardwareAddressSender()), - tcpip.Address(arp.ProtocolAddressTarget()), tcpip.LinkAddress(arp.HardwareAddressTarget()), - arp.IsValid(), - ) - return - default: - log.Infof("%s unknown network protocol", prefix) - return + + if pkt.NetworkHeader != nil { + switch protocol { + case header.IPv4ProtocolNumber: + ipv4 := header.IPv4(pkt.NetworkHeader) + fragmentOffset = ipv4.FragmentOffset() + moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments + src = ipv4.SourceAddress() + dst = ipv4.DestinationAddress() + transProto = ipv4.Protocol() + size = ipv4.TotalLength() - uint16(ipv4.HeaderLength()) + id = int(ipv4.ID()) + + case header.IPv6ProtocolNumber: + ipv6 := header.IPv6(pkt.NetworkHeader) + src = ipv6.SourceAddress() + dst = ipv6.DestinationAddress() + transProto = ipv6.NextHeader() + size = ipv6.PayloadLength() + + case header.ARPProtocolNumber: + arp := header.ARP(pkt.NetworkHeader) + log.Infof( + "%s arp %v (%v) -> %v (%v) valid:%v", + prefix, + tcpip.Address(arp.ProtocolAddressSender()), tcpip.LinkAddress(arp.HardwareAddressSender()), + tcpip.Address(arp.ProtocolAddressTarget()), tcpip.LinkAddress(arp.HardwareAddressTarget()), + arp.IsValid(), + ) + return + default: + log.Infof("%s unknown network protocol", prefix) + return + } } // Figure out the transport layer info. @@ -338,118 +339,120 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie srcPort := uint16(0) dstPort := uint16(0) details := "" - switch tcpip.TransportProtocolNumber(transProto) { - case header.ICMPv4ProtocolNumber: - transName = "icmp" - icmp := header.ICMPv4(b) - icmpType := "unknown" - if fragmentOffset == 0 { + if pkt.TransportHeader != nil { + switch tcpip.TransportProtocolNumber(transProto) { + case header.ICMPv4ProtocolNumber: + transName = "icmp" + icmp := header.ICMPv4(pkt.TransportHeader) + icmpType := "unknown" + if fragmentOffset == 0 { + switch icmp.Type() { + case header.ICMPv4EchoReply: + icmpType = "echo reply" + case header.ICMPv4DstUnreachable: + icmpType = "destination unreachable" + case header.ICMPv4SrcQuench: + icmpType = "source quench" + case header.ICMPv4Redirect: + icmpType = "redirect" + case header.ICMPv4Echo: + icmpType = "echo" + case header.ICMPv4TimeExceeded: + icmpType = "time exceeded" + case header.ICMPv4ParamProblem: + icmpType = "param problem" + case header.ICMPv4Timestamp: + icmpType = "timestamp" + case header.ICMPv4TimestampReply: + icmpType = "timestamp reply" + case header.ICMPv4InfoRequest: + icmpType = "info request" + case header.ICMPv4InfoReply: + icmpType = "info reply" + } + } + log.Infof("%s %s %v -> %v %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code()) + return + + case header.ICMPv6ProtocolNumber: + transName = "icmp" + icmp := header.ICMPv6(pkt.TransportHeader) + icmpType := "unknown" switch icmp.Type() { - case header.ICMPv4EchoReply: - icmpType = "echo reply" - case header.ICMPv4DstUnreachable: + case header.ICMPv6DstUnreachable: icmpType = "destination unreachable" - case header.ICMPv4SrcQuench: - icmpType = "source quench" - case header.ICMPv4Redirect: - icmpType = "redirect" - case header.ICMPv4Echo: - icmpType = "echo" - case header.ICMPv4TimeExceeded: + case header.ICMPv6PacketTooBig: + icmpType = "packet too big" + case header.ICMPv6TimeExceeded: icmpType = "time exceeded" - case header.ICMPv4ParamProblem: + case header.ICMPv6ParamProblem: icmpType = "param problem" - case header.ICMPv4Timestamp: - icmpType = "timestamp" - case header.ICMPv4TimestampReply: - icmpType = "timestamp reply" - case header.ICMPv4InfoRequest: - icmpType = "info request" - case header.ICMPv4InfoReply: - icmpType = "info reply" + case header.ICMPv6EchoRequest: + icmpType = "echo request" + case header.ICMPv6EchoReply: + icmpType = "echo reply" + case header.ICMPv6RouterSolicit: + icmpType = "router solicit" + case header.ICMPv6RouterAdvert: + icmpType = "router advert" + case header.ICMPv6NeighborSolicit: + icmpType = "neighbor solicit" + case header.ICMPv6NeighborAdvert: + icmpType = "neighbor advert" + case header.ICMPv6RedirectMsg: + icmpType = "redirect message" } - } - log.Infof("%s %s %v -> %v %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code()) - return - - case header.ICMPv6ProtocolNumber: - transName = "icmp" - icmp := header.ICMPv6(b) - icmpType := "unknown" - switch icmp.Type() { - case header.ICMPv6DstUnreachable: - icmpType = "destination unreachable" - case header.ICMPv6PacketTooBig: - icmpType = "packet too big" - case header.ICMPv6TimeExceeded: - icmpType = "time exceeded" - case header.ICMPv6ParamProblem: - icmpType = "param problem" - case header.ICMPv6EchoRequest: - icmpType = "echo request" - case header.ICMPv6EchoReply: - icmpType = "echo reply" - case header.ICMPv6RouterSolicit: - icmpType = "router solicit" - case header.ICMPv6RouterAdvert: - icmpType = "router advert" - case header.ICMPv6NeighborSolicit: - icmpType = "neighbor solicit" - case header.ICMPv6NeighborAdvert: - icmpType = "neighbor advert" - case header.ICMPv6RedirectMsg: - icmpType = "redirect message" - } - log.Infof("%s %s %v -> %v %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code()) - return - - case header.UDPProtocolNumber: - transName = "udp" - udp := header.UDP(b) - if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize { - srcPort = udp.SourcePort() - dstPort = udp.DestinationPort() - details = fmt.Sprintf("xsum: 0x%x", udp.Checksum()) - size -= header.UDPMinimumSize - } + log.Infof("%s %s %v -> %v %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code()) + return - case header.TCPProtocolNumber: - transName = "tcp" - tcp := header.TCP(b) - if fragmentOffset == 0 && len(tcp) >= header.TCPMinimumSize { - offset := int(tcp.DataOffset()) - if offset < header.TCPMinimumSize { - details += fmt.Sprintf("invalid packet: tcp data offset too small %d", offset) - break - } - if offset > len(tcp) && !moreFragments { - details += fmt.Sprintf("invalid packet: tcp data offset %d larger than packet buffer length %d", offset, len(tcp)) - break + case header.UDPProtocolNumber: + transName = "udp" + udp := header.UDP(pkt.TransportHeader) + if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize { + srcPort = udp.SourcePort() + dstPort = udp.DestinationPort() + details = fmt.Sprintf("xsum: 0x%x", udp.Checksum()) + size -= header.UDPMinimumSize } - srcPort = tcp.SourcePort() - dstPort = tcp.DestinationPort() - size -= uint16(offset) + case header.TCPProtocolNumber: + transName = "tcp" + tcp := header.TCP(pkt.TransportHeader) + if fragmentOffset == 0 && len(tcp) >= header.TCPMinimumSize { + offset := int(tcp.DataOffset()) + if offset < header.TCPMinimumSize { + details += fmt.Sprintf("invalid packet: tcp data offset too small %d", offset) + break + } + if offset > len(tcp) && !moreFragments { + details += fmt.Sprintf("invalid packet: tcp data offset %d larger than packet buffer length %d", offset, len(tcp)) + break + } - // Initialize the TCP flags. - flags := tcp.Flags() - flagsStr := []byte("FSRPAU") - for i := range flagsStr { - if flags&(1< %v unknown transport protocol: %d", prefix, src, dst, transProto) - return + default: + log.Infof("%s %v -> %v unknown transport protocol: %d", prefix, src, dst, transProto) + return + } } if gso != nil { -- cgit v1.2.3