1use rand::{RngExt, SeedableRng};
14use rand_chacha::ChaCha8Rng;
15use rayon::prelude::*;
16
/// Compressed Sparse Row (CSR) matrix of `f64` values.
///
/// Invariants (enforced by [`CsrMatrix::new`], assumed by the other methods):
/// - `row_offsets.len() == n_rows + 1`, `row_offsets[0] == 0`, non-decreasing;
/// - `col_indices.len() == values.len() == *row_offsets.last()`;
/// - every column index is `< n_cols`.
#[derive(Clone, Debug)]
pub struct CsrMatrix {
    // Entries of row i live at indices row_offsets[i]..row_offsets[i + 1].
    pub row_offsets: Vec<usize>,
    pub col_indices: Vec<usize>,
    pub values: Vec<f64>,
    pub n_rows: usize,
    pub n_cols: usize,
}

impl CsrMatrix {
    /// Validating constructor.
    ///
    /// # Errors
    /// Returns `Err` with a descriptive message when the arguments do not
    /// form a structurally valid CSR matrix (length mismatches, empty or
    /// non-monotone `row_offsets`, offsets not starting at 0, or an
    /// out-of-bounds column index). The bounds/monotonicity checks prevent
    /// later panics in `to_dense`/`row_sum` on malformed input.
    pub fn new(
        row_offsets: Vec<usize>,
        col_indices: Vec<usize>,
        values: Vec<f64>,
        n_rows: usize,
        n_cols: usize,
    ) -> Result<Self, String> {
        if row_offsets.len() != n_rows + 1 {
            return Err(format!(
                "row_offsets length {} != n_rows + 1 = {}",
                row_offsets.len(),
                n_rows + 1
            ));
        }
        if col_indices.len() != values.len() {
            return Err(format!(
                "col_indices length {} != values length {}",
                col_indices.len(),
                values.len()
            ));
        }
        let nnz = *row_offsets.last().ok_or("row_offsets must not be empty")?;
        if col_indices.len() != nnz {
            return Err(format!(
                "col_indices length {} != nnz from row_offsets {}",
                col_indices.len(),
                nnz
            ));
        }
        // Offsets must start at 0 and be non-decreasing, otherwise the
        // per-row ranges row_offsets[i]..row_offsets[i + 1] are invalid.
        if row_offsets[0] != 0 {
            return Err(format!("row_offsets must start at 0, got {}", row_offsets[0]));
        }
        if row_offsets.windows(2).any(|w| w[0] > w[1]) {
            return Err("row_offsets must be non-decreasing".to_string());
        }
        // Every stored column must fit the declared width.
        if let Some(&c) = col_indices.iter().find(|&&c| c >= n_cols) {
            return Err(format!(
                "col_index {} out of bounds for n_cols {}",
                c, n_cols
            ));
        }
        Ok(Self {
            row_offsets,
            col_indices,
            values,
            n_rows,
            n_cols,
        })
    }

    /// Number of explicitly stored (nonzero) entries.
    pub fn nnz(&self) -> usize {
        self.values.len()
    }

    /// Builds a CSR matrix from a row-major dense buffer, keeping only
    /// entries with `|v| > threshold`.
    pub fn from_dense(dense: &[f64], n_rows: usize, n_cols: usize, threshold: f64) -> Self {
        let mut row_offsets = Vec::with_capacity(n_rows + 1);
        let mut col_indices = Vec::new();
        let mut values = Vec::new();

        row_offsets.push(0);
        for i in 0..n_rows {
            for j in 0..n_cols {
                let v = dense[i * n_cols + j];
                if v.abs() > threshold {
                    col_indices.push(j);
                    values.push(v);
                }
            }
            // Running nnz count doubles as the next row offset.
            row_offsets.push(col_indices.len());
        }

        Self {
            row_offsets,
            col_indices,
            values,
            n_rows,
            n_cols,
        }
    }

    /// Expands to a row-major dense buffer of length `n_rows * n_cols`;
    /// unstored entries are 0.0.
    pub fn to_dense(&self) -> Vec<f64> {
        let mut dense = vec![0.0_f64; self.n_rows * self.n_cols];
        for i in 0..self.n_rows {
            for idx in self.row_offsets[i]..self.row_offsets[i + 1] {
                dense[i * self.n_cols + self.col_indices[idx]] = self.values[idx];
            }
        }
        dense
    }

    /// Sum of the stored values in row `i` (the weighted degree when this
    /// matrix is an adjacency matrix).
    fn row_sum(&self, i: usize) -> f64 {
        let mut s = 0.0_f64;
        for idx in self.row_offsets[i]..self.row_offsets[i + 1] {
            s += self.values[idx];
        }
        s
    }
}
118
/// Backing storage for the layer's adjacency matrix: either a dense
/// row-major buffer or a CSR sparse matrix.
pub enum AdjStorage {
    /// Row-major buffer of length `n_nodes * n_nodes`.
    Dense { adj: Vec<f64> },
    /// Compressed-sparse-row adjacency (square).
    Sparse { csr: CsrMatrix },
}
124
/// One graph-convolution layer with a fixed adjacency matrix, a seeded
/// random `n_features x n_features` weight matrix, and precomputed node
/// degrees (adjacency row sums) used for normalization.
pub struct StochasticGraphLayer {
    pub n_nodes: usize,
    pub n_features: usize,
    pub storage: AdjStorage,
    // Row-major n_features x n_features; weights[g * n_features + f_out].
    pub weights: Vec<f64>,
    // degrees[i] == sum of adjacency row i; 0.0 for isolated nodes.
    pub degrees: Vec<f64>,
}
132
133fn random_weights(n_features: usize, seed: u64) -> Vec<f64> {
134 let mut rng = ChaCha8Rng::seed_from_u64(seed);
135 let mut weights = vec![0.0_f64; n_features * n_features];
136 for w in &mut weights {
137 *w = rng.random::<f64>();
138 }
139 weights
140}
141
/// Row sums of a dense row-major `n x n` matrix (node degrees when `adj`
/// is an adjacency matrix).
fn dense_degrees(adj: &[f64], n: usize) -> Vec<f64> {
    (0..n)
        .map(|row| adj[row * n..(row + 1) * n].iter().sum())
        .collect()
}
153
154fn csr_degrees(csr: &CsrMatrix) -> Vec<f64> {
155 (0..csr.n_rows).map(|i| csr.row_sum(i)).collect()
156}
157
impl StochasticGraphLayer {
    /// Builds a layer backed by a dense row-major adjacency matrix.
    ///
    /// Degrees are precomputed as adjacency row sums; weights are drawn
    /// deterministically from `seed`.
    ///
    /// # Panics
    /// Panics if `adj_flat.len() != n_nodes * n_nodes`.
    pub fn new(adj_flat: Vec<f64>, n_nodes: usize, n_features: usize, seed: u64) -> Self {
        assert_eq!(
            adj_flat.len(),
            n_nodes * n_nodes,
            "adj_flat must have length n_nodes * n_nodes",
        );
        let degrees = dense_degrees(&adj_flat, n_nodes);
        Self {
            n_nodes,
            n_features,
            storage: AdjStorage::Dense { adj: adj_flat },
            weights: random_weights(n_features, seed),
            degrees,
        }
    }

    /// Builds a layer backed by a sparse (CSR) adjacency matrix.
    ///
    /// # Errors
    /// Returns `Err` if `csr` is not square.
    pub fn new_sparse(csr: CsrMatrix, n_features: usize, seed: u64) -> Result<Self, String> {
        if csr.n_rows != csr.n_cols {
            return Err(format!(
                "CSR must be square, got {}x{}",
                csr.n_rows, csr.n_cols
            ));
        }
        let n_nodes = csr.n_rows;
        let degrees = csr_degrees(&csr);
        Ok(Self {
            n_nodes,
            n_features,
            storage: AdjStorage::Sparse { csr },
            weights: random_weights(n_features, seed),
            degrees,
        })
    }

    /// Builds from a dense buffer, switching to CSR storage when the
    /// fraction of entries with `|v| > 1e-15` is below `density_threshold`.
    ///
    /// # Panics
    /// Panics if `adj_flat.len() != n_nodes * n_nodes`.
    pub fn from_dense_auto(
        adj_flat: Vec<f64>,
        n_nodes: usize,
        n_features: usize,
        seed: u64,
        density_threshold: f64,
    ) -> Self {
        assert_eq!(adj_flat.len(), n_nodes * n_nodes);
        let total = (n_nodes * n_nodes) as f64;
        let nnz = adj_flat.iter().filter(|v| v.abs() > 1e-15).count() as f64;
        // NOTE(review): when n_nodes == 0 this is 0.0 / 0.0 == NaN, so the
        // comparison below is false and the dense path is taken.
        let density = nnz / total;

        if density < density_threshold {
            // Same 1e-15 cutoff as the density estimate, so nnz matches.
            let csr = CsrMatrix::from_dense(&adj_flat, n_nodes, n_nodes, 1e-15);
            let degrees = csr_degrees(&csr);
            Self {
                n_nodes,
                n_features,
                storage: AdjStorage::Sparse { csr },
                weights: random_weights(n_features, seed),
                degrees,
            }
        } else {
            Self::new(adj_flat, n_nodes, n_features, seed)
        }
    }

    /// True when the adjacency is stored in CSR form.
    pub fn is_sparse(&self) -> bool {
        matches!(self.storage, AdjStorage::Sparse { .. })
    }

    /// Checks that `node_features` is a row-major `n_nodes x n_features`
    /// buffer of the expected length.
    fn validate_features(&self, node_features: &[f64]) -> Result<(), String> {
        if node_features.len() != self.n_nodes * self.n_features {
            return Err(format!(
                "node_features length mismatch: got {}, expected {}.",
                node_features.len(),
                self.n_nodes * self.n_features
            ));
        }
        Ok(())
    }

    /// Applies the weight matrix to each node's aggregated feature vector
    /// and squashes with tanh: out[i][f] = tanh(sum_g agg[i][g] * W[g][f]).
    /// Rows are processed in parallel with rayon, then re-flattened in
    /// node order.
    fn aggregate_and_transform(&self, agg_flat: &[f64]) -> Vec<f64> {
        let out_rows: Vec<Vec<f64>> = (0..self.n_nodes)
            .into_par_iter()
            .map(|i| {
                let agg = &agg_flat[i * self.n_features..(i + 1) * self.n_features];
                let mut out = vec![0.0_f64; self.n_features];
                for (f_out, out_val) in out.iter_mut().enumerate().take(self.n_features) {
                    let mut acc = 0.0_f64;
                    for (g, agg_val) in agg.iter().enumerate().take(self.n_features) {
                        // Weights are row-major: row g, column f_out.
                        acc += *agg_val * self.weights[g * self.n_features + f_out];
                    }
                    *out_val = acc.tanh();
                }
                out
            })
            .collect();
        let mut flat = Vec::with_capacity(self.n_nodes * self.n_features);
        for row in out_rows {
            flat.extend(row);
        }
        flat
    }

    /// Exact forward pass: degree-normalized neighbor aggregation,
    /// agg[i] = (sum_j A[i][j] * x[j]) / degrees[i] (nodes with degree 0
    /// are left unnormalized), followed by the linear transform + tanh of
    /// `aggregate_and_transform`. Aggregation rows run in parallel.
    ///
    /// # Errors
    /// Returns `Err` if `node_features.len() != n_nodes * n_features`.
    pub fn forward(&self, node_features: &[f64]) -> Result<Vec<f64>, String> {
        self.validate_features(node_features)?;

        let mut agg = vec![0.0_f64; self.n_nodes * self.n_features];

        match &self.storage {
            AdjStorage::Dense { adj } => {
                let agg_rows: Vec<Vec<f64>> = (0..self.n_nodes)
                    .into_par_iter()
                    .map(|i| {
                        let mut row = vec![0.0_f64; self.n_features];
                        for f in 0..self.n_features {
                            let mut acc = 0.0_f64;
                            for j in 0..self.n_nodes {
                                acc += adj[i * self.n_nodes + j]
                                    * node_features[j * self.n_features + f];
                            }
                            row[f] = acc;
                        }
                        if self.degrees[i] != 0.0 {
                            for x in &mut row {
                                *x /= self.degrees[i];
                            }
                        }
                        row
                    })
                    .collect();
                for (i, row) in agg_rows.into_iter().enumerate() {
                    agg[i * self.n_features..(i + 1) * self.n_features].copy_from_slice(&row);
                }
            }
            AdjStorage::Sparse { csr } => {
                // Same aggregation, but iterating only stored entries per row.
                let agg_rows: Vec<Vec<f64>> = (0..self.n_nodes)
                    .into_par_iter()
                    .map(|i| {
                        let mut row = vec![0.0_f64; self.n_features];
                        for idx in csr.row_offsets[i]..csr.row_offsets[i + 1] {
                            let j = csr.col_indices[idx];
                            let a_ij = csr.values[idx];
                            for f in 0..self.n_features {
                                row[f] += a_ij * node_features[j * self.n_features + f];
                            }
                        }
                        if self.degrees[i] != 0.0 {
                            for x in &mut row {
                                *x /= self.degrees[i];
                            }
                        }
                        row
                    })
                    .collect();
                for (i, row) in agg_rows.into_iter().enumerate() {
                    agg[i * self.n_features..(i + 1) * self.n_features].copy_from_slice(&row);
                }
            }
        }

        Ok(self.aggregate_and_transform(&agg))
    }

    /// Stochastic-computing forward pass: values are encoded as packed
    /// bitstreams of `length` bits and each product is approximated by
    /// AND + popcount, so results are randomized estimates of `forward`.
    ///
    /// All encodings share one ChaCha8 stream seeded with `seed`, so the
    /// result is deterministic for a given (input, length, seed) but
    /// sensitive to the order of the encode calls below.
    ///
    /// NOTE(review): adjacency values (sparse path) and weights are clamped
    /// to [0, 1] before encoding, but `node_features`, the dense `adj`, and
    /// the intermediate `agg` are not — presumably
    /// `encode_matrix_prob_to_packed` expects (or clamps to) probabilities;
    /// confirm against crate::bitstream.
    ///
    /// # Errors
    /// Returns `Err` on a feature-length mismatch or when `length == 0`.
    pub fn forward_sc(
        &self,
        node_features: &[f64],
        length: usize,
        seed: u64,
    ) -> Result<Vec<f64>, String> {
        self.validate_features(node_features)?;
        if length == 0 {
            return Err("length must be > 0 for SC mode.".to_string());
        }

        let mut rng = ChaCha8Rng::seed_from_u64(seed);
        // Number of u64 words needed to hold `length` bits per stream.
        let words = length.div_ceil(64);

        let feat_packed = crate::bitstream::encode_matrix_prob_to_packed(
            node_features,
            self.n_nodes,
            self.n_features,
            length,
            words,
            &mut rng,
        );

        let mut agg = vec![0.0_f64; self.n_nodes * self.n_features];

        match &self.storage {
            AdjStorage::Dense { adj } => {
                let adj_packed = crate::bitstream::encode_matrix_prob_to_packed(
                    adj,
                    self.n_nodes,
                    self.n_nodes,
                    length,
                    words,
                    &mut rng,
                );
                // agg[i][f] ≈ sum_j A[i][j] * x[j][f]: AND of the two
                // streams, popcount, then divide by the stream length.
                for i in 0..self.n_nodes {
                    for f in 0..self.n_features {
                        let mut pop_total = 0_u64;
                        for j in 0..self.n_nodes {
                            let a = &adj_packed[i * self.n_nodes + j];
                            let b = &feat_packed[j * self.n_features + f];
                            for w in 0..words {
                                pop_total += crate::bitstream::swar_popcount_word(a[w] & b[w]);
                            }
                        }
                        agg[i * self.n_features + f] = pop_total as f64 / length as f64;
                    }
                }
            }
            AdjStorage::Sparse { csr } => {
                // Encode only the nnz stored adjacency values, as a single
                // 1 x nnz "matrix"; adj_packed[idx] then lines up with
                // csr.values[idx].
                let nnz = csr.nnz();
                let adj_vals_clamped: Vec<f64> =
                    csr.values.iter().map(|v| v.clamp(0.0, 1.0)).collect();
                let adj_packed = crate::bitstream::encode_matrix_prob_to_packed(
                    &adj_vals_clamped,
                    1,
                    nnz,
                    length,
                    words,
                    &mut rng,
                );
                for i in 0..self.n_nodes {
                    #[allow(clippy::needless_range_loop)]
                    for idx in csr.row_offsets[i]..csr.row_offsets[i + 1] {
                        let j = csr.col_indices[idx];
                        let a = &adj_packed[idx];
                        for f in 0..self.n_features {
                            let b = &feat_packed[j * self.n_features + f];
                            let mut pop = 0_u64;
                            for w in 0..words {
                                pop += crate::bitstream::swar_popcount_word(a[w] & b[w]);
                            }
                            agg[i * self.n_features + f] += pop as f64 / length as f64;
                        }
                    }
                }
            }
        }

        // Degree-normalize, matching the exact-arithmetic forward pass
        // (degree-0 nodes are left as-is).
        for i in 0..self.n_nodes {
            if self.degrees[i] != 0.0 {
                for f in 0..self.n_features {
                    agg[i * self.n_features + f] /= self.degrees[i];
                }
            }
        }

        // Re-encode the aggregated values and the clamped weights, then
        // estimate out[i][f_out] = sum_g agg[i][g] * W[g][f_out] the same
        // AND + popcount way, and finish with tanh as in `forward`.
        let agg_packed = crate::bitstream::encode_matrix_prob_to_packed(
            &agg,
            self.n_nodes,
            self.n_features,
            length,
            words,
            &mut rng,
        );
        let w_clamped: Vec<f64> = self.weights.iter().map(|w| w.clamp(0.0, 1.0)).collect();
        let w_packed = crate::bitstream::encode_matrix_prob_to_packed(
            &w_clamped,
            self.n_features,
            self.n_features,
            length,
            words,
            &mut rng,
        );

        let mut out = Vec::with_capacity(self.n_nodes * self.n_features);
        for i in 0..self.n_nodes {
            for f_out in 0..self.n_features {
                let mut pop_total = 0_u64;
                for g in 0..self.n_features {
                    let a = &agg_packed[i * self.n_features + g];
                    let b = &w_packed[g * self.n_features + f_out];
                    for w in 0..words {
                        pop_total += crate::bitstream::swar_popcount_word(a[w] & b[w]);
                    }
                }
                out.push((pop_total as f64 / length as f64).tanh());
            }
        }

        Ok(out)
    }

    /// Returns a copy of the flat row-major weight matrix.
    pub fn get_weights(&self) -> Vec<f64> {
        self.weights.clone()
    }

    /// Replaces the weight matrix.
    ///
    /// # Errors
    /// Returns `Err` if `weights.len() != n_features * n_features`.
    pub fn set_weights(&mut self, weights: Vec<f64>) -> Result<(), String> {
        if weights.len() != self.n_features * self.n_features {
            return Err(format!(
                "weights length mismatch: got {}, expected {}.",
                weights.len(),
                self.n_features * self.n_features
            ));
        }
        self.weights = weights;
        Ok(())
    }
}