From 9bfb1d9bb208944eeb8dcf3c08e4617cb3ab84a4 Mon Sep 17 00:00:00 2001 From: Joachim Folz Date: Sun, 13 Oct 2024 12:06:42 +0200 Subject: [PATCH] fix prime=0 selected for domain=1 --- CHANGELOG.md | 1 + shufflish/__init__.py | 11 ++++++++++- test/test_completeness.py | 7 ++++--- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6258c89..0e58411 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased ### Fixed - Very large seeds no longer cause integer overflow +- 0 is no longer selected as coprime for domain=1 ## [0.0.2] - 2024-10-12 diff --git a/shufflish/__init__.py b/shufflish/__init__.py index 0a63986..4d30233 100644 --- a/shufflish/__init__.py +++ b/shufflish/__init__.py @@ -79,6 +79,9 @@ def _modular_prime_combinations(domain, primes, k): Generate all ``k``-combinations of the given primes that are unique mod ``domain``. Only considers primes that are coprime with ``domain``. """ + if domain == 1: + yield 1 + return primes = list(dict.fromkeys(p % domain for p in primes if domain % p != 0)) seen = set() ones = (1,) * (k-1) @@ -96,6 +99,9 @@ def _modular_prime_combinations_with_repetition(domain, primes, k): Only considers primes that are coprime with ``domain``. May repeat values. """ + if domain == 1: + yield 1 + return ones = (1,) * (k-1) primes = list(dict.fromkeys(p % domain for p in primes if domain % p != 0)) for p1, p2, p3 in combinations(chain(ones, primes), k): @@ -196,6 +202,8 @@ def _select_prime( Only considers primes that are coprime with ``domain``. This can be quite slow. """ + if domain == 1: + return 1 gen = _modular_prime_combinations(domain, primes, k) num_comb = None if primes is PRIMES and domain in NUM_COMBINATIONS: @@ -221,6 +229,8 @@ def _select_prime_with_repetition( Return the product of this combination mod domain. This is reasonably fast, but does not account for reptitions mod domain. """ + if domain == 1: + return 1 ones = (1,) * (k-1) primes = list(chain(ones, dict.fromkeys(p % domain for p in primes if domain % p != 0))) np = len(primes) @@ -302,7 +312,6 @@ def permutation( raise ValueError("domain must be < 2**63") if seed is None: seed = random.randrange(2**64) - # Step 1: Select coprime number if allow_repetition: prime = _select_prime_with_repetition(domain, seed, primes, num_primes) else: diff --git a/test/test_completeness.py b/test/test_completeness.py index 72bfcd3..ef4c3eb 100644 --- a/test/test_completeness.py +++ b/test/test_completeness.py @@ -9,17 +9,18 @@ def _is_complete(p, domain): assert len(set(p)) == domain, (p, domain) -def test_permutations_class(): +def test_permutations_function(): for domain in (1, 2, 3, 5, 7, 10, 11, 13, 100): for seed in range(domain): _is_complete(permutation(domain, seed), domain) -def test_permutation_function(): +def test_permutation_class(): for domain in (1, 2, 3, 5, 7, 10, 11, 13, 100): perms = Permutations(domain) + print(perms.coprimes) assert len(perms.coprimes) > 0, domain - for seed in range(domain): + for seed in range(domain * len(perms.coprimes)): _is_complete(perms.get(seed), domain)