summaryrefslogtreecommitdiffhomepage
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/conn.go21
-rw-r--r--src/conn_linux.go12
-rw-r--r--src/device.go2
-rw-r--r--src/main.go2
-rw-r--r--src/peer.go24
-rw-r--r--src/receive.go18
-rw-r--r--src/send.go19
-rw-r--r--src/timers.go2
-rw-r--r--src/tun.go4
-rw-r--r--src/uapi.go63
10 files changed, 84 insertions, 83 deletions
diff --git a/src/conn.go b/src/conn.go
index db4020d..012e24e 100644
--- a/src/conn.go
+++ b/src/conn.go
@@ -34,15 +34,20 @@ func parseEndpoint(s string) (*net.UDPAddr, error) {
return addr, err
}
-func ListeningUpdate(device *Device) error {
+func UpdateUDPListener(device *Device) error {
+ device.mutex.Lock()
+ defer device.mutex.Unlock()
+
netc := &device.net
netc.mutex.Lock()
defer netc.mutex.Unlock()
// close existing sockets
- if err := device.net.bind.Close(); err != nil {
- return err
+ if netc.bind != nil {
+ if err := netc.bind.Close(); err != nil {
+ return err
+ }
}
// open new sockets
@@ -64,13 +69,19 @@ func ListeningUpdate(device *Device) error {
return err
}
- // TODO: clear endpoint (src) caches
+ // clear cached source addresses
+
+ for _, peer := range device.peers {
+ peer.mutex.Lock()
+ peer.endpoint.value.ClearSrc()
+ peer.mutex.Unlock()
+ }
}
return nil
}
-func ListeningClose(device *Device) error {
+func CloseUDPListener(device *Device) error {
netc := &device.net
netc.mutex.Lock()
defer netc.mutex.Unlock()
diff --git a/src/conn_linux.go b/src/conn_linux.go
index 8942b03..4a5a3f0 100644
--- a/src/conn_linux.go
+++ b/src/conn_linux.go
@@ -133,7 +133,7 @@ func sockaddrToString(addr unix.RawSockaddrInet6) string {
}
}
-func (end *Endpoint) DestinationIP() net.IP {
+func (end *Endpoint) DstIP() net.IP {
switch end.dst.Family {
case unix.AF_INET6:
return end.dst.Addr[:]
@@ -150,20 +150,24 @@ func (end *Endpoint) DestinationIP() net.IP {
}
}
-func (end *Endpoint) SourceToBytes() []byte {
+func (end *Endpoint) SrcToBytes() []byte {
ptr := unsafe.Pointer(&end.src)
arr := (*[unix.SizeofSockaddrInet6]byte)(ptr)
return arr[:]
}
-func (end *Endpoint) SourceToString() string {
+func (end *Endpoint) SrcToString() string {
return sockaddrToString(end.src)
}
-func (end *Endpoint) DestinationToString() string {
+func (end *Endpoint) DstToString() string {
return sockaddrToString(end.dst)
}
+func (end *Endpoint) ClearDst() {
+ end.dst = unix.RawSockaddrInet6{}
+}
+
func (end *Endpoint) ClearSrc() {
end.src = unix.RawSockaddrInet6{}
}
diff --git a/src/device.go b/src/device.go
index d1e0685..1aae448 100644
--- a/src/device.go
+++ b/src/device.go
@@ -205,7 +205,7 @@ func (device *Device) RemoveAllPeers() {
func (device *Device) Close() {
device.RemoveAllPeers()
close(device.signal.stop)
- ListeningClose(device)
+ CloseUDPListener(device)
}
func (device *Device) WaitChannel() chan struct{} {
diff --git a/src/main.go b/src/main.go
index a05dbba..5aaed9b 100644
--- a/src/main.go
+++ b/src/main.go
@@ -14,8 +14,6 @@ func printUsage() {
}
func main() {
- test()
-
// parse arguments
var foreground bool
diff --git a/src/peer.go b/src/peer.go
index 791c091..f24dcd8 100644
--- a/src/peer.go
+++ b/src/peer.go
@@ -14,9 +14,12 @@ type Peer struct {
persistentKeepaliveInterval uint64
keyPairs KeyPairs
handshake Handshake
- endpoint Endpoint
device *Device
- stats struct {
+ endpoint struct {
+ set bool // has a known endpoint been discovered
+ value Endpoint // source / destination cache
+ }
+ stats struct {
txBytes uint64 // bytes send to peer (endpoint)
rxBytes uint64 // bytes received from peer
lastHandshakeNano int64 // nano seconds since epoch
@@ -105,6 +108,12 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic)
handshake.mutex.Unlock()
+ // reset endpoint
+
+ peer.endpoint.set = false
+ peer.endpoint.value.ClearDst()
+ peer.endpoint.value.ClearSrc()
+
// prepare queuing
peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
@@ -129,11 +138,20 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
return peer, nil
}
+/* Returns a short string identification for logging
+ */
func (peer *Peer) String() string {
+ if !peer.endpoint.set {
+ return fmt.Sprintf(
+ "peer(%d unknown %s)",
+ peer.id,
+ base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
+ )
+ }
return fmt.Sprintf(
"peer(%d %s %s)",
peer.id,
- peer.endpoint.DestinationToString(),
+ peer.endpoint.value.DstToString(),
base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
)
}
diff --git a/src/receive.go b/src/receive.go
index 664f1ba..1f05b2f 100644
--- a/src/receive.go
+++ b/src/receive.go
@@ -331,7 +331,7 @@ func (device *Device) RoutineHandshake() {
return
}
- srcBytes := elem.endpoint.SourceToBytes()
+ srcBytes := elem.endpoint.SrcToBytes()
if device.IsUnderLoad() {
// verify MAC2 field
@@ -340,8 +340,7 @@ func (device *Device) RoutineHandshake() {
// construct cookie reply
- logDebug.Println("Sending cookie reply to:", elem.endpoint.SourceToString())
-
+ logDebug.Println("Sending cookie reply to:", elem.endpoint.SrcToString())
sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type"
reply, err := device.mac.CreateReply(elem.packet, sender, srcBytes)
if err != nil {
@@ -365,9 +364,7 @@ func (device *Device) RoutineHandshake() {
// check ratelimiter
- if !device.ratelimiter.Allow(
- elem.endpoint.DestinationIP(),
- ) {
+ if !device.ratelimiter.Allow(elem.endpoint.DstIP()) {
continue
}
}
@@ -398,7 +395,7 @@ func (device *Device) RoutineHandshake() {
if peer == nil {
logInfo.Println(
"Recieved invalid initiation message from",
- elem.endpoint.DestinationToString(),
+ elem.endpoint.DstToString(),
)
continue
}
@@ -412,7 +409,8 @@ func (device *Device) RoutineHandshake() {
// TODO: Discover destination address also, only update on change
peer.mutex.Lock()
- peer.endpoint = elem.endpoint
+ peer.endpoint.set = true
+ peer.endpoint.value = elem.endpoint
peer.mutex.Unlock()
// create response
@@ -435,7 +433,7 @@ func (device *Device) RoutineHandshake() {
// send response
- _, err = peer.SendBuffer(packet)
+ err = peer.SendBuffer(packet)
if err == nil {
peer.TimerAnyAuthenticatedPacketTraversal()
}
@@ -458,7 +456,7 @@ func (device *Device) RoutineHandshake() {
if peer == nil {
logInfo.Println(
"Recieved invalid response message from",
- elem.endpoint.DestinationToString(),
+ elem.endpoint.DstToString(),
)
continue
}
diff --git a/src/send.go b/src/send.go
index 5c88ead..e37a736 100644
--- a/src/send.go
+++ b/src/send.go
@@ -105,24 +105,15 @@ func addToEncryptionQueue(
}
}
-func (peer *Peer) SendBuffer(buffer []byte) (int, error) {
+func (peer *Peer) SendBuffer(buffer []byte) error {
peer.device.net.mutex.RLock()
defer peer.device.net.mutex.RUnlock()
-
peer.mutex.RLock()
defer peer.mutex.RUnlock()
-
- endpoint := peer.endpoint
- if endpoint == nil {
- return 0, errors.New("No known endpoint for peer")
+ if !peer.endpoint.set {
+ return errors.New("No known endpoint for peer")
}
-
- conn := peer.device.net.conn
- if conn == nil {
- return 0, errors.New("No UDP socket for device")
- }
-
- return conn.WriteToUDP(buffer, endpoint)
+ return peer.device.net.bind.Send(buffer, &peer.endpoint.value)
}
/* Reads packets from the TUN and inserts
@@ -343,7 +334,7 @@ func (peer *Peer) RoutineSequentialSender() {
// send message and return buffer to pool
length := uint64(len(elem.packet))
- _, err := peer.SendBuffer(elem.packet)
+ err := peer.SendBuffer(elem.packet)
device.PutMessageBuffer(elem.buffer)
if err != nil {
logDebug.Println("Failed to send authenticated packet to peer", peer.String())
diff --git a/src/timers.go b/src/timers.go
index 99695ba..2a94005 100644
--- a/src/timers.go
+++ b/src/timers.go
@@ -288,7 +288,7 @@ func (peer *Peer) RoutineHandshakeInitiator() {
packet := writer.Bytes()
peer.mac.AddMacs(packet)
- _, err = peer.SendBuffer(packet)
+ err = peer.SendBuffer(packet)
if err != nil {
logError.Println(
"Failed to send handshake initiation message to",
diff --git a/src/tun.go b/src/tun.go
index 8e8c759..9eed987 100644
--- a/src/tun.go
+++ b/src/tun.go
@@ -47,7 +47,7 @@ func (device *Device) RoutineTUNEventReader() {
if !device.tun.isUp.Get() {
logInfo.Println("Interface set up")
device.tun.isUp.Set(true)
- updateUDPConn(device)
+ UpdateUDPListener(device)
}
}
@@ -55,7 +55,7 @@ func (device *Device) RoutineTUNEventReader() {
if device.tun.isUp.Get() {
logInfo.Println("Interface set down")
device.tun.isUp.Set(false)
- closeUDPConn(device)
+ CloseUDPListener(device)
}
}
}
diff --git a/src/uapi.go b/src/uapi.go
index 7d08e56..2de26ee 100644
--- a/src/uapi.go
+++ b/src/uapi.go
@@ -39,9 +39,10 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
send("private_key=" + device.privateKey.ToHex())
}
- if device.net.addr != nil {
- send(fmt.Sprintf("listen_port=%d", device.net.addr.Port))
+ if device.net.port != 0 {
+ send(fmt.Sprintf("listen_port=%d", device.net.port))
}
+
if device.net.fwmark != 0 {
send(fmt.Sprintf("fwmark=%d", device.net.fwmark))
}
@@ -52,8 +53,8 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
defer peer.mutex.RUnlock()
send("public_key=" + peer.handshake.remoteStatic.ToHex())
send("preshared_key=" + peer.handshake.presharedKey.ToHex())
- if peer.endpoint != nil {
- send("endpoint=" + peer.endpoint.String())
+ if peer.endpoint.set {
+ send("endpoint=" + peer.endpoint.value.DstToString())
}
nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano)
@@ -137,53 +138,24 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
logError.Println("Failed to set listen_port:", err)
return &IPCError{Code: ipcErrorInvalid}
}
-
- addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port))
- if err != nil {
- logError.Println("Failed to set listen_port:", err)
- return &IPCError{Code: ipcErrorInvalid}
- }
-
- device.net.mutex.Lock()
- device.net.addr = addr
- device.net.mutex.Unlock()
-
- err = updateUDPConn(device)
- if err != nil {
+ device.net.port = uint16(port)
+ if err := UpdateUDPListener(device); err != nil {
logError.Println("Failed to set listen_port:", err)
return &IPCError{Code: ipcErrorPortInUse}
}
- // TODO: Clear source address of all peers
-
case "fwmark":
fwmark, err := strconv.ParseUint(value, 10, 32)
if err != nil {
logError.Println("Invalid fwmark", err)
return &IPCError{Code: ipcErrorInvalid}
}
-
device.net.mutex.Lock()
- if fwmark > 0 || device.net.fwmark > 0 {
- device.net.fwmark = uint32(fwmark)
- err := SetMark(
- device.net.conn,
- device.net.fwmark,
- )
- if err != nil {
- logError.Println("Failed to set fwmark:", err)
- device.net.mutex.Unlock()
- return &IPCError{Code: ipcErrorIO}
- }
-
- // TODO: Clear source address of all peers
- }
+ device.net.fwmark = uint32(fwmark)
device.net.mutex.Unlock()
case "public_key":
-
// switch to peer configuration
-
deviceConfig = false
case "replace_peers":
@@ -218,7 +190,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
device.mutex.RLock()
if device.publicKey.Equals(pubKey) {
- // create dummy instance
+ // create dummy instance (not added to device)
peer = &Peer{}
dummy = true
@@ -244,6 +216,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
}
case "remove":
+
+ // remove currently selected peer from device
+
if value != "true" {
logError.Println("Failed to set remove, invalid value:", value)
return &IPCError{Code: ipcErrorInvalid}
@@ -256,6 +231,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
dummy = true
case "preshared_key":
+
+ // update PSK
+
peer.mutex.Lock()
err := peer.handshake.presharedKey.FromHex(value)
peer.mutex.Unlock()
@@ -265,14 +243,17 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
}
case "endpoint":
- addr, err := parseEndpoint(value)
+
+ // set endpoint destination and reset handshake timer
+
+ peer.mutex.Lock()
+ err := peer.endpoint.value.Set(value)
+ peer.endpoint.set = (err == nil)
+ peer.mutex.Unlock()
if err != nil {
logError.Println("Failed to set endpoint:", value)
return &IPCError{Code: ipcErrorInvalid}
}
- peer.mutex.Lock()
- peer.endpoint = addr
- peer.mutex.Unlock()
signalSend(peer.signal.handshakeReset)
case "persistent_keepalive_interval":