// sc_neurocore_engine/topology.rs

/// Probability mass the lazy random walk keeps on its starting node when
/// building the measures compared by Ollivier-Ricci curvature; the remaining
/// `1 - IDLENESS` is spread over the node's neighbours in proportion to edge
/// weight.
const IDLENESS: f64 = 0.5;

/// Absolute slack used by the min-cost-flow solver when comparing
/// transported mass, residual capacities and path costs, so floating-point
/// noise neither stalls nor prolongs the augmentation loop.
const TOLERANCE: f64 = 1e-12;
36
/// Reasons [`ollivier_ricci_curvature`] can fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CurvatureError {
    /// `n == 0`, or the adjacency slice does not hold exactly `n * n` entries.
    BadShape,
    /// An adjacency entry is NaN, infinite, or negative.
    BadValue,
    /// A requested node index is not smaller than `n`.
    BadIndex,
    /// The min-cost-flow transport solver could not route all of the source
    /// mass to the target measure.
    Infeasible,
}

impl std::fmt::Display for CurvatureError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            CurvatureError::BadShape => "adjacency matrix must be non-empty with n * n entries",
            CurvatureError::BadValue => "adjacency entries must be finite and non-negative",
            CurvatureError::BadIndex => "node index out of range",
            CurvatureError::Infeasible => "optimal transport problem is infeasible",
        };
        f.write_str(message)
    }
}

impl std::error::Error for CurvatureError {}
49
/// All-pairs hop distances of the directed graph encoded by `graph`.
///
/// `graph` is read as a row-major `n x n` matrix in which every off-diagonal
/// entry that is not `<= 0.0` counts as an arc of length one (the weight
/// itself does not matter). Entry `[s * n + t]` of the result is the fewest
/// arcs needed to go from `s` to `t`: `0.0` on the diagonal and
/// `f64::INFINITY` when `t` cannot be reached. One breadth-first search is
/// run per source node.
fn shortest_path_distances(graph: &[f64], n: usize) -> Vec<f64> {
    let mut distances = vec![f64::INFINITY; n * n];
    // FIFO frontier reused across searches; entries before `next` have
    // already been expanded.
    let mut frontier: Vec<usize> = Vec::with_capacity(n);
    for origin in 0..n {
        let row = &mut distances[origin * n..(origin + 1) * n];
        row[origin] = 0.0;
        frontier.clear();
        frontier.push(origin);
        let mut next = 0;
        while let Some(&node) = frontier.get(next) {
            next += 1;
            let hop = row[node] + 1.0;
            for neighbour in (0..n).filter(|&t| t != node) {
                let no_arc = graph[node * n + neighbour] <= 0.0;
                // BFS settles nodes in order, so a node is only improved
                // (and enqueued) the first time it is discovered.
                if no_arc || row[neighbour] <= hop {
                    continue;
                }
                row[neighbour] = hop;
                frontier.push(neighbour);
            }
        }
    }
    distances
}
79
80fn lazy_random_walk(graph: &[f64], n: usize, node: usize) -> Vec<f64> {
84 let mut distribution = vec![0.0; n];
85 distribution[node] = IDLENESS;
86 let mut row_sum = 0.0;
87 for k in 0..n {
88 if k != node {
89 row_sum += graph[node * n + k];
90 }
91 }
92 if row_sum == 0.0 {
93 distribution[node] = 1.0;
94 return distribution;
95 }
96 for k in 0..n {
97 if k != node {
98 distribution[k] += (1.0 - IDLENESS) * graph[node * n + k] / row_sum;
99 }
100 }
101 distribution
102}
103
104fn minimum_transport_cost(
110 source: &[f64],
111 target: &[f64],
112 distances: &[f64],
113 n: usize,
114) -> Result<f64, CurvatureError> {
115 let source_nodes: Vec<usize> = (0..n).filter(|&k| source[k] > 0.0).collect();
116 let target_nodes: Vec<usize> = (0..n).filter(|&k| target[k] > 0.0).collect();
117 if source_nodes.is_empty() || target_nodes.is_empty() {
118 return Ok(0.0);
119 }
120
121 let total_supply = source_nodes.len();
122 let total_demand = target_nodes.len();
123 let mut costs = vec![0.0; total_supply * total_demand];
124 for (s_idx, &s_node) in source_nodes.iter().enumerate() {
125 for (d_idx, &d_node) in target_nodes.iter().enumerate() {
126 let cost = distances[s_node * n + d_node];
127 if !cost.is_finite() {
128 return Ok(f64::INFINITY);
129 }
130 costs[s_idx * total_demand + d_idx] = cost;
131 }
132 }
133
134 let source_id = total_supply + total_demand;
135 let sink_id = source_id + 1;
136 let node_count = sink_id + 1;
137 let mut residual = vec![0.0; node_count * node_count];
138 let mut edge_cost = vec![0.0; node_count * node_count];
139
140 for (idx, &s_node) in source_nodes.iter().enumerate() {
141 residual[source_id * node_count + idx] = source[s_node];
142 }
143 for (idx, &d_node) in target_nodes.iter().enumerate() {
144 residual[(total_supply + idx) * node_count + sink_id] = target[d_node];
145 }
146 for s_idx in 0..total_supply {
147 for d_idx in 0..total_demand {
148 let u = s_idx;
149 let v = total_supply + d_idx;
150 let cost = costs[s_idx * total_demand + d_idx];
151 residual[u * node_count + v] = f64::INFINITY;
152 edge_cost[u * node_count + v] = cost;
153 edge_cost[v * node_count + u] = -cost;
154 }
155 }
156
157 let required: f64 = source.iter().sum();
158 let mut transported = 0.0;
159 let mut total_cost = 0.0;
160
161 while transported + TOLERANCE < required {
162 let mut dist = vec![f64::INFINITY; node_count];
163 let mut parent = vec![usize::MAX; node_count];
164 dist[source_id] = 0.0;
165 for _ in 0..node_count - 1 {
166 let mut updated = false;
167 for u in 0..node_count {
168 if !dist[u].is_finite() {
169 continue;
170 }
171 for v in 0..node_count {
172 if residual[u * node_count + v] <= TOLERANCE {
173 continue;
174 }
175 let candidate = dist[u] + edge_cost[u * node_count + v];
176 if candidate < dist[v] - TOLERANCE {
177 dist[v] = candidate;
178 parent[v] = u;
179 updated = true;
180 }
181 }
182 }
183 if !updated {
184 break;
185 }
186 }
187 if parent[sink_id] == usize::MAX {
188 return Err(CurvatureError::Infeasible);
189 }
190
191 let mut increment = required - transported;
192 let mut v = sink_id;
193 while v != source_id {
194 let u = parent[v];
195 increment = increment.min(residual[u * node_count + v]);
196 v = u;
197 }
198 let mut v = sink_id;
199 while v != source_id {
200 let u = parent[v];
201 residual[u * node_count + v] -= increment;
202 residual[v * node_count + u] += increment;
203 total_cost += increment * edge_cost[u * node_count + v];
204 v = u;
205 }
206 transported += increment;
207 }
208 Ok(total_cost)
209}
210
211pub fn ollivier_ricci_curvature(
219 knm: &[f64],
220 n: usize,
221 i: usize,
222 j: usize,
223) -> Result<f64, CurvatureError> {
224 if n == 0 || knm.len() != n * n {
225 return Err(CurvatureError::BadShape);
226 }
227 for &value in knm {
228 if !value.is_finite() || value < 0.0 {
229 return Err(CurvatureError::BadValue);
230 }
231 }
232 if i >= n || j >= n {
233 return Err(CurvatureError::BadIndex);
234 }
235 if i == j {
236 return Ok(0.0);
237 }
238
239 let distances = shortest_path_distances(knm, n);
240 let graph_distance = distances[i * n + j];
241 if !graph_distance.is_finite() || graph_distance <= 0.0 {
242 return Ok(0.0);
243 }
244 let mu_i = lazy_random_walk(knm, n, i);
245 let mu_j = lazy_random_walk(knm, n, j);
246 let w1 = minimum_transport_cost(&mu_i, &mu_j, &distances, n)?;
247 Ok(1.0 - w1 / graph_distance)
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    /// Unweighted complete graph on `n` nodes without self-loops.
    fn complete_graph(n: usize) -> Vec<f64> {
        let mut g = vec![1.0; n * n];
        for k in 0..n {
            g[k * n + k] = 0.0;
        }
        g
    }

    /// Adds an undirected unit-weight edge between `a` and `b`.
    ///
    /// Keeps the tests free of literal index arithmetic such as `0 * 4 + 1`,
    /// which trips Clippy's `erasing_op` / `identity_op` lints.
    fn connect(g: &mut [f64], n: usize, a: usize, b: usize) {
        g[a * n + b] = 1.0;
        g[b * n + a] = 1.0;
    }

    /// Undirected cycle `0 - 1 - ... - (n-1) - 0`.
    fn ring_graph(n: usize) -> Vec<f64> {
        let mut g = vec![0.0; n * n];
        for k in 0..n {
            connect(&mut g, n, k, (k + 1) % n);
        }
        g
    }

    #[test]
    fn self_pair_is_zero() {
        let g = complete_graph(4);
        assert_eq!(ollivier_ricci_curvature(&g, 4, 2, 2).unwrap(), 0.0);
    }

    #[test]
    fn complete_graph_is_positively_curved() {
        let g = complete_graph(5);
        let kappa = ollivier_ricci_curvature(&g, 5, 0, 1).unwrap();
        assert!(kappa > 0.0, "complete-graph curvature {kappa} not positive");
        // With IDLENESS = 0.5, mu_0 and mu_1 differ only by 0.375 of mass
        // sitting on node 0 instead of node 1, so W1 = 0.375 and kappa = 0.625.
        assert!((kappa - 0.625).abs() < 1e-9, "complete-graph curvature {kappa}");
    }

    #[test]
    fn disconnected_pair_returns_zero() {
        let n = 4;
        let mut g = vec![0.0; n * n];
        // Two components: {0, 1} and {2, 3}.
        connect(&mut g, n, 0, 1);
        connect(&mut g, n, 2, 3);
        let kappa = ollivier_ricci_curvature(&g, n, 0, 2).unwrap();
        assert_eq!(kappa, 0.0);
    }

    #[test]
    fn ring_is_less_curved_than_complete() {
        let n = 6;
        let kappa_ring = ollivier_ricci_curvature(&ring_graph(n), n, 0, 1).unwrap();
        let kappa_complete = ollivier_ricci_curvature(&complete_graph(n), n, 0, 1).unwrap();
        assert!(kappa_ring < kappa_complete);
        // A 6-cycle is flat for the 1/2-lazy walk: W1 = 1 = d(0, 1).
        assert!(kappa_ring.abs() < 1e-9, "ring curvature {kappa_ring}");
    }

    #[test]
    fn rejects_bad_shape() {
        let g = vec![0.0; 6];
        assert_eq!(
            ollivier_ricci_curvature(&g, 3, 0, 1),
            Err(CurvatureError::BadShape)
        );
    }

    #[test]
    fn rejects_empty_graph() {
        assert_eq!(
            ollivier_ricci_curvature(&[], 0, 0, 0),
            Err(CurvatureError::BadShape)
        );
    }

    #[test]
    fn rejects_negative_entry() {
        let mut g = complete_graph(3);
        g[1] = -1.0;
        assert_eq!(
            ollivier_ricci_curvature(&g, 3, 0, 1),
            Err(CurvatureError::BadValue)
        );
    }

    #[test]
    fn rejects_non_finite_entry() {
        for bad in [f64::NAN, f64::INFINITY] {
            let mut g = complete_graph(3);
            g[1] = bad;
            assert_eq!(
                ollivier_ricci_curvature(&g, 3, 0, 1),
                Err(CurvatureError::BadValue)
            );
        }
    }

    #[test]
    fn rejects_out_of_range_index() {
        let g = complete_graph(3);
        assert_eq!(
            ollivier_ricci_curvature(&g, 3, 0, 5),
            Err(CurvatureError::BadIndex)
        );
    }
}