summaryrefslogtreecommitdiffhomepage
path: root/device/noise-protocol.go
diff options
context:
space:
mode:
authorDavid Crawshaw <crawshaw@tailscale.com>2020-03-04 20:58:39 -0500
committerJason A. Donenfeld <Jason@zx2c4.com>2020-05-02 01:46:42 -0600
commitde374bfb44945e241d93ca821f35f6e3078e506b (patch)
tree9d44d90935b0f226ee30b98584da7ae998db05dc /device/noise-protocol.go
parent1a1c3d096888816c94cf1eb7c2747e83f008549f (diff)
device: give handshake state a type
And unexport handshake constants. Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
Diffstat (limited to 'device/noise-protocol.go')
-rw-r--r--device/noise-protocol.go55
1 files changed, 38 insertions, 17 deletions
diff --git a/device/noise-protocol.go b/device/noise-protocol.go
index ee327d2..6dcc831 100644
--- a/device/noise-protocol.go
+++ b/device/noise-protocol.go
@@ -7,6 +7,7 @@ package device
import (
"errors"
+ "fmt"
"sync"
"time"
@@ -16,14 +17,34 @@ import (
"golang.zx2c4.com/wireguard/tai64n"
)
+type handshakeState int
+
+// TODO(crawshaw): add commentary describing each state and the transitions
const (
- HandshakeZeroed = iota
- HandshakeInitiationCreated
- HandshakeInitiationConsumed
- HandshakeResponseCreated
- HandshakeResponseConsumed
+ handshakeZeroed = handshakeState(iota)
+ handshakeInitiationCreated
+ handshakeInitiationConsumed
+ handshakeResponseCreated
+ handshakeResponseConsumed
)
+func (hs handshakeState) String() string {
+ switch hs {
+ case handshakeZeroed:
+ return "handshakeZeroed"
+ case handshakeInitiationCreated:
+ return "handshakeInitiationCreated"
+ case handshakeInitiationConsumed:
+ return "handshakeInitiationConsumed"
+ case handshakeResponseCreated:
+ return "handshakeResponseCreated"
+ case handshakeResponseConsumed:
+ return "handshakeResponseConsumed"
+ default:
+ return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs))
+ }
+}
+
const (
NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com"
@@ -95,7 +116,7 @@ type MessageCookieReply struct {
}
type Handshake struct {
- state int
+ state handshakeState
mutex sync.RWMutex
hash [blake2s.Size]byte // hash value
chainKey [blake2s.Size]byte // chain key
@@ -135,7 +156,7 @@ func (h *Handshake) Clear() {
setZero(h.chainKey[:])
setZero(h.hash[:])
h.localIndex = 0
- h.state = HandshakeZeroed
+ h.state = handshakeZeroed
}
func (h *Handshake) mixHash(data []byte) {
@@ -221,7 +242,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.localIndex = msg.Sender
handshake.mixHash(msg.Timestamp[:])
- handshake.state = HandshakeInitiationCreated
+ handshake.state = handshakeInitiationCreated
return &msg, nil
}
@@ -316,7 +337,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
if now.After(handshake.lastInitiationConsumption) {
handshake.lastInitiationConsumption = now
}
- handshake.state = HandshakeInitiationConsumed
+ handshake.state = handshakeInitiationConsumed
handshake.mutex.Unlock()
@@ -331,7 +352,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
- if handshake.state != HandshakeInitiationConsumed {
+ if handshake.state != handshakeInitiationConsumed {
return nil, errors.New("handshake initiation must be consumed first")
}
@@ -387,7 +408,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixHash(msg.Empty[:])
}()
- handshake.state = HandshakeResponseCreated
+ handshake.state = handshakeResponseCreated
return &msg, nil
}
@@ -417,7 +438,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
handshake.mutex.RLock()
defer handshake.mutex.RUnlock()
- if handshake.state != HandshakeInitiationCreated {
+ if handshake.state != handshakeInitiationCreated {
return false
}
@@ -478,7 +499,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
handshake.hash = hash
handshake.chainKey = chainKey
handshake.remoteIndex = msg.Sender
- handshake.state = HandshakeResponseConsumed
+ handshake.state = handshakeResponseConsumed
handshake.mutex.Unlock()
@@ -503,7 +524,7 @@ func (peer *Peer) BeginSymmetricSession() error {
var sendKey [chacha20poly1305.KeySize]byte
var recvKey [chacha20poly1305.KeySize]byte
- if handshake.state == HandshakeResponseConsumed {
+ if handshake.state == handshakeResponseConsumed {
KDF2(
&sendKey,
&recvKey,
@@ -511,7 +532,7 @@ func (peer *Peer) BeginSymmetricSession() error {
nil,
)
isInitiator = true
- } else if handshake.state == HandshakeResponseCreated {
+ } else if handshake.state == handshakeResponseCreated {
KDF2(
&recvKey,
&sendKey,
@@ -520,7 +541,7 @@ func (peer *Peer) BeginSymmetricSession() error {
)
isInitiator = false
} else {
- return errors.New("invalid state for keypair derivation")
+ return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state)
}
// zero handshake
@@ -528,7 +549,7 @@ func (peer *Peer) BeginSymmetricSession() error {
setZero(handshake.chainKey[:])
setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
setZero(handshake.localEphemeral[:])
- peer.handshake.state = HandshakeZeroed
+ peer.handshake.state = handshakeZeroed
// create AEAD instances