/// Weighted fusion of several equally sized feature vectors.
///
/// Every modality contributes `n_features` values. [`FusionLayer::forward`]
/// blends them into one vector using a single scalar weight per modality.
#[derive(Clone, Debug)]
pub struct FusionLayer {
    /// Per-modality mixing weights as produced by [`FusionLayer::new`].
    pub weights: Vec<f64>,
    /// Number of modalities; equals `weights.len()` right after construction.
    pub n_modalities: usize,
    /// Number of features contributed by each modality.
    pub n_features: usize,
}

impl FusionLayer {
    /// Builds a fusion layer from un-normalized per-modality weights.
    ///
    /// The weights are divided by their sum. When that sum is not a finite,
    /// strictly positive number (all zeros, a negative total, a NaN entry, an
    /// infinite entry, or a sum that overflowed to infinity) every modality
    /// receives the same weight `1 / raw_weights.len()` instead, so the layer
    /// never ends up with NaN or silently zeroed weights.
    ///
    /// Individual negative entries are not rejected.
    pub fn new(raw_weights: &[f64], n_features: usize) -> Self {
        let total: f64 = raw_weights.iter().sum();
        // `total > 0.0` alone lets `+inf` through, which would turn finite
        // weights into 0.0 and infinite ones into NaN after the division.
        let weights: Vec<f64> = if total.is_finite() && total > 0.0 {
            raw_weights.iter().map(|w| w / total).collect()
        } else {
            // For an empty slice this yields an empty vector, so the division
            // by zero in the fill value is never observable.
            vec![1.0 / raw_weights.len() as f64; raw_weights.len()]
        };
        Self {
            n_modalities: weights.len(),
            weights,
            n_features,
        }
    }

    /// Computes the weighted sum of all modalities, feature by feature.
    ///
    /// `inputs` is the concatenation of the modalities: the values of
    /// modality `m` occupy `inputs[m * n_features..(m + 1) * n_features]`.
    /// The result has `n_features` entries.
    ///
    /// # Panics
    ///
    /// Panics if `inputs.len() != n_modalities * n_features`.
    pub fn forward(&self, inputs: &[f64]) -> Vec<f64> {
        assert_eq!(inputs.len(), self.n_modalities * self.n_features);
        let mut out = vec![0.0; self.n_features];
        for (m, &w) in self.weights.iter().enumerate() {
            let start = m * self.n_features;
            // Plain slicing (rather than `chunks_exact`) keeps the
            // `n_features == 0` case valid.
            let segment = &inputs[start..start + self.n_features];
            for (acc, &x) in out.iter_mut().zip(segment) {
                *acc += x * w;
            }
        }
        out
    }
}
52
53#[derive(Clone, Debug)]
57pub struct MemristiveLayer {
58 pub inner: crate::layer::DenseLayer,
59 pub stuck_mask: Vec<bool>,
60 pub stuck_values: Vec<f64>,
61}
62
63impl MemristiveLayer {
64 pub fn new(
65 n_inputs: usize,
66 n_neurons: usize,
67 length: usize,
68 seed: u64,
69 stuck_rate: f64,
70 variability: f64,
71 ) -> Self {
72 use rand::{RngExt, SeedableRng};
73 use rand_xoshiro::Xoshiro256PlusPlus;
74
75 let mut layer = crate::layer::DenseLayer::new(n_inputs, n_neurons, length, seed);
76 let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed.wrapping_add(0xDEFEC7));
77
78 let total = n_neurons * n_inputs;
79 let mut stuck_mask = vec![false; total];
80 let mut stuck_values = vec![0.0; total];
81
82 for i in 0..n_neurons {
83 for j in 0..n_inputs {
84 let idx = i * n_inputs + j;
85 let noise: f64 = rng.random::<f64>() * 2.0 * variability - variability;
87 layer.weights[i][j] = (layer.weights[i][j] + noise).clamp(0.0, 1.0);
88
89 if rng.random::<f64>() < stuck_rate {
91 stuck_mask[idx] = true;
92 stuck_values[idx] = if rng.random::<bool>() { 1.0 } else { 0.0 };
93 layer.weights[i][j] = stuck_values[idx];
94 }
95 }
96 }
97 layer.refresh_packed_weights();
98
99 Self {
100 inner: layer,
101 stuck_mask,
102 stuck_values,
103 }
104 }
105
106 pub fn forward(&self, input_values: &[f64], seed: u64) -> Result<Vec<f64>, String> {
107 self.inner.forward_fused(input_values, seed)
108 }
109}
110
111#[derive(Clone, Debug)]
113pub struct LearningLayer {
114 pub n_inputs: usize,
115 pub n_neurons: usize,
116 pub weights: Vec<Vec<f64>>,
117 pub learning_rate: f64,
118}
119
120impl LearningLayer {
121 pub fn new(n_inputs: usize, n_neurons: usize, learning_rate: f64, seed: u64) -> Self {
122 use rand::{RngExt, SeedableRng};
123 use rand_chacha::ChaCha8Rng;
124
125 let mut rng = ChaCha8Rng::seed_from_u64(seed);
126 let weights: Vec<Vec<f64>> = (0..n_neurons)
127 .map(|_| (0..n_inputs).map(|_| rng.random::<f64>()).collect())
128 .collect();
129 Self {
130 n_inputs,
131 n_neurons,
132 weights,
133 learning_rate,
134 }
135 }
136
137 #[allow(clippy::needless_range_loop)]
140 pub fn step(&mut self, input_spikes: &[bool], threshold: f64) -> Vec<bool> {
141 assert_eq!(input_spikes.len(), self.n_inputs);
142 let mut output = vec![false; self.n_neurons];
143
144 for i in 0..self.n_neurons {
145 let mut current = 0.0;
146 for j in 0..self.n_inputs {
147 if input_spikes[j] {
148 current += self.weights[i][j];
149 }
150 }
151 output[i] = current > threshold;
152
153 for j in 0..self.n_inputs {
155 if input_spikes[j] && output[i] {
156 self.weights[i][j] = (self.weights[i][j] + self.learning_rate).min(1.0);
158 } else if input_spikes[j] && !output[i] {
159 self.weights[i][j] = (self.weights[i][j] - self.learning_rate * 0.5).max(0.0);
161 }
162 }
163 }
164 output
165 }
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for floating-point comparisons below.
    const EPS: f64 = 1e-10;

    #[test]
    fn fusion_weighted_sum() {
        let layer = FusionLayer::new(&[0.7, 0.3], 4);
        assert_eq!(layer.n_modalities, 2);
        let inputs = vec![
            1.0, 1.0, 1.0, 1.0, // modality 0
            0.0, 0.0, 0.0, 0.0, // modality 1
        ];
        let out = layer.forward(&inputs);
        assert_eq!(out.len(), 4);
        // Every feature sees the same mix, not just the first one.
        for value in &out {
            assert!((value - 0.7).abs() < EPS);
        }
    }

    #[test]
    fn fusion_equal_weights() {
        let layer = FusionLayer::new(&[1.0, 1.0], 2);
        let inputs = vec![0.6, 0.4, 0.2, 0.8];
        let out = layer.forward(&inputs);
        assert!((out[0] - 0.4).abs() < EPS);
        assert!((out[1] - 0.6).abs() < EPS);
    }

    #[test]
    fn fusion_zero_weights_fall_back_to_uniform() {
        let layer = FusionLayer::new(&[0.0, 0.0], 1);
        assert!((layer.weights[0] - 0.5).abs() < EPS);
        assert!((layer.weights[1] - 0.5).abs() < EPS);
        let out = layer.forward(&[1.0, 3.0]);
        assert!((out[0] - 2.0).abs() < EPS);
    }

    #[test]
    #[should_panic]
    fn fusion_rejects_wrong_input_length() {
        let layer = FusionLayer::new(&[1.0, 1.0], 2);
        layer.forward(&[0.1, 0.2, 0.3]);
    }

    #[test]
    fn memristive_forward() {
        let layer = MemristiveLayer::new(4, 2, 256, 42, 0.05, 0.01);
        // One mask entry and one value entry per weight.
        assert_eq!(layer.stuck_mask.len(), 8);
        assert_eq!(layer.stuck_values.len(), 8);
        for (&stuck, &value) in layer.stuck_mask.iter().zip(&layer.stuck_values) {
            if stuck {
                assert!(value == 0.0 || value == 1.0);
            } else {
                assert_eq!(value, 0.0);
            }
        }
        let out = layer.forward(&[0.5, 0.5, 0.5, 0.5], 99).unwrap();
        assert_eq!(out.len(), 2);
    }

    #[test]
    fn learning_layer_fires() {
        let mut layer = LearningLayer::new(4, 2, 0.01, 42);
        let spikes = vec![true, true, true, true];
        let out = layer.step(&spikes, 0.5);
        assert_eq!(out.len(), 2);
    }

    #[test]
    fn learning_layer_update_rule() {
        // Hand-built weights make the outcome independent of the RNG.
        let mut layer = LearningLayer {
            n_inputs: 3,
            n_neurons: 2,
            weights: vec![vec![0.4, 0.3, 0.9], vec![0.1, 0.1, 0.9]],
            learning_rate: 0.1,
        };
        let out = layer.step(&[true, true, false], 0.5);
        assert_eq!(out, vec![true, false]);

        // Neuron 0 fired: active weights grow by the full learning rate.
        assert!((layer.weights[0][0] - 0.5).abs() < EPS);
        assert!((layer.weights[0][1] - 0.4).abs() < EPS);
        // Neuron 1 stayed silent: active weights shrink by half the rate.
        assert!((layer.weights[1][0] - 0.05).abs() < EPS);
        assert!((layer.weights[1][1] - 0.05).abs() < EPS);
        // The silent input is never touched.
        assert!((layer.weights[0][2] - 0.9).abs() < EPS);
        assert!((layer.weights[1][2] - 0.9).abs() < EPS);
    }

    #[test]
    fn learning_layer_weights_change() {
        let mut layer = LearningLayer::new(4, 2, 0.1, 42);
        let initial = layer.weights.clone();
        for _ in 0..50 {
            layer.step(&[true, true, false, false], 0.3);
        }
        assert_ne!(layer.weights, initial);
        for (row, before) in layer.weights.iter().zip(&initial) {
            // Inputs 2 and 3 never spiked, so their weights must be unchanged.
            assert_eq!(&row[2..], &before[2..]);
            // Updates are capped at 1.0 and floored at 0.0.
            assert!(row.iter().all(|w| (0.0..=1.0).contains(w)));
        }
    }
}