Skip to content

Commit

Permalink
add: pickletracer in utils
Browse files Browse the repository at this point in the history
  • Loading branch information
Secbone committed Nov 22, 2023
1 parent 0e1332d commit 039dd88
Show file tree
Hide file tree
Showing 3 changed files with 202 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Add
- Added `performance` in `toad.utils` for test code performance
- Added `pickletracer` in `toad.utils` for infer requirements in pickle object

## [0.1.2] - 2023-04-09

Expand Down
140 changes: 140 additions & 0 deletions toad/utils/pickletracer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import cloudpickle
from pickle import Unpickler
from cloudpickle import CloudPickler

_global_tracer = None

def get_current_tracer():
global _global_tracer
# if _global_tracer is None:
# raise ValueError("tracer is not initialized")
return _global_tracer


class Unpickler(Unpickler):
"""trace object dependences during unpickle"""
def find_class(self, module, name):
tracer = get_current_tracer()
tracer.add(module)
return super().find_class(module, name)


class Pickler(CloudPickler):
"""trace object dependences during pickle"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

import types
self._reduce_module = CloudPickler.dispatch_table[types.ModuleType]
self.dispatch_table[types.ModuleType] = self.reduce_module


def reduce_module(self, obj):
tracer = get_current_tracer()
tracer.add(obj.__name__)
return self._reduce_module(obj)


def __setattr__(self, name, value):
if name == 'persistent_id':
# fix torch module
def wrapper_func(obj):
from torch.nn import Module
if isinstance(obj, Module):
return None

return value(obj)

return super().__setattr__(name, wrapper_func)

return super().__setattr__(name, value)


class Tracer:
def __init__(self):
import re

self._modules = set()
self._ignore_modules = {"builtins"}
self._temp_dispatch_table = {}

# match python site packages path
self._regex = re.compile(r".*python[\d\.]+\/site-packages/[\w-]+")

def add(self, module):
root = module.split(".")[0]

if root in self._ignore_modules:
return

self._modules.add(root)

def trace(self, obj):
"""trace `obj` by picke and unpicke
"""
import io
dummy = io.BytesIO()

with self:
Pickler(dummy).dump(obj)
dummy.seek(0)
Unpickler(dummy).load()

return self.get_deps()


def get_deps(self):
import sys

deps = {
"pip": [],
"files": [],
}

for name in self._modules:
if name not in sys.modules:
# TODO: should raise error
continue

module = sys.modules[name]
# package module
if self._regex.match(module.__spec__.origin):
# TODO: spilt pip and conde pkg
deps["pip"].append(module)
continue

# local file module
deps["files"].append(module)

return deps


def __enter__(self):
global _global_tracer
if _global_tracer is not None:
raise ValueError("a tracer is already exists")

# save the Cloudpickler global dispatch table
self._temp_dispatch_table = CloudPickler.dispatch_table.copy()
# setup the global tracer
_global_tracer = self
return self

def __exit__(self, exc_type, exc_val, exc_tb):
global _global_tracer

# restore the dispatch table to Cloudpickler
CloudPickler.dispatch_table = self._temp_dispatch_table
# clean the global tracer
_global_tracer = None




def dump(obj, file, *args, **kwargs):
return Pickler(file).dump(obj)


def load(file, *args, **kwargs):
return Unpickler(file).load()

61 changes: 61 additions & 0 deletions toad/utils/pickletracer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@

from .pickletracer import Tracer, get_current_tracer


def test_tracer_with_clause():
assert get_current_tracer() is None
with Tracer() as t:
assert get_current_tracer() == t

assert get_current_tracer() is None


def test_trace_pyfunc():
import pandas as pd
import numpy as np
from sklearn.linear_model import LinearRegression
X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
# y = 1 * x_0 + 2 * x_1 + 3
y = np.dot(X, np.array([1, 2])) + 3
reg = LinearRegression().fit(X, y)
reg.score(X, y)

def func(data):
# data = dfunc(data)
df = pd.DataFrame(data)
return df

class Model:
def __init__(self, model, pref):
self.model = model
self.pref = pref

def predict(self, data):
data = self.pref(data)
return self.model.predict(data)


m = Model(reg, func)

deps = Tracer().trace(m)

assert set([m.__name__ for m in deps['pip']]) == set(['numpy', 'pandas', 'cloudpickle', 'sklearn'])


def test_default_cloudpickle():
import pandas as pd

def func(data):
# data = dfunc(data)
df = pd.DataFrame(data)
return df

deps = Tracer().trace(func)

import io
import cloudpickle

dummy = io.BytesIO()
# this should be correct after trace object
# test for restore cloudpickle global dispatch table
cloudpickle.dump(func, dummy)

0 comments on commit 039dd88

Please sign in to comment.