diff options
Diffstat (limited to 'conn')
-rw-r--r-- | conn/bind_linux.go | 47 | ||||
-rw-r--r-- | conn/bind_std.go | 37 | ||||
-rw-r--r-- | conn/bind_windows.go | 67 | ||||
-rw-r--r-- | conn/bindtest/bindtest.go | 39 | ||||
-rw-r--r-- | conn/conn.go | 24 | ||||
-rw-r--r-- | conn/conn_test.go | 24 |
6 files changed, 171 insertions, 67 deletions
diff --git a/conn/bind_linux.go b/conn/bind_linux.go index bd710ae..b6bc0dc 100644 --- a/conn/bind_linux.go +++ b/conn/bind_linux.go @@ -193,6 +193,10 @@ func (bind *LinuxSocketBind) SetMark(value uint32) error { return nil } +func (bind *LinuxSocketBind) BatchSize() int { + return 1 +} + func (bind *LinuxSocketBind) Close() error { // Take a readlock to shut down the sockets... bind.mu.RLock() @@ -223,29 +227,39 @@ func (bind *LinuxSocketBind) Close() error { return err2 } -func (bind *LinuxSocketBind) receiveIPv4(buf []byte) (int, Endpoint, error) { +func (bind *LinuxSocketBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) { bind.mu.RLock() defer bind.mu.RUnlock() if bind.sock4 == -1 { - return 0, nil, net.ErrClosed + return 0, net.ErrClosed } var end LinuxSocketEndpoint - n, err := receive4(bind.sock4, buf, &end) - return n, &end, err + n, err := receive4(bind.sock4, buffs[0], &end) + if err != nil { + return 0, err + } + eps[0] = &end + sizes[0] = n + return 1, nil } -func (bind *LinuxSocketBind) receiveIPv6(buf []byte) (int, Endpoint, error) { +func (bind *LinuxSocketBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) { bind.mu.RLock() defer bind.mu.RUnlock() if bind.sock6 == -1 { - return 0, nil, net.ErrClosed + return 0, net.ErrClosed } var end LinuxSocketEndpoint - n, err := receive6(bind.sock6, buf, &end) - return n, &end, err + n, err := receive6(bind.sock6, buffs[0], &end) + if err != nil { + return 0, err + } + eps[0] = &end + sizes[0] = n + return 1, nil } -func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error { +func (bind *LinuxSocketBind) Send(buffs [][]byte, end Endpoint) error { nend, ok := end.(*LinuxSocketEndpoint) if !ok { return ErrWrongEndpointType @@ -256,13 +270,24 @@ func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error { if bind.sock4 == -1 { return net.ErrClosed } - return send4(bind.sock4, nend, buff) + for _, buff := range buffs { + err := send4(bind.sock4, nend, buff) + if err != nil { + return err + } + } } else { if bind.sock6 == -1 { return net.ErrClosed } - return send6(bind.sock6, nend, buff) + for _, buff := range buffs { + err := send6(bind.sock6, nend, buff) + if err != nil { + return err + } + } } + return nil } func (end *LinuxSocketEndpoint) SrcIP() netip.Addr { diff --git a/conn/bind_std.go b/conn/bind_std.go index ae07aac..98fe23c 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -128,6 +128,10 @@ again: return fns, uint16(port), nil } +func (bind *StdNetBind) BatchSize() int { + return 1 +} + func (bind *StdNetBind) Close() error { bind.mu.Lock() defer bind.mu.Unlock() @@ -150,20 +154,30 @@ func (bind *StdNetBind) Close() error { } func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc { - return func(buff []byte) (int, Endpoint, error) { - n, endpoint, err := conn.ReadFromUDPAddrPort(buff) - return n, asEndpoint(endpoint), err + return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { + size, endpoint, err := conn.ReadFromUDPAddrPort(buffs[0]) + if err == nil { + sizes[0] = size + eps[0] = asEndpoint(endpoint) + return 1, nil + } + return 0, err } } func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc { - return func(buff []byte) (int, Endpoint, error) { - n, endpoint, err := conn.ReadFromUDPAddrPort(buff) - return n, asEndpoint(endpoint), err + return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { + size, endpoint, err := conn.ReadFromUDPAddrPort(buffs[0]) + if err == nil { + sizes[0] = size + eps[0] = asEndpoint(endpoint) + return 1, nil + } + return 0, err } } -func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error { +func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error { var err error nend, ok := endpoint.(StdNetEndpoint) if !ok { @@ -186,8 +200,13 @@ func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error { if conn == nil { return syscall.EAFNOSUPPORT } - _, err = conn.WriteToUDPAddrPort(buff, addrPort) - return err + for _, buff := range buffs { + _, err = conn.WriteToUDPAddrPort(buff, addrPort) + if err != nil { + return err + } + } + return nil } // endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint. diff --git a/conn/bind_windows.go b/conn/bind_windows.go index f8b187b..5a0b8c2 100644 --- a/conn/bind_windows.go +++ b/conn/bind_windows.go @@ -321,6 +321,11 @@ func (bind *WinRingBind) Close() error { return nil } +func (bind *WinRingBind) BatchSize() int { + // TODO: implement batching in and out of the ring + return 1 +} + func (bind *WinRingBind) SetMark(mark uint32) error { return nil } @@ -409,16 +414,22 @@ retry: return n, &ep, nil } -func (bind *WinRingBind) receiveIPv4(buf []byte) (int, Endpoint, error) { +func (bind *WinRingBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) { bind.mu.RLock() defer bind.mu.RUnlock() - return bind.v4.Receive(buf, &bind.isOpen) + n, ep, err := bind.v4.Receive(buffs[0], &bind.isOpen) + sizes[0] = n + eps[0] = ep + return 1, err } -func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) { +func (bind *WinRingBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) { bind.mu.RLock() defer bind.mu.RUnlock() - return bind.v6.Receive(buf, &bind.isOpen) + n, ep, err := bind.v6.Receive(buffs[0], &bind.isOpen) + sizes[0] = n + eps[0] = ep + return 1, err } func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error { @@ -473,32 +484,38 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomi return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) } -func (bind *WinRingBind) Send(buf []byte, endpoint Endpoint) error { +func (bind *WinRingBind) Send(buffs [][]byte, endpoint Endpoint) error { nend, ok := endpoint.(*WinRingEndpoint) if !ok { return ErrWrongEndpointType } bind.mu.RLock() defer bind.mu.RUnlock() - switch nend.family { - case windows.AF_INET: - if bind.v4.blackhole { - return nil - } - return bind.v4.Send(buf, nend, &bind.isOpen) - case windows.AF_INET6: - if bind.v6.blackhole { - return nil + for _, buf := range buffs { + switch nend.family { + case windows.AF_INET: + if bind.v4.blackhole { + continue + } + if err := bind.v4.Send(buf, nend, &bind.isOpen); err != nil { + return err + } + case windows.AF_INET6: + if bind.v6.blackhole { + continue + } + if err := bind.v6.Send(buf, nend, &bind.isOpen); err != nil { + return err + } } - return bind.v6.Send(buf, nend, &bind.isOpen) } return nil } -func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { - bind.mu.Lock() - defer bind.mu.Unlock() - sysconn, err := bind.ipv4.SyscallConn() +func (s *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { + s.mu.Lock() + defer s.mu.Unlock() + sysconn, err := s.ipv4.SyscallConn() if err != nil { return err } @@ -511,14 +528,14 @@ func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole if err != nil { return err } - bind.blackhole4 = blackhole + s.blackhole4 = blackhole return nil } -func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { - bind.mu.Lock() - defer bind.mu.Unlock() - sysconn, err := bind.ipv6.SyscallConn() +func (s *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { + s.mu.Lock() + defer s.mu.Unlock() + sysconn, err := s.ipv6.SyscallConn() if err != nil { return err } @@ -531,7 +548,7 @@ func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole if err != nil { return err } - bind.blackhole6 = blackhole + s.blackhole6 = blackhole return nil } diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go index 9605a2a..b33c53d 100644 --- a/conn/bindtest/bindtest.go +++ b/conn/bindtest/bindtest.go @@ -89,32 +89,39 @@ func (c *ChannelBind) Close() error { return nil } +func (c *ChannelBind) BatchSize() int { return 1 } + func (c *ChannelBind) SetMark(mark uint32) error { return nil } func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc { - return func(b []byte) (n int, ep conn.Endpoint, err error) { + return func(buffs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) { select { case <-c.closeSignal: - return 0, nil, net.ErrClosed + return 0, net.ErrClosed case rx := <-ch: - return copy(b, rx), c.target6, nil + copied := copy(buffs[0], rx) + sizes[0] = copied + eps[0] = c.target6 + return 1, nil } } } -func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error { - select { - case <-c.closeSignal: - return net.ErrClosed - default: - bc := make([]byte, len(b)) - copy(bc, b) - if ep.(ChannelEndpoint) == c.target4 { - *c.tx4 <- bc - } else if ep.(ChannelEndpoint) == c.target6 { - *c.tx6 <- bc - } else { - return os.ErrInvalid +func (c *ChannelBind) Send(buffs [][]byte, ep conn.Endpoint) error { + for _, b := range buffs { + select { + case <-c.closeSignal: + return net.ErrClosed + default: + bc := make([]byte, len(b)) + copy(bc, b) + if ep.(ChannelEndpoint) == c.target4 { + *c.tx4 <- bc + } else if ep.(ChannelEndpoint) == c.target6 { + *c.tx6 <- bc + } else { + return os.ErrInvalid + } } } return nil diff --git a/conn/conn.go b/conn/conn.go index 497b92a..8c0a827 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -15,10 +15,17 @@ import ( "strings" ) -// A ReceiveFunc receives a single inbound packet from the network. -// It writes the data into b. n is the length of the packet. -// ep is the remote endpoint. -type ReceiveFunc func(b []byte) (n int, ep Endpoint, err error) +const ( + DefaultBatchSize = 1 // maximum number of packets handled per read and write +) + +// A ReceiveFunc receives at least one packet from the network and writes them +// into packets. On a successful read it returns the number of elements of +// sizes, packets, and endpoints that should be evaluated. Some elements of +// sizes may be zero, and callers should ignore them. Callers must pass a sizes +// and eps slice with a length greater than or equal to the length of packets. +// These lengths must not exceed the length of the associated Bind.BatchSize(). +type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error) // A Bind listens on a port for both IPv6 and IPv4 UDP traffic. // @@ -38,11 +45,16 @@ type Bind interface { // This mark is passed to the kernel as the socket option SO_MARK. SetMark(mark uint32) error - // Send writes a packet b to address ep. - Send(b []byte, ep Endpoint) error + // Send writes one or more packets in buffs to address ep. The length of + // buffs must not exceed BatchSize(). + Send(buffs [][]byte, ep Endpoint) error // ParseEndpoint creates a new endpoint from a string. ParseEndpoint(s string) (Endpoint, error) + + // BatchSize is the number of buffers expected to be passed to + // the ReceiveFuncs, and the maximum expected to be passed to SendBatch. + BatchSize() int } // BindSocketToInterface is implemented by Bind objects that support being diff --git a/conn/conn_test.go b/conn/conn_test.go new file mode 100644 index 0000000..7a6231d --- /dev/null +++ b/conn/conn_test.go @@ -0,0 +1,24 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "testing" +) + +func TestPrettyName(t *testing.T) { + var ( + recvFunc ReceiveFunc = func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { return } + ) + + const want = "TestPrettyName" + + t.Run("ReceiveFunc.PrettyName", func(t *testing.T) { + if got := recvFunc.PrettyName(); got != want { + t.Errorf("PrettyName() = %v, want %v", got, want) + } + }) +} |