summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--noise-protocol.go33
-rw-r--r--noise_test.go8
-rw-r--r--receive.go33
-rw-r--r--send.go6
4 files changed, 44 insertions, 36 deletions
diff --git a/noise-protocol.go b/noise-protocol.go
index 82d553e..f72dcc4 100644
--- a/noise-protocol.go
+++ b/noise-protocol.go
@@ -319,6 +319,9 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
handshake.mutex.Unlock()
+ setZero(hash[:])
+ setZero(chainKey[:])
+
return peer
}
@@ -362,7 +365,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixKey(ss[:])
}()
- // add preshared key (psk)
+ // add preshared key
var tau [blake2s.Size]byte
var key [chacha20poly1305.KeySize]byte
@@ -457,7 +460,6 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
aead, _ := chacha20poly1305.New(key[:])
_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
if err != nil {
- device.log.Debug.Println("failed to open")
return false
}
mixHash(&hash, &hash, msg.Empty[:])
@@ -485,10 +487,10 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
return lookup.peer
}
-/* Derives a new key-pair from the current handshake state
+/* Derives a new keypair from the current handshake state
*
*/
-func (peer *Peer) NewKeypair() *Keypair {
+func (peer *Peer) DeriveNewKeypair() error {
device := peer.device
handshake := &peer.handshake
handshake.mutex.Lock()
@@ -517,12 +519,13 @@ func (peer *Peer) NewKeypair() *Keypair {
)
isInitiator = false
} else {
- return nil
+ return errors.New("invalid state for keypair derivation")
}
// zero handshake
setZero(handshake.chainKey[:])
+ setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
setZero(handshake.localEphemeral[:])
peer.handshake.state = HandshakeZeroed
@@ -576,5 +579,23 @@ func (peer *Peer) NewKeypair() *Keypair {
}
kp.mutex.Unlock()
- return keypair
+ return nil
+}
+
+func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
+ kp := &peer.keypairs
+ if kp.next != receivedKeypair {
+ return false
+ }
+ kp.mutex.Lock()
+ defer kp.mutex.Unlock()
+ if kp.next != receivedKeypair {
+ return false
+ }
+ old := kp.previous
+ kp.previous = kp.current
+ peer.device.DeleteKeypair(old)
+ kp.current = kp.next
+ kp.next = nil
+ return true
}
diff --git a/noise_test.go b/noise_test.go
index 37bfb94..ce32097 100644
--- a/noise_test.go
+++ b/noise_test.go
@@ -102,15 +102,15 @@ func TestNoiseHandshake(t *testing.T) {
t.Log("deriving keys")
- key1 := peer1.NewKeypair()
- key2 := peer2.NewKeypair()
+ key1 := peer1.DeriveNewKeypair()
+ key2 := peer2.DeriveNewKeypair()
if key1 == nil {
- t.Fatal("failed to dervice key-pair for peer 1")
+ t.Fatal("failed to dervice keypair for peer 1")
}
if key2 == nil {
- t.Fatal("failed to dervice key-pair for peer 2")
+ t.Fatal("failed to dervice keypair for peer 2")
}
// encrypting / decryption test
diff --git a/receive.go b/receive.go
index 32ff512..64253e6 100644
--- a/receive.go
+++ b/receive.go
@@ -189,7 +189,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
continue
}
- // check key-pair expiry
+ // check keypair expiry
if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
continue
@@ -475,7 +475,7 @@ func (device *Device) RoutineHandshake() {
continue
}
- if peer.NewKeypair() == nil {
+ if peer.DeriveNewKeypair() != nil {
continue
}
@@ -532,9 +532,9 @@ func (device *Device) RoutineHandshake() {
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketReceived()
- // derive key-pair
+ // derive keypair
- if peer.NewKeypair() == nil {
+ if peer.DeriveNewKeypair() != nil {
continue
}
@@ -597,25 +597,12 @@ func (peer *Peer) RoutineSequentialReceiver() {
peer.endpoint = elem.endpoint
peer.mutex.Unlock()
- // check if using new key-pair
-
- kp := &peer.keypairs
- if kp.next == elem.keypair {
- kp.mutex.Lock()
- if kp.next != elem.keypair {
- kp.mutex.Unlock()
- } else {
- old := kp.previous
- kp.previous = kp.current
- device.DeleteKeypair(old)
- kp.current = kp.next
- kp.next = nil
- kp.mutex.Unlock()
- peer.timersHandshakeComplete()
- select {
- case peer.signals.newKeypairArrived <- struct{}{}:
- default:
- }
+ // check if using new keypair
+ if peer.ReceivedWithKeypair(elem.keypair) {
+ peer.timersHandshakeComplete()
+ select {
+ case peer.signals.newKeypairArrived <- struct{}{}:
+ default:
}
}
diff --git a/send.go b/send.go
index 35e0d00..a8ec28c 100644
--- a/send.go
+++ b/send.go
@@ -47,7 +47,7 @@ type QueueOutboundElement struct {
buffer *[MaxMessageSize]byte // slice holding the packet data
packet []byte // slice of "buffer" (always!)
nonce uint64 // nonce for encryption
- keypair *Keypair // key-pair for encryption
+ keypair *Keypair // keypair for encryption
peer *Peer // related peer
}
@@ -306,11 +306,11 @@ func (peer *Peer) RoutineNonce() {
peer.SendHandshakeInitiation(false)
- logDebug.Println(peer, ": Awaiting key-pair")
+ logDebug.Println(peer, ": Awaiting keypair")
select {
case <-peer.signals.newKeypairArrived:
- logDebug.Println(peer, ": Obtained awaited key-pair")
+ logDebug.Println(peer, ": Obtained awaited keypair")
case <-peer.signals.flushNonceQueue:
for {
select {