Skip to content

Commit

Permalink
Correct G2 & chi2 computation on full contingency matrix
Browse files Browse the repository at this point in the history
Related to #383
  • Loading branch information
Yomguithereal committed Nov 21, 2024
1 parent 20dde77 commit 3c8e34a
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 55 deletions.
127 changes: 75 additions & 52 deletions src/cmd/vocab.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,6 @@ vocab cooc options:
--distrib Compute directed distributional similarity metrics instead.
--min-count <n> Minimum number of co-occurrence count to be included in the result.
[default: 1]
--complete Compute the complete chi2 & G2 metrics, instead of their approximation
based on the first cell of the contingency matrix. This
is of course more costly to compute.
Common options:
-h, --help Display this message
Expand Down Expand Up @@ -143,7 +140,6 @@ struct Args {
flag_forward: bool,
flag_distrib: bool,
flag_min_count: usize,
flag_complete: bool,
flag_output: Option<String>,
flag_no_headers: bool,
flag_delimiter: Option<Delimiter>,
Expand Down Expand Up @@ -342,10 +338,6 @@ pub fn run(argv: &[&str]) -> CliResult<()> {
};

if args.flag_distrib {
if args.flag_complete {
unimplemented!();
}

let output_headers: [&[u8]; 5] = [b"token1", b"token2", b"count", b"sdI", b"sdG2"];

wtr.write_record(output_headers)?;
Expand All @@ -357,9 +349,8 @@ pub fn run(argv: &[&str]) -> CliResult<()> {
];

wtr.write_record(output_headers)?;
cooccurrences.for_each_cooc_record(args.flag_min_count, args.flag_complete, |r| {
wtr.write_byte_record(r)
})?;
cooccurrences
.for_each_cooc_record(args.flag_min_count, |r| wtr.write_byte_record(r))?;
}

return Ok(wtr.flush()?);
Expand Down Expand Up @@ -765,17 +756,17 @@ fn compute_npmi(xy: usize, n: usize, pmi: f64) -> f64 {
}
}

#[inline]
fn compute_simplified_chi2_and_g2(x: usize, y: usize, xy: usize, n: usize) -> (f64, f64) {
// This version does not take into account the full contingency matrix.
let observed = xy as f64;
let expected = x as f64 * y as f64 / n as f64;
// #[inline]
// fn compute_simplified_chi2_and_g2(x: usize, y: usize, xy: usize, n: usize) -> (f64, f64) {
// // This version does not take into account the full contingency matrix.
// let observed = xy as f64;
// let expected = x as f64 * y as f64 / n as f64;

(
(observed - expected).powi(2) / expected,
2.0 * observed * (observed / expected).ln(),
)
}
// (
// (observed - expected).powi(2) / expected,
// 2.0 * observed * (observed / expected).ln(),
// )
// }

#[inline]
fn compute_simplified_g2(x: usize, y: usize, xy: usize, n: usize) -> f64 {
Expand All @@ -787,35 +778,82 @@ fn compute_simplified_g2(x: usize, y: usize, xy: usize, n: usize) -> f64 {
}

// NOTE: see code in issue https://github.com/medialab/xan/issues/295
// NOTE: it is possible to approximate chi2 and G2 for co-occurrences by
// only computing the (observed_11, expected_11) part related to the first
// cell of the contingency matrix. This works very well for chi2, but
// is a little bit more fuzzy for G2.
fn compute_chi2_and_g2(x: usize, y: usize, xy: usize, n: usize) -> (f64, f64) {
// This can be 0 if some item is present in all co-occurrences!
let not_x = (n - x) as f64;
let not_y = (n - y) as f64;
let nf = n as f64;

let observed_11 = xy as f64;
let observed_12 = (x - xy) as f64;
let observed_21 = (y - xy) as f64;
let observed_22 = (n - (x + y) + xy) as f64;
let observed_12 = (x - xy) as f64; // Is 0 if x only co-occurs with y
let observed_21 = (y - xy) as f64; // Is 0 if y only co-occurs with x

// NOTE: with few co-occurrences, self loops can produce a negative
// outcome...
let observed_22 = ((n + xy) as i64 - (x + y) as i64) as f64;

let nf = n as f64;

let expected_11 = x as f64 * y as f64 / nf;
let expected_11 = x as f64 * y as f64 / nf; // Cannot be 0
let expected_12 = x as f64 * not_y / nf;
let expected_21 = y as f64 * not_x / nf;
let expected_22 = not_x * not_y / nf;

debug_assert!(
observed_11 >= 0.0
&& observed_12 >= 0.0
&& observed_21 >= 0.0
// && observed_22 >= 0.0
&& expected_11 >= 0.0
&& expected_12 >= 0.0
&& expected_21 >= 0.0
&& expected_22 >= 0.0
);

let chi2_11 = (observed_11 - expected_11).powi(2) / expected_11;
let chi2_12 = (observed_12 - expected_12).powi(2) / expected_12;
let chi2_21 = (observed_21 - expected_21).powi(2) / expected_21;
let chi2_22 = (observed_22 - expected_22).powi(2) / expected_22;

let g2_11 = observed_11 * (observed_11 / expected_11).ln();
let g2_12 = observed_12 * (observed_12 / expected_12).ln();
let g2_21 = observed_21 * (observed_21 / expected_21).ln();
let g2_22 = observed_22 * (observed_22 / expected_22).ln();

(
chi2_11 + chi2_12 + chi2_21 + chi2_22,
2.0 * (g2_11 + g2_12 + g2_21 + g2_22),
)
let g2_12 = if observed_12 == 0.0 {
0.0
} else {
observed_12 * (observed_12 / expected_12).ln()
};
let g2_21 = if observed_21 == 0.0 {
0.0
} else {
observed_21 * (observed_21 / expected_21).ln()
};
let g2_22 = if observed_22 <= 0.0 {
0.0
} else {
observed_22 * (observed_22 / expected_22).ln()
};

let mut chi2 = chi2_11 + chi2_12 + chi2_21 + chi2_22;
let mut g2 = 2.0 * (g2_11 + g2_12 + g2_21 + g2_22);

// Dealing with degenerate cases that happen when the number
// of co-occurrences is very low, or when some item dominates
// the distribution.
if chi2.is_nan() {
chi2 = 0.0;
}

if chi2.is_infinite() {
chi2 = chi2_11;
}

if g2.is_infinite() {
g2 = g2_11;
}

(chi2, g2)
}

#[derive(Debug)]
Expand Down Expand Up @@ -912,12 +950,7 @@ impl Cooccurrences {
target_entry.gcf += 1;
}

fn for_each_cooc_record<F, E>(
self,
min_count: usize,
complete: bool,
mut callback: F,
) -> Result<(), E>
fn for_each_cooc_record<F, E>(self, min_count: usize, mut callback: F) -> Result<(), E>
where
F: FnMut(&csv::ByteRecord) -> Result<(), E>,
{
Expand All @@ -938,11 +971,7 @@ impl Cooccurrences {
let xy = *count;

// chi2/G2 computations
let (chi2, g2) = if complete {
compute_chi2_and_g2(x, y, xy, n)
} else {
compute_simplified_chi2_and_g2(x, y, xy, n)
};
let (chi2, g2) = compute_chi2_and_g2(x, y, xy, n);

// PMI-related computations
let pmi = compute_pmi(x, y, xy, n);
Expand All @@ -953,13 +982,7 @@ impl Cooccurrences {
csv_record.push_field(&target_entry.token);
csv_record.push_field(count.to_string().as_bytes());
csv_record.push_field(chi2.to_string().as_bytes());

if g2.is_nan() {
csv_record.push_field(b"");
} else {
csv_record.push_field(g2.to_string().as_bytes());
}

csv_record.push_field(g2.to_string().as_bytes());
csv_record.push_field(pmi.to_string().as_bytes());
csv_record.push_field(npmi.to_string().as_bytes());

Expand Down
6 changes: 3 additions & 3 deletions tests/test_vocab.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ fn vocab_cooc_sep_no_doc() {

let expected = vec![
svec!["token1", "token2", "count", "chi2", "G2", "pmi", "npmi"],
svec!["cat", "cat", "1", "2.25", "-2.772588722239781", "-2", "-1"],
svec!["cat", "cat", "1", "2.25", "-1.3862943611198906", "-2", "-1"],
svec!["cat", "dog", "2", "0", "0", "0", "0"],
svec!["cat", "rabbit", "1", "0", "0", "0", "0"],
];
Expand Down Expand Up @@ -262,7 +262,7 @@ fn vocab_cooc_no_sep() {

let expected = vec![
svec!["token1", "token2", "count", "chi2", "G2", "pmi", "npmi"],
svec!["cat", "cat", "1", "2.25", "-2.772588722239781", "-2", "-1"],
svec!["cat", "cat", "1", "2.25", "-1.3862943611198906", "-2", "-1"],
svec!["cat", "dog", "2", "0", "0", "0", "0"],
svec!["cat", "rabbit", "1", "0", "0", "0", "0"],
];
Expand Down Expand Up @@ -294,7 +294,7 @@ fn vocab_cooc_no_sep_window() {

let expected = vec![
svec!["token1", "token2", "count", "chi2", "G2", "pmi", "npmi"],
svec!["cat", "cat", "1", "2.25", "-2.772588722239781", "-2", "-1"],
svec!["cat", "cat", "1", "2.25", "-1.3862943611198906", "-2", "-1"],
svec!["cat", "dog", "2", "0", "0", "0", "0"],
svec!["cat", "rabbit", "1", "0", "0", "0", "0"],
];
Expand Down

0 comments on commit 3c8e34a

Please sign in to comment.