diff options
Diffstat (limited to 'internal')
-rw-r--r-- | internal/ratelimiter/ratelimiter.go | 149 | ||||
-rw-r--r-- | internal/ratelimiter/ratelimiter_test.go | 98 | ||||
-rw-r--r-- | internal/tai64n/tai64n.go | 26 | ||||
-rw-r--r-- | internal/tai64n/tai64n_test.go | 21 | ||||
-rw-r--r-- | internal/xchacha20poly1305/xchacha20.go | 169 | ||||
-rw-r--r-- | internal/xchacha20poly1305/xchacha20_test.go | 96 |
6 files changed, 559 insertions, 0 deletions
diff --git a/internal/ratelimiter/ratelimiter.go b/internal/ratelimiter/ratelimiter.go new file mode 100644 index 0000000..006900a --- /dev/null +++ b/internal/ratelimiter/ratelimiter.go @@ -0,0 +1,149 @@ +package ratelimiter + +/* Copyright (C) 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */ + +/* This file contains a port of the rate-limiter from the linux kernel version */ + +import ( + "net" + "sync" + "time" +) + +const ( + packetsPerSecond = 20 + packetsBurstable = 5 + garbageCollectTime = time.Second + packetCost = 1000000000 / packetsPerSecond + maxTokens = packetCost * packetsBurstable +) + +type RatelimiterEntry struct { + mutex sync.Mutex + lastTime time.Time + tokens int64 +} + +type Ratelimiter struct { + mutex sync.RWMutex + stop chan struct{} + tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry + tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry +} + +func (rate *Ratelimiter) Close() { + rate.mutex.Lock() + defer rate.mutex.Unlock() + + if rate.stop != nil { + close(rate.stop) + } +} + +func (rate *Ratelimiter) Init() { + rate.mutex.Lock() + defer rate.mutex.Unlock() + + // stop any ongoing garbage collection routine + + if rate.stop != nil { + close(rate.stop) + } + + rate.stop = make(chan struct{}) + rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry) + rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry) + + // start garbage collection routine + + go func() { + timer := time.NewTimer(time.Second) + for { + select { + case <-rate.stop: + return + case <-timer.C: + func() { + rate.mutex.Lock() + defer rate.mutex.Unlock() + + for key, entry := range rate.tableIPv4 { + entry.mutex.Lock() + if time.Now().Sub(entry.lastTime) > garbageCollectTime { + delete(rate.tableIPv4, key) + } + entry.mutex.Unlock() + } + + for key, entry := range rate.tableIPv6 { + entry.mutex.Lock() + if time.Now().Sub(entry.lastTime) > garbageCollectTime { + delete(rate.tableIPv6, key) + } + entry.mutex.Unlock() + } + }() + timer.Reset(time.Second) + } + } + }() +} + +func (rate *Ratelimiter) Allow(ip net.IP) bool { + var entry *RatelimiterEntry + var KeyIPv4 [net.IPv4len]byte + var KeyIPv6 [net.IPv6len]byte + + // lookup entry + + IPv4 := ip.To4() + IPv6 := ip.To16() + + rate.mutex.RLock() + + if IPv4 != nil { + copy(KeyIPv4[:], IPv4) + entry = rate.tableIPv4[KeyIPv4] + } else { + copy(KeyIPv6[:], IPv6) + entry = rate.tableIPv6[KeyIPv6] + } + + rate.mutex.RUnlock() + + // make new entry if not found + + if entry == nil { + rate.mutex.Lock() + entry = new(RatelimiterEntry) + entry.tokens = maxTokens - packetCost + entry.lastTime = time.Now() + if IPv4 != nil { + rate.tableIPv4[KeyIPv4] = entry + } else { + rate.tableIPv6[KeyIPv6] = entry + } + rate.mutex.Unlock() + return true + } + + // add tokens to entry + + entry.mutex.Lock() + now := time.Now() + entry.tokens += now.Sub(entry.lastTime).Nanoseconds() + entry.lastTime = now + if entry.tokens > maxTokens { + entry.tokens = maxTokens + } + + // subtract cost of packet + + if entry.tokens > packetCost { + entry.tokens -= packetCost + entry.mutex.Unlock() + return true + } + entry.mutex.Unlock() + return false +} diff --git a/internal/ratelimiter/ratelimiter_test.go b/internal/ratelimiter/ratelimiter_test.go new file mode 100644 index 0000000..37339ee --- /dev/null +++ b/internal/ratelimiter/ratelimiter_test.go @@ -0,0 +1,98 @@ +package ratelimiter + +import ( + "net" + "testing" + "time" +) + +type RatelimiterResult struct { + allowed bool + text string + wait time.Duration +} + +func TestRatelimiter(t *testing.T) { + + var ratelimiter Ratelimiter + var expectedResults []RatelimiterResult + + Nano := func(nano int64) time.Duration { + return time.Nanosecond * time.Duration(nano) + } + + Add := func(res RatelimiterResult) { + expectedResults = append( + expectedResults, + res, + ) + } + + for i := 0; i < packetsBurstable; i++ { + Add(RatelimiterResult{ + allowed: true, + text: "inital burst", + }) + } + + Add(RatelimiterResult{ + allowed: false, + text: "after burst", + }) + + Add(RatelimiterResult{ + allowed: true, + wait: Nano(time.Second.Nanoseconds() / packetsPerSecond), + text: "filling tokens for single packet", + }) + + Add(RatelimiterResult{ + allowed: false, + text: "not having refilled enough", + }) + + Add(RatelimiterResult{ + allowed: true, + wait: 2 * (Nano(time.Second.Nanoseconds() / packetsPerSecond)), + text: "filling tokens for two packet burst", + }) + + Add(RatelimiterResult{ + allowed: true, + text: "second packet in 2 packet burst", + }) + + Add(RatelimiterResult{ + allowed: false, + text: "packet following 2 packet burst", + }) + + ips := []net.IP{ + net.ParseIP("127.0.0.1"), + net.ParseIP("192.168.1.1"), + net.ParseIP("172.167.2.3"), + net.ParseIP("97.231.252.215"), + net.ParseIP("248.97.91.167"), + net.ParseIP("188.208.233.47"), + net.ParseIP("104.2.183.179"), + net.ParseIP("72.129.46.120"), + net.ParseIP("2001:0db8:0a0b:12f0:0000:0000:0000:0001"), + net.ParseIP("f5c2:818f:c052:655a:9860:b136:6894:25f0"), + net.ParseIP("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"), + net.ParseIP("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"), + net.ParseIP("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"), + net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"), + } + + ratelimiter.Init() + + for i, res := range expectedResults { + time.Sleep(res.wait) + for _, ip := range ips { + allowed := ratelimiter.Allow(ip) + if allowed != res.allowed { + t.Fatal("Test failed for", ip.String(), ", on:", i, "(", res.text, ")", "expected:", res.allowed, "got:", allowed) + } + } + } +} diff --git a/internal/tai64n/tai64n.go b/internal/tai64n/tai64n.go new file mode 100644 index 0000000..da5257c --- /dev/null +++ b/internal/tai64n/tai64n.go @@ -0,0 +1,26 @@ +package tai64n + +import ( + "bytes" + "encoding/binary" + "time" +) + +const TimestampSize = 12 +const base = uint64(4611686018427387914) + +type Timestamp [TimestampSize]byte + +func Now() Timestamp { + var tai64n Timestamp + now := time.Now() + secs := base + uint64(now.Unix()) + nano := uint32(now.UnixNano()) + binary.BigEndian.PutUint64(tai64n[:], secs) + binary.BigEndian.PutUint32(tai64n[8:], nano) + return tai64n +} + +func (t1 Timestamp) After(t2 Timestamp) bool { + return bytes.Compare(t1[:], t2[:]) > 0 +} diff --git a/internal/tai64n/tai64n_test.go b/internal/tai64n/tai64n_test.go new file mode 100644 index 0000000..389b65c --- /dev/null +++ b/internal/tai64n/tai64n_test.go @@ -0,0 +1,21 @@ +package tai64n + +import ( + "testing" + "time" +) + +/* Testing the essential property of the timestamp + * as used by WireGuard. + */ +func TestMonotonic(t *testing.T) { + old := Now() + for i := 0; i < 10000; i++ { + time.Sleep(time.Nanosecond) + next := Now() + if !next.After(old) { + t.Error("TAI64N, not monotonically increasing on nano-second scale") + } + old = next + } +} diff --git a/internal/xchacha20poly1305/xchacha20.go b/internal/xchacha20poly1305/xchacha20.go new file mode 100644 index 0000000..a6e59f0 --- /dev/null +++ b/internal/xchacha20poly1305/xchacha20.go @@ -0,0 +1,169 @@ +// Copyright (c) 2016 Andreas Auernhammer. All rights reserved. +// Use of this source code is governed by a license that can be +// found in the LICENSE file. + +package xchacha20poly1305 + +import ( + "encoding/binary" + "golang.org/x/crypto/chacha20poly1305" +) + +func hChaCha20(out *[32]byte, nonce []byte, key *[32]byte) { + + v00 := uint32(0x61707865) + v01 := uint32(0x3320646e) + v02 := uint32(0x79622d32) + v03 := uint32(0x6b206574) + + v04 := binary.LittleEndian.Uint32(key[0:]) + v05 := binary.LittleEndian.Uint32(key[4:]) + v06 := binary.LittleEndian.Uint32(key[8:]) + v07 := binary.LittleEndian.Uint32(key[12:]) + v08 := binary.LittleEndian.Uint32(key[16:]) + v09 := binary.LittleEndian.Uint32(key[20:]) + v10 := binary.LittleEndian.Uint32(key[24:]) + v11 := binary.LittleEndian.Uint32(key[28:]) + v12 := binary.LittleEndian.Uint32(nonce[0:]) + v13 := binary.LittleEndian.Uint32(nonce[4:]) + v14 := binary.LittleEndian.Uint32(nonce[8:]) + v15 := binary.LittleEndian.Uint32(nonce[12:]) + + for i := 0; i < 20; i += 2 { + v00 += v04 + v12 ^= v00 + v12 = (v12 << 16) | (v12 >> 16) + v08 += v12 + v04 ^= v08 + v04 = (v04 << 12) | (v04 >> 20) + v00 += v04 + v12 ^= v00 + v12 = (v12 << 8) | (v12 >> 24) + v08 += v12 + v04 ^= v08 + v04 = (v04 << 7) | (v04 >> 25) + v01 += v05 + v13 ^= v01 + v13 = (v13 << 16) | (v13 >> 16) + v09 += v13 + v05 ^= v09 + v05 = (v05 << 12) | (v05 >> 20) + v01 += v05 + v13 ^= v01 + v13 = (v13 << 8) | (v13 >> 24) + v09 += v13 + v05 ^= v09 + v05 = (v05 << 7) | (v05 >> 25) + v02 += v06 + v14 ^= v02 + v14 = (v14 << 16) | (v14 >> 16) + v10 += v14 + v06 ^= v10 + v06 = (v06 << 12) | (v06 >> 20) + v02 += v06 + v14 ^= v02 + v14 = (v14 << 8) | (v14 >> 24) + v10 += v14 + v06 ^= v10 + v06 = (v06 << 7) | (v06 >> 25) + v03 += v07 + v15 ^= v03 + v15 = (v15 << 16) | (v15 >> 16) + v11 += v15 + v07 ^= v11 + v07 = (v07 << 12) | (v07 >> 20) + v03 += v07 + v15 ^= v03 + v15 = (v15 << 8) | (v15 >> 24) + v11 += v15 + v07 ^= v11 + v07 = (v07 << 7) | (v07 >> 25) + v00 += v05 + v15 ^= v00 + v15 = (v15 << 16) | (v15 >> 16) + v10 += v15 + v05 ^= v10 + v05 = (v05 << 12) | (v05 >> 20) + v00 += v05 + v15 ^= v00 + v15 = (v15 << 8) | (v15 >> 24) + v10 += v15 + v05 ^= v10 + v05 = (v05 << 7) | (v05 >> 25) + v01 += v06 + v12 ^= v01 + v12 = (v12 << 16) | (v12 >> 16) + v11 += v12 + v06 ^= v11 + v06 = (v06 << 12) | (v06 >> 20) + v01 += v06 + v12 ^= v01 + v12 = (v12 << 8) | (v12 >> 24) + v11 += v12 + v06 ^= v11 + v06 = (v06 << 7) | (v06 >> 25) + v02 += v07 + v13 ^= v02 + v13 = (v13 << 16) | (v13 >> 16) + v08 += v13 + v07 ^= v08 + v07 = (v07 << 12) | (v07 >> 20) + v02 += v07 + v13 ^= v02 + v13 = (v13 << 8) | (v13 >> 24) + v08 += v13 + v07 ^= v08 + v07 = (v07 << 7) | (v07 >> 25) + v03 += v04 + v14 ^= v03 + v14 = (v14 << 16) | (v14 >> 16) + v09 += v14 + v04 ^= v09 + v04 = (v04 << 12) | (v04 >> 20) + v03 += v04 + v14 ^= v03 + v14 = (v14 << 8) | (v14 >> 24) + v09 += v14 + v04 ^= v09 + v04 = (v04 << 7) | (v04 >> 25) + } + + binary.LittleEndian.PutUint32(out[0:], v00) + binary.LittleEndian.PutUint32(out[4:], v01) + binary.LittleEndian.PutUint32(out[8:], v02) + binary.LittleEndian.PutUint32(out[12:], v03) + binary.LittleEndian.PutUint32(out[16:], v12) + binary.LittleEndian.PutUint32(out[20:], v13) + binary.LittleEndian.PutUint32(out[24:], v14) + binary.LittleEndian.PutUint32(out[28:], v15) +} + +func Encrypt( + dst []byte, + nonceFull *[24]byte, + plaintext []byte, + additionalData []byte, + key *[chacha20poly1305.KeySize]byte, +) []byte { + var nonce [chacha20poly1305.NonceSize]byte + var derivedKey [chacha20poly1305.KeySize]byte + hChaCha20(&derivedKey, nonceFull[:16], key) + aead, _ := chacha20poly1305.New(derivedKey[:]) + copy(nonce[4:], nonceFull[16:]) + return aead.Seal(dst, nonce[:], plaintext, additionalData) +} + +func Decrypt( + dst []byte, + nonceFull *[24]byte, + plaintext []byte, + additionalData []byte, + key *[chacha20poly1305.KeySize]byte, +) ([]byte, error) { + var nonce [chacha20poly1305.NonceSize]byte + var derivedKey [chacha20poly1305.KeySize]byte + hChaCha20(&derivedKey, nonceFull[:16], key) + aead, _ := chacha20poly1305.New(derivedKey[:]) + copy(nonce[4:], nonceFull[16:]) + return aead.Open(dst, nonce[:], plaintext, additionalData) +} diff --git a/internal/xchacha20poly1305/xchacha20_test.go b/internal/xchacha20poly1305/xchacha20_test.go new file mode 100644 index 0000000..5d5b78f --- /dev/null +++ b/internal/xchacha20poly1305/xchacha20_test.go @@ -0,0 +1,96 @@ +package xchacha20poly1305 + +import ( + "encoding/hex" + "testing" +) + +type XChaCha20Test struct { + Nonce string + Key string + PT string + CT string +} + +func TestXChaCha20(t *testing.T) { + + tests := []XChaCha20Test{ + { + Nonce: "000000000000000000000000000000000000000000000000", + Key: "0000000000000000000000000000000000000000000000000000000000000000", + PT: "00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", + CT: "789e9689e5208d7fd9e1f3c5b5341f48ef18a13e418998addadd97a3693a987f8e82ecd5c1433bfed1af49750c0f1ff29c4174a05b119aa3a9e8333812e0c0feb1299c5949d895ee01dbf50f8395dd84", + }, + { + Nonce: "0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f", + Key: "0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f", + PT: "0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f", + CT: "e1a046aa7f71e2af8b80b6408b2fd8d3a350278cde79c94d9efaa475e1339b3dd490127b", + }, + { + Nonce: "d9a8213e8a697508805c2c171ad54487ead9e3e02d82d5bc", + Key: "979196dbd78526f2f584f7534db3f5824d8ccfa858ca7e09bdd3656ecd36033c", + PT: "43cc6d624e451bbed952c3e071dc6c03392ce11eb14316a94b2fdc98b22fedea", + CT: "53c1e8bef2dbb8f2505ec010a7afe21d5a8e6dd8f987e4ea1a2ed5dfbc844ea400db34496fd2153526c6e87c36694200", + }, + } + + for _, test := range tests { + + nonce, err := hex.DecodeString(test.Nonce) + if err != nil { + panic(err) + } + + key, err := hex.DecodeString(test.Key) + if err != nil { + panic(err) + } + + pt, err := hex.DecodeString(test.PT) + if err != nil { + panic(err) + } + + func() { + var nonceArray [24]byte + var keyArray [32]byte + copy(nonceArray[:], nonce) + copy(keyArray[:], key) + + // test encryption + + ct := Encrypt( + nil, + &nonceArray, + pt, + nil, + &keyArray, + ) + ctHex := hex.EncodeToString(ct) + if ctHex != test.CT { + t.Fatal("encryption failed, expected:", test.CT, "got", ctHex) + } + + // test decryption + + ptp, err := Decrypt( + nil, + &nonceArray, + ct, + nil, + &keyArray, + ) + if err != nil { + t.Fatal(err) + } + + ptHex := hex.EncodeToString(ptp) + if ptHex != test.PT { + t.Fatal("decryption failed, expected:", test.PT, "got", ptHex) + } + }() + + } + +} |