diff options
-rw-r--r-- | src/conn.go | 7 | ||||
-rw-r--r-- | src/conn_linux.go | 43 | ||||
-rw-r--r-- | src/main.go | 5 | ||||
-rw-r--r-- | src/peer.go | 11 | ||||
-rw-r--r-- | src/send.go | 12 | ||||
-rw-r--r-- | src/uapi.go | 2 |
6 files changed, 54 insertions, 26 deletions
diff --git a/src/conn.go b/src/conn.go index 012e24e..b2caffb 100644 --- a/src/conn.go +++ b/src/conn.go @@ -45,15 +45,20 @@ func UpdateUDPListener(device *Device) error { // close existing sockets if netc.bind != nil { + println("close bind") if err := netc.bind.Close(); err != nil { return err } + netc.bind = nil + println("closed") } // open new sockets if device.tun.isUp.Get() { + println("creat") + // bind to new port var err error @@ -69,6 +74,8 @@ func UpdateUDPListener(device *Device) error { return err } + println("okay") + // clear cached source addresses for _, peer := range device.peers { diff --git a/src/conn_linux.go b/src/conn_linux.go index 51ca4f3..8cda460 100644 --- a/src/conn_linux.go +++ b/src/conn_linux.go @@ -50,10 +50,12 @@ func CreateUDPBind(port uint16) (UDPBind, uint16, error) { if err != nil { unix.Close(bind.sock6) } - return &bind, port, err + println(bind.sock6) + println(bind.sock4) + return bind, port, err } -func (bind *NativeBind) SetMark(value uint32) error { +func (bind NativeBind) SetMark(value uint32) error { err := unix.SetsockoptInt( bind.sock6, unix.SOL_SOCKET, @@ -73,7 +75,7 @@ func (bind *NativeBind) SetMark(value uint32) error { ) } -func (bind *NativeBind) Close() error { +func (bind NativeBind) Close() error { err1 := unix.Close(bind.sock6) err2 := unix.Close(bind.sock4) if err1 != nil { @@ -82,7 +84,7 @@ func (bind *NativeBind) Close() error { return err2 } -func (bind *NativeBind) ReceiveIPv6(buff []byte, end *Endpoint) (int, error) { +func (bind NativeBind) ReceiveIPv6(buff []byte, end *Endpoint) (int, error) { return receive6( bind.sock6, buff, @@ -90,7 +92,7 @@ func (bind *NativeBind) ReceiveIPv6(buff []byte, end *Endpoint) (int, error) { ) } -func (bind *NativeBind) ReceiveIPv4(buff []byte, end *Endpoint) (int, error) { +func (bind NativeBind) ReceiveIPv4(buff []byte, end *Endpoint) (int, error) { return receive4( bind.sock4, buff, @@ -98,7 +100,7 @@ func (bind *NativeBind) ReceiveIPv4(buff []byte, end *Endpoint) (int, error) { ) } -func (bind *NativeBind) Send(buff []byte, end *Endpoint) error { +func (bind NativeBind) Send(buff []byte, end *Endpoint) error { switch end.dst.Family { case unix.AF_INET6: return send6(bind.sock6, end, buff) @@ -236,7 +238,7 @@ func create6(port uint16) (int, uint16, error) { // create socket fd, err := unix.Socket( - unix.AF_INET, + unix.AF_INET6, unix.SOCK_DGRAM, 0, ) @@ -342,7 +344,7 @@ func send6(sock int, end *Endpoint, buff []byte) error { unix.Cmsghdr{ Level: unix.IPPROTO_IPV6, Type: unix.IPV6_PKTINFO, - Len: unix.SizeofInet6Pktinfo, + Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr, }, unix.Inet6Pktinfo{ Addr: end.src.Addr, @@ -368,15 +370,31 @@ func send6(sock int, end *Endpoint, buff []byte) error { uintptr(unsafe.Pointer(&msghdr)), 0, ) + + if errno == 0 { + return nil + } + + // clear src and retry + if errno == unix.EINVAL { end.ClearSrc() + cmsg.pktinfo = unix.Inet6Pktinfo{} + _, _, errno = unix.Syscall( + unix.SYS_SENDMSG, + uintptr(sock), + uintptr(unsafe.Pointer(&msghdr)), + 0, + ) } + return errno } func send4(sock int, end *Endpoint, buff []byte) error { println("send 4") println(end.DstToString()) + println(sock) // construct message header @@ -393,7 +411,7 @@ func send4(sock int, end *Endpoint, buff []byte) error { unix.Cmsghdr{ Level: unix.IPPROTO_IP, Type: unix.IP_PKTINFO, - Len: unix.SizeofInet4Pktinfo, + Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr, }, unix.Inet4Pktinfo{ Spec_dst: src4.src.Addr, @@ -419,10 +437,11 @@ func send4(sock int, end *Endpoint, buff []byte) error { 0, ) - println(sock) - fmt.Println(errno) + if errno == 0 { + return nil + } - // clear source cache and try again + // clear source and try again if errno == unix.EINVAL { end.ClearSrc() diff --git a/src/main.go b/src/main.go index 5aaed9b..05d56eb 100644 --- a/src/main.go +++ b/src/main.go @@ -84,7 +84,10 @@ func main() { logInfo := device.log.Info logError := device.log.Error - logInfo.Println("Starting device") + logDebug := device.log.Debug + + logInfo.Println("Device started") + logDebug.Println("Debug log enabled") // start configuration lister diff --git a/src/peer.go b/src/peer.go index f24dcd8..a98fc97 100644 --- a/src/peer.go +++ b/src/peer.go @@ -138,6 +138,17 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { return peer, nil } +func (peer *Peer) SendBuffer(buffer []byte) error { + peer.device.net.mutex.RLock() + defer peer.device.net.mutex.RUnlock() + peer.mutex.RLock() + defer peer.mutex.RUnlock() + if !peer.endpoint.set { + return errors.New("No known endpoint for peer") + } + return peer.device.net.bind.Send(buffer, &peer.endpoint.value) +} + /* Returns a short string identification for logging */ func (peer *Peer) String() string { diff --git a/src/send.go b/src/send.go index e37a736..52872f6 100644 --- a/src/send.go +++ b/src/send.go @@ -2,7 +2,6 @@ package main import ( "encoding/binary" - "errors" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" @@ -105,17 +104,6 @@ func addToEncryptionQueue( } } -func (peer *Peer) SendBuffer(buffer []byte) error { - peer.device.net.mutex.RLock() - defer peer.device.net.mutex.RUnlock() - peer.mutex.RLock() - defer peer.mutex.RUnlock() - if !peer.endpoint.set { - return errors.New("No known endpoint for peer") - } - return peer.device.net.bind.Send(buffer, &peer.endpoint.value) -} - /* Reads packets from the TUN and inserts * into nonce queue for peer * diff --git a/src/uapi.go b/src/uapi.go index accffd1..5098e3d 100644 --- a/src/uapi.go +++ b/src/uapi.go @@ -135,7 +135,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { case "listen_port": port, err := strconv.ParseUint(value, 10, 16) if err != nil { - logError.Println("Failed to set listen_port:", err) + logError.Println("Failed to parse listen_port:", err) return &IPCError{Code: ipcErrorInvalid} } device.net.port = uint16(port) |