Skip to content

Commit

Permalink
fix include sort order and formatting
Browse files Browse the repository at this point in the history
Signed-off-by: Sylvain Leclerc <[email protected]>
  • Loading branch information
sylvlecl committed Jul 27, 2024
1 parent e9e0e74 commit f992ac9
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 13 deletions.
1 change: 1 addition & 0 deletions src/cluster_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from pathlib import Path
from typing import Iterable, Optional, TextIO, Tuple

import pandas as pd
from pydantic import BaseModel, Field
from yaml import safe_load
Expand Down
10 changes: 7 additions & 3 deletions src/duration_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
# This file is part of the Antares project.
from abc import ABC, abstractmethod
from enum import Enum
from math import log, sqrt
from typing import List

import numpy as np
from math import log, sqrt

from random_generator import RNG

Expand All @@ -39,7 +40,9 @@ class GeneratorWrapper(DurationGenerator):
Used to keep backward compat with cpp implementation.
"""

def __init__(self, delegate: DurationGenerator, volatility: float, expecs: List[int]) -> None:
def __init__(
self, delegate: DurationGenerator, volatility: float, expecs: List[int]
) -> None:
self.volatility = volatility
self.expectations = expecs
self.delegate = delegate
Expand Down Expand Up @@ -97,11 +100,12 @@ def generate_duration(self, day: int) -> int:


def make_duration_generator(
rng: RNG, law: ProbilityLaw, volatility: float, expecs: List[int]
rng: RNG, law: ProbilityLaw, volatility: float, expecs: List[int]
) -> DurationGenerator:
"""
return a DurationGenerator for the given law
"""
base_rng: DurationGenerator
if law == ProbilityLaw.UNIFORM:
base_rng = UniformDurationGenerator(rng, volatility, expecs)
elif law == ProbilityLaw.GEOMETRIC:
Expand Down
5 changes: 3 additions & 2 deletions src/mersenne_twister.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from dataclasses import field
from typing import List, Tuple


@dataclasses.dataclass
class MersenneTwister:
periodN: int = 624
Expand All @@ -40,12 +41,12 @@ class MersenneTwister:
LOWER_MASK: int = 0x7FFFFFFF

MAG: Tuple[int, int] = (0, MATRIX_A)

mt: List[int] = field(default_factory=lambda: [0] * 624)

mti: int = 0

def seed(self, seed:int) -> None:
def seed(self, seed: int) -> None:
self.mt[0] = seed & 0xFFFFFFFF
for i in range(1, self.periodN):
self.mt[i] = 1812433253 * (self.mt[i - 1] ^ (self.mt[i - 1] >> 30)) + i
Expand Down
8 changes: 4 additions & 4 deletions src/random_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class RNG(ABC):
"""

@abstractmethod
def next(self) -> int:
def next(self) -> float:
...


Expand All @@ -31,7 +31,7 @@ class PythonRNG(ABC):
Native python RNG.
"""

def next(self) -> int:
def next(self) -> float:
return random.random()


Expand All @@ -40,9 +40,9 @@ class MersenneTwisterRNG(RNG):
Our own RNG based on Mersenne-Twister algorithm.
"""

def __init__(self, seed: Optional[int] = 5489):
def __init__(self, seed: int = 5489):
self._rng = MersenneTwister()
self._rng.seed(seed)

def next(self) -> int:
def next(self) -> float:
return self._rng.next()
9 changes: 5 additions & 4 deletions tests/test_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
# SPDX-License-Identifier: MPL-2.0
#
# This file is part of the Antares project.
import cProfile
import random
from pstats import SortKey

import pytest
import cProfile

from mersenne_twister import MersenneTwister
from random_generator import MersenneTwisterRNG
Expand All @@ -40,8 +40,8 @@ def cluster() -> ThermalCluster:
npo_max=[1 for i in range(NB_OF_DAY)],
)

def test_rng():

def test_rng():
random = MersenneTwister()
random.reset()
for _ in range(100):
Expand All @@ -54,6 +54,7 @@ def test_random():
for i in range(5):
print(rng.next())


def test_performances():
with cProfile.Profile() as pr:
days = 365
Expand All @@ -79,7 +80,6 @@ def test_performances():


def test_compare_with_simulator():

days = 365
cluster = ThermalCluster(
unit_count=10,
Expand All @@ -99,10 +99,11 @@ def test_compare_with_simulator():

generator = ThermalDataGenerator(rng=MersenneTwisterRNG(), days_per_year=days)
results = generator.generate_time_series(cluster, 1)
for i in range(365*24):
for i in range(365 * 24):
print(str(i) + " : " + str(results.available_power[0][i]))
print(results.available_power[0])


def test_ts_value(cluster):
ts_nb = 4

Expand Down

0 comments on commit f992ac9

Please sign in to comment.