Skip to main content

sc_neurocore_engine/
sc_inference.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Commercial license available
3// Copyright (C) 2020-2026 Miroslav Sotek. All rights reserved.
4// ORCID: 0009-0009-3560-0851
5// Contact: www.anulum.li | protoscience@anulum.li
6// SC-NeuroCore - Public SC inference over pre-packed weight bitstreams
7
//! Stochastic forward pass over caller-owned packed weight bitstreams.
//!
//! Each input probability is encoded with the 16-bit LFSR comparator
//! (`encoder::Lfsr16` semantics: taps 16, 14, 13, 11; emit `bit = reg < x_value`,
//! then advance the register), so for a fixed seed the result is bit-identical
//! to the NumPy fallback in `src/sc_neurocore/accel/sc_inference.py`.
14
15use std::error::Error;
16use std::fmt;
17
/// SC inference contract errors.
///
/// Each variant describes a caller-side contract violation of
/// `sc_forward_packed`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ScForwardError {
    /// `length` was zero, or `n_words` did not equal `ceil(length / 64)`.
    /// Carries the offending `length`.
    InvalidLength(usize),
    /// `weights_packed.len()` did not equal `n_out * n_in * n_words`.
    WeightLengthMismatch { expected: usize, actual: usize },
    /// `input_probs.len()` did not equal `n_in`.
    InputLengthMismatch { expected: usize, actual: usize },
    /// At least one input probability lay outside `[0, 1]` (or was NaN).
    ProbabilityOutOfRange,
}

impl fmt::Display for ScForwardError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Text kept verbatim so callers matching on it keep working.
            Self::InvalidLength(0) => f.write_str("length must be positive, got 0"),
            // In `sc_forward_packed` a positive length is rejected only when
            // `n_words` disagrees with it, so report that instead of the
            // misleading "must be positive".
            Self::InvalidLength(length) => write!(
                f,
                "n_words must equal ceil(length / 64) = {} for length {length}",
                length.div_ceil(64)
            ),
            Self::WeightLengthMismatch { expected, actual } => write!(
                f,
                "weights_packed length must be n_out*n_in*n_words ({expected}), got {actual}"
            ),
            Self::InputLengthMismatch { expected, actual } => write!(
                f,
                "input_probs length must be n_in ({expected}), got {actual}"
            ),
            Self::ProbabilityOutOfRange => write!(f, "input_probs must lie in [0, 1]"),
        }
    }
}

impl Error for ScForwardError {}
45
/// Derive the non-zero 16-bit LFSR seed for input `input_idx`.
///
/// The seed is the low 16 bits of `base_seed + input_idx` (wrapping on `u64`
/// overflow), with an all-zero result replaced by 1.
fn input_seed(base_seed: u64, input_idx: usize) -> u16 {
    // Truncating to u16 keeps exactly the low 16 bits of the wrapped sum.
    let candidate = base_seed.wrapping_add(input_idx as u64) as u16;
    // Zero is the LFSR lock-up state (feedback of all-zero bits is zero), so
    // substitute 1; every other value is already >= 1 and passes through.
    candidate.max(1)
}
55
/// Encode one probability into a `length`-bit LFSR comparator stream, packed
/// LSB-first into `out` (matching `bitstream::pack`).
///
/// Bit `t` of the stream is `reg_t < x_value`, where `reg_0 = seed`, each step
/// advances a 16-bit Fibonacci LFSR (taps 16, 14, 13, 11) and
/// `x_value = rint(p * 65536)` clamped to `[0, 65536]`.
///
/// Every word of `out` is overwritten. Stream bits land in the first
/// `ceil(length / 64)` words; padding bits past `length` and any surplus words
/// are zero. Callers therefore need not pre-zero the buffer.
///
/// # Panics
///
/// Panics if `out` holds fewer than `ceil(length / 64)` words.
fn lfsr_encode_packed(p: f64, length: usize, seed: u16, out: &mut [u64]) {
    assert!(
        out.len() >= length.div_ceil(64),
        "output buffer too short for a {length}-bit stream"
    );
    // `round_ties_even` matches NumPy `np.rint`; the ceiling allows p == 1.0.
    let x_value = (p * 65536.0).round_ties_even().clamp(0.0, 65536.0) as u32;
    let mut reg: u16 = seed;
    let mut remaining = length;
    for word in out.iter_mut() {
        let bits_in_word = remaining.min(64);
        remaining -= bits_in_word;
        // Build the word locally and store it once, so stale bits in `out`
        // can never leak into the stream.
        let mut packed = 0_u64;
        for bit in 0..bits_in_word {
            if u32::from(reg) < x_value {
                packed |= 1_u64 << bit;
            }
            // Fibonacci step: XOR of bits 15, 13, 12, 10 shifts in at bit 0.
            let feedback = ((reg >> 15) ^ (reg >> 13) ^ (reg >> 12) ^ (reg >> 10)) & 1;
            reg = (reg << 1) | feedback;
        }
        *word = packed;
    }
}
70
71/// Stochastic forward pass over pre-packed unipolar weight bitstreams.
72///
73/// `weights_packed` is row-major `n_out * n_in * n_words` with
74/// `n_words = ceil(length / 64)`. Returns `n_out` estimates of
75/// `weights @ input_probs`, the AND-then-popcount count divided by `length`.
76pub fn sc_forward_packed(
77    weights_packed: &[u64],
78    n_out: usize,
79    n_in: usize,
80    n_words: usize,
81    input_probs: &[f64],
82    length: usize,
83    seed: u64,
84) -> Result<Vec<f64>, ScForwardError> {
85    if length == 0 || n_words != length.div_ceil(64) {
86        return Err(ScForwardError::InvalidLength(length));
87    }
88    let expected_weights = n_out * n_in * n_words;
89    if weights_packed.len() != expected_weights {
90        return Err(ScForwardError::WeightLengthMismatch {
91            expected: expected_weights,
92            actual: weights_packed.len(),
93        });
94    }
95    if input_probs.len() != n_in {
96        return Err(ScForwardError::InputLengthMismatch {
97            expected: n_in,
98            actual: input_probs.len(),
99        });
100    }
101    if input_probs.iter().any(|&p| !(0.0..=1.0).contains(&p)) {
102        return Err(ScForwardError::ProbabilityOutOfRange);
103    }
104
105    let mut input_words = vec![0_u64; n_in * n_words];
106    for input_idx in 0..n_in {
107        let seed_i = input_seed(seed, input_idx);
108        lfsr_encode_packed(
109            input_probs[input_idx],
110            length,
111            seed_i,
112            &mut input_words[input_idx * n_words..(input_idx + 1) * n_words],
113        );
114    }
115
116    let length_f = length as f64;
117    let mut outputs = vec![0.0_f64; n_out];
118    for (output_idx, output) in outputs.iter_mut().enumerate() {
119        let mut accumulator: u64 = 0;
120        for input_idx in 0..n_in {
121            let weight_base = (output_idx * n_in + input_idx) * n_words;
122            let input_base = input_idx * n_words;
123            for word in 0..n_words {
124                accumulator += u64::from(
125                    (weights_packed[weight_base + word] & input_words[input_base + word])
126                        .count_ones(),
127                );
128            }
129        }
130        *output = accumulator as f64 / length_f;
131    }
132    Ok(outputs)
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn zero_probability_emits_no_ones() {
        let mut words = vec![0_u64; 2];
        lfsr_encode_packed(0.0, 128, 0xACE1, &mut words);
        assert_eq!(words[0].count_ones() + words[1].count_ones(), 0);
    }

    #[test]
    fn full_probability_emits_all_ones() {
        let mut words = vec![0_u64; 1];
        lfsr_encode_packed(1.0, 64, 0xACE1, &mut words);
        assert_eq!(words[0].count_ones(), 64);
    }

    #[test]
    fn all_ones_weights_recover_input_proportion() {
        // Weight bitstream of all ones -> AND-popcount/length recovers the input
        // proportion to within LFSR discretisation.
        let length = 4096;
        let n_words = length / 64;
        let weights = vec![u64::MAX; n_words];
        let result = sc_forward_packed(&weights, 1, 1, n_words, &[0.25], length, 0xACE1).unwrap();
        assert!((result[0] - 0.25).abs() < 0.02);
    }

    #[test]
    fn padding_bits_past_length_are_ignored() {
        // length 70 leaves 58 padding bits in the second weight word. Setting them
        // must not inflate the count, because the encoded input is zero there.
        let weights = vec![u64::MAX; 2];
        let result = sc_forward_packed(&weights, 1, 1, 2, &[1.0], 70, 1).unwrap();
        assert_eq!(result, vec![1.0]);
    }

    #[test]
    fn sums_over_inputs_and_keeps_rows_separate() {
        // Row 0: both weights all-ones; row 1: both weights all-zero.
        let weights = vec![u64::MAX, u64::MAX, 0, 0];
        let result = sc_forward_packed(&weights, 2, 2, 1, &[1.0, 1.0], 64, 3).unwrap();
        assert_eq!(result, vec![2.0, 0.0]);
    }

    #[test]
    fn is_deterministic_for_fixed_seed() {
        let weights = vec![0x5555_5555_5555_5555_u64; 2 * 3 * 2];
        let probs = [0.1, 0.5, 0.9];
        let first = sc_forward_packed(&weights, 2, 3, 2, &probs, 100, 42).unwrap();
        let second = sc_forward_packed(&weights, 2, 3, 2, &probs, 100, 42).unwrap();
        assert_eq!(first, second);
    }

    #[test]
    fn handles_degenerate_dimensions() {
        // No inputs: every output estimate is zero.
        assert_eq!(
            sc_forward_packed(&[], 3, 0, 1, &[], 64, 7).unwrap(),
            vec![0.0; 3]
        );
        // No outputs: empty result.
        assert!(sc_forward_packed(&[], 0, 1, 1, &[0.5], 64, 7)
            .unwrap()
            .is_empty());
    }

    #[test]
    fn rejects_bad_shapes() {
        assert_eq!(
            sc_forward_packed(&[0], 1, 1, 1, &[0.0], 0, 1).unwrap_err(),
            ScForwardError::InvalidLength(0)
        );
        // A valid length with the wrong word count is rejected too.
        assert_eq!(
            sc_forward_packed(&[0, 0], 1, 1, 2, &[0.0], 64, 1).unwrap_err(),
            ScForwardError::InvalidLength(64)
        );
        assert!(matches!(
            sc_forward_packed(&[0, 0], 1, 1, 1, &[0.0], 64, 1).unwrap_err(),
            ScForwardError::WeightLengthMismatch { .. }
        ));
        assert!(matches!(
            sc_forward_packed(&[0], 1, 1, 1, &[0.0, 0.0], 64, 1).unwrap_err(),
            ScForwardError::InputLengthMismatch { .. }
        ));
        assert_eq!(
            sc_forward_packed(&[0], 1, 1, 1, &[1.5], 64, 1).unwrap_err(),
            ScForwardError::ProbabilityOutOfRange
        );
        assert_eq!(
            sc_forward_packed(&[0], 1, 1, 1, &[-0.1], 64, 1).unwrap_err(),
            ScForwardError::ProbabilityOutOfRange
        );
        // NaN is not a probability.
        assert_eq!(
            sc_forward_packed(&[0], 1, 1, 1, &[f64::NAN], 64, 1).unwrap_err(),
            ScForwardError::ProbabilityOutOfRange
        );
    }
}
183}