From 2e4f26ac5ed206e981f662326e19d60b5f05bd4c Mon Sep 17 00:00:00 2001
From: Ghanan Gowripalan <ghanan@google.com>
Date: Thu, 23 Sep 2021 11:42:24 -0700
Subject: Compose ICMP endpoint with datagram-based endpoint

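An ICMP endpoint's write path can use the datagram-based endpoint
(transport/internal/network) instead of tracking its own route, TTL,
packet owner, shutdown flags and endpoint state.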

Updates #6565.
Test: Datagram-based generic socket + ICMP/ping syscall tests.
PiperOrigin-RevId: 398539844
---
 pkg/tcpip/transport/icmp/BUILD             |   2 +
 pkg/tcpip/transport/icmp/endpoint.go       | 426 ++++++++++-------------------
 pkg/tcpip/transport/icmp/endpoint_state.go |  35 ++-
 pkg/tcpip/transport/internal/network/BUILD |   1 +
 4 files changed, 165 insertions(+), 299 deletions(-)

(limited to 'pkg/tcpip')

diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
index bbc0e3ecc..4718ec4ec 100644
--- a/pkg/tcpip/transport/icmp/BUILD
+++ b/pkg/tcpip/transport/icmp/BUILD
@@ -33,6 +33,8 @@ go_library(
         "//pkg/tcpip/header",
         "//pkg/tcpip/ports",
         "//pkg/tcpip/stack",
+        "//pkg/tcpip/transport",
+        "//pkg/tcpip/transport/internal/network",
         "//pkg/tcpip/transport/raw",
         "//pkg/tcpip/transport/tcp",
         "//pkg/waiter",
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 1e519085d..b3436e44c 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -15,6 +15,7 @@
 package icmp
 
 import (
+	"fmt"
 	"io"
 	"time"
 
@@ -24,6 +25,8 @@ import (
 	"gvisor.dev/gvisor/pkg/tcpip/header"
 	"gvisor.dev/gvisor/pkg/tcpip/ports"
 	"gvisor.dev/gvisor/pkg/tcpip/stack"
+	"gvisor.dev/gvisor/pkg/tcpip/transport"
+	"gvisor.dev/gvisor/pkg/tcpip/transport/internal/network"
 	"gvisor.dev/gvisor/pkg/waiter"
 )
 
@@ -35,15 +38,6 @@ type icmpPacket struct {
 	receivedAt    time.Time             `state:".(int64)"`
 }
 
-type endpointState int
-
-const (
-	stateInitial endpointState = iota
-	stateBound
-	stateConnected
-	stateClosed
-)
-
 // endpoint represents an ICMP endpoint. This struct serves as the interface
 // between users of the endpoint and the protocol implementation; it is legal to
 // have concurrent goroutines make calls into the endpoint, they are properly
@@ -51,14 +45,18 @@ const (
 //
 // +stateify savable
 type endpoint struct {
-	stack.TransportEndpointInfo
 	tcpip.DefaultSocketOptionsHandler
 
 	// The following fields are initialized at creation time and are
 	// immutable.
 	stack       *stack.Stack `state:"manual"`
+	transProto  tcpip.TransportProtocolNumber
 	waiterQueue *waiter.Queue
 	uniqueID    uint64
+	net         network.Endpoint
+	// TODO(b/142022063): Add ability to save and restore per endpoint stats.
+	stats tcpip.TransportEndpointStats `state:"nosave"`
+	ops   tcpip.SocketOptions
 
 	// The following fields are used to manage the receive queue, and are
 	// protected by rcvMu.
@@ -70,38 +68,23 @@ type endpoint struct {
 
 	// The following fields are protected by the mu mutex.
 	mu sync.RWMutex `state:"nosave"`
-	// shutdownFlags represent the current shutdown state of the endpoint.
-	shutdownFlags tcpip.ShutdownFlags
-	state         endpointState
-	route         *stack.Route `state:"manual"`
-	ttl           uint8
-	stats         tcpip.TransportEndpointStats `state:"nosave"`
-
-	// owner is used to get uid and gid of the packet.
-	owner tcpip.PacketOwner
-
-	// ops is used to get socket level options.
-	ops tcpip.SocketOptions
-
 	// frozen indicates if the packets should be delivered to the endpoint
 	// during restore.
 	frozen bool
+	ident  uint16
 }
 
 func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
 	ep := &endpoint{
-		stack: s,
-		TransportEndpointInfo: stack.TransportEndpointInfo{
-			NetProto:   netProto,
-			TransProto: transProto,
-		},
+		stack:       s,
+		transProto:  transProto,
 		waiterQueue: waiterQueue,
-		state:       stateInitial,
 		uniqueID:    s.UniqueID(),
 	}
 	ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
 	ep.ops.SetSendBufferSize(32*1024, false /* notify */)
 	ep.ops.SetReceiveBufferSize(32*1024, false /* notify */)
+	ep.net.Init(s, netProto, transProto, &ep.ops)
 
 	// Override with stack defaults.
 	var ss tcpip.SendBufferSizeOption
@@ -128,35 +111,40 @@ func (e *endpoint) Abort() {
 // Close puts the endpoint in a closed state and frees all resources
 // associated with it.
 func (e *endpoint) Close() {
-	e.mu.Lock()
-	e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
-	switch e.state {
-	case stateBound, stateConnected:
-		bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
-		e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, bindToDevice)
-	}
+	notify := func() bool {
+		e.mu.Lock()
+		defer e.mu.Unlock()
+
+		switch state := e.net.State(); state {
+		case transport.DatagramEndpointStateInitial:
+		case transport.DatagramEndpointStateClosed:
+			return false
+		case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+			info := e.net.Info()
+			info.ID.LocalPort = e.ident
+			e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{info.NetProto}, e.transProto, info.ID, e, ports.Flags{}, tcpip.NICID(e.ops.GetBindToDevice()))
+		default:
+			panic(fmt.Sprintf("unhandled state = %s", state))
+		}
 
-	// Close the receive list and drain it.
-	e.rcvMu.Lock()
-	e.rcvClosed = true
-	e.rcvBufSize = 0
-	for !e.rcvList.Empty() {
-		p := e.rcvList.Front()
-		e.rcvList.Remove(p)
-	}
-	e.rcvMu.Unlock()
+		e.net.Shutdown()
+		e.net.Close()
 
-	if e.route != nil {
-		e.route.Release()
-		e.route = nil
-	}
-
-	// Update the state.
-	e.state = stateClosed
+		e.rcvMu.Lock()
+		defer e.rcvMu.Unlock()
+		e.rcvClosed = true
+		e.rcvBufSize = 0
+		for !e.rcvList.Empty() {
+			p := e.rcvList.Front()
+			e.rcvList.Remove(p)
+		}
 
-	e.mu.Unlock()
+		return true
+	}()
 
-	e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
+	if notify {
+		e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
+	}
 }
 
 // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
@@ -164,7 +152,7 @@ func (*endpoint) ModerateRecvBuf(int) {}
 
 // SetOwner implements tcpip.Endpoint.SetOwner.
 func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
-	e.owner = owner
+	e.net.SetOwner(owner)
 }
 
 // Read implements tcpip.Endpoint.Read.
@@ -214,13 +202,12 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
 //
 // Returns true for retry if preparation should be retried.
 // +checklocks:e.mu
-func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
-	switch e.state {
-	case stateInitial:
-	case stateConnected:
+func (e *endpoint) prepareForWriteInner(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
+	switch e.net.State() {
+	case transport.DatagramEndpointStateInitial:
+	case transport.DatagramEndpointStateConnected:
 		return false, nil
-
-	case stateBound:
+	case transport.DatagramEndpointStateBound:
 		if to == nil {
 			return false, &tcpip.ErrDestinationRequired{}
 		}
@@ -235,7 +222,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip
 
 	// The state changed when we released the shared locked and re-acquired
 	// it in exclusive mode. Try again.
-	if e.state != stateInitial {
+	if e.net.State() != transport.DatagramEndpointStateInitial {
 		return true, nil
 	}
 
@@ -270,27 +257,15 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
 	return n, err
 }
 
-func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
-	// MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
-	if opts.More {
-		return 0, &tcpip.ErrInvalidOptionValue{}
-	}
-
-	to := opts.To
-
+func (e *endpoint) prepareForWrite(opts tcpip.WriteOptions) (network.WriteContext, uint16, tcpip.Error) {
 	e.mu.RLock()
 	defer e.mu.RUnlock()
 
-	// If we've shutdown with SHUT_WR we are in an invalid state for sending.
-	if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
-		return 0, &tcpip.ErrClosedForSend{}
-	}
-
 	// Prepare for write.
 	for {
-		retry, err := e.prepareForWrite(to)
+		retry, err := e.prepareForWriteInner(opts.To)
 		if err != nil {
-			return 0, err
+			return network.WriteContext{}, 0, err
 		}
 
 		if !retry {
@@ -298,36 +273,16 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
 		}
 	}
 
-	route := e.route
-	if to != nil {
-		// Reject destination address if it goes through a different
-		// NIC than the endpoint was bound to.
-		nicID := to.NIC
-		if nicID == 0 {
-			nicID = tcpip.NICID(e.ops.GetBindToDevice())
-		}
-		if e.BindNICID != 0 {
-			if nicID != 0 && nicID != e.BindNICID {
-				return 0, &tcpip.ErrNoRoute{}
-			}
-
-			nicID = e.BindNICID
-		}
-
-		dst, netProto, err := e.checkV4MappedLocked(*to)
-		if err != nil {
-			return 0, err
-		}
-
-		// Find the endpoint.
-		r, err := e.stack.FindRoute(nicID, e.BindAddr, dst.Addr, netProto, false /* multicastLoop */)
-		if err != nil {
-			return 0, err
-		}
-		defer r.Release()
+	ctx, err := e.net.AcquireContextForWrite(opts)
+	return ctx, e.ident, err
+}
 
-		route = r
+func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
+	ctx, ident, err := e.prepareForWrite(opts)
+	if err != nil {
+		return 0, err
 	}
+	defer ctx.Release()
 
 	// TODO(https://gvisor.dev/issue/6538): Avoid this allocation.
 	v := make([]byte, p.Len())
@@ -335,17 +290,18 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
 		return 0, &tcpip.ErrBadBuffer{}
 	}
 
-	var err tcpip.Error
-	switch e.NetProto {
+	switch netProto, pktInfo := e.net.NetProto(), ctx.PacketInfo(); netProto {
 	case header.IPv4ProtocolNumber:
-		err = send4(route, e.ID.LocalPort, v, e.ttl, e.owner)
+		if err := send4(e.stack, &ctx, ident, v, pktInfo.MaxHeaderLength); err != nil {
+			return 0, err
+		}
 
 	case header.IPv6ProtocolNumber:
-		err = send6(route, e.ID.LocalPort, v, e.ttl)
-	}
-
-	if err != nil {
-		return 0, err
+		if err := send6(e.stack, &ctx, ident, v, pktInfo.LocalAddress, pktInfo.RemoteAddress, pktInfo.MaxHeaderLength); err != nil {
+			return 0, err
+		}
+	default:
+		panic(fmt.Sprintf("unhandled network protocol = %d", netProto))
 	}
 
 	return int64(len(v)), nil
@@ -358,24 +314,17 @@ func (e *endpoint) HasNIC(id int32) bool {
 	return e.stack.HasNIC(tcpip.NICID(id))
 }
 
-// SetSockOpt sets a socket option.
-func (*endpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error {
-	return nil
+// SetSockOpt implements tcpip.Endpoint.
+func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
+	return e.net.SetSockOpt(opt)
 }
 
-// SetSockOptInt sets a socket option. Currently not supported.
+// SetSockOptInt implements tcpip.Endpoint.
 func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
-	switch opt {
-	case tcpip.TTLOption:
-		e.mu.Lock()
-		e.ttl = uint8(v)
-		e.mu.Unlock()
-
-	}
-	return nil
+	return e.net.SetSockOptInt(opt, v)
 }
 
-// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
+// GetSockOptInt implements tcpip.Endpoint.
 func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
 	switch opt {
 	case tcpip.ReceiveQueueSizeOption:
@@ -388,31 +337,24 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
 		e.rcvMu.Unlock()
 		return v, nil
 
-	case tcpip.TTLOption:
-		e.rcvMu.Lock()
-		v := int(e.ttl)
-		e.rcvMu.Unlock()
-		return v, nil
-
 	default:
-		return -1, &tcpip.ErrUnknownProtocolOption{}
+		return e.net.GetSockOptInt(opt)
 	}
 }
 
-// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error {
-	return &tcpip.ErrUnknownProtocolOption{}
+// GetSockOpt implements tcpip.Endpoint.
+func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
+	return e.net.GetSockOpt(opt)
 }
 
-func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) tcpip.Error {
+func send4(s *stack.Stack, ctx *network.WriteContext, ident uint16, data buffer.View, maxHeaderLength uint16) tcpip.Error {
 	if len(data) < header.ICMPv4MinimumSize {
 		return &tcpip.ErrInvalidEndpointState{}
 	}
 
 	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
-		ReserveHeaderBytes: header.ICMPv4MinimumSize + int(r.MaxHeaderLength()),
+		ReserveHeaderBytes: header.ICMPv4MinimumSize + int(maxHeaderLength),
 	})
-	pkt.Owner = owner
 
 	icmpv4 := header.ICMPv4(pkt.TransportHeader().Push(header.ICMPv4MinimumSize))
 	pkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
@@ -427,36 +369,31 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi
 		return &tcpip.ErrInvalidEndpointState{}
 	}
 
-	// Because this icmp endpoint is implemented in the transport layer, we can
-	// only increment the 'stack-wide' stats but we can't increment the
-	// 'per-NetworkEndpoint' stats.
-	sentStat := r.Stats().ICMP.V4.PacketsSent.EchoRequest
-
 	icmpv4.SetChecksum(0)
 	icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0)))
-
 	pkt.Data().AppendView(data)
 
-	if ttl == 0 {
-		ttl = r.DefaultTTL()
-	}
+	// Because this icmp endpoint is implemented in the transport layer, we can
+	// only increment the 'stack-wide' stats but we can't increment the
+	// 'per-NetworkEndpoint' stats.
+	stats := s.Stats().ICMP.V4.PacketsSent
 
-	if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil {
-		r.Stats().ICMP.V4.PacketsSent.Dropped.Increment()
+	if err := ctx.WritePacket(pkt, false /* headerIncluded */); err != nil {
+		stats.Dropped.Increment()
 		return err
 	}
 
-	sentStat.Increment()
+	stats.EchoRequest.Increment()
 	return nil
 }
 
-func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Error {
+func send6(s *stack.Stack, ctx *network.WriteContext, ident uint16, data buffer.View, src, dst tcpip.Address, maxHeaderLength uint16) tcpip.Error {
 	if len(data) < header.ICMPv6EchoMinimumSize {
 		return &tcpip.ErrInvalidEndpointState{}
 	}
 
 	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
-		ReserveHeaderBytes: header.ICMPv6MinimumSize + int(r.MaxHeaderLength()),
+		ReserveHeaderBytes: header.ICMPv6MinimumSize + int(maxHeaderLength),
 	})
 
 	icmpv6 := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6MinimumSize))
@@ -469,43 +406,31 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Erro
 	if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 {
 		return &tcpip.ErrInvalidEndpointState{}
 	}
-	// Because this icmp endpoint is implemented in the transport layer, we can
-	// only increment the 'stack-wide' stats but we can't increment the
-	// 'per-NetworkEndpoint' stats.
-	sentStat := r.Stats().ICMP.V6.PacketsSent.EchoRequest
 
 	pkt.Data().AppendView(data)
 	dataRange := pkt.Data().AsRange()
 	icmpv6.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
 		Header:      icmpv6,
-		Src:         r.LocalAddress(),
-		Dst:         r.RemoteAddress(),
+		Src:         src,
+		Dst:         dst,
 		PayloadCsum: dataRange.Checksum(),
 		PayloadLen:  dataRange.Size(),
 	}))
 
-	if ttl == 0 {
-		ttl = r.DefaultTTL()
-	}
+	// Because this icmp endpoint is implemented in the transport layer, we can
+	// only increment the 'stack-wide' stats but we can't increment the
+	// 'per-NetworkEndpoint' stats.
+	stats := s.Stats().ICMP.V6.PacketsSent
 
-	if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil {
-		r.Stats().ICMP.V6.PacketsSent.Dropped.Increment()
+	if err := ctx.WritePacket(pkt, false /* headerIncluded */); err != nil {
+		stats.Dropped.Increment()
+		return err
 	}
 
-	sentStat.Increment()
+	stats.EchoRequest.Increment()
 	return nil
 }
 
-// checkV4MappedLocked determines the effective network protocol and converts
-// addr to its canonical form.
-func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
-	unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, false /* v6only */)
-	if err != nil {
-		return tcpip.FullAddress{}, 0, err
-	}
-	return unwrapped, netProto, nil
-}
-
 // Disconnect implements tcpip.Endpoint.Disconnect.
 func (*endpoint) Disconnect() tcpip.Error {
 	return &tcpip.ErrNotSupported{}
@@ -516,59 +441,21 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
 	e.mu.Lock()
 	defer e.mu.Unlock()
 
-	nicID := addr.NIC
-	localPort := uint16(0)
-	switch e.state {
-	case stateInitial:
-	case stateBound, stateConnected:
-		localPort = e.ID.LocalPort
-		if e.BindNICID == 0 {
-			break
-		}
+	err := e.net.ConnectAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error {
+		nextID.LocalPort = e.ident
 
-		if nicID != 0 && nicID != e.BindNICID {
-			return &tcpip.ErrInvalidEndpointState{}
+		nextID, err := e.registerWithStack(netProto, nextID)
+		if err != nil {
+			return err
 		}
 
-		nicID = e.BindNICID
-	default:
-		return &tcpip.ErrInvalidEndpointState{}
-	}
-
-	addr, netProto, err := e.checkV4MappedLocked(addr)
-	if err != nil {
-		return err
-	}
-
-	// Find a route to the desired destination.
-	r, err := e.stack.FindRoute(nicID, e.BindAddr, addr.Addr, netProto, false /* multicastLoop */)
-	if err != nil {
-		return err
-	}
-
-	id := stack.TransportEndpointID{
-		LocalAddress:  r.LocalAddress(),
-		LocalPort:     localPort,
-		RemoteAddress: r.RemoteAddress(),
-	}
-
-	// Even if we're connected, this endpoint can still be used to send
-	// packets on a different network protocol, so we register both even if
-	// v6only is set to false and this is an ipv6 endpoint.
-	netProtos := []tcpip.NetworkProtocolNumber{netProto}
-
-	id, err = e.registerWithStack(nicID, netProtos, id)
+		e.ident = nextID.LocalPort
+		return nil
+	})
 	if err != nil {
-		r.Release()
 		return err
 	}
 
-	e.ID = id
-	e.route = r
-	e.RegisterNICID = nicID
-
-	e.state = stateConnected
-
 	e.rcvMu.Lock()
 	e.rcvReady = true
 	e.rcvMu.Unlock()
@@ -586,10 +473,19 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error {
 func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
 	e.mu.Lock()
 	defer e.mu.Unlock()
-	e.shutdownFlags |= flags
 
-	if e.state != stateConnected {
+	switch state := e.net.State(); state {
+	case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
 		return &tcpip.ErrNotConnected{}
+	case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+	default:
+		panic(fmt.Sprintf("unhandled state = %s", state))
+	}
+
+	if flags&tcpip.ShutdownWrite != 0 {
+		if err := e.net.Shutdown(); err != nil {
+			return err
+		}
 	}
 
 	if flags&tcpip.ShutdownRead != 0 {
@@ -616,19 +512,18 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi
 	return nil, nil, &tcpip.ErrNotSupported{}
 }
 
-func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) {
+func (e *endpoint) registerWithStack(netProto tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) {
 	bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
 	if id.LocalPort != 0 {
 		// The endpoint already has a local port, just attempt to
 		// register it.
-		err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice)
-		return id, err
+		return id, e.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{netProto}, e.transProto, id, e, ports.Flags{}, bindToDevice)
 	}
 
 	// We need to find a port for the endpoint.
 	_, err := e.stack.PickEphemeralPort(e.stack.Rand(), func(p uint16) (bool, tcpip.Error) {
 		id.LocalPort = p
-		err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice)
+		err := e.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{netProto}, e.transProto, id, e, ports.Flags{}, bindToDevice)
 		switch err.(type) {
 		case nil:
 			return true, nil
@@ -645,42 +540,27 @@ func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkPro
 func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error {
 	// Don't allow binding once endpoint is not in the initial state
 	// anymore.
-	if e.state != stateInitial {
+	if e.net.State() != transport.DatagramEndpointStateInitial {
 		return &tcpip.ErrInvalidEndpointState{}
 	}
 
-	addr, netProto, err := e.checkV4MappedLocked(addr)
-	if err != nil {
-		return err
-	}
-
-	// Expand netProtos to include v4 and v6 if the caller is binding to a
-	// wildcard (empty) address, and this is an IPv6 endpoint with v6only
-	// set to false.
-	netProtos := []tcpip.NetworkProtocolNumber{netProto}
-
-	if len(addr.Addr) != 0 {
-		// A local address was specified, verify that it's valid.
-		if e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) == 0 {
-			return &tcpip.ErrBadLocalAddress{}
+	err := e.net.BindAndThen(addr, func(boundNetProto tcpip.NetworkProtocolNumber, boundAddr tcpip.Address) tcpip.Error {
+		id := stack.TransportEndpointID{
+			LocalPort:    addr.Port,
+			LocalAddress: addr.Addr,
+		}
+		id, err := e.registerWithStack(boundNetProto, id)
+		if err != nil {
+			return err
 		}
-	}
 
-	id := stack.TransportEndpointID{
-		LocalPort:    addr.Port,
-		LocalAddress: addr.Addr,
-	}
-	id, err = e.registerWithStack(addr.NIC, netProtos, id)
+		e.ident = id.LocalPort
+		return nil
+	})
 	if err != nil {
 		return err
 	}
 
-	e.ID = id
-	e.RegisterNICID = addr.NIC
-
-	// Mark endpoint as bound.
-	e.state = stateBound
-
 	e.rcvMu.Lock()
 	e.rcvReady = true
 	e.rcvMu.Unlock()
@@ -692,7 +572,7 @@ func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, addr tcpip.Address)
 	return addr == header.IPv4Broadcast ||
 		header.IsV4MulticastAddress(addr) ||
 		header.IsV6MulticastAddress(addr) ||
-		e.stack.IsSubnetBroadcast(nicID, e.NetProto, addr)
+		e.stack.IsSubnetBroadcast(nicID, e.net.NetProto(), addr)
 }
 
 // Bind binds the endpoint to a specific local address and port.
@@ -705,15 +585,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
 	e.mu.Lock()
 	defer e.mu.Unlock()
 
-	err := e.bindLocked(addr)
-	if err != nil {
-		return err
-	}
-
-	e.BindNICID = addr.NIC
-	e.BindAddr = addr.Addr
-
-	return nil
+	return e.bindLocked(addr)
 }
 
 // GetLocalAddress returns the address to which the endpoint is bound.
@@ -721,11 +593,9 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
 	e.mu.RLock()
 	defer e.mu.RUnlock()
 
-	return tcpip.FullAddress{
-		NIC:  e.RegisterNICID,
-		Addr: e.ID.LocalAddress,
-		Port: e.ID.LocalPort,
-	}, nil
+	addr := e.net.GetLocalAddress()
+	addr.Port = e.ident
+	return addr, nil
 }
 
 // GetRemoteAddress returns the address to which the endpoint is connected.
@@ -733,15 +603,11 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
 	e.mu.RLock()
 	defer e.mu.RUnlock()
 
-	if e.state != stateConnected {
-		return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
+	if addr, connected := e.net.GetRemoteAddress(); connected {
+		return addr, nil
 	}
 
-	return tcpip.FullAddress{
-		NIC:  e.RegisterNICID,
-		Addr: e.ID.RemoteAddress,
-		Port: e.ID.RemotePort,
-	}, nil
+	return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
 }
 
 // Readiness returns the current readiness of the endpoint. For example, if
@@ -766,7 +632,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
 // endpoint.
 func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
 	// Only accept echo replies.
-	switch e.NetProto {
+	switch e.net.NetProto() {
 	case header.IPv4ProtocolNumber:
 		h := header.ICMPv4(pkt.TransportHeader().View())
 		if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply {
@@ -840,9 +706,9 @@ func (e *endpoint) State() uint32 {
 // Info returns a copy of the endpoint info.
 func (e *endpoint) Info() tcpip.EndpointInfo {
 	e.mu.RLock()
-	// Make a copy of the endpoint info.
-	ret := e.TransportEndpointInfo
-	e.mu.RUnlock()
+	defer e.mu.RUnlock()
+	ret := e.net.Info()
+	ret.ID.LocalPort = e.ident
 	return &ret
 }
 
diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go
index b8b839e4a..dfe453ff9 100644
--- a/pkg/tcpip/transport/icmp/endpoint_state.go
+++ b/pkg/tcpip/transport/icmp/endpoint_state.go
@@ -15,11 +15,13 @@
 package icmp
 
 import (
+	"fmt"
 	"time"
 
 	"gvisor.dev/gvisor/pkg/tcpip"
 	"gvisor.dev/gvisor/pkg/tcpip/buffer"
 	"gvisor.dev/gvisor/pkg/tcpip/stack"
+	"gvisor.dev/gvisor/pkg/tcpip/transport"
 )
 
 // saveReceivedAt is invoked by stateify.
@@ -61,29 +63,24 @@ func (e *endpoint) beforeSave() {
 // Resume implements tcpip.ResumableEndpoint.Resume.
 func (e *endpoint) Resume(s *stack.Stack) {
 	e.thaw()
+
+	e.net.Resume(s)
+
 	e.stack = s
 	e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
 
-	if e.state != stateBound && e.state != stateConnected {
-		return
-	}
-
-	var err tcpip.Error
-	if e.state == stateConnected {
-		e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.ID.RemoteAddress, e.NetProto, false /* multicastLoop */)
+	switch state := e.net.State(); state {
+	case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
+	case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+		var err tcpip.Error
+		info := e.net.Info()
+		info.ID.LocalPort = e.ident
+		info.ID, err = e.registerWithStack(info.NetProto, info.ID)
 		if err != nil {
-			panic(err)
+			panic(fmt.Sprintf("e.registerWithStack(%d, %#v): %s", info.NetProto, info.ID, err))
 		}
-
-		e.ID.LocalAddress = e.route.LocalAddress()
-	} else if len(e.ID.LocalAddress) != 0 { // stateBound
-		if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.ID.LocalAddress) == 0 {
-			panic(&tcpip.ErrBadLocalAddress{})
-		}
-	}
-
-	e.ID, err = e.registerWithStack(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.ID)
-	if err != nil {
-		panic(err)
+		e.ident = info.ID.LocalPort
+	default:
+		panic(fmt.Sprintf("unhandled state = %s", state))
 	}
 }
diff --git a/pkg/tcpip/transport/internal/network/BUILD b/pkg/tcpip/transport/internal/network/BUILD
index b1edce39b..3818cb04e 100644
--- a/pkg/tcpip/transport/internal/network/BUILD
+++ b/pkg/tcpip/transport/internal/network/BUILD
@@ -9,6 +9,7 @@ go_library(
         "endpoint_state.go",
     ],
     visibility = [
+        "//pkg/tcpip/transport/icmp:__pkg__",
         "//pkg/tcpip/transport/raw:__pkg__",
         "//pkg/tcpip/transport/udp:__pkg__",
     ],
-- 
cgit v1.2.3