diff --git a/CHANGELOG.md b/CHANGELOG.md index f2b2ca2..e9602f7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,14 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). +## [Unreleased] +### Fixed +- Now removes claims that are only connected to deleted tweets when calling + `to_dgl`. This previously caused a bug that was due to a mismatch between + nodes in the dataset (which includes deleted ones) and nodes in the DGL graph + (which does not contain the deleted ones). + + ## [v1.6.1] - 2022-03-17 ### Fixed - Now correctly catches JSONDecodeError during rehydration. diff --git a/mumin/dgl.py b/mumin/dgl.py index 26d2039..ce3f45f 100644 --- a/mumin/dgl.py +++ b/mumin/dgl.py @@ -44,6 +44,14 @@ def build_dgl_dataset(nodes: Dict[str, pd.DataFrame], '`dgl` extension, like so: `pip install ' 'mumin[dgl]`') + # Remove the claims that are only connected to deleted tweets + tweet_df = nodes['tweet'].dropna() + claim_df = nodes['claim'] + discusses_df = relations[('tweet', 'discusses', 'claim')] + discusses_df = discusses_df[discusses_df.src.isin(tweet_df.index.tolist())] + claim_df = claim_df[claim_df.index.isin(discusses_df.tgt.tolist())] + nodes['claim'] = claim_df + # Set up the graph as a DGL graph graph_data = dict() for canonical_etype, rel_arr in relations.items(): @@ -66,8 +74,8 @@ def build_dgl_dataset(nodes: Dict[str, pd.DataFrame], # Get a dataframe containing the edges between allowed source and # target nodes (i.e., non-deleted) rel_arr = (relations[canonical_etype][['src', 'tgt']] - .query('src in @allowed_src.values() and ' - 'tgt in @allowed_tgt.values()') + .query('src in @allowed_src.keys() and ' + 'tgt in @allowed_tgt.keys()') .drop_duplicates()) # Convert the node indices in the edge dataframe to the new indices