diff options
Diffstat (limited to 'pkg/tcpip/network/ipv6/ipv6.go')
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 54 |
1 files changed, 32 insertions, 22 deletions
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index e506e99e9..a49b5ac77 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -19,6 +19,7 @@ import ( "encoding/binary" "fmt" "hash/fnv" + "math" "sort" "sync/atomic" "time" @@ -431,19 +432,27 @@ func (e *endpoint) MTU() uint32 { // MaxHeaderLength returns the maximum length needed by ipv6 headers (and // underlying protocols). func (e *endpoint) MaxHeaderLength() uint16 { + // TODO(gvisor.dev/issues/5035): The maximum header length returned here does + // not open the possibility for the caller to know about size required for + // extension headers. return e.nic.MaxHeaderLength() + header.IPv6MinimumSize } -func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { - length := uint16(pkt.Size()) - ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) +func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, extensionHeaders header.IPv6ExtHdrSerializer) { + extHdrsLen := extensionHeaders.Length() + length := pkt.Size() + extensionHeaders.Length() + if length > math.MaxUint16 { + panic(fmt.Sprintf("IPv6 payload too large: %d, must be <= %d", length, math.MaxUint16)) + } + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen)) ip.Encode(&header.IPv6Fields{ - PayloadLength: length, - NextHeader: uint8(params.Protocol), - HopLimit: params.TTL, - TrafficClass: params.TOS, - SrcAddr: srcAddr, - DstAddr: dstAddr, + PayloadLength: uint16(length), + TransportProtocol: params.Protocol, + HopLimit: params.TTL, + TrafficClass: params.TOS, + SrcAddr: srcAddr, + DstAddr: dstAddr, + ExtensionHeaders: extensionHeaders, }) pkt.NetworkProtocolNumber = ProtocolNumber } @@ -498,7 +507,7 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { - e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params) + e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* extensionHeaders */) // iptables filtering. All packets that reach here are locally // generated. @@ -587,7 +596,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe linkMTU := e.nic.MTU() for pb := pkts.Front(); pb != nil; pb = pb.Next() { - e.addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params) + e.addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params, nil /* extensionHeaders */) networkMTU, err := calculateNetworkMTU(linkMTU, uint32(pb.NetworkHeader().View().Size())) if err != nil { @@ -1793,24 +1802,25 @@ func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeaders hea fragPkt.NetworkProtocolNumber = ProtocolNumber originalIPHeadersLength := len(originalIPHeaders) - fragmentIPHeadersLength := originalIPHeadersLength + header.IPv6FragmentHeaderSize + + s := header.IPv6ExtHdrSerializer{&header.IPv6SerializableFragmentExtHdr{ + FragmentOffset: uint16(offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit), + M: more, + Identification: id, + }} + + fragmentIPHeadersLength := originalIPHeadersLength + s.Length() fragmentIPHeaders := header.IPv6(fragPkt.NetworkHeader().Push(fragmentIPHeadersLength)) - fragPkt.NetworkProtocolNumber = ProtocolNumber // Copy the IPv6 header and any extension headers already populated. if copied := copy(fragmentIPHeaders, originalIPHeaders); copied != originalIPHeadersLength { panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got %d, want %d", copied, originalIPHeadersLength)) } - fragmentIPHeaders.SetNextHeader(header.IPv6FragmentHeader) - fragmentIPHeaders.SetPayloadLength(uint16(copied + fragmentIPHeadersLength - header.IPv6MinimumSize)) - fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[originalIPHeadersLength:]) - fragmentHeader.Encode(&header.IPv6FragmentFields{ - M: more, - FragmentOffset: uint16(offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit), - Identification: id, - NextHeader: uint8(transportProto), - }) + nextHeader, _ := s.Serialize(transportProto, fragmentIPHeaders[originalIPHeadersLength:]) + + fragmentIPHeaders.SetNextHeader(nextHeader) + fragmentIPHeaders.SetPayloadLength(uint16(copied + fragmentIPHeadersLength - header.IPv6MinimumSize)) return fragPkt, more } |