Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Aug 15, 2023
1 parent a098f58 commit 5ec4fbe
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 33 deletions.
34 changes: 1 addition & 33 deletions ai_models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .checkpoint import peek
from .inputs import get_input
from .outputs import get_output
from .stepper import Stepper

LOG = logging.getLogger(__name__)

Expand All @@ -37,39 +38,6 @@ def __exit__(self, *args):
LOG.info("%s: %s.", self.title, seconds(elapsed))


class Stepper:
def __init__(self, step, lead_time):
self.step = step
self.lead_time = lead_time
self.start = time.time()
self.last = self.start
self.num_steps = lead_time // step
LOG.info("Starting inference for %s steps (%sh).", self.num_steps, lead_time)

def __enter__(self):
return self

def __call__(self, i, step):
now = time.time()
elapsed = now - self.start
speed = (i + 1) / elapsed
eta = (self.num_steps - i) / speed
LOG.info(
"Done %s out of %s in %s (%sh), ETA: %s.",
i + 1,
self.num_steps,
seconds(now - self.last),
step,
seconds(eta),
)
self.last = now

def __exit__(self, *args):
elapsed = time.time() - self.start
LOG.info("Elapsed: %s.", seconds(elapsed))
LOG.info("Average: %s per step.", seconds(elapsed / self.num_steps))


class ArchiveCollector:
def __init__(self) -> None:
self.expect = 0
Expand Down
46 changes: 46 additions & 0 deletions ai_models/stepper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# (C) Copyright 2023 European Centre for Medium-Range Weather Forecasts.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import logging
import time

from climetlab.utils.humanize import seconds

LOG = logging.getLogger(__name__)


class Stepper:
def __init__(self, step, lead_time):
self.step = step
self.lead_time = lead_time
self.start = time.time()
self.last = self.start
self.num_steps = lead_time // step
LOG.info("Starting inference for %s steps (%sh).", self.num_steps, lead_time)

def __enter__(self):
return self

def __call__(self, i, step):
now = time.time()
elapsed = now - self.start
speed = (i + 1) / elapsed
eta = (self.num_steps - i) / speed
LOG.info(
"Done %s out of %s in %s (%sh), ETA: %s.",
i + 1,
self.num_steps,
seconds(now - self.last),
step,
seconds(eta),
)
self.last = now

def __exit__(self, *args):
elapsed = time.time() - self.start
LOG.info("Elapsed: %s.", seconds(elapsed))
LOG.info("Average: %s per step.", seconds(elapsed / self.num_steps))

0 comments on commit 5ec4fbe

Please sign in to comment.