summaryrefslogtreecommitdiffhomepage
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/conn.go45
-rw-r--r--src/conn_linux.go35
-rw-r--r--src/device.go16
-rw-r--r--src/receive.go187
4 files changed, 151 insertions, 132 deletions
diff --git a/src/conn.go b/src/conn.go
index b2caffb..aa0b72b 100644
--- a/src/conn.go
+++ b/src/conn.go
@@ -34,6 +34,21 @@ func parseEndpoint(s string) (*net.UDPAddr, error) {
return addr, err
}
+/* Must hold device and net lock
+ */
+func unsafeCloseUDPListener(device *Device) error {
+ netc := &device.net
+ if netc.bind != nil {
+ if err := netc.bind.Close(); err != nil {
+ return err
+ }
+ netc.bind = nil
+ netc.update.Broadcast()
+ }
+ return nil
+}
+
+// must inform all listeners
func UpdateUDPListener(device *Device) error {
device.mutex.Lock()
defer device.mutex.Unlock()
@@ -44,26 +59,22 @@ func UpdateUDPListener(device *Device) error {
// close existing sockets
- if netc.bind != nil {
- println("close bind")
- if err := netc.bind.Close(); err != nil {
- return err
- }
- netc.bind = nil
- println("closed")
+ if err := unsafeCloseUDPListener(device); err != nil {
+ return err
}
+ // wait for reader
+
// open new sockets
if device.tun.isUp.Get() {
- println("creat")
-
// bind to new port
var err error
netc.bind, netc.port, err = CreateUDPBind(netc.port)
if err != nil {
+ netc.bind = nil
return err
}
@@ -74,8 +85,6 @@ func UpdateUDPListener(device *Device) error {
return err
}
- println("okay")
-
// clear cached source addresses
for _, peer := range device.peers {
@@ -83,14 +92,20 @@ func UpdateUDPListener(device *Device) error {
peer.endpoint.value.ClearSrc()
peer.mutex.Unlock()
}
+
+ // inform readers of updated bind
+
+ netc.update.Broadcast()
}
return nil
}
func CloseUDPListener(device *Device) error {
- netc := &device.net
- netc.mutex.Lock()
- defer netc.mutex.Unlock()
- return netc.bind.Close()
+ device.mutex.Lock()
+ device.net.mutex.Lock()
+ err := unsafeCloseUDPListener(device)
+ device.net.mutex.Unlock()
+ device.mutex.Unlock()
+ return err
}
diff --git a/src/conn_linux.go b/src/conn_linux.go
index 8cda460..05f9347 100644
--- a/src/conn_linux.go
+++ b/src/conn_linux.go
@@ -7,8 +7,8 @@
package main
import (
+ "encoding/binary"
"errors"
- "fmt"
"golang.org/x/sys/unix"
"net"
"strconv"
@@ -37,6 +37,17 @@ type NativeBind struct {
sock6 int
}
+func htons(val uint16) uint16 {
+ var out [unsafe.Sizeof(val)]byte
+ binary.BigEndian.PutUint16(out[:], val)
+ return *((*uint16)(unsafe.Pointer(&out[0])))
+}
+
+func ntohs(val uint16) uint16 {
+ tmp := ((*[unsafe.Sizeof(val)]byte)(unsafe.Pointer(&val)))
+ return binary.BigEndian.Uint16((*tmp)[:])
+}
+
func CreateUDPBind(port uint16) (UDPBind, uint16, error) {
var err error
var bind NativeBind
@@ -50,8 +61,6 @@ func CreateUDPBind(port uint16) (UDPBind, uint16, error) {
if err != nil {
unix.Close(bind.sock6)
}
- println(bind.sock6)
- println(bind.sock4)
return bind, port, err
}
@@ -297,13 +306,11 @@ func (end *Endpoint) SetDst(s string) error {
return err
}
- fmt.Println(addr, err)
-
ipv4 := addr.IP.To4()
if ipv4 != nil {
dst := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
dst.Family = unix.AF_INET
- dst.Port = uint16(addr.Port)
+ dst.Port = htons(uint16(addr.Port))
dst.Zero = [8]byte{}
copy(dst.Addr[:], ipv4)
end.ClearSrc()
@@ -318,7 +325,7 @@ func (end *Endpoint) SetDst(s string) error {
}
dst := &end.dst
dst.Family = unix.AF_INET6
- dst.Port = uint16(addr.Port)
+ dst.Port = htons(uint16(addr.Port))
dst.Flowinfo = 0
dst.Scope_id = zone
copy(dst.Addr[:], ipv6[:])
@@ -392,9 +399,6 @@ func send6(sock int, end *Endpoint, buff []byte) error {
}
func send4(sock int, end *Endpoint, buff []byte) error {
- println("send 4")
- println(end.DstToString())
- println(sock)
// construct message header
@@ -425,6 +429,7 @@ func send4(sock int, end *Endpoint, buff []byte) error {
Name: (*byte)(unsafe.Pointer(&end.dst)),
Namelen: unix.SizeofSockaddrInet4,
Control: (*byte)(unsafe.Pointer(&cmsg)),
+ Flags: 0,
}
msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
@@ -437,10 +442,6 @@ func send4(sock int, end *Endpoint, buff []byte) error {
0,
)
- if errno == 0 {
- return nil
- }
-
// clear source and try again
if errno == unix.EINVAL {
@@ -454,6 +455,12 @@ func send4(sock int, end *Endpoint, buff []byte) error {
)
}
+ // errno = 0 is still an error instance
+
+ if errno == 0 {
+ return nil
+ }
+
return errno
}
diff --git a/src/device.go b/src/device.go
index 1aae448..a348c68 100644
--- a/src/device.go
+++ b/src/device.go
@@ -23,9 +23,10 @@ type Device struct {
}
net struct {
mutex sync.RWMutex
- bind UDPBind
- port uint16
- fwmark uint32
+ bind UDPBind // bind interface
+ port uint16 // listening port
+ fwmark uint32 // mark value (0 = disabled)
+ update *sync.Cond // the bind was updated
}
mutex sync.RWMutex
privateKey NoisePrivateKey
@@ -38,8 +39,7 @@ type Device struct {
handshake chan QueueHandshakeElement
}
signal struct {
- stop chan struct{}
- updateBind chan struct{}
+ stop chan struct{}
}
underLoadUntil atomic.Value
ratelimiter Ratelimiter
@@ -163,6 +163,12 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
device.signal.stop = make(chan struct{})
+ // prepare net
+
+ device.net.port = 0
+ device.net.bind = nil
+ device.net.update = sync.NewCond(&device.net.mutex)
+
// start workers
for i := 0; i < runtime.NumCPU(); i += 1 {
diff --git a/src/receive.go b/src/receive.go
index 1f05b2f..cb53f80 100644
--- a/src/receive.go
+++ b/src/receive.go
@@ -99,135 +99,126 @@ func (device *Device) RoutineReceiveIncomming(IPVersion int) {
for {
- // wait for new conn
-
- logDebug.Println("Waiting for udp socket")
-
- select {
- case <-device.signal.stop:
- return
-
- case <-device.signal.updateBind:
-
- // fetch new socket
+ // wait for bind
+
+ logDebug.Println("Waiting for udp bind")
+ device.net.mutex.Lock()
+ device.net.update.Wait()
+ bind := device.net.bind
+ device.net.mutex.Unlock()
+ if bind == nil {
+ continue
+ }
- device.net.mutex.RLock()
- bind := device.net.bind
- device.net.mutex.RUnlock()
- if bind == nil {
- continue
- }
+ logDebug.Println("LISTEN\n\n\n")
- logDebug.Println("Listening for inbound packets")
+ // receive datagrams until conn is closed
- // receive datagrams until conn is closed
+ buffer := device.GetMessageBuffer()
- buffer := device.GetMessageBuffer()
+ var size int
+ var err error
- var size int
- var err error
+ for {
- for {
+ // read next datagram
- // read next datagram
+ var endpoint Endpoint
- var endpoint Endpoint
-
- switch IPVersion {
- case ipv4.Version:
- size, err = bind.ReceiveIPv4(buffer[:], &endpoint)
- case ipv6.Version:
- size, err = bind.ReceiveIPv6(buffer[:], &endpoint)
- default:
- return
- }
+ switch IPVersion {
+ case ipv4.Version:
+ size, err = bind.ReceiveIPv4(buffer[:], &endpoint)
+ case ipv6.Version:
+ size, err = bind.ReceiveIPv6(buffer[:], &endpoint)
+ default:
+ return
+ }
- if err != nil {
- break
- }
+ if err != nil {
+ break
+ }
- if size < MinMessageSize {
- continue
- }
+ if size < MinMessageSize {
+ continue
+ }
- // check size of packet
+ // check size of packet
- packet := buffer[:size]
- msgType := binary.LittleEndian.Uint32(packet[:4])
+ packet := buffer[:size]
+ msgType := binary.LittleEndian.Uint32(packet[:4])
- var okay bool
+ var okay bool
- switch msgType {
+ switch msgType {
- // check if transport
+ // check if transport
- case MessageTransportType:
+ case MessageTransportType:
- // check size
+ // check size
- if len(packet) < MessageTransportType {
- continue
- }
+ if len(packet) < MessageTransportType {
+ continue
+ }
- // lookup key pair
+ // lookup key pair
- receiver := binary.LittleEndian.Uint32(
- packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
- )
- value := device.indices.Lookup(receiver)
- keyPair := value.keyPair
- if keyPair == nil {
- continue
- }
+ receiver := binary.LittleEndian.Uint32(
+ packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
+ )
+ value := device.indices.Lookup(receiver)
+ keyPair := value.keyPair
+ if keyPair == nil {
+ continue
+ }
- // check key-pair expiry
+ // check key-pair expiry
- if keyPair.created.Add(RejectAfterTime).Before(time.Now()) {
- continue
- }
+ if keyPair.created.Add(RejectAfterTime).Before(time.Now()) {
+ continue
+ }
- // create work element
+ // create work element
- peer := value.peer
- elem := &QueueInboundElement{
- packet: packet,
- buffer: buffer,
- keyPair: keyPair,
- dropped: AtomicFalse,
- }
- elem.mutex.Lock()
+ peer := value.peer
+ elem := &QueueInboundElement{
+ packet: packet,
+ buffer: buffer,
+ keyPair: keyPair,
+ dropped: AtomicFalse,
+ }
+ elem.mutex.Lock()
- // add to decryption queues
+ // add to decryption queues
- device.addToDecryptionQueue(device.queue.decryption, elem)
- device.addToInboundQueue(peer.queue.inbound, elem)
- buffer = device.GetMessageBuffer()
- continue
+ device.addToDecryptionQueue(device.queue.decryption, elem)
+ device.addToInboundQueue(peer.queue.inbound, elem)
+ buffer = device.GetMessageBuffer()
+ continue
- // otherwise it is a fixed size & handshake related packet
+ // otherwise it is a fixed size & handshake related packet
- case MessageInitiationType:
- okay = len(packet) == MessageInitiationSize
+ case MessageInitiationType:
+ okay = len(packet) == MessageInitiationSize
- case MessageResponseType:
- okay = len(packet) == MessageResponseSize
+ case MessageResponseType:
+ okay = len(packet) == MessageResponseSize
- case MessageCookieReplyType:
- okay = len(packet) == MessageCookieReplySize
- }
+ case MessageCookieReplyType:
+ okay = len(packet) == MessageCookieReplySize
+ }
- if okay {
- device.addToHandshakeQueue(
- device.queue.handshake,
- QueueHandshakeElement{
- msgType: msgType,
- buffer: buffer,
- packet: packet,
- endpoint: endpoint,
- },
- )
- buffer = device.GetMessageBuffer()
- }
+ if okay {
+ device.addToHandshakeQueue(
+ device.queue.handshake,
+ QueueHandshakeElement{
+ msgType: msgType,
+ buffer: buffer,
+ packet: packet,
+ endpoint: endpoint,
+ },
+ )
+ buffer = device.GetMessageBuffer()
}
}
}