Commit 39bcc93: blackify

Mittmich committed Oct 7, 2023
1 parent 604a648 commit 39bcc93
Showing 15 changed files with 907 additions and 537 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -26,7 +26,7 @@
"bioframe==0.3.3",
"sparse==0.13.0",
"multiprocess>=0.70.13",
"numba>=0.57.0"
"numba>=0.57.0",
]

test_requirements = [
8 changes: 2 additions & 6 deletions spoc/cli.py
@@ -61,9 +61,7 @@ def bin_contacts(
file_manager = FileManager(use_dask=True)
contacts = file_manager.load_contacts(contact_path)
# binning
-    binner = GenomicBinner(
-        bin_size=bin_size
-    )
+    binner = GenomicBinner(bin_size=bin_size)
pixels = binner.bin_contacts(contacts, same_chromosome=same_chromosome)
# persisting
file_manager.write_pixels(pixel_path, pixels)
@@ -81,9 +79,7 @@ def merge_contacts(contact_paths, output):
"""Functionality to merge annotated fragments"""
file_manager = FileManager(use_dask=True)
manipulator = ContactManipulator()
-    contact_files = [
-        file_manager.load_contacts(path) for path in contact_paths
-    ]
+    contact_files = [file_manager.load_contacts(path) for path in contact_paths]
merged = manipulator.merge_contacts(contact_files)
file_manager.write_multiway_contacts(output, merged)

236 changes: 164 additions & 72 deletions spoc/contacts.py

Large diffs are not rendered by default.

40 changes: 16 additions & 24 deletions spoc/dataframe_models.py
@@ -50,13 +49,12 @@ class ContactSchema:
"mapping_quality": pa.Column(int),
"align_score": pa.Column(int),
"align_base_qscore": pa.Column(int),
"metadata": pa.Column(
str,
required=False
),
"metadata": pa.Column(str, required=False),
}

-    def __init__(self, number_fragments: int = 3, contains_metadata: bool = True) -> None:
+    def __init__(
+        self, number_fragments: int = 3, contains_metadata: bool = True
+    ) -> None:
self._number_fragments = number_fragments
self._schema = pa.DataFrameSchema(
dict(
@@ -98,9 +97,7 @@ def validate_header(self, data_frame: DataFrame) -> None:
self._schema, data_frame, "Header is invalid!"
)

-    def validate(
-        self, data_frame: DataFrame
-    ) -> DataFrame:
+    def validate(self, data_frame: DataFrame) -> DataFrame:
"""Validate multiway contact dataframe"""
self.validate_header(data_frame)
return self._schema.validate(data_frame)
@@ -126,23 +123,20 @@ def _get_contact_fields(self):
"start": pa.Column(int),
}
else:
-            return {
-                "chrom": pa.Column(str),
-                "start": pa.Column(int)
-            }
+            return {"chrom": pa.Column(str), "start": pa.Column(int)}

def _get_constant_fields(self):
if self._same_chromosome:
-            return {
-                "chrom": pa.Column(str),
-                "count": pa.Column(int),
-                "corrected_count": pa.Column(float, required=False),
-            }
+            return {
+                "chrom": pa.Column(str),
+                "count": pa.Column(int),
+                "corrected_count": pa.Column(float, required=False),
+            }
else:
-            return {
-                "count": pa.Column(int),
-                "corrected_count": pa.Column(float, required=False),
-            }
+            return {
+                "count": pa.Column(int),
+                "corrected_count": pa.Column(float, required=False),
+            }

def _expand_contact_fields(self, expansions: Iterable = (1, 2, 3)) -> dict:
"""adds suffixes to fields"""
@@ -161,8 +155,6 @@ def validate_header(self, data_frame: DataFrame) -> None:
self._schema, data_frame, "Header is invalid!"
)

-    def validate(
-        self, data_frame: DataFrame
-    ) -> DataFrame:
+    def validate(self, data_frame: DataFrame) -> DataFrame:
"""Validate multiway contact dataframe"""
return self._schema.validate(data_frame)
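
Note: for context, a minimal sketch of the pandera validation pattern these schemas rely on. The column names follow the diff; the schema itself is illustrative, not spoc's actual `ContactSchema`.

```python
# Minimal sketch of the pandera pattern used in dataframe_models.py.
# Column names are taken from the diff; everything else is illustrative.
import pandas as pd
import pandera as pa

schema = pa.DataFrameSchema(
    {
        "chrom": pa.Column(str),
        "start": pa.Column(int),
        "count": pa.Column(int),
        # optional column: validation passes even if it is absent
        "corrected_count": pa.Column(float, required=False),
    }
)

df = pd.DataFrame({"chrom": ["chr1"], "start": [1000], "count": [3]})
validated = schema.validate(df)  # raises pa.errors.SchemaError on mismatch
```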
5 changes: 4 additions & 1 deletion spoc/file_parameter_models.py
@@ -3,8 +3,10 @@
from pydantic import BaseModel
from typing import Optional, List


class ContactsParameters(BaseModel):
"""Parameters for multiway contacts"""

number_fragments: Optional[int] = None
metadata_combi: Optional[List[str]] = None
label_sorted: bool = False
@@ -14,10 +16,11 @@

class PixelParameters(BaseModel):
"""Parameters for genomic pixels"""

number_fragments: Optional[int] = None
binsize: Optional[int] = None
metadata_combi: Optional[List[str]] = None
label_sorted: bool = False
binary_labels_equal: bool = False
symmetry_flipped: bool = False
-    same_chromosome: bool = True
\ No newline at end of file
+    same_chromosome: bool = True
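
Note: a short, hypothetical illustration of how these pydantic parameter models round-trip through plain dicts, which is what the `.dict()` calls in spoc/io.py appear to rely on; the field values are made up.

```python
# Illustrative only: round-tripping PixelParameters through a dict,
# mirroring the global_parameters.dict() calls in spoc/io.py.
from spoc.file_parameter_models import PixelParameters

params = PixelParameters(number_fragments=3, binsize=10_000)
as_dict = params.dict()  # plain dict, e.g. for JSON metadata
restored = PixelParameters(**as_dict)
assert restored == params
```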
43 changes: 23 additions & 20 deletions spoc/fragments.py
@@ -16,7 +16,9 @@ class Fragments:

def __init__(self, fragment_frame: DataFrame) -> None:
self._data = FragmentSchema.validate(fragment_frame)
-        self._contains_metadata = True if "metadata" in fragment_frame.columns else False
+        self._contains_metadata = (
+            True if "metadata" in fragment_frame.columns else False
+        )

@property
def data(self):
@@ -25,7 +27,7 @@ def data(self):
@property
def contains_metadata(self):
return self._contains_metadata

@property
def is_dask(self):
return isinstance(self._data, dd.DataFrame)
@@ -78,7 +80,6 @@ def annotate_fragments(self, fragments: Fragments) -> Fragments:
)



class FragmentExpander:
"""Expands n-way fragments over sequencing reads
to yield contacts."""
@@ -89,7 +90,7 @@ def __init__(self, number_fragments: int, contains_metadata: bool = True) -> Non
self._schema = ContactSchema(number_fragments, contains_metadata)

@staticmethod
-    def _add_suffix(row, suffix:int, contains_metadata:bool) -> Dict:
+    def _add_suffix(row, suffix: int, contains_metadata: bool) -> Dict:
"""expands contact fields"""
output = {}
for key in ContactSchema.get_contact_fields(contains_metadata):
Expand All @@ -98,9 +99,13 @@ def _add_suffix(row, suffix:int, contains_metadata:bool) -> Dict:

def _get_expansion_output_structure(self) -> pd.DataFrame:
"""returns expansion output dataframe structure for dask"""
-        return pd.DataFrame(columns=list(self._schema._schema.columns.keys()) + ['level_2']).set_index(["read_name", 'read_length', 'level_2'])
+        return pd.DataFrame(
+            columns=list(self._schema._schema.columns.keys()) + ["level_2"]
+        ).set_index(["read_name", "read_length", "level_2"])

-    def _expand_single_read(self, read_df: pd.DataFrame, contains_metadata:bool) -> pd.DataFrame:
+    def _expand_single_read(
+        self, read_df: pd.DataFrame, contains_metadata: bool
+    ) -> pd.DataFrame:
"""Expands a single read"""
if len(read_df) < self._number_fragments:
return pd.DataFrame()
@@ -116,9 +121,7 @@ def _expand_single_read(self, read_df: pd.DataFrame, contains_metadata:bool) ->
contact = {}
# add reads
for index, align in enumerate(alignments, start=1):
-                contact.update(
-                    self._add_suffix(align, index, contains_metadata)
-                )
+                contact.update(self._add_suffix(align, index, contains_metadata))
result.append(contact)
return pd.DataFrame(result)

@@ -130,15 +133,15 @@ def expand(self, fragments: Fragments) -> Contacts:
else:
kwargs = dict()
# expand
-        contact_df = fragments.data\
-            .groupby(["read_name", "read_length"])\
-            .apply(self._expand_single_read,
-                   contains_metadata=fragments.contains_metadata,
-                   **kwargs)\
-            .reset_index()\
-            .drop("level_2", axis=1)
-        #return contact_df
-        return Contacts(
-            contact_df,
-            number_fragments=self._number_fragments
+        contact_df = (
+            fragments.data.groupby(["read_name", "read_length"])
+            .apply(
+                self._expand_single_read,
+                contains_metadata=fragments.contains_metadata,
+                **kwargs,
+            )
+            .reset_index()
+            .drop("level_2", axis=1)
+        )
+        # return contact_df
+        return Contacts(contact_df, number_fragments=self._number_fragments)
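
Note: the `kwargs` built here (via `_get_expansion_output_structure`) exist because dask's `groupby().apply()` cannot infer the output schema and needs an empty frame passed as `meta`. A standalone sketch of that pattern, with illustrative column names unrelated to spoc's actual schema:

```python
# Sketch of the dask groupby-apply pattern used in FragmentExpander.expand.
# Column names are illustrative only.
import dask.dataframe as dd
import pandas as pd

pdf = pd.DataFrame({"read_name": ["r1", "r1", "r2"], "start": [1, 5, 2]})
ddf = dd.from_pandas(pdf, npartitions=2)

# empty frame describing the columns/dtypes the apply will produce
meta = pd.DataFrame(
    {"read_name": pd.Series(dtype=str), "n_fragments": pd.Series(dtype=int)}
)

result = (
    ddf.groupby("read_name")
    .apply(
        lambda g: pd.DataFrame({"read_name": [g.name], "n_fragments": [len(g)]}),
        meta=meta,
    )
    .compute()
)
```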
68 changes: 35 additions & 33 deletions spoc/io.py
@@ -21,35 +21,33 @@
class FileManager:
"""Is responsible for loading and writing files"""

-    def __init__(
-        self, use_dask: bool = False
-    ) -> None:
+    def __init__(self, use_dask: bool = False) -> None:
if use_dask:
self._parquet_reader_func = dd.read_parquet
else:
self._parquet_reader_func = pd.read_parquet

-    def _write_parquet_dask(self, path: str, df: dd.DataFrame, global_parameters: BaseModel) -> None:
+    def _write_parquet_dask(
+        self, path: str, df: dd.DataFrame, global_parameters: BaseModel
+    ) -> None:
"""Write parquet file using dask"""
custom_meta_data = {
-            'spoc'.encode(): json.dumps(global_parameters.dict()).encode()
+            "spoc".encode(): json.dumps(global_parameters.dict()).encode()
}
-        dd.to_parquet(
-            df,
-            path,
-            custom_metadata=custom_meta_data
-        )
+        dd.to_parquet(df, path, custom_metadata=custom_meta_data)

-    def _write_parquet_pandas(self, path: str, df: pd.DataFrame ,global_parameters: BaseModel) -> None:
+    def _write_parquet_pandas(
+        self, path: str, df: pd.DataFrame, global_parameters: BaseModel
+    ) -> None:
"""Write parquet file using pandas. Pyarrow is needed because
the pandas .to_parquet method does not support writing custom metadata."""
table = pa.Table.from_pandas(df)
# Add metadata
custom_meta_key = "spoc"
existing_meta = table.schema.metadata
combined_meta = {
-            custom_meta_key.encode() : json.dumps(global_parameters.dict()).encode(),
-            **existing_meta
+            custom_meta_key.encode(): json.dumps(global_parameters.dict()).encode(),
+            **existing_meta,
}
table = table.replace_schema_metadata(combined_meta)
# write table
@@ -60,7 +58,7 @@ def _load_parquet_global_parameters(path: str) -> BaseModel:
"""Load global parameters from parquet file"""
# check if path is a directory, if so, we need to read the schema from one of the partitioned files
if os.path.isdir(path):
-            path = path + '/' + os.listdir(path)[0]
+            path = path + "/" + os.listdir(path)[0]
global_parameters = pa.parquet.read_schema(path).metadata.get("spoc".encode())
if global_parameters is not None:
global_parameters = json.loads(global_parameters.decode())
@@ -91,25 +89,28 @@ def load_fragments(self, path: str):
def write_fragments(path: str, fragments: Fragments) -> None:
"""Write annotated fragments"""
# Write fragments
-        fragments.data.to_parquet(path, row_group_size=1024*1024)
+        fragments.data.to_parquet(path, row_group_size=1024 * 1024)

def write_multiway_contacts(self, path: str, contacts: Contacts) -> None:
"""Write multiway contacts"""
if contacts.is_dask:
-            self._write_parquet_dask(path, contacts.data, contacts.get_global_parameters())
+            self._write_parquet_dask(
+                path, contacts.data, contacts.get_global_parameters()
+            )
        else:
-            self._write_parquet_pandas(path, contacts.data, contacts.get_global_parameters())
+            self._write_parquet_pandas(
+                path, contacts.data, contacts.get_global_parameters()
+            )

-    def load_contacts(self, path: str, global_parameters: Optional[ContactsParameters] = None) -> Contacts:
+    def load_contacts(
+        self, path: str, global_parameters: Optional[ContactsParameters] = None
+    ) -> Contacts:
"""Load multiway contacts"""
if global_parameters is None:
global_parameters = self._load_parquet_global_parameters(path)
else:
global_parameters = global_parameters.dict()
-        return Contacts(
-            self._parquet_reader_func(path), **global_parameters
-        )
+        return Contacts(self._parquet_reader_func(path), **global_parameters)

@staticmethod
def load_chromosome_sizes(path: str):
@@ -134,23 +135,22 @@ def _load_pixel_metadata(path: str):
else:
raise ValueError(f"Metadata file not found at {metadata_path}")
return metadata

@staticmethod
def list_pixels(path: str):
"""List available pixels"""
# read metadata.json
metadata = FileManager._load_pixel_metadata(path)
# instantiate pixel parameters
-        pixels = [
-            PixelParameters(**params) for params in metadata.values()
-        ]
+        pixels = [PixelParameters(**params) for params in metadata.values()]
return pixels


-    def load_pixels(self, path: str, global_parameters: PixelParameters, load_dataframe:bool = True) -> Pixels:
+    def load_pixels(
+        self, path: str, global_parameters: PixelParameters, load_dataframe: bool = True
+    ) -> Pixels:
"""Loads specific pixels instance based on global parameters.
load_dataframe specifies whether the dataframe should be loaded, or whether pixels
-        should be instantiated based on the path alone. """
+        should be instantiated based on the path alone."""
metadata = self._load_pixel_metadata(path)
# find matching pixels
for pixel_path, value in metadata.items():
@@ -187,9 +187,11 @@ def write_pixels(self, path: str, pixels: Pixels) -> None:
write_path = Path(path) / self._get_pixel_hash_path(path, pixels)
# write pixels
if pixels.data is None:
raise ValueError("Writing pixels only suppported for pixels hodling dataframes!")
pixels.data.to_parquet(write_path, row_group_size=1024*1024)
raise ValueError(
"Writing pixels only suppported for pixels hodling dataframes!"
)
pixels.data.to_parquet(write_path, row_group_size=1024 * 1024)
# write metadata
current_metadata[write_path.name] = pixels.get_global_parameters().dict()
with open(metadata_path, "w") as f:
-            json.dump(current_metadata, f)
\ No newline at end of file
+            json.dump(current_metadata, f)
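
Note: `_write_parquet_pandas` drops to pyarrow because, as the docstring above says, `pandas.DataFrame.to_parquet` has no hook for custom schema metadata. A self-contained sketch of the round trip implemented by `_write_parquet_pandas` and `_load_parquet_global_parameters`; the file name and metadata values are made up.

```python
# Sketch: attach custom key-value metadata to a parquet schema and read
# it back without loading the data, as spoc/io.py does under the "spoc" key.
import json
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

df = pd.DataFrame({"chrom": ["chr1"], "start": [1000]})
table = pa.Table.from_pandas(df)
meta = {
    b"spoc": json.dumps({"number_fragments": 3}).encode(),
    **(table.schema.metadata or {}),  # keep pandas' own metadata
}
pq.write_table(table.replace_schema_metadata(meta), "contacts.parquet")

raw = pq.read_schema("contacts.parquet").metadata[b"spoc"]
print(json.loads(raw.decode()))  # {'number_fragments': 3}
```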