summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go17
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go16
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go45
3 files changed, 40 insertions, 38 deletions
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 9cb81245a..770f56c3d 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -72,7 +72,24 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
received.Invalid.Increment()
return
}
+
+ // Only send a reply if the checksum is valid.
+ wantChecksum := h.Checksum()
+ // Reset the checksum field to 0 to can calculate the proper
+ // checksum. We'll have to reset this before we hand the packet
+ // off.
+ h.SetChecksum(0)
+ gotChecksum := ^header.ChecksumVV(vv, 0 /* initial */)
+ if gotChecksum != wantChecksum {
+ // It's possible that a raw socket expects to receive this.
+ h.SetChecksum(wantChecksum)
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv)
+ received.Invalid.Increment()
+ return
+ }
+
// It's possible that a raw socket expects to receive this.
+ h.SetChecksum(wantChecksum)
e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv)
vv := vv.Clone(nil)
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 00840cfcf..cc384dd3d 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -661,6 +661,22 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
+ // Only accept echo replies.
+ switch e.netProto {
+ case header.IPv4ProtocolNumber:
+ h := header.ICMPv4(vv.First())
+ if h.Type() != header.ICMPv4EchoReply {
+ e.stack.Stats().DroppedPackets.Increment()
+ return
+ }
+ case header.IPv6ProtocolNumber:
+ h := header.ICMPv6(vv.First())
+ if h.Type() != header.ICMPv6EchoReply {
+ e.stack.Stats().DroppedPackets.Increment()
+ return
+ }
+ }
+
e.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 7004c7ff4..1a16a3607 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -80,11 +80,9 @@ type endpoint struct {
// The following fields are protected by mu.
mu sync.RWMutex `state:"nosave"`
sndBufSize int
- // shutdownFlags represent the current shutdown state of the endpoint.
- shutdownFlags tcpip.ShutdownFlags
- closed bool
- connected bool
- bound bool
+ closed bool
+ connected bool
+ bound bool
// registeredNIC is the NIC to which th endpoint is explicitly
// registered. Is set when Connect or Bind are used to specify a NIC.
registeredNIC tcpip.NICID
@@ -192,12 +190,6 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintp
return 0, nil, tcpip.ErrInvalidEndpointState
}
- // Check whether we've shutdown writing.
- if ep.shutdownFlags&tcpip.ShutdownWrite != 0 {
- ep.mu.RUnlock()
- return 0, nil, tcpip.ErrClosedForSend
- }
-
// Did the user caller provide a destination? If not, use the connected
// destination.
if opts.To == nil {
@@ -205,7 +197,7 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintp
// connected to another address.
if !ep.connected {
ep.mu.RUnlock()
- return 0, nil, tcpip.ErrNotConnected
+ return 0, nil, tcpip.ErrDestinationRequired
}
if ep.route.IsResolutionRequired() {
@@ -355,7 +347,7 @@ func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
return nil
}
-// Shutdown implements tcpip.Endpoint.Shutdown.
+// Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets.
func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
@@ -363,20 +355,6 @@ func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
if !ep.connected {
return tcpip.ErrNotConnected
}
-
- ep.shutdownFlags |= flags
-
- if flags&tcpip.ShutdownRead != 0 {
- ep.rcvMu.Lock()
- wasClosed := ep.rcvClosed
- ep.rcvClosed = true
- ep.rcvMu.Unlock()
-
- if !wasClosed {
- ep.waiterQueue.Notify(waiter.EventIn)
- }
- }
-
return nil
}
@@ -427,17 +405,8 @@ func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
- ep.mu.RLock()
- defer ep.mu.RUnlock()
-
- if !ep.connected {
- return tcpip.FullAddress{}, tcpip.ErrNotConnected
- }
-
- return tcpip.FullAddress{
- NIC: ep.registeredNIC,
- Addr: ep.route.RemoteAddress,
- }, nil
+ // Even a connected socket doesn't return a remote address.
+ return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
// Readiness implements tcpip.Endpoint.Readiness.