Skip to content

Commit

Permalink
Added tests for new functions; removed tests for removed functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathanhhb committed Jan 22, 2025
1 parent 116bf2b commit c29062a
Showing 1 changed file with 77 additions and 0 deletions.
77 changes: 77 additions & 0 deletions tests/test_spatialpops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import unittest
import numpy as np
from laser_core.demographics.spatialpops import distribute_population_skewed, distribute_population_tapered

class TestPopulationDistribution(unittest.TestCase):

def test_distribute_population_skewed_basic(self):
np.random.seed(42)
result = distribute_population_skewed(1000, 5, 0.3)
self.assertEqual(sum(result), 1000)
self.assertEqual(len(result), 5)
self.assertAlmostEqual(sum(result[:2]) / 1000, 0.8, delta=0.1)

def test_distribute_population_skewed_alternative(self):
np.random.seed(42)
result = distribute_population_skewed(500, 3, 0.4)
self.assertEqual(sum(result), 500)
self.assertEqual(len(result), 3)
self.assertAlmostEqual(sum(result[:1]) / 500, 0.6, delta=0.1)

def test_distribute_population_skewed_zero_nodes(self):
with self.assertRaises(ValueError):
distribute_population_skewed(1000, 0, 0.3)

def test_distribute_population_skewed_zero_population(self):
with self.assertRaises(ValueError):
distribute_population_skewed(0, 5, 0.3)
#self.assertEqual(result, [0, 0, 0, 0, 0])

def test_distribute_population_skewed_invalid_fraction(self):
with self.assertRaises(ValueError):
distribute_population_skewed(1000, 5, -0.1)
with self.assertRaises(ValueError):
distribute_population_skewed(1000, 5, 1.5)

def test_distribute_population_tapered_basic(self):
result = distribute_population_tapered(1000, 5)
self.assertEqual(sum(result), 1000)
self.assertEqual(len(result), 5)
self.assertGreater(result[0], result[1])
self.assertGreater(result[1], result[2])

def test_distribute_population_tapered_small_population(self):
result = distribute_population_tapered(10, 4)
self.assertEqual(sum(result), 10)
self.assertEqual(len(result), 4)

def test_distribute_population_tapered_equal_distribution(self):
result = distribute_population_tapered(10, 10)
self.assertEqual(sum(result), 10)
self.assertEqual(len(result), 10)
self.assertIn(0, result)

def test_distribute_population_tapered_large_nodes(self):
result = distribute_population_tapered(100, 50)
self.assertEqual(sum(result), 100)
self.assertEqual(len(result), 50)
self.assertGreater(result[0], result[-1])

def test_distribute_population_tapered_zero_nodes(self):
with self.assertRaises(ValueError):
distribute_population_tapered(1000, 0)

def test_distribute_population_tapered_zero_population(self):
with self.assertRaises(ValueError):
result = distribute_population_tapered(0, 5)
#self.assertEqual(sum(result), 0)
#self.assertEqual(len(result), 5)
#self.assertTrue(all(v == 0 for v in result))

def test_distribute_population_tapered_adjustment(self):
result = distribute_population_tapered(1200, 3)
self.assertEqual(sum(result), 1200)
self.assertEqual(len(result), 3)

if __name__ == '__main__':
unittest.main()

0 comments on commit c29062a

Please sign in to comment.