diff options
Diffstat (limited to 'pkg/tcpip/network/ipv4/ipv4.go')
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 31 |
1 files changed, 14 insertions, 17 deletions
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 354ac1e1d..4b34d0bfb 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -215,21 +215,18 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return e.protocol.Number() } -func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { +func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, options header.IPv4OptionsSerializer) { hdrLen := header.IPv4MinimumSize - var opts header.IPv4Options - if params.Options != nil { - var ok bool - if opts, ok = params.Options.(header.IPv4Options); !ok { - panic(fmt.Sprintf("want IPv4Options, got %T", params.Options)) - } - hdrLen += opts.SizeWithPadding() - if hdrLen > header.IPv4MaximumHeaderSize { - // Since we have no way to report an error we must either panic or create - // a packet which is different to what was requested. Choose panic as this - // would be a programming error that should be caught in testing. - panic(fmt.Sprintf("IPv4 Options %d bytes, Max %d", params.Options.SizeWithPadding(), header.IPv4MaximumOptionsSize)) - } + var optLen int + if options != nil { + optLen = int(options.Length()) + } + hdrLen += optLen + if hdrLen > header.IPv4MaximumHeaderSize { + // Since we have no way to report an error we must either panic or create + // a packet which is different to what was requested. Choose panic as this + // would be a programming error that should be caught in testing. + panic(fmt.Sprintf("IPv4 Options %d bytes, Max %d", optLen, header.IPv4MaximumOptionsSize)) } ip := header.IPv4(pkt.NetworkHeader().Push(hdrLen)) length := uint16(pkt.Size()) @@ -245,7 +242,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet Protocol: uint8(params.Protocol), SrcAddr: srcAddr, DstAddr: dstAddr, - Options: opts, + Options: options, }) ip.SetChecksum(^ip.CalculateChecksum()) pkt.NetworkProtocolNumber = ProtocolNumber @@ -276,7 +273,7 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { - e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params) + e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */) // iptables filtering. All packets that reach here are locally // generated. @@ -364,7 +361,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params) + e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */) networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) if err != nil { r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) |