From 4f798e10f2476e65011493962be5aa6b3d9008eb Mon Sep 17 00:00:00 2001 From: j2gg0s Date: Wed, 19 Oct 2022 15:02:41 +0800 Subject: [PATCH] feat: allow user custome time precision. --- README.md | 7 ++-- snowflake.go | 60 ++++++++++++++++++++++++-------- snowflake_test.go | 89 ++++++++++++++++++++++++++++++++++++++++------- 3 files changed, 125 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index 23af210..14038af 100644 --- a/README.md +++ b/README.md @@ -25,9 +25,10 @@ By default, the ID format follows the original Twitter snowflake format. ### Custom Format You can alter the number of bits used for the node id and step number (sequence) -by setting the snowflake.NodeBits and snowflake.StepBits values. Remember that -There is a maximum of 22 bits available that can be shared between these two -values. You do not have to use all 22 bits. +by setting the snowflake.NodeBits and snowflake.StepBits values. You can alter +time precision by setting the snowflake.TimePrecision value. Remember that +There is a maximum of 22 bits available that can be shared between these two values. +You do not have to use all 22 bits. ### Custom Epoch By default this package uses the Twitter Epoch of 1288834974657 or Nov 04 2010 01:42:54. diff --git a/snowflake.go b/snowflake.go index 588a09f..6d9aca6 100644 --- a/snowflake.go +++ b/snowflake.go @@ -24,6 +24,9 @@ var ( // Remember, you have a total 22 bits to share between Node/Step StepBits uint8 = 12 + // TimePrecision define the precision of timestamp, default is millisecond. + TimePrecision time.Duration = time.Millisecond + // DEPRECATED: the below four variables will be removed in a future release. mu sync.Mutex nodeMax int64 = -1 ^ (-1 << NodeBits) @@ -89,6 +92,14 @@ type Node struct { stepMask int64 timeShift uint8 nodeShift uint8 + + timePrecision time.Duration + + // avoid new slice when generate single id + singleIDSlice []ID + + // for unit test + sinceFn func(time.Time) time.Duration } // An ID is a custom type used for a snowflake ID. This is used so we can @@ -116,6 +127,10 @@ func NewNode(node int64) (*Node, error) { n.stepMask = -1 ^ (-1 << StepBits) n.timeShift = NodeBits + StepBits n.nodeShift = StepBits + n.timePrecision = TimePrecision + n.singleIDSlice = make([]ID, 1) + + n.sinceFn = time.Since if n.node < 0 || n.node > n.nodeMax { return nil, errors.New("Node number must be between 0 and " + strconv.FormatInt(n.nodeMax, 10)) @@ -133,32 +148,47 @@ func NewNode(node int64) (*Node, error) { // - Make sure your system is keeping accurate system time // - Make sure you never have multiple nodes running with the same node ID func (n *Node) Generate() ID { + return n.generate(1)[0] +} + +func (n *Node) GenerateMany(num int) []ID { + return n.generate(num) +} +func (n *Node) generate(num int) []ID { n.mu.Lock() defer n.mu.Unlock() - now := time.Since(n.epoch).Milliseconds() + i := 0 + ids := n.singleIDSlice + if num > 1 { + ids = make([]ID, num) + } + + now := int64(n.sinceFn(n.epoch) / n.timePrecision) - if now == n.time { - n.step = (n.step + 1) & n.stepMask + if now > n.time { + n.step = 0 + } else if now < n.time { + now = n.time + } - if n.step == 0 { + for i < num { + for n.step <= n.stepMask && i < num { + ids[i] = ID((now)<> timeShift) + Epoch + return (int64(f)>>timeShift)*int64(TimePrecision/time.Millisecond) + Epoch } // Node returns an int64 of the snowflake ID node number diff --git a/snowflake_test.go b/snowflake_test.go index ff750c4..4b1dacd 100644 --- a/snowflake_test.go +++ b/snowflake_test.go @@ -2,8 +2,10 @@ package snowflake import ( "bytes" + "fmt" "reflect" "testing" + "time" ) //****************************************************************************** @@ -26,7 +28,6 @@ func TestNewNode(t *testing.T) { // lazy check if Generate will create duplicate IDs // would be good to later enhance this with more smarts func TestGenerateDuplicateID(t *testing.T) { - node, _ := NewNode(1) var x, y ID @@ -39,6 +40,48 @@ func TestGenerateDuplicateID(t *testing.T) { } } +func TestGenerate(t *testing.T) { + for _, nodeID := range []int64{0, 1, 12, nodeMax} { + node, err := NewNode(nodeID) + if err != nil { + t.Fatalf("%d error creating NewNode, %s", nodeID, err) + } + + now := time.Since(node.epoch) + cnt := int64(0) + node.sinceFn = func(time.Time) time.Duration { + if cnt == node.stepMask+1 { + cnt = 0 + now += node.timePrecision + } else { + cnt += 1 + } + return now + } + + st := now + for i := int64(0); i < 2*node.stepMask; i++ { + if i != 0 && (i&node.stepMask) == 0 { + st += node.timePrecision + } + + stamp := int64(st / node.timePrecision) + step := i & node.stepMask + + id := node.Generate() + if id.Node() != nodeID { + t.Fatalf("%d/%d expected node %d, got %d", nodeID, i, nodeID, id.Node()) + } + if id.Step() != step { + t.Fatalf("%d/%d expected step %d, got %d", nodeID, i, step, id.Step()) + } + if id.Time()-Epoch != stamp { + t.Fatalf("%d/%d expected time %d, got %d", nodeID, i, stamp, id.Time()-Epoch) + } + } + } +} + // I feel like there's probably a better way func TestRace(t *testing.T) { @@ -47,7 +90,7 @@ func TestRace(t *testing.T) { go func() { for i := 0; i < 1000000000; i++ { - NewNode(1) + _, _ = NewNode(1) } }() @@ -382,7 +425,7 @@ func BenchmarkParseBase32(b *testing.B) { b.ResetTimer() for n := 0; n < b.N; n++ { - ParseBase32([]byte(b32i)) + _, _ = ParseBase32([]byte(b32i)) } } func BenchmarkBase32(b *testing.B) { @@ -407,11 +450,10 @@ func BenchmarkParseBase58(b *testing.B) { b.ResetTimer() for n := 0; n < b.N; n++ { - ParseBase58([]byte(b58)) + _, _ = ParseBase58([]byte(b58)) } } func BenchmarkBase58(b *testing.B) { - node, _ := NewNode(1) sf := node.Generate() @@ -422,29 +464,50 @@ func BenchmarkBase58(b *testing.B) { sf.Base58() } } -func BenchmarkGenerate(b *testing.B) { +func BenchmarkGenerate(b *testing.B) { node, _ := NewNode(1) + now := time.Millisecond + node.sinceFn = func(time.Time) time.Duration { + now += time.Millisecond + return now + } b.ReportAllocs() - b.ResetTimer() + for n := 0; n < b.N; n++ { _ = node.Generate() } } -func BenchmarkGenerateMaxSequence(b *testing.B) { +func BenchmarkGenerateMany(b *testing.B) { + for _, num := range []int{10, 30, 100} { + b.Run(fmt.Sprintf("%d", num), func(b *testing.B) { + node, _ := NewNode(int64(num)) + now := time.Millisecond + node.sinceFn = func(time.Time) time.Duration { + now += time.Millisecond + return now + } - NodeBits = 1 - StepBits = 21 - node, _ := NewNode(1) + b.ReportAllocs() + b.ResetTimer() - b.ReportAllocs() + for n := 0; n < b.N; n++ { + _ = node.GenerateMany(num) + } + }) + } +} + +func BenchmarkSince(b *testing.B) { + epoch := time.Unix(Epoch/1000, (Epoch%1000)*1000000) + b.ReportAllocs() b.ResetTimer() for n := 0; n < b.N; n++ { - _ = node.Generate() + time.Since(epoch) } }