diff options
-rw-r--r-- | device/noise-protocol.go | 55 |
1 files changed, 38 insertions, 17 deletions
diff --git a/device/noise-protocol.go b/device/noise-protocol.go index ee327d2..6dcc831 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -7,6 +7,7 @@ package device import ( "errors" + "fmt" "sync" "time" @@ -16,14 +17,34 @@ import ( "golang.zx2c4.com/wireguard/tai64n" ) +type handshakeState int + +// TODO(crawshaw): add commentary describing each state and the transitions const ( - HandshakeZeroed = iota - HandshakeInitiationCreated - HandshakeInitiationConsumed - HandshakeResponseCreated - HandshakeResponseConsumed + handshakeZeroed = handshakeState(iota) + handshakeInitiationCreated + handshakeInitiationConsumed + handshakeResponseCreated + handshakeResponseConsumed ) +func (hs handshakeState) String() string { + switch hs { + case handshakeZeroed: + return "handshakeZeroed" + case handshakeInitiationCreated: + return "handshakeInitiationCreated" + case handshakeInitiationConsumed: + return "handshakeInitiationConsumed" + case handshakeResponseCreated: + return "handshakeResponseCreated" + case handshakeResponseConsumed: + return "handshakeResponseConsumed" + default: + return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs)) + } +} + const ( NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s" WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com" @@ -95,7 +116,7 @@ type MessageCookieReply struct { } type Handshake struct { - state int + state handshakeState mutex sync.RWMutex hash [blake2s.Size]byte // hash value chainKey [blake2s.Size]byte // chain key @@ -135,7 +156,7 @@ func (h *Handshake) Clear() { setZero(h.chainKey[:]) setZero(h.hash[:]) h.localIndex = 0 - h.state = HandshakeZeroed + h.state = handshakeZeroed } func (h *Handshake) mixHash(data []byte) { @@ -221,7 +242,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e handshake.localIndex = msg.Sender handshake.mixHash(msg.Timestamp[:]) - handshake.state = HandshakeInitiationCreated + handshake.state = handshakeInitiationCreated return &msg, nil } @@ -316,7 +337,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { if now.After(handshake.lastInitiationConsumption) { handshake.lastInitiationConsumption = now } - handshake.state = HandshakeInitiationConsumed + handshake.state = handshakeInitiationConsumed handshake.mutex.Unlock() @@ -331,7 +352,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error handshake.mutex.Lock() defer handshake.mutex.Unlock() - if handshake.state != HandshakeInitiationConsumed { + if handshake.state != handshakeInitiationConsumed { return nil, errors.New("handshake initiation must be consumed first") } @@ -387,7 +408,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error handshake.mixHash(msg.Empty[:]) }() - handshake.state = HandshakeResponseCreated + handshake.state = handshakeResponseCreated return &msg, nil } @@ -417,7 +438,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { handshake.mutex.RLock() defer handshake.mutex.RUnlock() - if handshake.state != HandshakeInitiationCreated { + if handshake.state != handshakeInitiationCreated { return false } @@ -478,7 +499,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { handshake.hash = hash handshake.chainKey = chainKey handshake.remoteIndex = msg.Sender - handshake.state = HandshakeResponseConsumed + handshake.state = handshakeResponseConsumed handshake.mutex.Unlock() @@ -503,7 +524,7 @@ func (peer *Peer) BeginSymmetricSession() error { var sendKey [chacha20poly1305.KeySize]byte var recvKey [chacha20poly1305.KeySize]byte - if handshake.state == HandshakeResponseConsumed { + if handshake.state == handshakeResponseConsumed { KDF2( &sendKey, &recvKey, @@ -511,7 +532,7 @@ func (peer *Peer) BeginSymmetricSession() error { nil, ) isInitiator = true - } else if handshake.state == HandshakeResponseCreated { + } else if handshake.state == handshakeResponseCreated { KDF2( &recvKey, &sendKey, @@ -520,7 +541,7 @@ func (peer *Peer) BeginSymmetricSession() error { ) isInitiator = false } else { - return errors.New("invalid state for keypair derivation") + return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state) } // zero handshake @@ -528,7 +549,7 @@ func (peer *Peer) BeginSymmetricSession() error { setZero(handshake.chainKey[:]) setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line. setZero(handshake.localEphemeral[:]) - peer.handshake.state = HandshakeZeroed + peer.handshake.state = handshakeZeroed // create AEAD instances |