summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go42
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go56
2 files changed, 80 insertions, 18 deletions
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index 308f620e5..f8c7cf3c5 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -556,32 +556,50 @@ type PacketConn struct {
wq *waiter.Queue
}
-// NewPacketConn creates a new PacketConn.
-func NewPacketConn(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*PacketConn, error) {
- // Create UDP endpoint and bind it.
+// DialUDP creates a new PacketConn.
+//
+// If laddr is nil, a local address is automatically chosen.
+//
+// If raddr is nil, the PacketConn is left unconnected.
+func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*PacketConn, error) {
var wq waiter.Queue
ep, err := s.NewEndpoint(udp.ProtocolNumber, network, &wq)
if err != nil {
return nil, errors.New(err.String())
}
- if err := ep.Bind(addr); err != nil {
- ep.Close()
- return nil, &net.OpError{
- Op: "bind",
- Net: "udp",
- Addr: fullToUDPAddr(addr),
- Err: errors.New(err.String()),
+ if laddr != nil {
+ if err := ep.Bind(*laddr); err != nil {
+ ep.Close()
+ return nil, &net.OpError{
+ Op: "bind",
+ Net: "udp",
+ Addr: fullToUDPAddr(*laddr),
+ Err: errors.New(err.String()),
+ }
}
}
- c := &PacketConn{
+ c := PacketConn{
stack: s,
ep: ep,
wq: &wq,
}
c.deadlineTimer.init()
- return c, nil
+
+ if raddr != nil {
+ if err := c.ep.Connect(*raddr); err != nil {
+ c.ep.Close()
+ return nil, &net.OpError{
+ Op: "connect",
+ Net: "udp",
+ Addr: fullToUDPAddr(*raddr),
+ Err: errors.New(err.String()),
+ }
+ }
+ }
+
+ return &c, nil
}
func (c *PacketConn) newOpError(op string, err error) *net.OpError {
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index 39efe44c7..5cd208113 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -371,9 +371,9 @@ func TestUDPForwarder(t *testing.T) {
})
s.SetTransportProtocolHandler(udp.ProtocolNumber, fwd.HandlePacket)
- c2, err := NewPacketConn(s, addr2, ipv4.ProtocolNumber)
+ c2, err := DialUDP(s, &addr2, nil, ipv4.ProtocolNumber)
if err != nil {
- t.Fatal("NewPacketConn(port 5):", err)
+ t.Fatal("DialUDP(bind port 5):", err)
}
sent := "abc123"
@@ -452,13 +452,13 @@ func TestPacketConnTransfer(t *testing.T) {
addr2 := tcpip.FullAddress{NICID, ip2, 11311}
s.AddAddress(NICID, ipv4.ProtocolNumber, ip2)
- c1, err := NewPacketConn(s, addr1, ipv4.ProtocolNumber)
+ c1, err := DialUDP(s, &addr1, nil, ipv4.ProtocolNumber)
if err != nil {
- t.Fatal("NewPacketConn(port 4):", err)
+ t.Fatal("DialUDP(bind port 4):", err)
}
- c2, err := NewPacketConn(s, addr2, ipv4.ProtocolNumber)
+ c2, err := DialUDP(s, &addr2, nil, ipv4.ProtocolNumber)
if err != nil {
- t.Fatal("NewPacketConn(port 5):", err)
+ t.Fatal("DialUDP(bind port 5):", err)
}
c1.SetDeadline(time.Now().Add(time.Second))
@@ -491,6 +491,50 @@ func TestPacketConnTransfer(t *testing.T) {
}
}
+func TestConnectedPacketConnTransfer(t *testing.T) {
+ s, e := newLoopbackStack()
+ if e != nil {
+ t.Fatalf("newLoopbackStack() = %v", e)
+ }
+
+ ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
+ addr := tcpip.FullAddress{NICID, ip, 11211}
+ s.AddAddress(NICID, ipv4.ProtocolNumber, ip)
+
+ c1, err := DialUDP(s, &addr, nil, ipv4.ProtocolNumber)
+ if err != nil {
+ t.Fatal("DialUDP(bind port 4):", err)
+ }
+ c2, err := DialUDP(s, nil, &addr, ipv4.ProtocolNumber)
+ if err != nil {
+ t.Fatal("DialUDP(bind port 5):", err)
+ }
+
+ c1.SetDeadline(time.Now().Add(time.Second))
+ c2.SetDeadline(time.Now().Add(time.Second))
+
+ sent := "abc123"
+ if n, err := c2.Write([]byte(sent)); err != nil || n != len(sent) {
+ t.Errorf("got c2.Write(%q) = %d, %v, want = %d, %v", sent, n, err, len(sent), nil)
+ }
+ recv := make([]byte, len(sent))
+ n, err := c1.Read(recv)
+ if err != nil || n != len(recv) {
+ t.Errorf("got c1.Read() = %d, %v, want = %d, %v", n, err, len(recv), nil)
+ }
+
+ if recv := string(recv); recv != sent {
+ t.Errorf("got recv = %q, want = %q", recv, sent)
+ }
+
+ if err := c1.Close(); err != nil {
+ t.Error("c1.Close():", err)
+ }
+ if err := c2.Close(); err != nil {
+ t.Error("c2.Close():", err)
+ }
+}
+
func makePipe() (c1, c2 net.Conn, stop func(), err error) {
s, e := newLoopbackStack()
if e != nil {