From 99bf75a6dc1cbad2d688ef80354f45940a40a3a1 Mon Sep 17 00:00:00 2001 From: Ian Gudger Date: Tue, 13 Aug 2019 12:09:48 -0700 Subject: gonet: Replace NewPacketConn with DialUDP. This better matches the standard library and allows creating connected PacketConns. PiperOrigin-RevId: 263187462 --- pkg/tcpip/adapters/gonet/gonet.go | 42 +++++++++++++++++-------- pkg/tcpip/adapters/gonet/gonet_test.go | 56 ++++++++++++++++++++++++++++++---- 2 files changed, 80 insertions(+), 18 deletions(-) diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index 308f620e5..f8c7cf3c5 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -556,32 +556,50 @@ type PacketConn struct { wq *waiter.Queue } -// NewPacketConn creates a new PacketConn. -func NewPacketConn(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*PacketConn, error) { - // Create UDP endpoint and bind it. +// DialUDP creates a new PacketConn. +// +// If laddr is nil, a local address is automatically chosen. +// +// If raddr is nil, the PacketConn is left unconnected. +func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*PacketConn, error) { var wq waiter.Queue ep, err := s.NewEndpoint(udp.ProtocolNumber, network, &wq) if err != nil { return nil, errors.New(err.String()) } - if err := ep.Bind(addr); err != nil { - ep.Close() - return nil, &net.OpError{ - Op: "bind", - Net: "udp", - Addr: fullToUDPAddr(addr), - Err: errors.New(err.String()), + if laddr != nil { + if err := ep.Bind(*laddr); err != nil { + ep.Close() + return nil, &net.OpError{ + Op: "bind", + Net: "udp", + Addr: fullToUDPAddr(*laddr), + Err: errors.New(err.String()), + } } } - c := &PacketConn{ + c := PacketConn{ stack: s, ep: ep, wq: &wq, } c.deadlineTimer.init() - return c, nil + + if raddr != nil { + if err := c.ep.Connect(*raddr); err != nil { + c.ep.Close() + return nil, &net.OpError{ + Op: "connect", + Net: "udp", + Addr: fullToUDPAddr(*raddr), + Err: errors.New(err.String()), + } + } + } + + return &c, nil } func (c *PacketConn) newOpError(op string, err error) *net.OpError { diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index 39efe44c7..5cd208113 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -371,9 +371,9 @@ func TestUDPForwarder(t *testing.T) { }) s.SetTransportProtocolHandler(udp.ProtocolNumber, fwd.HandlePacket) - c2, err := NewPacketConn(s, addr2, ipv4.ProtocolNumber) + c2, err := DialUDP(s, &addr2, nil, ipv4.ProtocolNumber) if err != nil { - t.Fatal("NewPacketConn(port 5):", err) + t.Fatal("DialUDP(bind port 5):", err) } sent := "abc123" @@ -452,13 +452,13 @@ func TestPacketConnTransfer(t *testing.T) { addr2 := tcpip.FullAddress{NICID, ip2, 11311} s.AddAddress(NICID, ipv4.ProtocolNumber, ip2) - c1, err := NewPacketConn(s, addr1, ipv4.ProtocolNumber) + c1, err := DialUDP(s, &addr1, nil, ipv4.ProtocolNumber) if err != nil { - t.Fatal("NewPacketConn(port 4):", err) + t.Fatal("DialUDP(bind port 4):", err) } - c2, err := NewPacketConn(s, addr2, ipv4.ProtocolNumber) + c2, err := DialUDP(s, &addr2, nil, ipv4.ProtocolNumber) if err != nil { - t.Fatal("NewPacketConn(port 5):", err) + t.Fatal("DialUDP(bind port 5):", err) } c1.SetDeadline(time.Now().Add(time.Second)) @@ -491,6 +491,50 @@ func TestPacketConnTransfer(t *testing.T) { } } +func TestConnectedPacketConnTransfer(t *testing.T) { + s, e := newLoopbackStack() + if e != nil { + t.Fatalf("newLoopbackStack() = %v", e) + } + + ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) + addr := tcpip.FullAddress{NICID, ip, 11211} + s.AddAddress(NICID, ipv4.ProtocolNumber, ip) + + c1, err := DialUDP(s, &addr, nil, ipv4.ProtocolNumber) + if err != nil { + t.Fatal("DialUDP(bind port 4):", err) + } + c2, err := DialUDP(s, nil, &addr, ipv4.ProtocolNumber) + if err != nil { + t.Fatal("DialUDP(bind port 5):", err) + } + + c1.SetDeadline(time.Now().Add(time.Second)) + c2.SetDeadline(time.Now().Add(time.Second)) + + sent := "abc123" + if n, err := c2.Write([]byte(sent)); err != nil || n != len(sent) { + t.Errorf("got c2.Write(%q) = %d, %v, want = %d, %v", sent, n, err, len(sent), nil) + } + recv := make([]byte, len(sent)) + n, err := c1.Read(recv) + if err != nil || n != len(recv) { + t.Errorf("got c1.Read() = %d, %v, want = %d, %v", n, err, len(recv), nil) + } + + if recv := string(recv); recv != sent { + t.Errorf("got recv = %q, want = %q", recv, sent) + } + + if err := c1.Close(); err != nil { + t.Error("c1.Close():", err) + } + if err := c2.Close(); err != nil { + t.Error("c2.Close():", err) + } +} + func makePipe() (c1, c2 net.Conn, stop func(), err error) { s, e := newLoopbackStack() if e != nil { -- cgit v1.2.3