summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/dhcp/client.go17
-rw-r--r--pkg/dhcp/server.go15
-rw-r--r--pkg/sentry/socket/epsocket/epsocket.go23
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go35
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go2
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/main.go2
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go40
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go22
-rw-r--r--pkg/tcpip/stack/registration.go6
-rw-r--r--pkg/tcpip/stack/route.go16
-rw-r--r--pkg/tcpip/stack/stack.go4
-rw-r--r--pkg/tcpip/stack/transport_test.go12
-rw-r--r--pkg/tcpip/tcpip.go7
-rw-r--r--pkg/tcpip/transport/ping/endpoint.go29
-rw-r--r--pkg/tcpip/transport/tcp/connect.go2
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go16
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go48
-rw-r--r--pkg/tcpip/transport/tcp/tcp_timestamp_test.go4
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go31
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go18
20 files changed, 220 insertions, 129 deletions
diff --git a/pkg/dhcp/client.go b/pkg/dhcp/client.go
index cf8472c5f..92c634a14 100644
--- a/pkg/dhcp/client.go
+++ b/pkg/dhcp/client.go
@@ -195,10 +195,23 @@ func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) (cfg
wopts := tcpip.WriteOptions{
To: serverAddr,
}
- if _, err := ep.Write(tcpip.SlicePayload(h), wopts); err != nil {
+ var resCh <-chan struct{}
+ if _, resCh, err = ep.Write(tcpip.SlicePayload(h), wopts); err != nil && resCh == nil {
return Config{}, fmt.Errorf("dhcp discovery write: %v", err)
}
+ if resCh != nil {
+ select {
+ case <-resCh:
+ case <-ctx.Done():
+ return Config{}, fmt.Errorf("dhcp client address resolution: %v", tcpip.ErrAborted)
+ }
+
+ if _, _, err := ep.Write(tcpip.SlicePayload(h), wopts); err != nil {
+ return Config{}, fmt.Errorf("dhcp discovery write: %v", err)
+ }
+ }
+
we, ch := waiter.NewChannelEntry(nil)
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
@@ -289,7 +302,7 @@ func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) (cfg
reqOpts = append(reqOpts, option{optClientID, clientID})
}
h.setOptions(reqOpts)
- if _, err := ep.Write(tcpip.SlicePayload(h), wopts); err != nil {
+ if _, _, err := ep.Write(tcpip.SlicePayload(h), wopts); err != nil {
return Config{}, fmt.Errorf("dhcp discovery write: %v", err)
}
diff --git a/pkg/dhcp/server.go b/pkg/dhcp/server.go
index 003e272b2..26700bdbc 100644
--- a/pkg/dhcp/server.go
+++ b/pkg/dhcp/server.go
@@ -95,9 +95,22 @@ func (c *epConn) Read() (buffer.View, tcpip.FullAddress, error) {
}
func (c *epConn) Write(b []byte, addr *tcpip.FullAddress) error {
- if _, err := c.ep.Write(tcpip.SlicePayload(b), tcpip.WriteOptions{To: addr}); err != nil {
+ _, resCh, err := c.ep.Write(tcpip.SlicePayload(b), tcpip.WriteOptions{To: addr})
+ if err != nil && resCh == nil {
return fmt.Errorf("write: %v", err)
}
+
+ if resCh != nil {
+ select {
+ case <-resCh:
+ case <-c.ctx.Done():
+ return fmt.Errorf("dhcp server address resolution: %v", tcpip.ErrAborted)
+ }
+
+ if _, _, err := c.ep.Write(tcpip.SlicePayload(b), tcpip.WriteOptions{To: addr}); err != nil {
+ return fmt.Errorf("write: %v", err)
+ }
+ }
return nil
}
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index 4d32f7a31..550569b4c 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -276,10 +276,21 @@ func (i *ioSequencePayload) Size() int {
// Write implements fs.FileOperations.Write.
func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
f := &ioSequencePayload{ctx: ctx, src: src}
- n, err := s.Endpoint.Write(f, tcpip.WriteOptions{})
+ n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{})
if err == tcpip.ErrWouldBlock {
return int64(n), syserror.ErrWouldBlock
}
+
+ if resCh != nil {
+ t := ctx.(*kernel.Task)
+ if err := t.Block(resCh); err != nil {
+ return int64(n), syserr.FromError(err).ToError()
+ }
+
+ n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{})
+ return int64(n), syserr.TranslateNetstackError(err).ToError()
+ }
+
return int64(n), syserr.TranslateNetstackError(err).ToError()
}
@@ -1016,7 +1027,13 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
EndOfRecord: flags&linux.MSG_EOR != 0,
}
- n, err := s.Endpoint.Write(tcpip.SlicePayload(v), opts)
+ n, resCh, err := s.Endpoint.Write(tcpip.SlicePayload(v), opts)
+ if resCh != nil {
+ if err := t.Block(resCh); err != nil {
+ return int(n), syserr.FromError(err)
+ }
+ n, _, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts)
+ }
if err != tcpip.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
return int(n), syserr.TranslateNetstackError(err)
}
@@ -1030,7 +1047,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
v.TrimFront(int(n))
total := n
for {
- n, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts)
+ n, _, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts)
v.TrimFront(int(n))
total += n
if err != tcpip.ErrWouldBlock {
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index 490b9c648..b64dce720 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -393,9 +393,22 @@ func (c *Conn) Write(b []byte) (int, error) {
}
var n uintptr
- n, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
+ var resCh <-chan struct{}
+ n, resCh, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
nbytes += int(n)
v.TrimFront(int(n))
+
+ if resCh != nil {
+ select {
+ case <-deadline:
+ return nbytes, c.newOpError("write", &timeoutError{})
+ case <-resCh:
+ }
+
+ n, _, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
+ nbytes += int(n)
+ v.TrimFront(int(n))
+ }
}
if err == nil {
@@ -571,7 +584,16 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
copy(v, b)
wopts := tcpip.WriteOptions{To: &fullAddr}
- n, err := c.ep.Write(tcpip.SlicePayload(v), wopts)
+ n, resCh, err := c.ep.Write(tcpip.SlicePayload(v), wopts)
+ if resCh != nil {
+ select {
+ case <-deadline:
+ return int(n), c.newRemoteOpError("write", addr, &timeoutError{})
+ case <-resCh:
+ }
+
+ n, _, err = c.ep.Write(tcpip.SlicePayload(v), wopts)
+ }
if err == tcpip.ErrWouldBlock {
// Create wait queue entry that notifies a channel.
@@ -579,15 +601,16 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
c.wq.EventRegister(&waitEntry, waiter.EventOut)
defer c.wq.EventUnregister(&waitEntry)
for {
- n, err = c.ep.Write(tcpip.SlicePayload(v), wopts)
- if err != tcpip.ErrWouldBlock {
- break
- }
select {
case <-deadline:
return int(n), c.newRemoteOpError("write", addr, &timeoutError{})
case <-notifyCh:
}
+
+ n, _, err = c.ep.Write(tcpip.SlicePayload(v), wopts)
+ if err != tcpip.ErrWouldBlock {
+ break
+ }
}
}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index a8eef4cf2..b8e53c13e 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -190,7 +190,7 @@ func TestLinkResolution(t *testing.T) {
if ctx.Err() != nil {
break
}
- if _, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: 1, Addr: lladdr1}}); err == tcpip.ErrNoLinkAddress {
+ if _, _, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: 1, Addr: lladdr1}}); err == tcpip.ErrNoLinkAddress {
// There's something asynchronous going on; yield to let it do its thing.
runtime.Gosched()
} else if err == nil {
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
index d029193fb..c4707736e 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -80,7 +80,7 @@ func writer(ch chan struct{}, ep tcpip.Endpoint) {
v.CapLength(n)
for len(v) > 0 {
- n, err := ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
+ n, _, err := ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
if err != nil {
fmt.Println("Write failed:", err)
return
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
index 04b8f251a..3a147a75f 100644
--- a/pkg/tcpip/stack/linkaddrcache.go
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -88,12 +88,14 @@ type linkAddrEntry struct {
linkAddr tcpip.LinkAddress
expiration time.Time
s entryState
+ resDone bool
// wakers is a set of waiters for address resolution result. Anytime
// state transitions out of 'incomplete' these waiters are notified.
wakers map[*sleep.Waker]struct{}
cancel chan struct{}
+ resCh chan struct{}
}
func (e *linkAddrEntry) state() entryState {
@@ -182,15 +184,20 @@ func (c *linkAddrCache) makeAndAddEntry(k tcpip.FullAddress, v tcpip.LinkAddress
// someone waiting for address resolution on it.
entry.changeState(expired)
if entry.cancel != nil {
- entry.cancel <- struct{}{}
+ if !entry.resDone {
+ close(entry.resCh)
+ }
+ close(entry.cancel)
}
*entry = linkAddrEntry{
addr: k,
linkAddr: v,
expiration: time.Now().Add(c.ageLimit),
+ resDone: false,
wakers: make(map[*sleep.Waker]struct{}),
cancel: make(chan struct{}, 1),
+ resCh: make(chan struct{}, 1),
}
c.cache[k] = entry
@@ -202,10 +209,10 @@ func (c *linkAddrCache) makeAndAddEntry(k tcpip.FullAddress, v tcpip.LinkAddress
}
// get reports any known link address for k.
-func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) (tcpip.LinkAddress, *tcpip.Error) {
+func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
if linkRes != nil {
if addr, ok := linkRes.ResolveStaticAddress(k.Addr); ok {
- return addr, nil
+ return addr, nil, nil
}
}
@@ -214,10 +221,11 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo
if entry == nil || entry.state() == expired {
c.mu.Unlock()
if linkRes == nil {
- return "", tcpip.ErrNoLinkAddress
+ return "", nil, tcpip.ErrNoLinkAddress
}
- c.startAddressResolution(k, linkRes, localAddr, linkEP, waker)
- return "", tcpip.ErrWouldBlock
+
+ ch := c.startAddressResolution(k, linkRes, localAddr, linkEP, waker)
+ return "", ch, tcpip.ErrWouldBlock
}
defer c.mu.Unlock()
@@ -227,13 +235,13 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo
// in that case it's safe to consider it ready.
fallthrough
case ready:
- return entry.linkAddr, nil
+ return entry.linkAddr, nil, nil
case failed:
- return "", tcpip.ErrNoLinkAddress
+ return "", nil, tcpip.ErrNoLinkAddress
case incomplete:
// Address resolution is still in progress.
entry.addWaker(waker)
- return "", tcpip.ErrWouldBlock
+ return "", entry.resCh, tcpip.ErrWouldBlock
default:
panic(fmt.Sprintf("invalid cache entry state: %d", s))
}
@@ -249,13 +257,13 @@ func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) {
}
}
-func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) {
+func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) <-chan struct{} {
c.mu.Lock()
defer c.mu.Unlock()
// Look up again with lock held to ensure entry wasn't added by someone else.
if e := c.cache[k]; e != nil && e.state() != expired {
- return
+ return nil
}
// Add 'incomplete' entry in the cache to mark that resolution is in progress.
@@ -274,6 +282,15 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link
select {
case <-time.After(c.resolutionTimeout):
if stop := c.checkLinkRequest(k, i); stop {
+ // If entry is evicted then resCh is already closed.
+ c.mu.Lock()
+ if e, ok := c.cache[k]; ok {
+ if !e.resDone {
+ e.resDone = true
+ close(e.resCh)
+ }
+ }
+ c.mu.Unlock()
return
}
case <-cancel:
@@ -281,6 +298,7 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link
}
}
}()
+ return e.resCh
}
// checkLinkRequest checks whether previous attempt to resolve address has succeeded
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index f0988d6de..e46267f12 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -73,7 +73,7 @@ func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressRe
defer s.Done()
for {
- if got, err := c.get(addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock {
+ if got, _, err := c.get(addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock {
return got, err
}
s.Fetch(true)
@@ -95,7 +95,7 @@ func TestCacheOverflow(t *testing.T) {
for i := len(testaddrs) - 1; i >= 0; i-- {
e := testaddrs[i]
c.add(e.addr, e.linkAddr)
- got, err := c.get(e.addr, nil, "", nil, nil)
+ got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
t.Errorf("insert %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err)
}
@@ -106,7 +106,7 @@ func TestCacheOverflow(t *testing.T) {
// Expect to find at least half of the most recent entries.
for i := 0; i < linkAddrCacheSize/2; i++ {
e := testaddrs[i]
- got, err := c.get(e.addr, nil, "", nil, nil)
+ got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err)
}
@@ -117,7 +117,7 @@ func TestCacheOverflow(t *testing.T) {
// The earliest entries should no longer be in the cache.
for i := len(testaddrs) - 1; i >= len(testaddrs)-linkAddrCacheSize; i-- {
e := testaddrs[i]
- if _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
+ if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err)
}
}
@@ -143,7 +143,7 @@ func TestCacheConcurrent(t *testing.T) {
// can fit in the cache, so our eviction strategy requires that
// the last entry be present and the first be missing.
e := testaddrs[len(testaddrs)-1]
- got, err := c.get(e.addr, nil, "", nil, nil)
+ got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
}
@@ -152,7 +152,7 @@ func TestCacheConcurrent(t *testing.T) {
}
e = testaddrs[0]
- if _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
+ if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
}
}
@@ -162,7 +162,7 @@ func TestCacheAgeLimit(t *testing.T) {
e := testaddrs[0]
c.add(e.addr, e.linkAddr)
time.Sleep(50 * time.Millisecond)
- if _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
+ if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
}
}
@@ -172,7 +172,7 @@ func TestCacheReplace(t *testing.T) {
e := testaddrs[0]
l2 := e.linkAddr + "2"
c.add(e.addr, e.linkAddr)
- got, err := c.get(e.addr, nil, "", nil, nil)
+ got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
}
@@ -181,7 +181,7 @@ func TestCacheReplace(t *testing.T) {
}
c.add(e.addr, l2)
- got, err = c.get(e.addr, nil, "", nil, nil)
+ got, _, err = c.get(e.addr, nil, "", nil, nil)
if err != nil {
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
}
@@ -206,7 +206,7 @@ func TestCacheResolution(t *testing.T) {
// Check that after resolved, address stays in the cache and never returns WouldBlock.
for i := 0; i < 10; i++ {
e := testaddrs[len(testaddrs)-1]
- got, err := c.get(e.addr, linkRes, "", nil, nil)
+ got, _, err := c.get(e.addr, linkRes, "", nil, nil)
if err != nil {
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
}
@@ -256,7 +256,7 @@ func TestStaticResolution(t *testing.T) {
addr := tcpip.Address("broadcast")
want := tcpip.LinkAddress("mac_broadcast")
- got, err := c.get(tcpip.FullAddress{Addr: addr}, linkRes, "", nil, nil)
+ got, _, err := c.get(tcpip.FullAddress{Addr: addr}, linkRes, "", nil, nil)
if err != nil {
t.Errorf("c.get(%q)=%q, got error: %v", string(addr), string(got), err)
}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 595c7e793..0acec2984 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -289,7 +289,11 @@ type LinkAddressCache interface {
// registered with the network protocol, the cache attempts to resolve the address
// and returns ErrWouldBlock. Waker is notified when address resolution is
// complete (success or not).
- GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, *tcpip.Error)
+ //
+ // If address resolution is required, ErrNoLinkAddress and a notification channel is
+ // returned for the top level caller to block. Channel is closed once address resolution
+ // is complete (success or not).
+ GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error)
// RemoveWaker removes a waker that has been added in GetLinkAddress().
RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker)
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index cc9b24e23..6c6400c33 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -89,11 +89,15 @@ func (r *Route) Capabilities() LinkEndpointCapabilities {
// Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in
// case address resolution requires blocking, e.g. wait for ARP reply. Waker is
// notified when address resolution is complete (success or not).
-func (r *Route) Resolve(waker *sleep.Waker) *tcpip.Error {
+//
+// If address resolution is required, ErrNoLinkAddress and a notification channel is
+// returned for the top level caller to block. Channel is closed once address resolution
+// is complete (success or not).
+func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
if !r.IsResolutionRequired() {
// Nothing to do if there is no cache (which does the resolution on cache miss) or
// link address is already known.
- return nil
+ return nil, nil
}
nextAddr := r.NextHop
@@ -101,16 +105,16 @@ func (r *Route) Resolve(waker *sleep.Waker) *tcpip.Error {
// Local link address is already known.
if r.RemoteAddress == r.LocalAddress {
r.RemoteLinkAddress = r.LocalLinkAddress
- return nil
+ return nil, nil
}
nextAddr = r.RemoteAddress
}
- linkAddr, err := r.ref.linkCache.GetLinkAddress(r.ref.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker)
+ linkAddr, ch, err := r.ref.linkCache.GetLinkAddress(r.ref.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker)
if err != nil {
- return err
+ return ch, err
}
r.RemoteLinkAddress = linkAddr
- return nil
+ return nil, nil
}
// RemoveWaker removes a waker that has been added in Resolve().
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 699519be1..d1ec6a660 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -831,12 +831,12 @@ func (s *Stack) AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr t
}
// GetLinkAddress implements LinkAddressCache.GetLinkAddress.
-func (s *Stack) GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, *tcpip.Error) {
+func (s *Stack) GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
s.mu.RLock()
nic := s.nics[nicid]
if nic == nil {
s.mu.RUnlock()
- return "", tcpip.ErrUnknownNICID
+ return "", nil, tcpip.ErrUnknownNICID
}
s.mu.RUnlock()
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 9ec37e7b6..98cc3b120 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -60,21 +60,21 @@ func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.Contr
return buffer.View{}, tcpip.ControlMessages{}, nil
}
-func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tcpip.Error) {
+func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
if len(f.route.RemoteAddress) == 0 {
- return 0, tcpip.ErrNoRoute
+ return 0, nil, tcpip.ErrNoRoute
}
hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength()))
v, err := p.Get(p.Size())
if err != nil {
- return 0, err
+ return 0, nil, err
}
if err := f.route.WritePacket(hdr, buffer.View(v).ToVectorisedView(), fakeTransNumber, 123); err != nil {
- return 0, err
+ return 0, nil, err
}
- return uintptr(len(v)), nil
+ return uintptr(len(v)), nil, nil
}
func (f *fakeTransportEndpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
@@ -362,7 +362,7 @@ func TestTransportSend(t *testing.T) {
// Create buffer that will hold the payload.
view := buffer.NewView(30)
- _, err = ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{})
+ _, _, err = ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{})
if err != nil {
t.Fatalf("write failed: %v", err)
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 61272cb05..5f210cdd0 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -306,7 +306,12 @@ type Endpoint interface {
//
// Note that unlike io.Writer.Write, it is not an error for Write to
// perform a partial write.
- Write(Payload, WriteOptions) (uintptr, *Error)
+ //
+ // For UDP and Ping sockets if address resolution is required,
+ // ErrNoLinkAddress and a notification channel is returned for the caller to
+ // block. Channel is closed once address resolution is complete (success or
+ // not). The channel is only non-nil in this case.
+ Write(Payload, WriteOptions) (uintptr, <-chan struct{}, *Error)
// Peek reads data without consuming it from the endpoint.
//
diff --git a/pkg/tcpip/transport/ping/endpoint.go b/pkg/tcpip/transport/ping/endpoint.go
index fcfb96624..055daa918 100644
--- a/pkg/tcpip/transport/ping/endpoint.go
+++ b/pkg/tcpip/transport/ping/endpoint.go
@@ -198,10 +198,10 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
-func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
// MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
if opts.More {
- return 0, tcpip.ErrInvalidOptionValue
+ return 0, nil, tcpip.ErrInvalidOptionValue
}
to := opts.To
@@ -211,14 +211,14 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tc
// If we've shutdown with SHUT_WR we are in an invalid state for sending.
if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
- return 0, tcpip.ErrClosedForSend
+ return 0, nil, tcpip.ErrClosedForSend
}
// Prepare for write.
for {
retry, err := e.prepareForWrite(to)
if err != nil {
- return 0, err
+ return 0, nil, err
}
if !retry {
@@ -241,7 +241,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tc
// Recheck state after lock was re-acquired.
if e.state != stateConnected {
- return 0, tcpip.ErrInvalidEndpointState
+ return 0, nil, tcpip.ErrInvalidEndpointState
}
}
} else {
@@ -250,7 +250,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tc
nicid := to.NIC
if e.bindNICID != 0 {
if nicid != 0 && nicid != e.bindNICID {
- return 0, tcpip.ErrNoRoute
+ return 0, nil, tcpip.ErrNoRoute
}
nicid = e.bindNICID
@@ -260,13 +260,13 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tc
to = &toCopy
netProto, err := e.checkV4Mapped(to, true)
if err != nil {
- return 0, err
+ return 0, nil, err
}
// Find the enpoint.
r, err := e.stack.FindRoute(nicid, e.bindAddr, to.Addr, netProto)
if err != nil {
- return 0, err
+ return 0, nil, err
}
defer r.Release()
@@ -275,23 +275,20 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tc
if route.IsResolutionRequired() {
waker := &sleep.Waker{}
- if err := route.Resolve(waker); err != nil {
+ if ch, err := route.Resolve(waker); err != nil {
if err == tcpip.ErrWouldBlock {
// Link address needs to be resolved. Resolution was triggered the
// background. Better luck next time.
- //
- // TODO: queue up the request and send after link address
- // is resolved.
route.RemoveWaker(waker)
- return 0, tcpip.ErrNoLinkAddress
+ return 0, ch, tcpip.ErrNoLinkAddress
}
- return 0, err
+ return 0, nil, err
}
}
v, err := p.Get(p.Size())
if err != nil {
- return 0, err
+ return 0, nil, err
}
switch e.netProto {
@@ -302,7 +299,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tc
err = sendPing6(route, e.id.LocalPort, v)
}
- return uintptr(len(v)), err
+ return uintptr(len(v)), nil, err
}
// Peek only returns data from a single datagram, so do nothing here.
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 68c0d4472..27dbcace2 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -365,7 +365,7 @@ func (h *handshake) resolveRoute() *tcpip.Error {
for {
switch index {
case wakerForResolution:
- if err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock {
+ if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock {
// Either success (err == nil) or failure.
return err
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index e82e25233..707d6be96 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -492,7 +492,7 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
}
// Write writes data to the endpoint's peer.
-func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
// Linux completely ignores any address passed to sendto(2) for TCP sockets
// (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
// and opts.EndOfRecord are also ignored.
@@ -504,15 +504,15 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tc
if e.state != stateConnected {
switch e.state {
case stateError:
- return 0, e.hardError
+ return 0, nil, e.hardError
default:
- return 0, tcpip.ErrClosedForSend
+ return 0, nil, tcpip.ErrClosedForSend
}
}
// Nothing to do if the buffer is empty.
if p.Size() == 0 {
- return 0, nil
+ return 0, nil, nil
}
e.sndBufMu.Lock()
@@ -520,20 +520,20 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tc
// Check if the connection has already been closed for sends.
if e.sndClosed {
e.sndBufMu.Unlock()
- return 0, tcpip.ErrClosedForSend
+ return 0, nil, tcpip.ErrClosedForSend
}
// Check against the limit.
avail := e.sndBufSize - e.sndBufUsed
if avail <= 0 {
e.sndBufMu.Unlock()
- return 0, tcpip.ErrWouldBlock
+ return 0, nil, tcpip.ErrWouldBlock
}
v, perr := p.Get(avail)
if perr != nil {
e.sndBufMu.Unlock()
- return 0, perr
+ return 0, nil, perr
}
var err *tcpip.Error
@@ -558,7 +558,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tc
// Let the protocol goroutine do the work.
e.sndWaker.Assert()
}
- return uintptr(l), err
+ return uintptr(l), nil, err
}
// Peek reads data without consuming it from the endpoint.
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index ac21e565b..48852ea47 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -869,7 +869,7 @@ func TestSimpleSend(t *testing.T) {
view := buffer.NewView(len(data))
copy(view, data)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %v", err)
}
@@ -910,7 +910,7 @@ func TestZeroWindowSend(t *testing.T) {
view := buffer.NewView(len(data))
copy(view, data)
- _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{})
+ _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{})
if err != nil {
t.Fatalf("Write failed: %v", err)
}
@@ -971,7 +971,7 @@ func TestScaledWindowConnect(t *testing.T) {
view := buffer.NewView(len(data))
copy(view, data)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %v", err)
}
@@ -1004,7 +1004,7 @@ func TestNonScaledWindowConnect(t *testing.T) {
view := buffer.NewView(len(data))
copy(view, data)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %v", err)
}
@@ -1077,7 +1077,7 @@ func TestScaledWindowAccept(t *testing.T) {
view := buffer.NewView(len(data))
copy(view, data)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %v", err)
}
@@ -1150,7 +1150,7 @@ func TestNonScaledWindowAccept(t *testing.T) {
view := buffer.NewView(len(data))
copy(view, data)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %v", err)
}
@@ -1265,7 +1265,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) {
view := buffer.NewView(len(data))
copy(view, data)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %v", err)
}
@@ -1653,7 +1653,7 @@ func TestSendOnResetConnection(t *testing.T) {
// Try to write.
view := buffer.NewView(10)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset {
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset {
t.Fatalf("got c.EP.Write(...) = %v, want = %v", err, tcpip.ErrConnectionReset)
}
}
@@ -1763,7 +1763,7 @@ func TestFinWithNoPendingData(t *testing.T) {
// Write something out, and have it acknowledged.
view := buffer.NewView(10)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %v", err)
}
@@ -1836,7 +1836,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
// any of them.
view := buffer.NewView(10)
for i := tcp.InitialCwnd; i > 0; i-- {
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %v", err)
}
}
@@ -1922,7 +1922,7 @@ func TestFinWithPendingData(t *testing.T) {
// Write something out, and acknowledge it to get cwnd to 2.
view := buffer.NewView(10)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %v", err)
}
@@ -1948,7 +1948,7 @@ func TestFinWithPendingData(t *testing.T) {
})
// Write new data, but don't acknowledge it.
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %v", err)
}
@@ -2009,7 +2009,7 @@ func TestFinWithPartialAck(t *testing.T) {
// Write something out, and acknowledge it to get cwnd to 2. Also send
// FIN from the test side.
view := buffer.NewView(10)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %v", err)
}
@@ -2046,7 +2046,7 @@ func TestFinWithPartialAck(t *testing.T) {
)
// Write new data, but don't acknowledge it.
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %v", err)
}
@@ -2116,7 +2116,7 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) {
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
- if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %v", err)
}
@@ -2158,7 +2158,7 @@ func TestCongestionAvoidance(t *testing.T) {
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
- if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %v", err)
}
@@ -2263,7 +2263,7 @@ func TestCubicCongestionAvoidance(t *testing.T) {
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
- if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %v", err)
}
@@ -2371,7 +2371,7 @@ func TestFastRecovery(t *testing.T) {
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
- if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %v", err)
}
@@ -2503,11 +2503,11 @@ func TestRetransmit(t *testing.T) {
// Write all the data in two shots. Packets will only be written at the
// MTU size though.
half := data[:len(data)/2]
- if _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %v", err)
}
half = data[len(data)/2:]
- if _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %v", err)
}
@@ -2605,7 +2605,7 @@ func scaledSendWindow(t *testing.T, scale uint8) {
// Send some data. Check that it's capped by the window size.
view := buffer.NewView(65535)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %v", err)
}
@@ -3099,7 +3099,7 @@ func TestSelfConnect(t *testing.T) {
data := []byte{1, 2, 3}
view := buffer.NewView(len(data))
copy(view, data)
- if _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %v", err)
}
@@ -3290,7 +3290,7 @@ func TestPathMTUDiscovery(t *testing.T) {
data[i] = byte(i)
}
- if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %v", err)
}
@@ -3495,7 +3495,7 @@ func TestKeepalive(t *testing.T) {
// Send some data and wait before ACKing it. Keepalives should be disabled
// during this period.
view := buffer.NewView(3)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %v", err)
}
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
index 894ead507..ca16fc8fa 100644
--- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
@@ -147,7 +147,7 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS
view := buffer.NewView(len(data))
copy(view, data)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Unexpected error from Write: %v", err)
}
@@ -210,7 +210,7 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd
view := buffer.NewView(len(data))
copy(view, data)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Unexpected error from Write: %v", err)
}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index f2dd98f35..6ed805357 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -258,10 +258,10 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
-func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
// MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
if opts.More {
- return 0, tcpip.ErrInvalidOptionValue
+ return 0, nil, tcpip.ErrInvalidOptionValue
}
to := opts.To
@@ -271,14 +271,14 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tc
// If we've shutdown with SHUT_WR we are in an invalid state for sending.
if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
- return 0, tcpip.ErrClosedForSend
+ return 0, nil, tcpip.ErrClosedForSend
}
// Prepare for write.
for {
retry, err := e.prepareForWrite(to)
if err != nil {
- return 0, err
+ return 0, nil, err
}
if !retry {
@@ -303,7 +303,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tc
// Recheck state after lock was re-acquired.
if e.state != stateConnected {
- return 0, tcpip.ErrInvalidEndpointState
+ return 0, nil, tcpip.ErrInvalidEndpointState
}
}
} else {
@@ -312,7 +312,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tc
nicid := to.NIC
if e.bindNICID != 0 {
if nicid != 0 && nicid != e.bindNICID {
- return 0, tcpip.ErrNoRoute
+ return 0, nil, tcpip.ErrNoRoute
}
nicid = e.bindNICID
@@ -322,13 +322,13 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tc
to = &toCopy
netProto, err := e.checkV4Mapped(to, false)
if err != nil {
- return 0, err
+ return 0, nil, err
}
// Find the enpoint.
r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, to.Addr, netProto)
if err != nil {
- return 0, err
+ return 0, nil, err
}
defer r.Release()
@@ -338,23 +338,20 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tc
if route.IsResolutionRequired() {
waker := &sleep.Waker{}
- if err := route.Resolve(waker); err != nil {
+ if ch, err := route.Resolve(waker); err != nil {
if err == tcpip.ErrWouldBlock {
// Link address needs to be resolved. Resolution was triggered the background.
// Better luck next time.
- //
- // TODO: queue up the request and send after link address
- // is resolved.
route.RemoveWaker(waker)
- return 0, tcpip.ErrNoLinkAddress
+ return 0, ch, tcpip.ErrNoLinkAddress
}
- return 0, err
+ return 0, nil, err
}
}
v, err := p.Get(p.Size())
if err != nil {
- return 0, err
+ return 0, nil, err
}
ttl := route.DefaultTTL()
@@ -363,9 +360,9 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tc
}
if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.id.LocalPort, dstPort, ttl); err != nil {
- return 0, err
+ return 0, nil, err
}
- return uintptr(len(v)), nil
+ return uintptr(len(v)), nil, nil
}
// Peek only returns data from a single datagram, so do nothing here.
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 46110c8ff..c3f592bd4 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -482,7 +482,7 @@ func TestV4ReadOnV4(t *testing.T) {
func testV4Write(c *testContext) uint16 {
// Write to V4 mapped address.
payload := buffer.View(newPayload())
- n, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
+ n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort},
})
if err != nil {
@@ -512,7 +512,7 @@ func testV4Write(c *testContext) uint16 {
func testV6Write(c *testContext) uint16 {
// Write to v6 address.
payload := buffer.View(newPayload())
- n, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
+ n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
})
if err != nil {
@@ -590,7 +590,7 @@ func TestDualWriteConnectedToV6(t *testing.T) {
// Write to V4 mapped address.
payload := buffer.View(newPayload())
- _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
+ _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort},
})
if err != tcpip.ErrNetworkUnreachable {
@@ -613,7 +613,7 @@ func TestDualWriteConnectedToV4Mapped(t *testing.T) {
// Write to v6 address.
payload := buffer.View(newPayload())
- _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
+ _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
})
if err != tcpip.ErrInvalidEndpointState {
@@ -629,7 +629,7 @@ func TestV4WriteOnV6Only(t *testing.T) {
// Write to V4 mapped address.
payload := buffer.View(newPayload())
- _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
+ _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort},
})
if err != tcpip.ErrNoRoute {
@@ -650,7 +650,7 @@ func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
// Write to v6 address.
payload := buffer.View(newPayload())
- _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
+ _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
})
if err != tcpip.ErrInvalidEndpointState {
@@ -671,7 +671,7 @@ func TestV6WriteOnConnected(t *testing.T) {
// Write without destination.
payload := buffer.View(newPayload())
- n, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{})
+ n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{})
if err != nil {
c.t.Fatalf("Write failed: %v", err)
}
@@ -707,7 +707,7 @@ func TestV4WriteOnConnected(t *testing.T) {
// Write without destination.
payload := buffer.View(newPayload())
- n, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{})
+ n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{})
if err != nil {
c.t.Fatalf("Write failed: %v", err)
}
@@ -856,7 +856,7 @@ func TestTTL(t *testing.T) {
c.t.Fatalf("SetSockOpt failed: %v", err)
}
- n, err := c.ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: addr, Port: port}})
+ n, _, err := c.ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: addr, Port: port}})
if err != nil {
c.t.Fatalf("Write failed: %v", err)
}