diff options
Diffstat (limited to 'pkg/tcpip/network/ipv4')
-rw-r--r-- | pkg/tcpip/network/ipv4/igmp.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 24 |
2 files changed, 19 insertions, 11 deletions
diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go index da88d65d1..d9b5fe6ed 100644 --- a/pkg/tcpip/network/ipv4/igmp.go +++ b/pkg/tcpip/network/ipv4/igmp.go @@ -262,13 +262,15 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip localAddr := addressEndpoint.AddressWithPrefix().Address addressEndpoint.DecRef() addressEndpoint = nil - igmp.ep.addIPHeader(localAddr, destAddress, pkt, stack.NetworkHeaderParams{ + if err := igmp.ep.addIPHeader(localAddr, destAddress, pkt, stack.NetworkHeaderParams{ Protocol: header.IGMPProtocolNumber, TTL: header.IGMPTTL, TOS: stack.DefaultTOS, }, header.IPv4OptionsSerializer{ &header.IPv4SerializableRouterAlertOption{}, - }) + }); err != nil { + panic(fmt.Sprintf("failed to add IP header: %s", err)) + } sentStats := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index cc045c7a9..bb25a76fe 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -237,7 +237,7 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return e.protocol.Number() } -func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, options header.IPv4OptionsSerializer) { +func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, options header.IPv4OptionsSerializer) *tcpip.Error { hdrLen := header.IPv4MinimumSize var optLen int if options != nil { @@ -245,19 +245,19 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet } hdrLen += optLen if hdrLen > header.IPv4MaximumHeaderSize { - // Since we have no way to report an error we must either panic or create - // a packet which is different to what was requested. Choose panic as this - // would be a programming error that should be caught in testing. - panic(fmt.Sprintf("IPv4 Options %d bytes, Max %d", optLen, header.IPv4MaximumOptionsSize)) + return tcpip.ErrMessageTooLong } ip := header.IPv4(pkt.NetworkHeader().Push(hdrLen)) - length := uint16(pkt.Size()) + length := pkt.Size() + if length > math.MaxUint16 { + return tcpip.ErrMessageTooLong + } // RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic // datagrams. Since the DF bit is never being set here, all datagrams // are non-atomic and need an ID. id := atomic.AddUint32(&e.protocol.ids[hashRoute(srcAddr, dstAddr, params.Protocol, e.protocol.hashIV)%buckets], 1) ip.Encode(&header.IPv4Fields{ - TotalLength: length, + TotalLength: uint16(length), ID: uint16(id), TTL: params.TTL, TOS: params.TOS, @@ -268,6 +268,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet }) ip.SetChecksum(^ip.CalculateChecksum()) pkt.NetworkProtocolNumber = ProtocolNumber + return nil } // handleFragments fragments pkt and calls the handler function on each @@ -295,7 +296,9 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { - e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */) + if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */); err != nil { + return err + } // iptables filtering. All packets that reach here are locally // generated. @@ -383,7 +386,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */) + if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */); err != nil { + return 0, err + } + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) if err != nil { r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) |