diff options
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/tcpip/network/ipv4/icmp.go | 17 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 16 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint.go | 45 |
3 files changed, 40 insertions, 38 deletions
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 9cb81245a..770f56c3d 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -72,7 +72,24 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } + + // Only send a reply if the checksum is valid. + wantChecksum := h.Checksum() + // Reset the checksum field to 0 to can calculate the proper + // checksum. We'll have to reset this before we hand the packet + // off. + h.SetChecksum(0) + gotChecksum := ^header.ChecksumVV(vv, 0 /* initial */) + if gotChecksum != wantChecksum { + // It's possible that a raw socket expects to receive this. + h.SetChecksum(wantChecksum) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) + received.Invalid.Increment() + return + } + // It's possible that a raw socket expects to receive this. + h.SetChecksum(wantChecksum) e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) vv := vv.Clone(nil) diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 00840cfcf..cc384dd3d 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -661,6 +661,22 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { + // Only accept echo replies. + switch e.netProto { + case header.IPv4ProtocolNumber: + h := header.ICMPv4(vv.First()) + if h.Type() != header.ICMPv4EchoReply { + e.stack.Stats().DroppedPackets.Increment() + return + } + case header.IPv6ProtocolNumber: + h := header.ICMPv6(vv.First()) + if h.Type() != header.ICMPv6EchoReply { + e.stack.Stats().DroppedPackets.Increment() + return + } + } + e.rcvMu.Lock() // Drop the packet if our buffer is currently full. diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 7004c7ff4..1a16a3607 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -80,11 +80,9 @@ type endpoint struct { // The following fields are protected by mu. mu sync.RWMutex `state:"nosave"` sndBufSize int - // shutdownFlags represent the current shutdown state of the endpoint. - shutdownFlags tcpip.ShutdownFlags - closed bool - connected bool - bound bool + closed bool + connected bool + bound bool // registeredNIC is the NIC to which th endpoint is explicitly // registered. Is set when Connect or Bind are used to specify a NIC. registeredNIC tcpip.NICID @@ -192,12 +190,6 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintp return 0, nil, tcpip.ErrInvalidEndpointState } - // Check whether we've shutdown writing. - if ep.shutdownFlags&tcpip.ShutdownWrite != 0 { - ep.mu.RUnlock() - return 0, nil, tcpip.ErrClosedForSend - } - // Did the user caller provide a destination? If not, use the connected // destination. if opts.To == nil { @@ -205,7 +197,7 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintp // connected to another address. if !ep.connected { ep.mu.RUnlock() - return 0, nil, tcpip.ErrNotConnected + return 0, nil, tcpip.ErrDestinationRequired } if ep.route.IsResolutionRequired() { @@ -355,7 +347,7 @@ func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return nil } -// Shutdown implements tcpip.Endpoint.Shutdown. +// Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets. func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() @@ -363,20 +355,6 @@ func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { if !ep.connected { return tcpip.ErrNotConnected } - - ep.shutdownFlags |= flags - - if flags&tcpip.ShutdownRead != 0 { - ep.rcvMu.Lock() - wasClosed := ep.rcvClosed - ep.rcvClosed = true - ep.rcvMu.Unlock() - - if !wasClosed { - ep.waiterQueue.Notify(waiter.EventIn) - } - } - return nil } @@ -427,17 +405,8 @@ func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { - ep.mu.RLock() - defer ep.mu.RUnlock() - - if !ep.connected { - return tcpip.FullAddress{}, tcpip.ErrNotConnected - } - - return tcpip.FullAddress{ - NIC: ep.registeredNIC, - Addr: ep.route.RemoteAddress, - }, nil + // Even a connected socket doesn't return a remote address. + return tcpip.FullAddress{}, tcpip.ErrNotConnected } // Readiness implements tcpip.Endpoint.Readiness. |