1use std::error::Error;
16use std::fmt;
17
/// Validation failures reported by [`sc_forward_packed`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ScForwardError {
    /// `length` was zero, or `n_words` did not equal `ceil(length / 64)`.
    /// Carries the offending `length`.
    InvalidLength(usize),
    /// `weights_packed.len()` differed from `n_out * n_in * n_words`.
    WeightLengthMismatch { expected: usize, actual: usize },
    /// `input_probs.len()` differed from `n_in`.
    InputLengthMismatch { expected: usize, actual: usize },
    /// At least one entry of `input_probs` lay outside `[0, 1]` (NaN included).
    ProbabilityOutOfRange,
}
26
27impl fmt::Display for ScForwardError {
28 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29 match self {
30 Self::InvalidLength(length) => write!(f, "length must be positive, got {length}"),
31 Self::WeightLengthMismatch { expected, actual } => write!(
32 f,
33 "weights_packed length must be n_out*n_in*n_words ({expected}), got {actual}"
34 ),
35 Self::InputLengthMismatch { expected, actual } => write!(
36 f,
37 "input_probs length must be n_in ({expected}), got {actual}"
38 ),
39 Self::ProbabilityOutOfRange => write!(f, "input_probs must lie in [0, 1]"),
40 }
41 }
42}
43
44impl Error for ScForwardError {}
45
/// Derives the 16-bit LFSR seed for input `input_idx` from `base_seed`.
///
/// The seed is the low 16 bits of `base_seed + input_idx` (wrapping). An
/// all-zero register never changes under the LFSR update in
/// `lfsr_encode_packed`, so a zero result is replaced with 1.
fn input_seed(base_seed: u64, input_idx: usize) -> u16 {
    let offset = base_seed.wrapping_add(input_idx as u64);
    // `as u16` keeps exactly the low 16 bits; `max(1)` only alters the zero case.
    (offset as u16).max(1)
}
55
/// Encodes probability `p` as a `length`-bit stochastic bitstream, OR-ing the
/// bits into `out`.
///
/// Bit `t` of the stream is written to bit `t % 64` of `out[t / 64]`, so `out`
/// must hold at least `ceil(length / 64)` words (indexing panics otherwise).
/// Existing bits in `out` are not cleared; callers pass a zeroed buffer.
///
/// A bit is set when the current 16-bit LFSR state (starting at `seed`) is
/// below `round(p * 65536)` clamped to `[0, 65536]`. Because every `u16`
/// state is below 65536, `p == 1.0` yields all ones; `p == 0.0` yields none.
fn lfsr_encode_packed(p: f64, length: usize, seed: u16, out: &mut [u64]) {
    let threshold = (p * 65536.0).round_ties_even().clamp(0.0, 65536.0) as u32;
    let mut state = seed;
    for bit in 0..length {
        let (word, offset) = (bit / 64, bit % 64);
        // Branchless: OR in a 1 at `offset` only when the comparison holds.
        out[word] |= u64::from(u32::from(state) < threshold) << offset;
        // Left-shifting Fibonacci LFSR with taps at bits 15, 13, 12 and 10.
        let tap_xor = (state >> 15) ^ (state >> 13) ^ (state >> 12) ^ (state >> 10);
        state = (state << 1) | (tap_xor & 1);
    }
}
70
71pub fn sc_forward_packed(
77 weights_packed: &[u64],
78 n_out: usize,
79 n_in: usize,
80 n_words: usize,
81 input_probs: &[f64],
82 length: usize,
83 seed: u64,
84) -> Result<Vec<f64>, ScForwardError> {
85 if length == 0 || n_words != length.div_ceil(64) {
86 return Err(ScForwardError::InvalidLength(length));
87 }
88 let expected_weights = n_out * n_in * n_words;
89 if weights_packed.len() != expected_weights {
90 return Err(ScForwardError::WeightLengthMismatch {
91 expected: expected_weights,
92 actual: weights_packed.len(),
93 });
94 }
95 if input_probs.len() != n_in {
96 return Err(ScForwardError::InputLengthMismatch {
97 expected: n_in,
98 actual: input_probs.len(),
99 });
100 }
101 if input_probs.iter().any(|&p| !(0.0..=1.0).contains(&p)) {
102 return Err(ScForwardError::ProbabilityOutOfRange);
103 }
104
105 let mut input_words = vec![0_u64; n_in * n_words];
106 for input_idx in 0..n_in {
107 let seed_i = input_seed(seed, input_idx);
108 lfsr_encode_packed(
109 input_probs[input_idx],
110 length,
111 seed_i,
112 &mut input_words[input_idx * n_words..(input_idx + 1) * n_words],
113 );
114 }
115
116 let length_f = length as f64;
117 let mut outputs = vec![0.0_f64; n_out];
118 for (output_idx, output) in outputs.iter_mut().enumerate() {
119 let mut accumulator: u64 = 0;
120 for input_idx in 0..n_in {
121 let weight_base = (output_idx * n_in + input_idx) * n_words;
122 let input_base = input_idx * n_words;
123 for word in 0..n_words {
124 accumulator += u64::from(
125 (weights_packed[weight_base + word] & input_words[input_base + word])
126 .count_ones(),
127 );
128 }
129 }
130 *output = accumulator as f64 / length_f;
131 }
132 Ok(outputs)
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn zero_probability_emits_no_ones() {
        let mut words = [0_u64; 2];
        lfsr_encode_packed(0.0, 128, 0xACE1, &mut words);
        assert_eq!(words[0].count_ones() + words[1].count_ones(), 0);
    }

    #[test]
    fn full_probability_emits_all_ones() {
        let mut words = [0_u64; 1];
        lfsr_encode_packed(1.0, 64, 0xACE1, &mut words);
        assert_eq!(words[0].count_ones(), 64);
    }

    #[test]
    fn encoding_matches_hand_computed_lfsr_states() {
        // From 0xACE1 the register visits 0xACE1, 0x59C3, 0xB387; only the
        // middle state is below the p = 0.5 threshold of 32768.
        let mut words = [0_u64; 1];
        lfsr_encode_packed(0.5, 3, 0xACE1, &mut words);
        assert_eq!(words[0], 0b010);
    }

    #[test]
    fn input_seed_never_returns_zero() {
        assert_eq!(input_seed(0, 0), 1);
        assert_eq!(input_seed(0xFFFF, 1), 1);
        assert_eq!(input_seed(0xACE1, 2), 0xACE3);
    }

    #[test]
    fn all_ones_weights_recover_input_proportion() {
        let length = 4096;
        let n_words = length / 64;
        let weights = vec![u64::MAX; n_words];
        let result = sc_forward_packed(&weights, 1, 1, n_words, &[0.25], length, 0xACE1).unwrap();
        assert!((result[0] - 0.25).abs() < 0.02);
    }

    #[test]
    fn exact_small_case() {
        // Input stream for seed 0xACE1 at p = 0.5 over 3 bits is 0b010.
        let result = sc_forward_packed(&[u64::MAX, 0], 2, 1, 1, &[0.5], 3, 0xACE1).unwrap();
        assert_eq!(result, vec![1.0 / 3.0, 0.0]);
    }

    #[test]
    fn empty_dimensions() {
        assert_eq!(
            sc_forward_packed(&[], 2, 0, 1, &[], 64, 7).unwrap(),
            vec![0.0, 0.0]
        );
        assert!(sc_forward_packed(&[], 0, 2, 1, &[0.5, 0.5], 64, 7)
            .unwrap()
            .is_empty());
        // Probabilities are still validated when there are no outputs.
        assert_eq!(
            sc_forward_packed(&[], 0, 1, 1, &[f64::NAN], 64, 7).unwrap_err(),
            ScForwardError::ProbabilityOutOfRange
        );
    }

    #[test]
    fn rejects_bad_shapes() {
        assert_eq!(
            sc_forward_packed(&[0], 1, 1, 1, &[0.0], 0, 1).unwrap_err(),
            ScForwardError::InvalidLength(0)
        );
        // `n_words` inconsistent with `length` is reported as `InvalidLength` too.
        assert_eq!(
            sc_forward_packed(&[0, 0], 1, 1, 2, &[0.0], 64, 1).unwrap_err(),
            ScForwardError::InvalidLength(64)
        );
        assert!(matches!(
            sc_forward_packed(&[0, 0], 1, 1, 1, &[0.0], 64, 1).unwrap_err(),
            ScForwardError::WeightLengthMismatch { .. }
        ));
        assert!(matches!(
            sc_forward_packed(&[0], 1, 1, 1, &[0.0, 0.0], 64, 1).unwrap_err(),
            ScForwardError::InputLengthMismatch { .. }
        ));
        assert_eq!(
            sc_forward_packed(&[0], 1, 1, 1, &[1.5], 64, 1).unwrap_err(),
            ScForwardError::ProbabilityOutOfRange
        );
        assert_eq!(
            sc_forward_packed(&[0], 1, 1, 1, &[-0.1], 64, 1).unwrap_err(),
            ScForwardError::ProbabilityOutOfRange
        );
    }

    #[test]
    fn invalid_length_message_mentions_length() {
        let message = ScForwardError::InvalidLength(0).to_string();
        assert!(message.contains("length must be positive"));
    }
}