package main

import (
	"crypto/hmac"
	"crypto/rand"
	"golang.org/x/crypto/blake2s"
	"golang.org/x/crypto/chacha20poly1305"
	"sync"
	"time"
)

// CookieChecker holds the state used to verify the trailing mac1/mac2 fields
// of incoming messages and to build cookie replies.
//
// All fields are guarded by mutex:
//   - mac1.key: BLAKE2s-256 of WGLabelMAC1 followed by the public key given
//     to Init; keys the BLAKE2s-128 MAC checked by CheckMAC1.
//   - mac2.secret: random value (rotated by CreateReply once it is older than
//     CookieRefreshTime) that keys the per-source cookie derivation.
//   - mac2.secretSet: time of the last rotation; Init resets it to the zero
//     time so the secret is treated as expired.
//   - mac2.encryptionKey: BLAKE2s-256 of WGLabelCookie followed by the public
//     key; used to encrypt the cookie placed in replies.
type CookieChecker struct {
	mutex sync.RWMutex
	mac1  struct {
		key [blake2s.Size]byte
	}
	mac2 struct {
		secret        [blake2s.Size]byte
		secretSet     time.Time
		encryptionKey [chacha20poly1305.KeySize]byte
	}
}

// CookieGenerator holds the state used to add mac1/mac2 to outgoing messages
// and to accept cookies delivered in cookie replies.
//
// All fields are guarded by mutex:
//   - mac1.key: BLAKE2s-256 of WGLabelMAC1 followed by the public key given
//     to Init; keys the BLAKE2s-128 MAC written by AddMacs.
//   - mac2.cookie: the most recent cookie decrypted by ConsumeReply.
//   - mac2.cookieSet: when that cookie was stored; AddMacs only uses the
//     cookie while it is younger than CookieRefreshTime. Init resets it to the
//     zero time, which marks the cookie as expired.
//   - mac2.hasLastMAC1 / mac2.lastMAC1: the mac1 most recently written by
//     AddMacs; ConsumeReply ignores replies until it is set and passes it to
//     the decryption routine.
//   - mac2.encryptionKey: BLAKE2s-256 of WGLabelCookie followed by the public
//     key; used to decrypt cookies from replies.
type CookieGenerator struct {
	mutex sync.RWMutex
	mac1  struct {
		key [blake2s.Size]byte
	}
	mac2 struct {
		cookie        [blake2s.Size128]byte
		cookieSet     time.Time
		hasLastMAC1   bool
		lastMAC1      [blake2s.Size128]byte
		encryptionKey [chacha20poly1305.KeySize]byte
	}
}

func (st *CookieChecker) Init(pk NoisePublicKey) {
	// Init derives both checker keys from pk and marks the cookie secret as
	// expired, so the next CreateReply call rotates it:
	//
	//	mac1.key           = BLAKE2s-256(WGLabelMAC1   || pk)
	//	mac2.encryptionKey = BLAKE2s-256(WGLabelCookie || pk)

	// deriveInto overwrites dst with BLAKE2s-256(label || pk).
	deriveInto := func(dst []byte, label []byte) {
		// New256 only errors for an over-long key; a nil key cannot fail.
		h, _ := blake2s.New256(nil)
		h.Write(label)
		h.Write(pk[:])
		h.Sum(dst[:0])
	}

	st.mutex.Lock()
	defer st.mutex.Unlock()

	deriveInto(st.mac1.key[:], []byte(WGLabelMAC1))
	deriveInto(st.mac2.encryptionKey[:], []byte(WGLabelCookie))

	// The zero time is always older than CookieRefreshTime.
	st.mac2.secretSet = time.Time{}
}

// CheckMAC1 reports whether msg carries a valid mac1 field.
//
// msg is expected to end with two 16-byte MAC fields, mac1 followed by mac2.
// mac1 must equal the BLAKE2s-128 MAC, keyed with mac1.key, of every byte
// preceding it. A msg too short to hold both fields is rejected instead of
// panicking on a negative slice index.
func (st *CookieChecker) CheckMAC1(msg []byte) bool {
	smac2 := len(msg) - blake2s.Size128
	smac1 := smac2 - blake2s.Size128
	if smac1 < 0 {
		return false
	}

	st.mutex.RLock()
	defer st.mutex.RUnlock()

	var mac1 [blake2s.Size128]byte
	// New128 only errors for an empty or over-long key; mac1.key is 32 bytes.
	mac, _ := blake2s.New128(st.mac1.key[:])
	mac.Write(msg[:smac1])
	mac.Sum(mac1[:0])

	// Constant-time compare so timing does not reveal how many bytes matched.
	return hmac.Equal(mac1[:], msg[smac1:smac2])
}

// CheckMAC2 reports whether msg carries a valid mac2 field for a sender
// identified by src.
//
// The expected cookie is the BLAKE2s-128 MAC of src keyed with the current
// cookie secret; mac2 (the last 16 bytes of msg) must equal the BLAKE2s-128
// MAC, keyed with that cookie, of everything before it (including mac1).
// It returns false when the secret is older than CookieRefreshTime (including
// the never-set zero time), and when msg is too short to hold a mac2 field,
// which previously caused a slice-bounds panic.
func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool {
	smac2 := len(msg) - blake2s.Size128
	if smac2 < 0 {
		return false
	}

	st.mutex.RLock()
	defer st.mutex.RUnlock()

	if time.Since(st.mac2.secretSet) > CookieRefreshTime {
		return false
	}

	// derive the cookie this source would have been given

	var cookie [blake2s.Size128]byte
	func() {
		mac, _ := blake2s.New128(st.mac2.secret[:])
		mac.Write(src)
		mac.Sum(cookie[:0])
	}()

	// calculate mac of packet (including mac1)

	var mac2 [blake2s.Size128]byte
	func() {
		mac, _ := blake2s.New128(cookie[:])
		mac.Write(msg[:smac2])
		mac.Sum(mac2[:0])
	}()

	return hmac.Equal(mac2[:], msg[smac2:])
}

// CreateReply builds a cookie reply to msg, addressed to receiver index recv,
// for a sender identified by src.
//
// The cookie is the BLAKE2s-128 MAC of src keyed with the current secret. It
// is sealed via XChaCha20Poly1305Encrypt under mac2.encryptionKey with a fresh
// random nonce and bound to msg's mac1 field. If the secret is older than
// CookieRefreshTime, it is rotated first.
//
// msg must end with the 16-byte mac1 and mac2 fields. A shorter msg panics;
// presumably callers only get here after CheckMAC1 succeeded (TODO confirm).
// An error is returned only if reading from crypto/rand fails.
func (st *CookieChecker) CreateReply(
	msg []byte,
	recv uint32,
	src []byte,
) (*MessageCookieReply, error) {
	st.mutex.RLock()

	// Refresh the cookie secret. A read lock cannot be upgraded in place, so
	// drop it and take the write lock, then re-check expiry. Another goroutine
	// may have rotated the secret in the gap, and rotating it a second time
	// would invalidate the cookie that goroutine just issued.
	if time.Since(st.mac2.secretSet) > CookieRefreshTime {
		st.mutex.RUnlock()
		st.mutex.Lock()
		if time.Since(st.mac2.secretSet) > CookieRefreshTime {
			// Fill a temporary so a failed read leaves the old secret intact.
			var secret [blake2s.Size]byte
			if _, err := rand.Read(secret[:]); err != nil {
				st.mutex.Unlock()
				return nil, err
			}
			st.mac2.secret = secret
			st.mac2.secretSet = time.Now()
		}
		st.mutex.Unlock()
		st.mutex.RLock()
	}
	// Deferred so the read lock is released even if slicing msg panics.
	defer st.mutex.RUnlock()

	// derive cookie

	var cookie [blake2s.Size128]byte
	func() {
		mac, _ := blake2s.New128(st.mac2.secret[:])
		mac.Write(src)
		mac.Sum(cookie[:0])
	}()

	// encrypt cookie

	smac2 := len(msg) - blake2s.Size128
	smac1 := smac2 - blake2s.Size128

	reply := new(MessageCookieReply)
	reply.Type = MessageCookieReplyType
	reply.Receiver = recv

	if _, err := rand.Read(reply.Nonce[:]); err != nil {
		return nil, err
	}

	XChaCha20Poly1305Encrypt(
		reply.Cookie[:0],
		&reply.Nonce,
		cookie[:],
		msg[smac1:smac2],
		&st.mac2.encryptionKey,
	)

	return reply, nil
}

// Init derives the generator's keys from pk and discards any stored cookie by
// marking it as expired:
//
//	mac1.key           = BLAKE2s-256(WGLabelMAC1   || pk)
//	mac2.encryptionKey = BLAKE2s-256(WGLabelCookie || pk)
func (st *CookieGenerator) Init(pk NoisePublicKey) {
	st.mutex.Lock()
	defer st.mutex.Unlock()

	derivations := [...]struct {
		label []byte
		out   []byte
	}{
		{label: []byte(WGLabelMAC1), out: st.mac1.key[:]},
		{label: []byte(WGLabelCookie), out: st.mac2.encryptionKey[:]},
	}
	for _, d := range derivations {
		// New256 only errors for an over-long key; a nil key cannot fail.
		h, _ := blake2s.New256(nil)
		h.Write(d.label)
		h.Write(pk[:])
		h.Sum(d.out[:0])
	}

	// The zero time is always older than CookieRefreshTime.
	st.mac2.cookieSet = time.Time{}
}

// ConsumeReply tries to decrypt the cookie carried by msg and reports whether
// it was accepted.
//
// Decryption uses mac2.encryptionKey, msg's nonce, and the mac1 last recorded
// by AddMacs (presumably as associated data, tying the reply to that message).
// Replies are rejected until AddMacs has recorded a mac1. On success, the
// cookie is stored and timestamped for AddMacs to use when computing mac2.
func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool {
	st.mutex.Lock()
	defer st.mutex.Unlock()

	// Without a recorded mac1 there is nothing to check the reply against.
	if !st.mac2.hasLastMAC1 {
		return false
	}

	var decrypted [blake2s.Size128]byte
	if _, err := XChaCha20Poly1305Decrypt(
		decrypted[:0],
		&msg.Nonce,
		msg.Cookie[:],
		st.mac2.lastMAC1[:],
		&st.mac2.encryptionKey,
	); err != nil {
		return false
	}

	st.mac2.cookie = decrypted
	st.mac2.cookieSet = time.Now()
	return true
}

// AddMacs fills in the trailing mac1 and mac2 fields of msg in place.
//
// msg must end with two 16-byte fields, mac1 followed by mac2; a shorter msg
// panics.
//   - mac1 is the BLAKE2s-128 MAC, keyed with mac1.key, of everything before
//     it. It is also remembered so ConsumeReply can decrypt a cookie reply.
//   - mac2 is the BLAKE2s-128 MAC, keyed with the stored cookie, of everything
//     before mac2, but only while the cookie is younger than
//     CookieRefreshTime. Otherwise mac2 is zeroed, so stale bytes already in
//     the caller's buffer are never sent in its place.
func (st *CookieGenerator) AddMacs(msg []byte) {
	smac2 := len(msg) - blake2s.Size128
	smac1 := smac2 - blake2s.Size128

	mac1 := msg[smac1:smac2]
	mac2 := msg[smac2:]

	st.mutex.Lock()
	defer st.mutex.Unlock()

	// set mac1

	func() {
		mac, _ := blake2s.New128(st.mac1.key[:])
		mac.Write(msg[:smac1])
		// mac1[:0] has capacity within msg, so Sum writes the tag in place.
		mac.Sum(mac1[:0])
	}()
	copy(st.mac2.lastMAC1[:], mac1)
	st.mac2.hasLastMAC1 = true

	// set mac2 (zero when there is no fresh cookie)

	if time.Since(st.mac2.cookieSet) > CookieRefreshTime {
		for i := range mac2 {
			mac2[i] = 0
		}
		return
	}

	func() {
		mac, _ := blake2s.New128(st.mac2.cookie[:])
		mac.Write(msg[:smac2])
		mac.Sum(mac2[:0])
	}()
}