Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve machine_id ergonomics #9

Merged
merged 2 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ jobs:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.x'
- name: Black Check
uses: jpetrucciani/black-check@20.8b1
uses: jpetrucciani/black-check@24.8.0
- name: python-isort
uses: isort/isort-action@v0.1.0
uses: isort/isort-action@v1
with:
isortVersion: 5.7.0
isort-version: 5.13.2
configuration: --profile black --diff --check-only
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ The generator can be configured with variety of options, such as
custom `machine_id`, `start_time` etc.

- `start_time` should be an instance of `datetime.datetime`.
- `machine_id` should be a callable which returns an integer value
upto 16-bits.
- `machine_id` should be an integer value upto 16-bits, callable or
`None` (will be used random machine id).

## License

Expand Down
89 changes: 48 additions & 41 deletions sonyflake/sonyflake.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,29 @@
import datetime
import ipaddress
from functools import partial
from random import randrange
from socket import gethostbyname, gethostname
from threading import Lock
from time import sleep
from typing import Callable, Dict, Optional
from typing import Any, Callable, Dict, Optional, Union
from warnings import warn

BIT_LEN_TIME = 39
BIT_LEN_SEQUENCE = 8
BIT_LEN_MACHINE_ID = 63 - (BIT_LEN_TIME + BIT_LEN_SEQUENCE)
MAX_MACHINE_ID = (1 << BIT_LEN_MACHINE_ID) - 1
UTC = datetime.timezone.utc
SONYFLAKE_EPOCH = datetime.datetime(2014, 9, 1, 0, 0, 0, tzinfo=UTC)
random_machine_id = partial(randrange, 0, MAX_MACHINE_ID + 1)
random_machine_id.__doc__ = "Returns a random machine ID."
utc_now = partial(datetime.datetime.now, tz=UTC)


def lower_16bit_private_ip() -> int:
"""
Returns the lower 16 bits of the private IP address.
"""
ip: ipaddress.IPv4Address = ipaddress.ip_address(gethostbyname(gethostname()))
ip = ipaddress.ip_address(gethostbyname(gethostname()))
ip_bytes = ip.packed
return (ip_bytes[2] << 8) + ip_bytes[3]

Expand All @@ -29,29 +36,21 @@ class SonyFlake:
_start_time: int
_machine_id: int

def __new__(
cls,
start_time: Optional[datetime.datetime] = None,
machine_id: Optional[Callable[[], int]] = None,
check_machine_id: Optional[Callable[[int], bool]] = None,
):
if start_time and datetime.datetime.now(UTC) < start_time:
return None
instance = super().__new__(cls)
if machine_id is not None:
instance._machine_id = machine_id()
else:
instance._machine_id = lower_16bit_private_ip()
if check_machine_id is not None:
if not check_machine_id(instance._machine_id):
return None
return instance
__slots__ = (
"_now",
"mutex",
"_start_time",
"_machine_id",
"elapsed_time",
"sequence",
)

def __init__(
self,
start_time: Optional[datetime.datetime] = None,
machine_id: Optional[Callable[[], int]] = None,
check_machine_id: Optional[Callable[[int], bool]] = None,
machine_id: Union[None, int, Callable[[], int]] = None,
check_machine_id: Any = None,
now: Callable[[], datetime.datetime] = utc_now,
) -> None:
"""
Create a new instance of `SonyFlake` unique ID generator.
Expand All @@ -65,30 +64,40 @@ def __init__(
* If `start_time` is ahead of the current time, SonyFlake is
not created.

`machine_id` returns the unique ID of the SonyFlake instance.

* If `machine_id` returns an error, SonyFlake is not created.
`machine_id` a unique ID of the SonyFlake instance in range [0x0000, 0xFFFF].

* If `machine_id` is nil, default `machine_id` is used.
* If `machine_id` is an integer, it is used as is.

* Default `machine_id` returns the lower 16 bits of the
private IP address.
* If `machine_id` is a callable, it is called to get the machine ID.

`check_machine_id` validates the uniqueness of the machine ID.

* If `check_machine_id` returns `False`, SonyFlake is not
created.

* If `check_machine_id` is `None`, no validation is done.
* Otherwise, a random machine ID is generated.
"""

if start_time is None:
start_time = SONYFLAKE_EPOCH

if now() < start_time:
raise ValueError("start_time cannot be in future")

if machine_id is None:
_machine_id = random_machine_id()
elif callable(machine_id):
_machine_id = machine_id()
else:
_machine_id = machine_id

if not (0 <= _machine_id <= MAX_MACHINE_ID):
raise ValueError("machine_id must be in range [0x0000, 0xFFFF]")

if check_machine_id is not None:
warn("check_machine_id is deprecated", DeprecationWarning)

self.mutex = Lock()
self._now = now
self._machine_id = _machine_id
self._start_time = self.to_sonyflake_time(start_time)
self.elapsed_time = self.current_elapsed_time()
self.sequence = (1 << BIT_LEN_SEQUENCE) - 1
if not hasattr(self, "_machine_id"):
self._machine_id = machine_id and machine_id() or lower_16bit_private_ip()

@staticmethod
def to_sonyflake_time(given_time: datetime.datetime) -> int:
Expand All @@ -110,7 +119,7 @@ def current_time(self) -> int:
"""
Get current UTC time in the SonyFlake's time value.
"""
return self.to_sonyflake_time(datetime.datetime.now(UTC))
return self.to_sonyflake_time(self._now())

def current_elapsed_time(self) -> int:
"""
Expand All @@ -136,7 +145,7 @@ def next_id(self) -> int:
if self.sequence == 0:
self.elapsed_time += 1
overtime = self.elapsed_time - current_time
sleep(self.sleep_time(overtime))
sleep(self.sleep_time(overtime, self._now()))
return self.to_id()

def to_id(self) -> int:
Expand All @@ -147,13 +156,11 @@ def to_id(self) -> int:
return time | sequence | self.machine_id

@staticmethod
def sleep_time(duration: int) -> float:
def sleep_time(duration: int, now: datetime.datetime) -> float:
"""
Calculate the time remaining until generation of new ID.
"""
return (
duration * 10 - (datetime.datetime.now(UTC).timestamp() * 100) % 1
) / 100
return (duration * 10 - (now.timestamp() * 100) % 1) / 100

@staticmethod
def decompose(_id: int) -> Dict[str, int]:
Expand Down
42 changes: 27 additions & 15 deletions tests/test_sonyflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,23 @@
from time import sleep
from unittest import TestCase

from pytest import raises

from sonyflake.sonyflake import (
BIT_LEN_SEQUENCE,
SONYFLAKE_EPOCH,
SonyFlake,
lower_16bit_private_ip,
random_machine_id,
)


class SonyFlakeTestCase(TestCase):
def setUp(self):
start_time = datetime.now(timezone.utc)
self.sf = SonyFlake(start_time)
self.machine_id = 0x7F7F
self.sf = SonyFlake(start_time, machine_id=self.machine_id)
self.start_time = SonyFlake.to_sonyflake_time(start_time)
self.machine_id = lower_16bit_private_ip()

@staticmethod
def _current_time():
Expand All @@ -28,7 +31,10 @@ def _sleep(duration):
return sleep(duration / 100)

def test_sonyflake_epoch(self):
sf = SonyFlake(start_time=SONYFLAKE_EPOCH)
sf = SonyFlake(
start_time=SONYFLAKE_EPOCH,
machine_id=self.machine_id,
)
self.assertEqual(sf.start_time, 140952960000)
next_id = sf.next_id()
parts = SonyFlake.decompose(next_id)
Expand All @@ -37,12 +43,9 @@ def test_sonyflake_epoch(self):
self.assertEqual(parts["sequence"], 0)

def test_sonyflake_custom_machine_id(self):
machine_id = randint(1, 255 ** 2)

def get_machine_id():
return machine_id
machine_id = randint(1, 255**2)

sf = SonyFlake(machine_id=get_machine_id)
sf = SonyFlake(machine_id=machine_id)
next_id = sf.next_id()
parts = SonyFlake.decompose(next_id)
self.assertEqual(parts["machine_id"], machine_id)
Expand All @@ -59,15 +62,16 @@ def test_sonyflake_once(self):

def test_sonyflake_future(self):
future_start_time = datetime.now(timezone.utc) + timedelta(minutes=1)
sonyflake = SonyFlake(start_time=future_start_time)
self.assertIsNone(sonyflake, "SonyFlake starting in the future")

def test_sonyflake_invalid_machine_id(self):
def check_machine_id(_: int) -> bool:
return False
with raises(ValueError, match=r"start_time cannot be in future"):
SonyFlake(start_time=future_start_time)

sonyflake = SonyFlake(check_machine_id=check_machine_id)
self.assertIsNone(sonyflake, "Machine ID check failed")
def test_sonyflake_invalid_machine_id(self):
for machine_id in [-1, 0xFFFF + 1]:
with raises(
ValueError, match=r"machine_id must be in range \[0x0000, 0xFFFF\]"
):
SonyFlake(machine_id=machine_id)

def test_sonyflake_for_10sec(self):
last_id = 0
Expand Down Expand Up @@ -100,3 +104,11 @@ def test_sonyflake_in_parallel(self):
result_set = set(results)
self.assertEqual(len(results), len(result_set))
self.assertCountEqual(results, result_set)


def test_random_machine_id() -> None:
assert random_machine_id()


def test_lower_16bit_private_ip() -> None:
assert lower_16bit_private_ip()