summaryrefslogtreecommitdiffhomepage
path: root/device/noise-protocol.go
diff options
context:
space:
mode:
Diffstat (limited to 'device/noise-protocol.go')
-rw-r--r--device/noise-protocol.go63
1 files changed, 32 insertions, 31 deletions
diff --git a/device/noise-protocol.go b/device/noise-protocol.go
index 117e960..e8f6145 100644
--- a/device/noise-protocol.go
+++ b/device/noise-protocol.go
@@ -175,8 +175,6 @@ func init() {
}
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
- errZeroECDHResult := errors.New("ECDH returned all zeros")
-
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
@@ -204,9 +202,9 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mixHash(msg.Ephemeral[:])
// encrypt static key
- ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
- if isZero(ss[:]) {
- return nil, errZeroECDHResult
+ ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
+ if err != nil {
+ return nil, err
}
var key [chacha20poly1305.KeySize]byte
KDF2(
@@ -221,7 +219,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
// encrypt timestamp
if isZero(handshake.precomputedStaticStatic[:]) {
- return nil, errZeroECDHResult
+ return nil, errInvalidPublicKey
}
KDF2(
&handshake.chainKey,
@@ -264,11 +262,10 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
// decrypt static key
- var err error
var peerPK NoisePublicKey
var key [chacha20poly1305.KeySize]byte
- ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
- if isZero(ss[:]) {
+ ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
+ if err != nil {
return nil
}
KDF2(&chainKey, &key, chainKey[:], ss[:])
@@ -384,12 +381,16 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixHash(msg.Ephemeral[:])
handshake.mixKey(msg.Ephemeral[:])
- func() {
- ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
- handshake.mixKey(ss[:])
- ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
- handshake.mixKey(ss[:])
- }()
+ ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
+ if err != nil {
+ return nil, err
+ }
+ handshake.mixKey(ss[:])
+ ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
+ if err != nil {
+ return nil, err
+ }
+ handshake.mixKey(ss[:])
// add preshared key
@@ -406,11 +407,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixHash(tau[:])
- func() {
- aead, _ := chacha20poly1305.New(key[:])
- aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
- handshake.mixHash(msg.Empty[:])
- }()
+ aead, _ := chacha20poly1305.New(key[:])
+ aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
+ handshake.mixHash(msg.Empty[:])
handshake.state = handshakeResponseCreated
@@ -455,17 +454,19 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
- func() {
- ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
- mixKey(&chainKey, &chainKey, ss[:])
- setZero(ss[:])
- }()
+ ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
+ if err != nil {
+ return false
+ }
+ mixKey(&chainKey, &chainKey, ss[:])
+ setZero(ss[:])
- func() {
- ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
- mixKey(&chainKey, &chainKey, ss[:])
- setZero(ss[:])
- }()
+ ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
+ if err != nil {
+ return false
+ }
+ mixKey(&chainKey, &chainKey, ss[:])
+ setZero(ss[:])
// add preshared key (psk)
@@ -483,7 +484,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
// authenticate transcript
aead, _ := chacha20poly1305.New(key[:])
- _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
+ _, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
if err != nil {
return false
}