summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go67
-rw-r--r--pkg/tcpip/tcpip.go28
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go7
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go2
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go6
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go13
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go6
7 files changed, 58 insertions, 71 deletions
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index fdeec12d3..7c7495c30 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -16,6 +16,7 @@
package gonet
import (
+ "bytes"
"context"
"errors"
"io"
@@ -354,8 +355,6 @@ func (c *TCPConn) Write(b []byte) (int, error) {
default:
}
- v := buffer.NewViewFromBytes(b)
-
// We must handle two soft failure conditions simultaneously:
// 1. Write may write nothing and return tcpip.ErrWouldBlock.
// If this happens, we need to register for notifications if we have
@@ -368,22 +367,23 @@ func (c *TCPConn) Write(b []byte) (int, error) {
// There is no guarantee that all of the condition #1s will occur before
// all of the condition #2s or visa-versa.
var (
- err *tcpip.Error
- nbytes int
- reg bool
- notifyCh chan struct{}
+ r bytes.Reader
+ nbytes int
+ entry waiter.Entry
+ ch <-chan struct{}
)
- for nbytes < len(b) && (err == tcpip.ErrWouldBlock || err == nil) {
- if err == tcpip.ErrWouldBlock {
- if !reg {
- // Only register once.
- reg = true
-
- // Create wait queue entry that notifies a channel.
- var waitEntry waiter.Entry
- waitEntry, notifyCh = waiter.NewChannelEntry(nil)
- c.wq.EventRegister(&waitEntry, waiter.EventOut)
- defer c.wq.EventUnregister(&waitEntry)
+ for nbytes != len(b) {
+ r.Reset(b[nbytes:])
+ n, err := c.ep.Write(&r, tcpip.WriteOptions{})
+ nbytes += int(n)
+ switch err {
+ case nil:
+ case tcpip.ErrWouldBlock:
+ if ch == nil {
+ entry, ch = waiter.NewChannelEntry(nil)
+
+ c.wq.EventRegister(&entry, waiter.EventOut)
+ defer c.wq.EventUnregister(&entry)
} else {
// Don't wait immediately after registration in case more data
// became available between when we last checked and when we setup
@@ -391,22 +391,15 @@ func (c *TCPConn) Write(b []byte) (int, error) {
select {
case <-deadline:
return nbytes, c.newOpError("write", &timeoutError{})
- case <-notifyCh:
+ case <-ch:
+ continue
}
}
+ default:
+ return nbytes, c.newOpError("write", errors.New(err.String()))
}
-
- var n int64
- n, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
- nbytes += int(n)
- v.TrimFront(int(n))
- }
-
- if err == nil {
- return nbytes, nil
}
-
- return nbytes, c.newOpError("write", errors.New(err.String()))
+ return nbytes, nil
}
// Close implements net.Conn.Close.
@@ -644,16 +637,18 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
}
// If we're being called by Write, there is no addr
- wopts := tcpip.WriteOptions{}
+ writeOptions := tcpip.WriteOptions{}
if addr != nil {
ua := addr.(*net.UDPAddr)
- wopts.To = &tcpip.FullAddress{Addr: tcpip.Address(ua.IP), Port: uint16(ua.Port)}
+ writeOptions.To = &tcpip.FullAddress{
+ Addr: tcpip.Address(ua.IP),
+ Port: uint16(ua.Port),
+ }
}
- v := buffer.NewView(len(b))
- copy(v, b)
-
- n, err := c.ep.Write(tcpip.SlicePayload(v), wopts)
+ var r bytes.Reader
+ r.Reset(b)
+ n, err := c.ep.Write(&r, writeOptions)
if err == tcpip.ErrWouldBlock {
// Create wait queue entry that notifies a channel.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
@@ -666,7 +661,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
case <-notifyCh:
}
- n, err = c.ep.Write(tcpip.SlicePayload(v), wopts)
+ n, err = c.ep.Write(&r, writeOptions)
if err != tcpip.ErrWouldBlock {
break
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 4f59e4ff7..fe01029ad 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -29,6 +29,7 @@
package tcpip
import (
+ "bytes"
"errors"
"fmt"
"io"
@@ -471,30 +472,15 @@ type FullAddress struct {
// This interface allows the endpoint to request the amount of data it needs
// based on internal buffers without exposing them.
type Payloader interface {
- // FullPayload returns all available bytes.
- FullPayload() ([]byte, *Error)
+ io.Reader
- // Payload returns a slice containing at most size bytes.
- Payload(size int) ([]byte, *Error)
+ // Len returns the number of bytes of the unread portion of the
+ // Reader.
+ Len() int
}
-// SlicePayload implements Payloader for slices.
-//
-// This is typically used for tests.
-type SlicePayload []byte
-
-// FullPayload implements Payloader.FullPayload.
-func (s SlicePayload) FullPayload() ([]byte, *Error) {
- return s, nil
-}
-
-// Payload implements Payloader.Payload.
-func (s SlicePayload) Payload(size int) ([]byte, *Error) {
- if size > len(s) {
- size = len(s)
- }
- return s[:size], nil
-}
+var _ Payloader = (*bytes.Buffer)(nil)
+var _ Payloader = (*bytes.Reader)(nil)
var _ io.Writer = (*SliceWriter)(nil)
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 256e19296..af00ed548 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -313,11 +313,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
route = r
}
- v, err := p.FullPayload()
- if err != nil {
- return 0, err
+ v := make([]byte, p.Len())
+ if _, err := io.ReadFull(p, v); err != nil {
+ return 0, tcpip.ErrBadBuffer
}
+ var err *tcpip.Error
switch e.NetProto {
case header.IPv4ProtocolNumber:
err = send4(route, e.ID.LocalPort, v, e.ttl, e.owner)
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index c0d6fb442..6fd116a98 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -207,7 +207,7 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul
return res, nil
}
-func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) {
+func (*endpoint) Write(tcpip.Payloader, tcpip.WriteOptions) (int64, *tcpip.Error) {
// TODO(gvisor.dev/issue/173): Implement.
return 0, tcpip.ErrInvalidOptionValue
}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index ae743f75e..2dacf5a64 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -280,9 +280,9 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
return 0, tcpip.ErrInvalidEndpointState
}
- payloadBytes, err := p.FullPayload()
- if err != nil {
- return 0, err
+ payloadBytes := make([]byte, p.Len())
+ if _, err := io.ReadFull(p, payloadBytes); err != nil {
+ return 0, tcpip.ErrBadBuffer
}
// If this is an unassociated socket and callee provided a nonzero
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index ea509ac73..8d27d43c2 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1534,14 +1534,19 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
}
// Fetch data.
- v, perr := p.Payload(avail)
- if perr != nil || len(v) == 0 {
- // Note that perr may be nil if len(v) == 0.
+ if l := p.Len(); l < avail {
+ avail = l
+ }
+ if avail == 0 {
+ return 0, nil
+ }
+ v := make([]byte, avail)
+ if _, err := io.ReadFull(p, v); err != nil {
if opts.Atomic {
e.sndBufMu.Unlock()
e.UnlockUser()
}
- return 0, perr
+ return 0, tcpip.ErrBadBuffer
}
if !opts.Atomic {
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 9f9b3d510..8544fcb08 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -514,9 +514,9 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
return 0, tcpip.ErrBroadcastDisabled
}
- v, err := p.FullPayload()
- if err != nil {
- return 0, err
+ v := make([]byte, p.Len())
+ if _, err := io.ReadFull(p, v); err != nil {
+ return 0, tcpip.ErrBadBuffer
}
if len(v) > header.UDPMaximumPacketSize {
// Payload can't possibly fit in a packet.