/// Late-fusion layer: combines several modality feature vectors into one
/// output vector via a normalized weighted sum.
#[derive(Clone, Debug)]
pub struct FusionLayer {
    /// Per-modality mixing weights, normalized so they sum to 1 whenever the
    /// raw weights had a positive sum.
    pub weights: Vec<f64>,
    /// Number of modalities being fused (always equals `weights.len()`).
    pub n_modalities: usize,
    /// Length of each modality's feature vector.
    pub n_features: usize,
}

impl FusionLayer {
    /// Builds a fusion layer from unnormalized per-modality weights.
    ///
    /// The raw weights are rescaled to sum to 1. If their sum is not
    /// strictly positive (e.g. all zeros), every modality instead receives
    /// an equal share `1 / raw_weights.len()`.
    pub fn new(raw_weights: &[f64], n_features: usize) -> Self {
        let sum: f64 = raw_weights.iter().sum();
        let weights: Vec<f64> = if sum > 0.0 {
            raw_weights.iter().map(|&w| w / sum).collect()
        } else {
            // Degenerate raw weights: fall back to a uniform distribution.
            let uniform = 1.0 / raw_weights.len() as f64;
            vec![uniform; raw_weights.len()]
        };
        Self {
            n_modalities: weights.len(),
            weights,
            n_features,
        }
    }

    /// Fuses the flat input buffer into a single feature vector.
    ///
    /// `inputs` holds the modalities concatenated back to back: modality `m`
    /// occupies `inputs[m * n_features .. (m + 1) * n_features]`. The result
    /// is the weighted element-wise sum across modalities.
    ///
    /// # Panics
    /// Panics if `inputs.len() != n_modalities * n_features`.
    pub fn forward(&self, inputs: &[f64]) -> Vec<f64> {
        assert_eq!(inputs.len(), self.n_modalities * self.n_features);
        let mut fused = vec![0.0; self.n_features];
        for (m, &w) in self.weights.iter().enumerate() {
            let base = m * self.n_features;
            for (f, acc) in fused.iter_mut().enumerate() {
                *acc += w * inputs[base + f];
            }
        }
        fused
    }
}
51
/// A `DenseLayer` augmented with a memristive-hardware fault model:
/// the wrapped layer's weights carry injected device noise, and a record of
/// per-cell "stuck" faults is kept alongside it.
#[derive(Clone, Debug)]
pub struct MemristiveLayer {
    /// Underlying dense layer; its weights were perturbed and partially
    /// frozen at construction time.
    pub inner: crate::layer::DenseLayer,
    /// Flat, neuron-major mask of length `n_neurons * n_inputs`
    /// (cell `(i, j)` lives at index `i * n_inputs + j`); `true` marks a
    /// cell stuck at the corresponding `stuck_values` entry.
    pub stuck_mask: Vec<bool>,
    /// Frozen conductance per cell — always 0.0 (stuck off) or 1.0 (stuck
    /// on); meaningful only where `stuck_mask` is `true`.
    pub stuck_values: Vec<f64>,
}
61
62impl MemristiveLayer {
63 pub fn new(
64 n_inputs: usize,
65 n_neurons: usize,
66 length: usize,
67 seed: u64,
68 stuck_rate: f64,
69 variability: f64,
70 ) -> Self {
71 use rand::{RngExt, SeedableRng};
72 use rand_xoshiro::Xoshiro256PlusPlus;
73
74 let mut layer = crate::layer::DenseLayer::new(n_inputs, n_neurons, length, seed);
75 let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed.wrapping_add(0xDEFEC7));
76
77 let total = n_neurons * n_inputs;
78 let mut stuck_mask = vec![false; total];
79 let mut stuck_values = vec![0.0; total];
80
81 for i in 0..n_neurons {
82 for j in 0..n_inputs {
83 let idx = i * n_inputs + j;
84 let noise: f64 = rng.random::<f64>() * 2.0 * variability - variability;
86 layer.weights[i][j] = (layer.weights[i][j] + noise).clamp(0.0, 1.0);
87
88 if rng.random::<f64>() < stuck_rate {
90 stuck_mask[idx] = true;
91 stuck_values[idx] = if rng.random::<bool>() { 1.0 } else { 0.0 };
92 layer.weights[i][j] = stuck_values[idx];
93 }
94 }
95 }
96 layer.refresh_packed_weights();
97
98 Self {
99 inner: layer,
100 stuck_mask,
101 stuck_values,
102 }
103 }
104
105 pub fn forward(&self, input_values: &[f64], seed: u64) -> Result<Vec<f64>, String> {
106 self.inner.forward_fused(input_values, seed)
107 }
108}
109
/// A binary-spiking layer trained online with a Hebbian-style rule:
/// weights on active inputs grow when the neuron fires with them and
/// shrink when it does not (see `step`).
#[derive(Clone, Debug)]
pub struct LearningLayer {
    /// Number of input channels each neuron listens to.
    pub n_inputs: usize,
    /// Number of neurons in the layer.
    pub n_neurons: usize,
    /// Weight matrix indexed `[neuron][input]`; `step` keeps every entry
    /// clamped to `[0.0, 1.0]`.
    pub weights: Vec<Vec<f64>>,
    /// Potentiation step size; depression uses half this value.
    pub learning_rate: f64,
}
118
119impl LearningLayer {
120 pub fn new(n_inputs: usize, n_neurons: usize, learning_rate: f64, seed: u64) -> Self {
121 use rand::{RngExt, SeedableRng};
122 use rand_chacha::ChaCha8Rng;
123
124 let mut rng = ChaCha8Rng::seed_from_u64(seed);
125 let weights: Vec<Vec<f64>> = (0..n_neurons)
126 .map(|_| (0..n_inputs).map(|_| rng.random::<f64>()).collect())
127 .collect();
128 Self {
129 n_inputs,
130 n_neurons,
131 weights,
132 learning_rate,
133 }
134 }
135
136 #[allow(clippy::needless_range_loop)]
139 pub fn step(&mut self, input_spikes: &[bool], threshold: f64) -> Vec<bool> {
140 assert_eq!(input_spikes.len(), self.n_inputs);
141 let mut output = vec![false; self.n_neurons];
142
143 for i in 0..self.n_neurons {
144 let mut current = 0.0;
145 for j in 0..self.n_inputs {
146 if input_spikes[j] {
147 current += self.weights[i][j];
148 }
149 }
150 output[i] = current > threshold;
151
152 for j in 0..self.n_inputs {
154 if input_spikes[j] && output[i] {
155 self.weights[i][j] = (self.weights[i][j] + self.learning_rate).min(1.0);
157 } else if input_spikes[j] && !output[i] {
158 self.weights[i][j] = (self.weights[i][j] - self.learning_rate * 0.5).max(0.0);
160 }
161 }
162 }
163 output
164 }
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    /// Unequal weights: the all-ones modality dominated by weight 0.7.
    #[test]
    fn fusion_weighted_sum() {
        let layer = FusionLayer::new(&[0.7, 0.3], 4);
        let mut inputs = vec![1.0; 4];
        inputs.extend(vec![0.0; 4]);
        let out = layer.forward(&inputs);
        assert_eq!(out.len(), 4);
        assert!((out[0] - 0.7).abs() < 1e-10);
    }

    /// Equal raw weights normalize to 0.5 each, so the output is the mean.
    #[test]
    fn fusion_equal_weights() {
        let layer = FusionLayer::new(&[1.0, 1.0], 2);
        let out = layer.forward(&[0.6, 0.4, 0.2, 0.8]);
        for (got, want) in out.iter().zip([0.4, 0.6]) {
            assert!((got - want).abs() < 1e-10);
        }
    }

    /// Smoke test: a faulty layer still produces one output per neuron.
    #[test]
    fn memristive_forward() {
        let layer = MemristiveLayer::new(4, 2, 256, 42, 0.05, 0.01);
        let out = layer.forward(&[0.5; 4], 99).unwrap();
        assert_eq!(out.len(), 2);
    }

    /// Smoke test: one step yields one spike flag per neuron.
    #[test]
    fn learning_layer_fires() {
        let mut layer = LearningLayer::new(4, 2, 0.01, 42);
        let out = layer.step(&[true; 4], 0.5);
        assert_eq!(out.len(), 2);
    }

    /// Repeated stimulation must move the weights away from their init.
    #[test]
    fn learning_layer_weights_change() {
        let mut layer = LearningLayer::new(4, 2, 0.1, 42);
        let before = layer.weights.clone();
        for _ in 0..50 {
            let _ = layer.step(&[true, true, false, false], 0.3);
        }
        assert_ne!(before, layer.weights);
    }
}