diff options
Diffstat (limited to 'tai64n')
-rw-r--r-- | tai64n/tai64n.go | 11 | ||||
-rw-r--r-- | tai64n/tai64n_test.go | 40 |
2 files changed, 32 insertions, 19 deletions
diff --git a/tai64n/tai64n.go b/tai64n/tai64n.go index fb32d0c..2838f4f 100644 --- a/tai64n/tai64n.go +++ b/tai64n/tai64n.go @@ -17,16 +17,19 @@ const whitenerMask = uint32(0x1000000 - 1) type Timestamp [TimestampSize]byte -func Now() Timestamp { +func stamp(t time.Time) Timestamp { var tai64n Timestamp - now := time.Now() - secs := base + uint64(now.Unix()) - nano := uint32(now.Nanosecond()) &^ whitenerMask + secs := base + uint64(t.Unix()) + nano := uint32(t.Nanosecond()) &^ whitenerMask binary.BigEndian.PutUint64(tai64n[:], secs) binary.BigEndian.PutUint32(tai64n[8:], nano) return tai64n } +func Now() Timestamp { + return stamp(time.Now()) +} + func (t1 Timestamp) After(t2 Timestamp) bool { return bytes.Compare(t1[:], t2[:]) > 0 } diff --git a/tai64n/tai64n_test.go b/tai64n/tai64n_test.go index 05a9d8f..6df7367 100644 --- a/tai64n/tai64n_test.go +++ b/tai64n/tai64n_test.go @@ -10,21 +10,31 @@ import ( "time" ) -/* Testing the essential property of the timestamp - * as used by WireGuard. - */ +// Test that timestamps are monotonic as required by Wireguard and that +// nanosecond-level information is whitened to prevent side channel attacks. func TestMonotonic(t *testing.T) { - old := Now() - for i := 0; i < 50; i++ { - next := Now() - if next.After(old) { - t.Error("Whitening insufficient") - } - time.Sleep(time.Duration(whitenerMask)/time.Nanosecond + 1) - next = Now() - if !next.After(old) { - t.Error("Not monotonically increasing on whitened nano-second scale") - } - old = next + startTime := time.Unix(0, 123456789) // a nontrivial bit pattern + // Whitening should reduce timestamp granularity + // to more than 10 but fewer than 20 milliseconds. + tests := []struct { + name string + t1, t2 time.Time + wantAfter bool + }{ + {"after_10_ns", startTime, startTime.Add(10 * time.Nanosecond), false}, + {"after_10_us", startTime, startTime.Add(10 * time.Microsecond), false}, + {"after_1_ms", startTime, startTime.Add(time.Millisecond), false}, + {"after_10_ms", startTime, startTime.Add(10 * time.Millisecond), false}, + {"after_20_ms", startTime, startTime.Add(20 * time.Millisecond), true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ts1, ts2 := stamp(tt.t1), stamp(tt.t2) + got := ts2.After(ts1) + if got != tt.wantAfter { + t.Errorf("after = %v; want %v", got, tt.wantAfter) + } + }) } } |