diff --git a/graphbolt/src/fused_csc_sampling_graph.cc b/graphbolt/src/fused_csc_sampling_graph.cc
index 68be349cb032..a26adcdd84e7 100644
--- a/graphbolt/src/fused_csc_sampling_graph.cc
+++ b/graphbolt/src/fused_csc_sampling_graph.cc
@@ -1160,14 +1160,41 @@ int64_t TemporalNumPick(
     const torch::optional<torch::Tensor>& edge_timestamp, int64_t seed_offset,
     int64_t offset, int64_t num_neighbors) {
   constexpr int64_t kFastPathThreshold = 1000;
-  if (num_neighbors > kFastPathThreshold && !probs_or_mask.has_value()) {
+  if (num_neighbors > kFastPathThreshold && !probs_or_mask.has_value() &&
+      fanout != -1) {
     // TODO: Currently we use the fast path both in TemporalNumPick and
     // TemporalPick. We may only sample once in TemporalNumPick and use the
     // sampled edges in TemporalPick to avoid sampling twice.
-    auto [success, sampled_edges] = FastTemporalPick(
-        seed_timestamp, csc_indics, fanout, replace, seed_pre_time_window,
-        node_timestamp, edge_timestamp, seed_offset, offset, num_neighbors);
-    if (success) return sampled_edges.size();
+    int64_t sampled_count = 0;
+    auto timestamp =
+        utils::GetValueByIndex<int64_t>(seed_timestamp, seed_offset);
+    for (int64_t edge_id = offset;
+         edge_id < offset + num_neighbors && sampled_count < fanout;
+         edge_id++) {
+      if (replace && sampled_count > 0) {
+        sampled_count = fanout;
+        break;
+      }
+      if (node_timestamp.has_value()) {
+        bool flag = true;
+        AT_DISPATCH_INDEX_TYPES(
+            csc_indics.scalar_type(), "CheckNodeTimeStamp", ([&] {
+              int64_t neighbor_id =
+                  utils::GetValueByIndex<index_t>(csc_indics, edge_id);
+              if (utils::GetValueByIndex<int64_t>(
+                      node_timestamp.value(), neighbor_id) >= timestamp)
+                flag = false;
+            }));
+        if (!flag) continue;
+      }
+      if (edge_timestamp.has_value() &&
+          utils::GetValueByIndex<int64_t>(edge_timestamp.value(), edge_id) >=
+              timestamp) {
+        continue;
+      }
+      sampled_count++;
+    }
+    return sampled_count;
   }
   torch::optional<int64_t> time_window = torch::nullopt;
   if (seed_pre_time_window.has_value()) {
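
Note: the sketch below restates the counting logic the new fast path performs, stripped of the graphbolt tensor types and the utils::GetValueByIndex dispatch, so the control flow is easier to review in isolation. The function name CountTemporalPicks and the raw-pointer interface are illustrative assumptions, not part of the graphbolt API.

    // Standalone sketch (hypothetical helper, not graphbolt code): count how
    // many of a seed's neighbors are temporally valid, capped at `fanout`.
    // A neighbor is counted only when its node timestamp and edge timestamp
    // (where present) are earlier than the seed's timestamp.
    #include <cstdint>

    int64_t CountTemporalPicks(
        int64_t seed_time, const int64_t* neighbor_ids,
        const int64_t* node_time,  // nullptr if node timestamps are absent
        const int64_t* edge_time,  // nullptr if edge timestamps are absent
        int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace) {
      int64_t sampled_count = 0;
      for (int64_t edge_id = offset;
           edge_id < offset + num_neighbors && sampled_count < fanout;
           ++edge_id) {
        // With replacement, one valid neighbor is enough to fill the fanout.
        if (replace && sampled_count > 0) return fanout;
        if (node_time != nullptr &&
            node_time[neighbor_ids[edge_id]] >= seed_time) {
          continue;  // Neighbor node is not earlier than the seed: skip it.
        }
        if (edge_time != nullptr && edge_time[edge_id] >= seed_time) {
          continue;  // Edge is not earlier than the seed: skip it.
        }
        ++sampled_count;
      }
      return sampled_count;
    }

As in the patch, the count never exceeds fanout, which is why the fast path is now gated on fanout != -1: with an unbounded fanout the loop condition sampled_count < fanout would not behave as intended.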