diff options
Diffstat (limited to 'device/peer.go')
-rw-r--r-- | device/peer.go | 50 |
1 files changed, 35 insertions, 15 deletions
diff --git a/device/peer.go b/device/peer.go index 2fb5da6..47a2f14 100644 --- a/device/peer.go +++ b/device/peer.go @@ -17,17 +17,20 @@ import ( type Peer struct { isRunning atomic.Bool - sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer keypairs Keypairs handshake Handshake device *Device - endpoint conn.Endpoint stopping sync.WaitGroup // routines pending stop txBytes atomic.Uint64 // bytes send to peer (endpoint) rxBytes atomic.Uint64 // bytes received from peer lastHandshakeNano atomic.Int64 // nano seconds since epoch - disableRoaming bool + endpoint struct { + sync.Mutex + val conn.Endpoint + clearSrcOnTx bool // signal to val.ClearSrc() prior to next packet transmission + disableRoaming bool + } timers struct { retransmitHandshake *Timer @@ -74,8 +77,6 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { // create peer peer := new(Peer) - peer.Lock() - defer peer.Unlock() peer.cookieGenerator.Init(pk) peer.device = device @@ -97,7 +98,11 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { handshake.mutex.Unlock() // reset endpoint - peer.endpoint = nil + peer.endpoint.Lock() + peer.endpoint.val = nil + peer.endpoint.disableRoaming = false + peer.endpoint.clearSrcOnTx = false + peer.endpoint.Unlock() // init timers peer.timersInit() @@ -116,14 +121,19 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error { return nil } - peer.RLock() - defer peer.RUnlock() - - if peer.endpoint == nil { + peer.endpoint.Lock() + endpoint := peer.endpoint.val + if endpoint == nil { + peer.endpoint.Unlock() return errors.New("no known endpoint for peer") } + if peer.endpoint.clearSrcOnTx { + endpoint.ClearSrc() + peer.endpoint.clearSrcOnTx = false + } + peer.endpoint.Unlock() - err := peer.device.net.bind.Send(buffers, peer.endpoint) + err := peer.device.net.bind.Send(buffers, endpoint) if err == nil { var totalLen uint64 for _, b := range buffers { @@ -267,10 +277,20 @@ func (peer *Peer) Stop() { } func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) { - if peer.disableRoaming { + peer.endpoint.Lock() + defer peer.endpoint.Unlock() + if peer.endpoint.disableRoaming { + return + } + peer.endpoint.clearSrcOnTx = false + peer.endpoint.val = endpoint +} + +func (peer *Peer) markEndpointSrcForClearing() { + peer.endpoint.Lock() + defer peer.endpoint.Unlock() + if peer.endpoint.val == nil { return } - peer.Lock() - peer.endpoint = endpoint - peer.Unlock() + peer.endpoint.clearSrcOnTx = true } |