summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/udp
diff options
context:
space:
mode:
authorTing-Yu Wang <anivia@google.com>2021-01-07 14:14:58 -0800
committergVisor bot <gvisor-bot@google.com>2021-01-07 14:17:18 -0800
commitb1de1da318631c6d29f6c04dea370f712078f443 (patch)
treeb4e7f8f1b8fd195fa5d16257c5687126e1c7c9f6 /pkg/tcpip/transport/udp
parentf4b4ed666d13eef6aebe23189b1431a933de0d8e (diff)
netstack: Refactor tcpip.Endpoint.Read
Read now takes a destination io.Writer, count, options. Keeping the method name Read, in contrast to the Write method. This enables: * direct transfer of views under VV * zero copy It also eliminates the need for sentry to keep a slice of view because userspace had requested a read that is smaller than the view returned, removing the complexity there. Read/Peek/ReadPacket are now consolidated together and some duplicate code is removed. PiperOrigin-RevId: 350636322
Diffstat (limited to 'pkg/tcpip/transport/udp')
-rw-r--r--pkg/tcpip/transport/udp/BUILD1
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go43
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go31
3 files changed, 48 insertions, 27 deletions
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index 7ebae63d8..153e8c950 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -58,5 +58,6 @@ go_test(
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
"//pkg/waiter",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 4e8bd8b04..075de1db0 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -16,6 +16,7 @@ package udp
import (
"fmt"
+ "io"
"sync/atomic"
"gvisor.dev/gvisor/pkg/sync"
@@ -282,11 +283,10 @@ func (e *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (e *endpoint) ModerateRecvBuf(copied int) {}
-// Read reads data from the endpoint. This method does not block if
-// there is no data pending.
-func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+// Read implements tcpip.Endpoint.Read.
+func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
if err := e.LastError(); err != nil {
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return tcpip.ReadResult{}, err
}
e.rcvMu.Lock()
@@ -298,18 +298,17 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
err = tcpip.ErrClosedForReceive
}
e.rcvMu.Unlock()
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return tcpip.ReadResult{}, err
}
p := e.rcvList.Front()
- e.rcvList.Remove(p)
- e.rcvBufSize -= p.data.Size()
- e.rcvMu.Unlock()
-
- if addr != nil {
- *addr = p.senderAddress
+ if !opts.Peek {
+ e.rcvList.Remove(p)
+ e.rcvBufSize -= p.data.Size()
}
+ e.rcvMu.Unlock()
+ // Control Messages
cm := tcpip.ControlMessages{
HasTimestamp: true,
Timestamp: p.timestamp,
@@ -331,7 +330,22 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
cm.HasOriginalDstAddress = true
cm.OriginalDstAddress = p.destinationAddress
}
- return p.data.ToView(), cm, nil
+
+ // Read Result
+ res := tcpip.ReadResult{
+ Total: p.data.Size(),
+ ControlMessages: cm,
+ }
+ if opts.NeedRemoteAddr {
+ res.RemoteAddr = p.senderAddress
+ }
+
+ n, err := p.data.ReadTo(dst, count, opts.Peek)
+ if n == 0 && err != nil {
+ return res, tcpip.ErrBadBuffer
+ }
+ res.Count = n
+ return res, nil
}
// prepareForWrite prepares the endpoint for sending data. In particular, it
@@ -566,11 +580,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return int64(len(v)), nil, nil
}
-// Peek only returns data from a single datagram, so do nothing here.
-func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) {
- return 0, nil
-}
-
// OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet.
func (e *endpoint) OnReuseAddressSet(v bool) {
e.mu.Lock()
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 8429f34b4..455b8c2aa 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -18,10 +18,12 @@ import (
"bytes"
"context"
"fmt"
+ "io/ioutil"
"math/rand"
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -595,13 +597,13 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe
// Take a snapshot of the stats to validate them at the end of the test.
epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
- var addr tcpip.FullAddress
- v, cm, err := c.ep.Read(&addr)
+ var buf bytes.Buffer
+ res, err := c.ep.Read(&buf, defaultMTU, tcpip.ReadOptions{NeedRemoteAddr: true})
if err == tcpip.ErrWouldBlock {
// Wait for data to become available.
select {
case <-ch:
- v, cm, err = c.ep.Read(&addr)
+ res, err = c.ep.Read(&buf, defaultMTU, tcpip.ReadOptions{NeedRemoteAddr: true})
case <-time.After(300 * time.Millisecond):
if packetShouldBeDropped {
@@ -621,23 +623,32 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe
}
if packetShouldBeDropped {
- c.t.Fatalf("Read unexpectedly received data from %s", addr.Addr)
+ c.t.Fatalf("Read unexpectedly received data from %s", res.RemoteAddr.Addr)
}
- // Check the peer address.
+ // Check the read result.
h := flow.header4Tuple(incoming)
- if addr.Addr != h.srcAddr.Addr {
- c.t.Fatalf("got address = %s, want = %s", addr.Addr, h.srcAddr.Addr)
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ RemoteAddr: tcpip.FullAddress{Addr: h.srcAddr.Addr},
+ }, res, checker.IgnoreCmpPath(
+ "ControlMessages", // ControlMessages will be checked later.
+ "RemoteAddr.NIC",
+ "RemoteAddr.Port",
+ )); diff != "" {
+ c.t.Fatalf("Read: unexpected result (-want +got):\n%s", diff)
}
// Check the payload.
+ v := buf.Bytes()
if !bytes.Equal(payload, v) {
c.t.Fatalf("got payload = %x, want = %x", v, payload)
}
// Run any checkers against the ControlMessages.
for _, f := range checkers {
- f(c.t, cm)
+ f(c.t, res.ControlMessages)
}
c.checkEndpointReadStats(1, epstats, err)
@@ -828,8 +839,8 @@ func TestV4ReadSelfSource(t *testing.T) {
t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource)
}
- if _, _, err := c.ep.Read(nil); err != tt.wantErr {
- t.Errorf("got c.ep.Read(nil) = %s, want = %s", err, tt.wantErr)
+ if _, err := c.ep.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}); err != tt.wantErr {
+ t.Errorf("got c.ep.Read = %s, want = %s", err, tt.wantErr)
}
})
}