// Copyright 2018 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package header provides the implementation of the encoding and decoding of // network protocol headers. package header import ( "encoding/binary" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) func calculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bool) { v := initial if odd { v += uint32(buf[0]) buf = buf[1:] } l := len(buf) odd = l&1 != 0 if odd { l-- v += uint32(buf[l]) << 8 } for i := 0; i < l; i += 2 { v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) } return ChecksumCombine(uint16(v), uint16(v>>16)), odd } func unrolledCalculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bool) { v := initial if odd { v += uint32(buf[0]) buf = buf[1:] } l := len(buf) odd = l&1 != 0 if odd { l-- v += uint32(buf[l]) << 8 } for (l - 64) >= 0 { i := 0 v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11]) v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13]) v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) i += 16 v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11]) v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13]) v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) i += 16 v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11]) v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13]) v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) i += 16 v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11]) v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13]) v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) buf = buf[64:] l = l - 64 } if (l - 32) >= 0 { i := 0 v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11]) v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13]) v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) i += 16 v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11]) v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13]) v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) buf = buf[32:] l = l - 32 } if (l - 16) >= 0 { i := 0 v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11]) v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13]) v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) buf = buf[16:] l = l - 16 } if (l - 8) >= 0 { i := 0 v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) buf = buf[8:] l = l - 8 } if (l - 4) >= 0 { i := 0 v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) buf = buf[4:] l = l - 4 } // At this point since l was even before we started unrolling // there can be only two bytes left to add. if l != 0 { v += (uint32(buf[0]) << 8) + uint32(buf[1]) } return ChecksumCombine(uint16(v), uint16(v>>16)), odd } // ChecksumOld calculates the checksum (as defined in RFC 1071) of the bytes in // the given byte array. This function uses a non-optimized implementation. Its // only retained for reference and to use as a benchmark/test. Most code should // use the header.Checksum function. // // The initial checksum must have been computed on an even number of bytes. func ChecksumOld(buf []byte, initial uint16) uint16 { s, _ := calculateChecksum(buf, false, uint32(initial)) return s } // Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the // given byte array. This function uses an optimized unrolled version of the // checksum algorithm. // // The initial checksum must have been computed on an even number of bytes. func Checksum(buf []byte, initial uint16) uint16 { s, _ := unrolledCalculateChecksum(buf, false, uint32(initial)) return s } // ChecksumVV calculates the checksum (as defined in RFC 1071) of the bytes in // the given VectorizedView. // // The initial checksum must have been computed on an even number of bytes. func ChecksumVV(vv buffer.VectorisedView, initial uint16) uint16 { var c Checksumer for _, v := range vv.Views() { c.Add([]byte(v)) } return ChecksumCombine(initial, c.Checksum()) } // Checksumer calculates checksum defined in RFC 1071. type Checksumer struct { sum uint16 odd bool } // Add adds b to checksum. func (c *Checksumer) Add(b []byte) { if len(b) > 0 { c.sum, c.odd = unrolledCalculateChecksum(b, c.odd, uint32(c.sum)) } } // Checksum returns the latest checksum value. func (c *Checksumer) Checksum() uint16 { return c.sum } // ChecksumCombine combines the two uint16 to form their checksum. This is done // by adding them and the carry. // // Note that checksum a must have been computed on an even number of bytes. func ChecksumCombine(a, b uint16) uint16 { v := uint32(a) + uint32(b) return uint16(v + v>>16) } // PseudoHeaderChecksum calculates the pseudo-header checksum for the given // destination protocol and network address. Pseudo-headers are needed by // transport layers when calculating their own checksum. func PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, srcAddr tcpip.Address, dstAddr tcpip.Address, totalLen uint16) uint16 { xsum := Checksum([]byte(srcAddr), 0) xsum = Checksum([]byte(dstAddr), xsum) // Add the length portion of the checksum to the pseudo-checksum. tmp := make([]byte, 2) binary.BigEndian.PutUint16(tmp, totalLen) xsum = Checksum(tmp, xsum) return Checksum([]byte{0, uint8(protocol)}, xsum) }