1use rand::{RngExt, SeedableRng};
15use rand_chacha::ChaCha8Rng;
16use rayon::prelude::*;
17
/// A sparse matrix in Compressed Sparse Row (CSR) format.
///
/// Row `i`'s stored entries live at positions
/// `row_offsets[i]..row_offsets[i + 1]` of `col_indices` / `values`.
#[derive(Clone, Debug)]
pub struct CsrMatrix {
    /// Prefix offsets into `col_indices`/`values`; length `n_rows + 1`,
    /// first entry 0, non-decreasing, last entry == nnz.
    pub row_offsets: Vec<usize>,
    /// Column index of each stored value (each `< n_cols`).
    pub col_indices: Vec<usize>,
    /// Stored values, ordered row by row.
    pub values: Vec<f64>,
    pub n_rows: usize,
    pub n_cols: usize,
}

impl CsrMatrix {
    /// Build a CSR matrix, validating its structural invariants.
    ///
    /// # Errors
    /// Returns `Err` if `row_offsets` has the wrong length, does not start
    /// at 0, is not non-decreasing, or disagrees with the number of stored
    /// entries; if `col_indices` and `values` differ in length; or if any
    /// column index is `>= n_cols`.
    pub fn new(
        row_offsets: Vec<usize>,
        col_indices: Vec<usize>,
        values: Vec<f64>,
        n_rows: usize,
        n_cols: usize,
    ) -> Result<Self, String> {
        if row_offsets.len() != n_rows + 1 {
            return Err(format!(
                "row_offsets length {} != n_rows + 1 = {}",
                row_offsets.len(),
                n_rows + 1
            ));
        }
        if col_indices.len() != values.len() {
            return Err(format!(
                "col_indices length {} != values length {}",
                col_indices.len(),
                values.len()
            ));
        }
        let nnz = *row_offsets.last().ok_or("row_offsets must not be empty")?;
        if col_indices.len() != nnz {
            return Err(format!(
                "col_indices length {} != nnz from row_offsets {}",
                col_indices.len(),
                nnz
            ));
        }
        // Structural checks so later indexing (to_dense, row_sum) cannot
        // panic or read the wrong entries on a malformed matrix.
        if row_offsets[0] != 0 {
            return Err(format!(
                "row_offsets must start at 0, got {}",
                row_offsets[0]
            ));
        }
        if row_offsets.windows(2).any(|w| w[0] > w[1]) {
            return Err("row_offsets must be non-decreasing".to_string());
        }
        if let Some(&bad) = col_indices.iter().find(|&&c| c >= n_cols) {
            return Err(format!(
                "column index {} out of range (n_cols = {})",
                bad, n_cols
            ));
        }
        Ok(Self {
            row_offsets,
            col_indices,
            values,
            n_rows,
            n_cols,
        })
    }

    /// Number of stored (non-zero) entries.
    pub fn nnz(&self) -> usize {
        self.values.len()
    }

    /// Build from a row-major dense matrix, keeping entries with
    /// `|v| > threshold`.
    ///
    /// # Panics
    /// Panics if `dense.len() != n_rows * n_cols`.
    pub fn from_dense(dense: &[f64], n_rows: usize, n_cols: usize, threshold: f64) -> Self {
        assert_eq!(
            dense.len(),
            n_rows * n_cols,
            "dense must have length n_rows * n_cols"
        );
        let mut row_offsets = Vec::with_capacity(n_rows + 1);
        let mut col_indices = Vec::new();
        let mut values = Vec::new();

        row_offsets.push(0);
        for i in 0..n_rows {
            for j in 0..n_cols {
                let v = dense[i * n_cols + j];
                if v.abs() > threshold {
                    col_indices.push(j);
                    values.push(v);
                }
            }
            // Offset after each row is the running count of stored entries.
            row_offsets.push(col_indices.len());
        }

        Self {
            row_offsets,
            col_indices,
            values,
            n_rows,
            n_cols,
        }
    }

    /// Expand to a row-major dense matrix; unstored entries become 0.0.
    pub fn to_dense(&self) -> Vec<f64> {
        let mut dense = vec![0.0_f64; self.n_rows * self.n_cols];
        for i in 0..self.n_rows {
            for idx in self.row_offsets[i]..self.row_offsets[i + 1] {
                dense[i * self.n_cols + self.col_indices[idx]] = self.values[idx];
            }
        }
        dense
    }

    /// Sum of the stored values in row `i` (the weighted degree when this
    /// is an adjacency matrix).
    fn row_sum(&self, i: usize) -> f64 {
        let mut s = 0.0_f64;
        for idx in self.row_offsets[i]..self.row_offsets[i + 1] {
            s += self.values[idx];
        }
        s
    }
}
119
120pub enum AdjStorage {
122 Dense { adj: Vec<f64> },
123 Sparse { csr: CsrMatrix },
124}
125
/// A graph layer that aggregates neighbor features (normalized by weighted
/// node degree) and applies a linear transform followed by tanh.
pub struct StochasticGraphLayer {
    // Number of graph nodes; the adjacency matrix is n_nodes x n_nodes.
    pub n_nodes: usize,
    // Per-node feature dimension; weights are n_features x n_features.
    pub n_features: usize,
    // Dense or CSR adjacency storage.
    pub storage: AdjStorage,
    // Row-major flattened n_features x n_features weight matrix.
    pub weights: Vec<f64>,
    // Per-node weighted degree (row sums of the adjacency matrix).
    pub degrees: Vec<f64>,
}
133
134fn random_weights(n_features: usize, seed: u64) -> Vec<f64> {
135 let mut rng = ChaCha8Rng::seed_from_u64(seed);
136 let mut weights = vec![0.0_f64; n_features * n_features];
137 for w in &mut weights {
138 *w = rng.random::<f64>();
139 }
140 weights
141}
142
/// Row sums (weighted degrees) of a row-major `n x n` dense matrix.
fn dense_degrees(adj: &[f64], n: usize) -> Vec<f64> {
    (0..n)
        .map(|row| adj[row * n..(row + 1) * n].iter().sum::<f64>())
        .collect()
}
154
155fn csr_degrees(csr: &CsrMatrix) -> Vec<f64> {
156 (0..csr.n_rows).map(|i| csr.row_sum(i)).collect()
157}
158
159impl StochasticGraphLayer {
160 pub fn new(adj_flat: Vec<f64>, n_nodes: usize, n_features: usize, seed: u64) -> Self {
162 assert_eq!(
163 adj_flat.len(),
164 n_nodes * n_nodes,
165 "adj_flat must have length n_nodes * n_nodes",
166 );
167 let degrees = dense_degrees(&adj_flat, n_nodes);
168 Self {
169 n_nodes,
170 n_features,
171 storage: AdjStorage::Dense { adj: adj_flat },
172 weights: random_weights(n_features, seed),
173 degrees,
174 }
175 }
176
177 pub fn new_sparse(csr: CsrMatrix, n_features: usize, seed: u64) -> Result<Self, String> {
179 if csr.n_rows != csr.n_cols {
180 return Err(format!(
181 "CSR must be square, got {}x{}",
182 csr.n_rows, csr.n_cols
183 ));
184 }
185 let n_nodes = csr.n_rows;
186 let degrees = csr_degrees(&csr);
187 Ok(Self {
188 n_nodes,
189 n_features,
190 storage: AdjStorage::Sparse { csr },
191 weights: random_weights(n_features, seed),
192 degrees,
193 })
194 }
195
196 pub fn from_dense_auto(
198 adj_flat: Vec<f64>,
199 n_nodes: usize,
200 n_features: usize,
201 seed: u64,
202 density_threshold: f64,
203 ) -> Self {
204 assert_eq!(adj_flat.len(), n_nodes * n_nodes);
205 let total = (n_nodes * n_nodes) as f64;
206 let nnz = adj_flat.iter().filter(|v| v.abs() > 1e-15).count() as f64;
207 let density = nnz / total;
208
209 if density < density_threshold {
210 let csr = CsrMatrix::from_dense(&adj_flat, n_nodes, n_nodes, 1e-15);
211 let degrees = csr_degrees(&csr);
212 Self {
213 n_nodes,
214 n_features,
215 storage: AdjStorage::Sparse { csr },
216 weights: random_weights(n_features, seed),
217 degrees,
218 }
219 } else {
220 Self::new(adj_flat, n_nodes, n_features, seed)
221 }
222 }
223
224 pub fn is_sparse(&self) -> bool {
226 matches!(self.storage, AdjStorage::Sparse { .. })
227 }
228
229 fn validate_features(&self, node_features: &[f64]) -> Result<(), String> {
230 if node_features.len() != self.n_nodes * self.n_features {
231 return Err(format!(
232 "node_features length mismatch: got {}, expected {}.",
233 node_features.len(),
234 self.n_nodes * self.n_features
235 ));
236 }
237 Ok(())
238 }
239
240 fn aggregate_and_transform(&self, agg_flat: &[f64]) -> Vec<f64> {
242 let out_rows: Vec<Vec<f64>> = (0..self.n_nodes)
243 .into_par_iter()
244 .map(|i| {
245 let agg = &agg_flat[i * self.n_features..(i + 1) * self.n_features];
246 let mut out = vec![0.0_f64; self.n_features];
247 for (f_out, out_val) in out.iter_mut().enumerate().take(self.n_features) {
248 let mut acc = 0.0_f64;
249 for (g, agg_val) in agg.iter().enumerate().take(self.n_features) {
250 acc += *agg_val * self.weights[g * self.n_features + f_out];
251 }
252 *out_val = acc.tanh();
253 }
254 out
255 })
256 .collect();
257 let mut flat = Vec::with_capacity(self.n_nodes * self.n_features);
258 for row in out_rows {
259 flat.extend(row);
260 }
261 flat
262 }
263
264 pub fn forward(&self, node_features: &[f64]) -> Result<Vec<f64>, String> {
266 self.validate_features(node_features)?;
267
268 let mut agg = vec![0.0_f64; self.n_nodes * self.n_features];
269
270 match &self.storage {
271 AdjStorage::Dense { adj } => {
272 let agg_rows: Vec<Vec<f64>> = (0..self.n_nodes)
273 .into_par_iter()
274 .map(|i| {
275 let mut row = vec![0.0_f64; self.n_features];
276 for f in 0..self.n_features {
277 let mut acc = 0.0_f64;
278 for j in 0..self.n_nodes {
279 acc += adj[i * self.n_nodes + j]
280 * node_features[j * self.n_features + f];
281 }
282 row[f] = acc;
283 }
284 if self.degrees[i] != 0.0 {
285 for x in &mut row {
286 *x /= self.degrees[i];
287 }
288 }
289 row
290 })
291 .collect();
292 for (i, row) in agg_rows.into_iter().enumerate() {
293 agg[i * self.n_features..(i + 1) * self.n_features].copy_from_slice(&row);
294 }
295 }
296 AdjStorage::Sparse { csr } => {
297 let agg_rows: Vec<Vec<f64>> = (0..self.n_nodes)
298 .into_par_iter()
299 .map(|i| {
300 let mut row = vec![0.0_f64; self.n_features];
301 for idx in csr.row_offsets[i]..csr.row_offsets[i + 1] {
302 let j = csr.col_indices[idx];
303 let a_ij = csr.values[idx];
304 for f in 0..self.n_features {
305 row[f] += a_ij * node_features[j * self.n_features + f];
306 }
307 }
308 if self.degrees[i] != 0.0 {
309 for x in &mut row {
310 *x /= self.degrees[i];
311 }
312 }
313 row
314 })
315 .collect();
316 for (i, row) in agg_rows.into_iter().enumerate() {
317 agg[i * self.n_features..(i + 1) * self.n_features].copy_from_slice(&row);
318 }
319 }
320 }
321
322 Ok(self.aggregate_and_transform(&agg))
323 }
324
325 pub fn forward_sc(
330 &self,
331 node_features: &[f64],
332 length: usize,
333 seed: u64,
334 ) -> Result<Vec<f64>, String> {
335 self.validate_features(node_features)?;
336 if length == 0 {
337 return Err("length must be > 0 for SC mode.".to_string());
338 }
339
340 let mut rng = ChaCha8Rng::seed_from_u64(seed);
341 let words = length.div_ceil(64);
342
343 let feat_packed = crate::bitstream::encode_matrix_prob_to_packed(
344 node_features,
345 self.n_nodes,
346 self.n_features,
347 length,
348 words,
349 &mut rng,
350 );
351
352 let mut agg = vec![0.0_f64; self.n_nodes * self.n_features];
353
354 match &self.storage {
355 AdjStorage::Dense { adj } => {
356 let adj_packed = crate::bitstream::encode_matrix_prob_to_packed(
357 adj,
358 self.n_nodes,
359 self.n_nodes,
360 length,
361 words,
362 &mut rng,
363 );
364 for i in 0..self.n_nodes {
365 for f in 0..self.n_features {
366 let mut pop_total = 0_u64;
367 for j in 0..self.n_nodes {
368 let a = &adj_packed[i * self.n_nodes + j];
369 let b = &feat_packed[j * self.n_features + f];
370 for w in 0..words {
371 pop_total += crate::bitstream::swar_popcount_word(a[w] & b[w]);
372 }
373 }
374 agg[i * self.n_features + f] = pop_total as f64 / length as f64;
375 }
376 }
377 }
378 AdjStorage::Sparse { csr } => {
379 let nnz = csr.nnz();
381 let adj_vals_clamped: Vec<f64> =
382 csr.values.iter().map(|v| v.clamp(0.0, 1.0)).collect();
383 let adj_packed = crate::bitstream::encode_matrix_prob_to_packed(
384 &adj_vals_clamped,
385 1,
386 nnz,
387 length,
388 words,
389 &mut rng,
390 );
391 for i in 0..self.n_nodes {
392 #[allow(clippy::needless_range_loop)]
393 for idx in csr.row_offsets[i]..csr.row_offsets[i + 1] {
394 let j = csr.col_indices[idx];
395 let a = &adj_packed[idx];
396 for f in 0..self.n_features {
397 let b = &feat_packed[j * self.n_features + f];
398 let mut pop = 0_u64;
399 for w in 0..words {
400 pop += crate::bitstream::swar_popcount_word(a[w] & b[w]);
401 }
402 agg[i * self.n_features + f] += pop as f64 / length as f64;
403 }
404 }
405 }
406 }
407 }
408
409 for i in 0..self.n_nodes {
410 if self.degrees[i] != 0.0 {
411 for f in 0..self.n_features {
412 agg[i * self.n_features + f] /= self.degrees[i];
413 }
414 }
415 }
416
417 let agg_packed = crate::bitstream::encode_matrix_prob_to_packed(
418 &agg,
419 self.n_nodes,
420 self.n_features,
421 length,
422 words,
423 &mut rng,
424 );
425 let w_clamped: Vec<f64> = self.weights.iter().map(|w| w.clamp(0.0, 1.0)).collect();
426 let w_packed = crate::bitstream::encode_matrix_prob_to_packed(
427 &w_clamped,
428 self.n_features,
429 self.n_features,
430 length,
431 words,
432 &mut rng,
433 );
434
435 let mut out = Vec::with_capacity(self.n_nodes * self.n_features);
436 for i in 0..self.n_nodes {
437 for f_out in 0..self.n_features {
438 let mut pop_total = 0_u64;
439 for g in 0..self.n_features {
440 let a = &agg_packed[i * self.n_features + g];
441 let b = &w_packed[g * self.n_features + f_out];
442 for w in 0..words {
443 pop_total += crate::bitstream::swar_popcount_word(a[w] & b[w]);
444 }
445 }
446 out.push((pop_total as f64 / length as f64).tanh());
447 }
448 }
449
450 Ok(out)
451 }
452
453 pub fn get_weights(&self) -> Vec<f64> {
454 self.weights.clone()
455 }
456
457 pub fn set_weights(&mut self, weights: Vec<f64>) -> Result<(), String> {
458 if weights.len() != self.n_features * self.n_features {
459 return Err(format!(
460 "weights length mismatch: got {}, expected {}.",
461 weights.len(),
462 self.n_features * self.n_features
463 ));
464 }
465 self.weights = weights;
466 Ok(())
467 }
468}