diff options
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/tcpip/checker/checker.go | 20 | ||||
-rw-r--r-- | pkg/tcpip/header/ipv4.go | 26 | ||||
-rw-r--r-- | pkg/tcpip/header/tcp.go | 53 | ||||
-rw-r--r-- | pkg/tcpip/header/udp.go | 27 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 23 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/segment.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 34 |
7 files changed, 105 insertions, 92 deletions
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index fef065b05..12c39dfa3 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -53,9 +53,8 @@ func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) { t.Error("Not a valid IPv4 packet") } - xsum := ipv4.CalculateChecksum() - if xsum != 0 && xsum != 0xffff { - t.Errorf("Bad checksum: 0x%x, checksum in packet: 0x%x", xsum, ipv4.Checksum()) + if !ipv4.IsChecksumValid() { + t.Errorf("Bad checksum, got = %d", ipv4.Checksum()) } for _, f := range checkers { @@ -400,18 +399,11 @@ func TCP(checkers ...TransportChecker) NetworkChecker { t.Errorf("Bad protocol, got = %d, want = %d", p, header.TCPProtocolNumber) } - // Verify the checksum. tcp := header.TCP(last.Payload()) - l := uint16(len(tcp)) - - xsum := header.Checksum([]byte(first.SourceAddress()), 0) - xsum = header.Checksum([]byte(first.DestinationAddress()), xsum) - xsum = header.Checksum([]byte{0, byte(last.TransportProtocol())}, xsum) - xsum = header.Checksum([]byte{byte(l >> 8), byte(l)}, xsum) - xsum = header.Checksum(tcp, xsum) - - if xsum != 0 && xsum != 0xffff { - t.Errorf("Bad checksum: 0x%x, checksum in segment: 0x%x", xsum, tcp.Checksum()) + payload := tcp.Payload() + payloadChecksum := header.Checksum(payload, 0) + if !tcp.IsChecksumValid(first.SourceAddress(), first.DestinationAddress(), payloadChecksum, uint16(len(payload))) { + t.Errorf("Bad checksum, got = %d", tcp.Checksum()) } // Run the transport checkers. diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 85bd164cd..2be21ec75 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -455,6 +455,32 @@ func IsV4LinkLocalMulticastAddress(addr tcpip.Address) bool { return ipv4LinkLocalMulticastSubnet.Contains(addr) } +// IsChecksumValid returns true iff the IPv4 header's checksum is valid. +func (b IPv4) IsChecksumValid() bool { + // There has been some confusion regarding verifying checksums. We need + // just look for negative 0 (0xffff) as the checksum, as it's not possible to + // get positive 0 (0) for the checksum. Some bad implementations could get it + // when doing entry replacement in the early days of the Internet, + // however the lore that one needs to check for both persists. + // + // RFC 1624 section 1 describes the source of this confusion as: + // [the partial recalculation method described in RFC 1071] computes a + // result for certain cases that differs from the one obtained from + // scratch (one's complement of one's complement sum of the original + // fields). + // + // However RFC 1624 section 5 clarifies that if using the verification method + // "recommended by RFC 1071, it does not matter if an intermediate system + // generated a -0 instead of +0". + // + // RFC1071 page 1 specifies the verification method as: + // (3) To check a checksum, the 1's complement sum is computed over the + // same set of octets, including the checksum field. If the result + // is all 1 bits (-0 in 1's complement arithmetic), the check + // succeeds. + return b.CalculateChecksum() == 0xffff +} + // IsV4MulticastAddress determines if the provided address is an IPv4 multicast // address (range 224.0.0.0 to 239.255.255.255). The four most significant bits // will be 1110 = 0xe0. diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go index adc835d30..0df517000 100644 --- a/pkg/tcpip/header/tcp.go +++ b/pkg/tcpip/header/tcp.go @@ -216,104 +216,104 @@ const ( TCPDefaultMSS = 536 ) -// SourcePort returns the "source port" field of the tcp header. +// SourcePort returns the "source port" field of the TCP header. func (b TCP) SourcePort() uint16 { return binary.BigEndian.Uint16(b[TCPSrcPortOffset:]) } -// DestinationPort returns the "destination port" field of the tcp header. +// DestinationPort returns the "destination port" field of the TCP header. func (b TCP) DestinationPort() uint16 { return binary.BigEndian.Uint16(b[TCPDstPortOffset:]) } -// SequenceNumber returns the "sequence number" field of the tcp header. +// SequenceNumber returns the "sequence number" field of the TCP header. func (b TCP) SequenceNumber() uint32 { return binary.BigEndian.Uint32(b[TCPSeqNumOffset:]) } -// AckNumber returns the "ack number" field of the tcp header. +// AckNumber returns the "ack number" field of the TCP header. func (b TCP) AckNumber() uint32 { return binary.BigEndian.Uint32(b[TCPAckNumOffset:]) } -// DataOffset returns the "data offset" field of the tcp header. The return +// DataOffset returns the "data offset" field of the TCP header. The return // value is the length of the TCP header in bytes. func (b TCP) DataOffset() uint8 { return (b[TCPDataOffset] >> 4) * 4 } -// Payload returns the data in the tcp packet. +// Payload returns the data in the TCP packet. func (b TCP) Payload() []byte { return b[b.DataOffset():] } -// Flags returns the flags field of the tcp header. +// Flags returns the flags field of the TCP header. func (b TCP) Flags() TCPFlags { return TCPFlags(b[TCPFlagsOffset]) } -// WindowSize returns the "window size" field of the tcp header. +// WindowSize returns the "window size" field of the TCP header. func (b TCP) WindowSize() uint16 { return binary.BigEndian.Uint16(b[TCPWinSizeOffset:]) } -// Checksum returns the "checksum" field of the tcp header. +// Checksum returns the "checksum" field of the TCP header. func (b TCP) Checksum() uint16 { return binary.BigEndian.Uint16(b[TCPChecksumOffset:]) } -// UrgentPointer returns the "urgent pointer" field of the tcp header. +// UrgentPointer returns the "urgent pointer" field of the TCP header. func (b TCP) UrgentPointer() uint16 { return binary.BigEndian.Uint16(b[TCPUrgentPtrOffset:]) } -// SetSourcePort sets the "source port" field of the tcp header. +// SetSourcePort sets the "source port" field of the TCP header. func (b TCP) SetSourcePort(port uint16) { binary.BigEndian.PutUint16(b[TCPSrcPortOffset:], port) } -// SetDestinationPort sets the "destination port" field of the tcp header. +// SetDestinationPort sets the "destination port" field of the TCP header. func (b TCP) SetDestinationPort(port uint16) { binary.BigEndian.PutUint16(b[TCPDstPortOffset:], port) } -// SetChecksum sets the checksum field of the tcp header. +// SetChecksum sets the checksum field of the TCP header. func (b TCP) SetChecksum(checksum uint16) { binary.BigEndian.PutUint16(b[TCPChecksumOffset:], checksum) } -// SetDataOffset sets the data offset field of the tcp header. headerLen should +// SetDataOffset sets the data offset field of the TCP header. headerLen should // be the length of the TCP header in bytes. func (b TCP) SetDataOffset(headerLen uint8) { b[TCPDataOffset] = (headerLen / 4) << 4 } -// SetSequenceNumber sets the sequence number field of the tcp header. +// SetSequenceNumber sets the sequence number field of the TCP header. func (b TCP) SetSequenceNumber(seqNum uint32) { binary.BigEndian.PutUint32(b[TCPSeqNumOffset:], seqNum) } -// SetAckNumber sets the ack number field of the tcp header. +// SetAckNumber sets the ack number field of the TCP header. func (b TCP) SetAckNumber(ackNum uint32) { binary.BigEndian.PutUint32(b[TCPAckNumOffset:], ackNum) } -// SetFlags sets the flags field of the tcp header. +// SetFlags sets the flags field of the TCP header. func (b TCP) SetFlags(flags uint8) { b[TCPFlagsOffset] = flags } -// SetWindowSize sets the window size field of the tcp header. +// SetWindowSize sets the window size field of the TCP header. func (b TCP) SetWindowSize(rcvwnd uint16) { binary.BigEndian.PutUint16(b[TCPWinSizeOffset:], rcvwnd) } -// SetUrgentPoiner sets the window size field of the tcp header. +// SetUrgentPoiner sets the window size field of the TCP header. func (b TCP) SetUrgentPoiner(urgentPointer uint16) { binary.BigEndian.PutUint16(b[TCPUrgentPtrOffset:], urgentPointer) } -// CalculateChecksum calculates the checksum of the tcp segment. +// CalculateChecksum calculates the checksum of the TCP segment. // partialChecksum is the checksum of the network-layer pseudo-header // and the checksum of the segment data. func (b TCP) CalculateChecksum(partialChecksum uint16) uint16 { @@ -321,6 +321,13 @@ func (b TCP) CalculateChecksum(partialChecksum uint16) uint16 { return Checksum(b[:b.DataOffset()], partialChecksum) } +// IsChecksumValid returns true iff the TCP header's checksum is valid. +func (b TCP) IsChecksumValid(src, dst tcpip.Address, payloadChecksum, payloadLength uint16) bool { + xsum := PseudoHeaderChecksum(TCPProtocolNumber, src, dst, uint16(b.DataOffset())+payloadLength) + xsum = ChecksumCombine(xsum, payloadChecksum) + return b.CalculateChecksum(xsum) == 0xffff +} + // Options returns a slice that holds the unparsed TCP options in the segment. func (b TCP) Options() []byte { return b[TCPMinimumSize:b.DataOffset()] @@ -340,7 +347,7 @@ func (b TCP) encodeSubset(seq, ack uint32, flags TCPFlags, rcvwnd uint16) { binary.BigEndian.PutUint16(b[TCPWinSizeOffset:], rcvwnd) } -// Encode encodes all the fields of the tcp header. +// Encode encodes all the fields of the TCP header. func (b TCP) Encode(t *TCPFields) { b.encodeSubset(t.SeqNum, t.AckNum, t.Flags, t.WindowSize) binary.BigEndian.PutUint16(b[TCPSrcPortOffset:], t.SrcPort) @@ -350,7 +357,7 @@ func (b TCP) Encode(t *TCPFields) { binary.BigEndian.PutUint16(b[TCPUrgentPtrOffset:], t.UrgentPointer) } -// EncodePartial updates a subset of the fields of the tcp header. It is useful +// EncodePartial updates a subset of the fields of the TCP header. It is useful // in cases when similar segments are produced. func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32, flags TCPFlags, rcvwnd uint16) { // Add the total length and "flags" field contributions to the checksum. @@ -374,7 +381,7 @@ func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32 } // ParseSynOptions parses the options received in a SYN segment and returns the -// relevant ones. opts should point to the option part of the TCP Header. +// relevant ones. opts should point to the option part of the TCP header. func ParseSynOptions(opts []byte, isAck bool) TCPSynOptions { limit := len(opts) diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go index 98bdd29db..ae9d167ff 100644 --- a/pkg/tcpip/header/udp.go +++ b/pkg/tcpip/header/udp.go @@ -64,17 +64,17 @@ const ( UDPProtocolNumber tcpip.TransportProtocolNumber = 17 ) -// SourcePort returns the "source port" field of the udp header. +// SourcePort returns the "source port" field of the UDP header. func (b UDP) SourcePort() uint16 { return binary.BigEndian.Uint16(b[udpSrcPort:]) } -// DestinationPort returns the "destination port" field of the udp header. +// DestinationPort returns the "destination port" field of the UDP header. func (b UDP) DestinationPort() uint16 { return binary.BigEndian.Uint16(b[udpDstPort:]) } -// Length returns the "length" field of the udp header. +// Length returns the "length" field of the UDP header. func (b UDP) Length() uint16 { return binary.BigEndian.Uint16(b[udpLength:]) } @@ -84,39 +84,46 @@ func (b UDP) Payload() []byte { return b[UDPMinimumSize:] } -// Checksum returns the "checksum" field of the udp header. +// Checksum returns the "checksum" field of the UDP header. func (b UDP) Checksum() uint16 { return binary.BigEndian.Uint16(b[udpChecksum:]) } -// SetSourcePort sets the "source port" field of the udp header. +// SetSourcePort sets the "source port" field of the UDP header. func (b UDP) SetSourcePort(port uint16) { binary.BigEndian.PutUint16(b[udpSrcPort:], port) } -// SetDestinationPort sets the "destination port" field of the udp header. +// SetDestinationPort sets the "destination port" field of the UDP header. func (b UDP) SetDestinationPort(port uint16) { binary.BigEndian.PutUint16(b[udpDstPort:], port) } -// SetChecksum sets the "checksum" field of the udp header. +// SetChecksum sets the "checksum" field of the UDP header. func (b UDP) SetChecksum(checksum uint16) { binary.BigEndian.PutUint16(b[udpChecksum:], checksum) } -// SetLength sets the "length" field of the udp header. +// SetLength sets the "length" field of the UDP header. func (b UDP) SetLength(length uint16) { binary.BigEndian.PutUint16(b[udpLength:], length) } -// CalculateChecksum calculates the checksum of the udp packet, given the +// CalculateChecksum calculates the checksum of the UDP packet, given the // checksum of the network-layer pseudo-header and the checksum of the payload. func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 { // Calculate the rest of the checksum. return Checksum(b[:UDPMinimumSize], partialChecksum) } -// Encode encodes all the fields of the udp header. +// IsChecksumValid returns true iff the UDP header's checksum is valid. +func (b UDP) IsChecksumValid(src, dst tcpip.Address, payloadChecksum uint16) bool { + xsum := PseudoHeaderChecksum(UDPProtocolNumber, dst, src, b.Length()) + xsum = ChecksumCombine(xsum, payloadChecksum) + return b.CalculateChecksum(xsum) == 0xffff +} + +// Encode encodes all the fields of the UDP header. func (b UDP) Encode(u *UDPFields) { binary.BigEndian.PutUint16(b[udpSrcPort:], u.SrcPort) binary.BigEndian.PutUint16(b[udpDstPort:], u.DstPort) diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 9a3dc78cb..a82a5790d 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -1178,28 +1178,7 @@ func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv4, bool) return nil, false } - // There has been some confusion regarding verifying checksums. We need - // just look for negative 0 (0xffff) as the checksum, as it's not possible to - // get positive 0 (0) for the checksum. Some bad implementations could get it - // when doing entry replacement in the early days of the Internet, - // however the lore that one needs to check for both persists. - // - // RFC 1624 section 1 describes the source of this confusion as: - // [the partial recalculation method described in RFC 1071] computes a - // result for certain cases that differs from the one obtained from - // scratch (one's complement of one's complement sum of the original - // fields). - // - // However RFC 1624 section 5 clarifies that if using the verification method - // "recommended by RFC 1071, it does not matter if an intermediate system - // generated a -0 instead of +0". - // - // RFC1071 page 1 specifies the verification method as: - // (3) To check a checksum, the 1's complement sum is computed over the - // same set of octets, including the checksum field. If the result - // is all 1 bits (-0 in 1's complement arithmetic), the check - // succeeds. - if h.CalculateChecksum() != 0xffff { + if !h.IsChecksumValid() { return nil, false } diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 8edd6775b..c28641be3 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -236,20 +236,14 @@ func (s *segment) parse(skipChecksumValidation bool) bool { s.options = []byte(s.hdr[header.TCPMinimumSize:]) s.parsedOptions = header.ParseTCPOptions(s.options) - - verifyChecksum := true if skipChecksumValidation { s.csumValid = true - verifyChecksum = false - } - if verifyChecksum { + } else { s.csum = s.hdr.Checksum() - xsum := header.PseudoHeaderChecksum(ProtocolNumber, s.srcAddr, s.dstAddr, uint16(s.data.Size()+len(s.hdr))) - xsum = s.hdr.CalculateChecksum(xsum) - xsum = header.ChecksumVV(s.data, xsum) - s.csumValid = xsum == 0xffff + payloadChecksum := header.ChecksumVV(s.data, 0) + payloadLength := uint16(s.data.Size()) + s.csumValid = s.hdr.IsChecksumValid(s.srcAddr, s.dstAddr, payloadChecksum, payloadLength) } - s.sequenceNumber = seqnum.Value(s.hdr.SequenceNumber()) s.ackNumber = seqnum.Value(s.hdr.AckNumber()) s.flags = s.hdr.Flags() diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 956da0e0c..f26c7ca10 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1255,20 +1255,29 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { } // verifyChecksum verifies the checksum unless RX checksum offload is enabled. -// On IPv4, UDP checksum is optional, and a zero value means the transmitter -// omitted the checksum generation (RFC768). -// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool { - if !pkt.RXTransportChecksumValidated && - (hdr.Checksum() != 0 || pkt.NetworkProtocolNumber == header.IPv6ProtocolNumber) { - netHdr := pkt.Network() - xsum := header.PseudoHeaderChecksum(ProtocolNumber, netHdr.DestinationAddress(), netHdr.SourceAddress(), hdr.Length()) - for _, v := range pkt.Data().Views() { - xsum = header.Checksum(v, xsum) - } - return hdr.CalculateChecksum(xsum) == 0xffff + if pkt.RXTransportChecksumValidated { + return true + } + + // On IPv4, UDP checksum is optional, and a zero value means the transmitter + // omitted the checksum generation, as per RFC 768: + // + // An all zero transmitted checksum value means that the transmitter + // generated no checksum (for debugging or for higher level protocols that + // don't care). + // + // On IPv6, UDP checksum is not optional, as per RFC 2460 Section 8.1: + // + // Unlike IPv4, when UDP packets are originated by an IPv6 node, the UDP + // checksum is not optional. + if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber && hdr.Checksum() == 0 { + return true } - return true + + netHdr := pkt.Network() + payloadChecksum := pkt.Data().AsRange().Checksum() + return hdr.IsChecksumValid(netHdr.SourceAddress(), netHdr.DestinationAddress(), payloadChecksum) } // HandlePacket is called by the stack when new packets arrive to this transport @@ -1284,7 +1293,6 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB } if !verifyChecksum(hdr, pkt) { - // Checksum Error. e.stack.Stats().UDP.ChecksumErrors.Increment() e.stats.ReceiveErrors.ChecksumErrors.Increment() return |