Skip to content

Commit

Permalink
avoid hard dependency on prettytable;
Browse files Browse the repository at this point in the history
expose strictness option for tracing
  • Loading branch information
amakelov committed Aug 11, 2024
1 parent c870aaf commit f5e8fce
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 10 deletions.
4 changes: 0 additions & 4 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
mandala.egg-info/*

scratchpad/
mandala/scipy/

# visualizations
# Ignore files cached by Hypothesis
Expand All @@ -77,9 +76,6 @@ mandala/tests/output/*
*.db
*.parquet

## PyCharm stuff
.idea/**

# ignore docs build
site/

Expand Down
13 changes: 11 additions & 2 deletions mandala/cf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .common_imports import *
from .common_imports import sess
from .config import Config
import textwrap
import pprint
from .utils import (
Expand All @@ -20,6 +21,10 @@

from .viz import Node, Edge, SOLARIZED_LIGHT, to_dot_string, write_output

if Config.has_prettytable:
import prettytable
from io import StringIO


def get_name_proj(op: Op) -> Callable[[str], str]:
if op.name == __make_list__.name:
Expand Down Expand Up @@ -2340,8 +2345,12 @@ def get_func_stats(self) -> pd.DataFrame:
return pd.DataFrame(rows)

def _get_prettytable_str(self, df: pd.DataFrame) -> str:
import prettytable
from io import StringIO
if not Config.has_prettytable:
# fallback
logger.info(
"Install the `prettytable` package to get a prettier output for the `info` method."
)
return str(df)

output = StringIO()
df.to_csv(output, index=False)
Expand Down
7 changes: 7 additions & 0 deletions mandala/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ class Config:
has_rich = True
except ImportError:
has_rich = False

try:
import prettytable

has_prettytable = True
except ImportError:
has_prettytable = False


if Config.has_torch:
Expand Down
8 changes: 7 additions & 1 deletion mandala/deps/tracers/dec_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def register_call(self, func: Callable) -> CallableNode:
)
if len(closure_names) > 0:
msg = f"Found closure variables accessed by function {module_name}.{qualname}:\n{closure_names}"
raise ValueError(msg)
self._process_failure(msg)
### get call node
node = CallableNode.from_runtime(
module_name=module_name, obj_name=qualname, code_obj=extract_code(obj=func)
Expand Down Expand Up @@ -269,3 +269,9 @@ def __enter__(self):

def __exit__(self, exc_type, exc_val, exc_tb):
DecTracer.set_active_trace_obj(None)

def _process_failure(self, msg: str):
if self.strict:
raise RuntimeError(msg)
else:
logger.warning(msg)
4 changes: 2 additions & 2 deletions mandala/storage.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .common_imports import *
from tqdm import tqdm
import prettytable
import datetime
from .model import *
import sqlite3
Expand All @@ -26,6 +25,7 @@ class Storage:
def __init__(self, db_path: str = ":memory:",
deps_path: Optional[Union[str, Path]] = None,
tracer_impl: Optional[type] = None,
strict_tracing: bool = True,
deps_package: Optional[str] = None,
):
self.db = DBAdapter(db_path=db_path)
Expand Down Expand Up @@ -70,7 +70,7 @@ def __init__(self, db_path: str = ":memory:",
versioner = Versioner(
paths=roots,
TracerCls=DecTracer if tracer_impl is None else tracer_impl,
strict=True,
strict=strict_tracing,
track_methods=True,
package_name=deps_package,
)
Expand Down
9 changes: 8 additions & 1 deletion mandala/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,20 @@
import joblib
import io
import inspect
import prettytable
import sqlite3
from .config import *
from abc import ABC, abstractmethod
from typing import Hashable, TypeVar
if Config.has_prettytable:
import prettytable

def dataframe_to_prettytable(df: pd.DataFrame) -> str:
if not Config.has_prettytable:
# fallback to pandas printing
logger.info(
"Install the 'prettytable' package to get prettier tables in the console."
)
return df.to_string()
# Initialize a PrettyTable object
table = prettytable.PrettyTable()

Expand Down

0 comments on commit f5e8fce

Please sign in to comment.