Skip to main content

sc_neurocore_engine/
topology.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Commercial license available
3// © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
4// © Code 2020–2026 Miroslav Šotek. All rights reserved.
5// ORCID: 0009-0009-3560-0851
6// Contact: www.anulum.li | protoscience@anulum.li
7// SC-NeuroCore — Rust Ollivier-Ricci curvature (parity with src/sc_neurocore/math/topology.py)
8
9//! Rust implementation of discrete Ollivier-Ricci curvature on a
10//! coupling graph.
11//!
12//! Parity target: `ollivier_ricci_curvature` in
13//! `src/sc_neurocore/math/topology.py`. For the same coupling matrix
14//! and node pair, the Rust and Python paths return the same value to
15//! within float64 round-off.
16//!
17//! References:
18//!   Ollivier, Y. (2009). "Ricci curvature of Markov chains on metric
19//!     spaces." J. Functional Analysis 256(3): 810-864.
20//!
21//! Definition:
22//!   kappa(i, j) = 1 - W1(mu_i, mu_j) / d(i, j)
23//! where d is the unweighted shortest-path (hop) distance, mu_i is the
24//! lazy random walk distribution from node i (idleness 1/2, the rest
25//! split by outgoing coupling weight), and W1 is the Wasserstein-1
26//! (earth-mover) distance under the hop metric, solved exactly by a
27//! successive-shortest-path min-cost flow.
28//!
29//! The min-cost-flow loop mirrors the Python reference iteration order
30//! (Bellman-Ford with ascending node scan) so the chosen augmenting
31//! paths — and therefore the floating-point accumulation of the
32//! transport cost — are identical across the two implementations.
33
/// Probability mass the lazy random walk keeps at its starting node; the
/// remaining `1 - IDLENESS` is spread over the node's outgoing couplings.
const IDLENESS: f64 = 0.5;
/// Numerical slack used by the min-cost-flow solver for residual-capacity,
/// path-relaxation, and transported-mass termination comparisons.
const TOLERANCE: f64 = 1.0e-12;
36
/// Failure modes of an Ollivier-Ricci curvature request.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CurvatureError {
    /// `knm` is not a square, non-empty matrix.
    BadShape,
    /// `knm` carries a non-finite or negative entry.
    BadValue,
    /// A node index is outside `[0, n)`.
    BadIndex,
    /// The transport sub-problem admitted no augmenting path.
    Infeasible,
}

impl std::fmt::Display for CurvatureError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::BadShape => "coupling matrix must be a non-empty n x n matrix",
            Self::BadValue => "coupling matrix entries must be finite and non-negative",
            Self::BadIndex => "node index out of range",
            Self::Infeasible => "transport problem admitted no augmenting path",
        };
        f.write_str(message)
    }
}

// No wrapped cause: every variant is a leaf error, so the default
// `source()` (returning `None`) is correct.
impl std::error::Error for CurvatureError {}
49
/// All-pairs hop-count distances over the directed graph encoded by `graph`.
///
/// `graph` is a row-major `n x n` matrix in which `u -> v` is an edge when
/// `u != v` and `graph[u * n + v]` is not `<= 0.0` (self-loops are ignored).
/// One breadth-first search is run per origin. Entry `[s * n + t]` of the
/// result is the number of hops on a shortest `s -> t` path: `0.0` on the
/// diagonal and `f64::INFINITY` when `t` cannot be reached from `s`.
fn shortest_path_distances(graph: &[f64], n: usize) -> Vec<f64> {
    let mut hops = vec![f64::INFINITY; n * n];
    // FIFO frontier shared across origins; `cursor` is the next node to expand.
    let mut frontier: Vec<usize> = Vec::with_capacity(n);

    for origin in 0..n {
        let row = &mut hops[origin * n..(origin + 1) * n];
        row[origin] = 0.0;
        frontier.clear();
        frontier.push(origin);

        let mut cursor = 0;
        while cursor < frontier.len() {
            let node = frontier[cursor];
            cursor += 1;
            let step = row[node] + 1.0;
            let couplings = &graph[node * n..(node + 1) * n];
            for (neighbour, &weight) in couplings.iter().enumerate() {
                // Skip self-loops and non-positive couplings.
                if neighbour == node || weight <= 0.0 {
                    continue;
                }
                if step < row[neighbour] {
                    row[neighbour] = step;
                    frontier.push(neighbour);
                }
            }
        }
    }
    hops
}
79
80/// Lazy random walk distribution from `node`: probability `IDLENESS`
81/// stays, the remainder is split across outgoing edges by coupling
82/// weight. A sink (no outgoing weight) keeps all mass at `node`.
83fn lazy_random_walk(graph: &[f64], n: usize, node: usize) -> Vec<f64> {
84    let mut distribution = vec![0.0; n];
85    distribution[node] = IDLENESS;
86    let mut row_sum = 0.0;
87    for k in 0..n {
88        if k != node {
89            row_sum += graph[node * n + k];
90        }
91    }
92    if row_sum == 0.0 {
93        distribution[node] = 1.0;
94        return distribution;
95    }
96    for k in 0..n {
97        if k != node {
98            distribution[k] += (1.0 - IDLENESS) * graph[node * n + k] / row_sum;
99        }
100    }
101    distribution
102}
103
104/// Exact Wasserstein-1 cost between two distributions under the hop
105/// metric `distances`, via a successive-shortest-path min-cost flow.
106///
107/// Returns `Ok(f64::INFINITY)` when some required transport edge has
108/// infinite (unreachable) cost, matching the Python early return.
109fn minimum_transport_cost(
110    source: &[f64],
111    target: &[f64],
112    distances: &[f64],
113    n: usize,
114) -> Result<f64, CurvatureError> {
115    let source_nodes: Vec<usize> = (0..n).filter(|&k| source[k] > 0.0).collect();
116    let target_nodes: Vec<usize> = (0..n).filter(|&k| target[k] > 0.0).collect();
117    if source_nodes.is_empty() || target_nodes.is_empty() {
118        return Ok(0.0);
119    }
120
121    let total_supply = source_nodes.len();
122    let total_demand = target_nodes.len();
123    let mut costs = vec![0.0; total_supply * total_demand];
124    for (s_idx, &s_node) in source_nodes.iter().enumerate() {
125        for (d_idx, &d_node) in target_nodes.iter().enumerate() {
126            let cost = distances[s_node * n + d_node];
127            if !cost.is_finite() {
128                return Ok(f64::INFINITY);
129            }
130            costs[s_idx * total_demand + d_idx] = cost;
131        }
132    }
133
134    let source_id = total_supply + total_demand;
135    let sink_id = source_id + 1;
136    let node_count = sink_id + 1;
137    let mut residual = vec![0.0; node_count * node_count];
138    let mut edge_cost = vec![0.0; node_count * node_count];
139
140    for (idx, &s_node) in source_nodes.iter().enumerate() {
141        residual[source_id * node_count + idx] = source[s_node];
142    }
143    for (idx, &d_node) in target_nodes.iter().enumerate() {
144        residual[(total_supply + idx) * node_count + sink_id] = target[d_node];
145    }
146    for s_idx in 0..total_supply {
147        for d_idx in 0..total_demand {
148            let u = s_idx;
149            let v = total_supply + d_idx;
150            let cost = costs[s_idx * total_demand + d_idx];
151            residual[u * node_count + v] = f64::INFINITY;
152            edge_cost[u * node_count + v] = cost;
153            edge_cost[v * node_count + u] = -cost;
154        }
155    }
156
157    let required: f64 = source.iter().sum();
158    let mut transported = 0.0;
159    let mut total_cost = 0.0;
160
161    while transported + TOLERANCE < required {
162        let mut dist = vec![f64::INFINITY; node_count];
163        let mut parent = vec![usize::MAX; node_count];
164        dist[source_id] = 0.0;
165        for _ in 0..node_count - 1 {
166            let mut updated = false;
167            for u in 0..node_count {
168                if !dist[u].is_finite() {
169                    continue;
170                }
171                for v in 0..node_count {
172                    if residual[u * node_count + v] <= TOLERANCE {
173                        continue;
174                    }
175                    let candidate = dist[u] + edge_cost[u * node_count + v];
176                    if candidate < dist[v] - TOLERANCE {
177                        dist[v] = candidate;
178                        parent[v] = u;
179                        updated = true;
180                    }
181                }
182            }
183            if !updated {
184                break;
185            }
186        }
187        if parent[sink_id] == usize::MAX {
188            return Err(CurvatureError::Infeasible);
189        }
190
191        let mut increment = required - transported;
192        let mut v = sink_id;
193        while v != source_id {
194            let u = parent[v];
195            increment = increment.min(residual[u * node_count + v]);
196            v = u;
197        }
198        let mut v = sink_id;
199        while v != source_id {
200            let u = parent[v];
201            residual[u * node_count + v] -= increment;
202            residual[v * node_count + u] += increment;
203            total_cost += increment * edge_cost[u * node_count + v];
204            v = u;
205        }
206        transported += increment;
207    }
208    Ok(total_cost)
209}
210
211/// Discrete Ollivier-Ricci curvature between nodes `i` and `j`.
212///
213/// `knm` is row-major `n x n` and must be square, finite, and
214/// non-negative (validated here so the PyO3 wrapper and the pure-Rust
215/// callers share one contract). Returns `0.0` for `i == j` and for
216/// node pairs at zero or infinite graph distance, matching the Python
217/// reference.
218pub fn ollivier_ricci_curvature(
219    knm: &[f64],
220    n: usize,
221    i: usize,
222    j: usize,
223) -> Result<f64, CurvatureError> {
224    if n == 0 || knm.len() != n * n {
225        return Err(CurvatureError::BadShape);
226    }
227    for &value in knm {
228        if !value.is_finite() || value < 0.0 {
229            return Err(CurvatureError::BadValue);
230        }
231    }
232    if i >= n || j >= n {
233        return Err(CurvatureError::BadIndex);
234    }
235    if i == j {
236        return Ok(0.0);
237    }
238
239    let distances = shortest_path_distances(knm, n);
240    let graph_distance = distances[i * n + j];
241    if !graph_distance.is_finite() || graph_distance <= 0.0 {
242        return Ok(0.0);
243    }
244    let mu_i = lazy_random_walk(knm, n, i);
245    let mu_j = lazy_random_walk(knm, n, j);
246    let w1 = minimum_transport_cost(&mu_i, &mu_j, &distances, n)?;
247    Ok(1.0 - w1 / graph_distance)
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    /// Complete graph `K_n`: unit couplings everywhere except the diagonal.
    fn complete_graph(n: usize) -> Vec<f64> {
        let mut g = vec![1.0; n * n];
        for k in 0..n {
            g[k * n + k] = 0.0;
        }
        g
    }

    /// Adds the undirected unit edge `a - b` to the row-major `n x n` matrix `g`.
    fn connect(g: &mut [f64], n: usize, a: usize, b: usize) {
        g[a * n + b] = 1.0;
        g[b * n + a] = 1.0;
    }

    /// Undirected path `0 - 1 - ... - (n - 1)` with unit couplings.
    fn path_graph(n: usize) -> Vec<f64> {
        let mut g = vec![0.0; n * n];
        for k in 1..n {
            connect(&mut g, n, k - 1, k);
        }
        g
    }

    /// Undirected cycle `C_n` with unit couplings.
    fn ring_graph(n: usize) -> Vec<f64> {
        let mut g = vec![0.0; n * n];
        for k in 0..n {
            connect(&mut g, n, k, (k + 1) % n);
        }
        g
    }

    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-12,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn self_pair_is_zero() {
        let g = complete_graph(4);
        assert_eq!(ollivier_ricci_curvature(&g, 4, 2, 2).unwrap(), 0.0);
    }

    #[test]
    fn complete_graph_is_positively_curved() {
        let g = complete_graph(5);
        let kappa = ollivier_ricci_curvature(&g, 5, 0, 1).unwrap();
        assert!(kappa > 0.0, "complete-graph curvature {kappa} not positive");
    }

    #[test]
    fn complete_graph_matches_closed_form() {
        // mu_0 and mu_1 differ only at nodes 0 and 1, by 1/2 - 1/(2(n-1)),
        // moved over one hop: kappa = n / (2(n - 1)) = 5/8 for n = 5.
        let n = 5;
        let kappa = ollivier_ricci_curvature(&complete_graph(n), n, 0, 1).unwrap();
        assert_close(kappa, 0.625);
    }

    #[test]
    fn single_edge_has_unit_curvature() {
        // Both lazy walks are (1/2, 1/2), so W1 = 0.
        let kappa = ollivier_ricci_curvature(&path_graph(2), 2, 0, 1).unwrap();
        assert_close(kappa, 1.0);
    }

    #[test]
    fn path_end_edge_has_half_curvature() {
        // mu_0 = (1/2, 1/2, 0), mu_1 = (1/4, 1/2, 1/4): W1 = 1/2.
        let kappa = ollivier_ricci_curvature(&path_graph(3), 3, 0, 1).unwrap();
        assert_close(kappa, 0.5);
    }

    #[test]
    fn six_cycle_is_flat() {
        // W1 = 1 (dual witness f = [1, 2, 3, 2, 1, 0]), so kappa = 0.
        let kappa = ollivier_ricci_curvature(&ring_graph(6), 6, 0, 1).unwrap();
        assert_close(kappa, 0.0);
    }

    #[test]
    fn disconnected_pair_returns_zero() {
        // Two isolated edges: 0-1 and 2-3. Pair (0, 2) is disconnected.
        let n = 4;
        let mut g = vec![0.0; n * n];
        connect(&mut g, n, 0, 1);
        connect(&mut g, n, 2, 3);
        assert_eq!(ollivier_ricci_curvature(&g, n, 0, 2).unwrap(), 0.0);
    }

    #[test]
    fn directed_unreachable_pair_returns_zero() {
        // Single directed edge 0 -> 1: node 0 is unreachable from node 1.
        let g = [0.0, 1.0, 0.0, 0.0];
        assert_eq!(ollivier_ricci_curvature(&g, 2, 1, 0).unwrap(), 0.0);
    }

    #[test]
    fn unreachable_support_gives_negative_infinity() {
        // 0 -> 1 and 0 -> 2; nodes 1 and 2 are sinks, so support node 2 of
        // mu_0 cannot reach support node 1 of mu_1 and W1 is infinite.
        let g = [0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
        assert_eq!(
            ollivier_ricci_curvature(&g, 3, 0, 1).unwrap(),
            f64::NEG_INFINITY
        );
    }

    #[test]
    fn ring_is_less_curved_than_complete() {
        let n = 6;
        let kappa_ring = ollivier_ricci_curvature(&ring_graph(n), n, 0, 1).unwrap();
        let kappa_complete = ollivier_ricci_curvature(&complete_graph(n), n, 0, 1).unwrap();
        assert!(kappa_ring < kappa_complete);
    }

    #[test]
    fn rejects_bad_shape() {
        let g = vec![0.0; 6];
        assert_eq!(
            ollivier_ricci_curvature(&g, 3, 0, 1),
            Err(CurvatureError::BadShape)
        );
    }

    #[test]
    fn rejects_empty_matrix() {
        assert_eq!(
            ollivier_ricci_curvature(&[], 0, 0, 0),
            Err(CurvatureError::BadShape)
        );
    }

    #[test]
    fn rejects_negative_entry() {
        let mut g = complete_graph(3);
        g[1] = -1.0;
        assert_eq!(
            ollivier_ricci_curvature(&g, 3, 0, 1),
            Err(CurvatureError::BadValue)
        );
    }

    #[test]
    fn rejects_nan_entry() {
        let mut g = complete_graph(3);
        g[1] = f64::NAN;
        assert_eq!(
            ollivier_ricci_curvature(&g, 3, 0, 1),
            Err(CurvatureError::BadValue)
        );
    }

    #[test]
    fn rejects_out_of_range_index() {
        let g = complete_graph(3);
        assert_eq!(
            ollivier_ricci_curvature(&g, 3, 0, 5),
            Err(CurvatureError::BadIndex)
        );
    }
}