diff options
-rw-r--r-- | src/conn.go | 16 | ||||
-rw-r--r-- | src/conn_linux.go | 14 | ||||
-rw-r--r-- | src/device.go | 12 | ||||
-rw-r--r-- | src/receive.go | 13 |
4 files changed, 32 insertions, 23 deletions
diff --git a/src/conn.go b/src/conn.go index aa0b72b..0347262 100644 --- a/src/conn.go +++ b/src/conn.go @@ -37,15 +37,14 @@ func parseEndpoint(s string) (*net.UDPAddr, error) { /* Must hold device and net lock */ func unsafeCloseUDPListener(device *Device) error { + var err error netc := &device.net if netc.bind != nil { - if err := netc.bind.Close(); err != nil { - return err - } + err = netc.bind.Close() netc.bind = nil - netc.update.Broadcast() + netc.update.Add(1) } - return nil + return err } // must inform all listeners @@ -63,7 +62,7 @@ func UpdateUDPListener(device *Device) error { return err } - // wait for reader + // assumption: netc.update WaitGroup should be exactly 1 // open new sockets @@ -93,9 +92,10 @@ func UpdateUDPListener(device *Device) error { peer.mutex.Unlock() } - // inform readers of updated bind + // decrease waitgroup to 0 - netc.update.Broadcast() + device.log.Debug.Println("UDP bind has been updated") + netc.update.Done() } return nil diff --git a/src/conn_linux.go b/src/conn_linux.go index 05f9347..383ff7e 100644 --- a/src/conn_linux.go +++ b/src/conn_linux.go @@ -84,9 +84,15 @@ func (bind NativeBind) SetMark(value uint32) error { ) } +func closeUnblock(fd int) error { + // shutdown to unblock readers + unix.Shutdown(fd, unix.SHUT_RD) + return unix.Close(fd) +} + func (bind NativeBind) Close() error { - err1 := unix.Close(bind.sock6) - err2 := unix.Close(bind.sock4) + err1 := closeUnblock(bind.sock6) + err2 := closeUnblock(bind.sock4) if err1 != nil { return err1 } @@ -125,13 +131,13 @@ func sockaddrToString(addr unix.RawSockaddrInet6) string { switch addr.Family { case unix.AF_INET6: - udpAddr.Port = int(addr.Port) + udpAddr.Port = int(ntohs(addr.Port)) udpAddr.IP = addr.Addr[:] return udpAddr.String() case unix.AF_INET: ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr)) - udpAddr.Port = int(ptr.Port) + udpAddr.Port = int(ntohs(ptr.Port)) udpAddr.IP = net.IPv4( ptr.Addr[0], ptr.Addr[1], diff --git a/src/device.go b/src/device.go index a348c68..033a387 100644 --- a/src/device.go +++ b/src/device.go @@ -23,10 +23,10 @@ type Device struct { } net struct { mutex sync.RWMutex - bind UDPBind // bind interface - port uint16 // listening port - fwmark uint32 // mark value (0 = disabled) - update *sync.Cond // the bind was updated + bind UDPBind // bind interface + port uint16 // listening port + fwmark uint32 // mark value (0 = disabled) + update sync.WaitGroup // the bind was updated (acting as a barrier) } mutex sync.RWMutex privateKey NoisePrivateKey @@ -167,7 +167,7 @@ func NewDevice(tun TUNDevice, logLevel int) *Device { device.net.port = 0 device.net.bind = nil - device.net.update = sync.NewCond(&device.net.mutex) + device.net.update.Add(1) // start workers @@ -209,9 +209,11 @@ func (device *Device) RemoveAllPeers() { } func (device *Device) Close() { + device.log.Info.Println("Closing device") device.RemoveAllPeers() close(device.signal.stop) CloseUDPListener(device) + device.tun.device.Close() } func (device *Device) WaitChannel() chan struct{} { diff --git a/src/receive.go b/src/receive.go index cb53f80..3e88be3 100644 --- a/src/receive.go +++ b/src/receive.go @@ -95,23 +95,22 @@ func (device *Device) addToHandshakeQueue( func (device *Device) RoutineReceiveIncomming(IPVersion int) { logDebug := device.log.Debug - logDebug.Println("Routine, receive incomming, started") + logDebug.Println("Routine, receive incomming, IP version:", IPVersion) for { // wait for bind - logDebug.Println("Waiting for udp bind") - device.net.mutex.Lock() + logDebug.Println("Waiting for UDP socket, IP version:", IPVersion) + device.net.update.Wait() + device.net.mutex.RLock() bind := device.net.bind - device.net.mutex.Unlock() + device.net.mutex.RUnlock() if bind == nil { continue } - logDebug.Println("LISTEN\n\n\n") - // receive datagrams until conn is closed buffer := device.GetMessageBuffer() @@ -427,6 +426,8 @@ func (device *Device) RoutineHandshake() { err = peer.SendBuffer(packet) if err == nil { peer.TimerAnyAuthenticatedPacketTraversal() + } else { + logError.Println("Failed to send response to:", peer.String(), err) } case MessageResponseType: |