1use rayon::prelude::*;
16
/// A spike event token: `(unit_index, spike_time)`, where the time is the
/// sample index scaled by the tokeniser's `dt`.
pub type SpikeToken = (usize, f64);
19
20pub fn tokenise_spikes(trains: &[&[i32]], dt: f64) -> Vec<SpikeToken> {
25 let mut tokens: Vec<SpikeToken> = trains
26 .par_iter()
27 .enumerate()
28 .flat_map_iter(|(uid, train)| {
29 train
30 .iter()
31 .enumerate()
32 .filter(|(_, &v)| v != 0)
33 .map(move |(idx, _)| (uid, idx as f64 * dt))
34 })
35 .collect();
36 tokens.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
37 tokens
38}
39
40pub fn sinusoidal_position_encode(timestamps: &[f64], d_model: usize) -> Vec<f64> {
47 let n = timestamps.len();
48 let mut pe = vec![0.0_f64; n * d_model];
49 let half_d = d_model / 2 + d_model % 2;
50 let divisors: Vec<f64> = (0..half_d)
51 .map(|i| 10000.0_f64.powf(2.0 * i as f64 / d_model as f64))
52 .collect();
53
54 pe.par_chunks_mut(d_model)
55 .enumerate()
56 .for_each(|(row, pe_row)| {
57 let t = timestamps[row];
58 for (k, div) in divisors.iter().enumerate() {
59 let col_sin = 2 * k;
60 let col_cos = 2 * k + 1;
61 let angle = t / div;
62 pe_row[col_sin] = angle.sin();
63 if col_cos < d_model {
64 pe_row[col_cos] = angle.cos();
65 }
66 }
67 });
68 pe
69}
70
71pub fn scaled_dot_product_attention(
78 queries: &[f64],
79 keys: &[f64],
80 values: &[f64],
81 nq: usize,
82 nk: usize,
83 d: usize,
84) -> Vec<f64> {
85 let inv_sqrt_d = 1.0 / (d as f64).sqrt();
86 let mut output = vec![0.0_f64; nq * d];
87
88 output
89 .par_chunks_mut(d)
90 .enumerate()
91 .for_each(|(i, out_row)| {
92 let q_row = &queries[i * d..(i + 1) * d];
93 let mut scores = vec![0.0_f64; nk];
95 let mut max_score = f64::NEG_INFINITY;
96 for j in 0..nk {
97 let k_row = &keys[j * d..(j + 1) * d];
98 let mut dot = 0.0;
99 for f in 0..d {
100 dot += q_row[f] * k_row[f];
101 }
102 scores[j] = dot * inv_sqrt_d;
103 if scores[j] > max_score {
104 max_score = scores[j];
105 }
106 }
107 let mut sum_exp = 0.0;
109 for s in &mut scores {
110 *s = (*s - max_score).exp();
111 sum_exp += *s;
112 }
113 let inv_sum = 1.0 / (sum_exp + 1e-30);
114 for s in &mut scores {
115 *s *= inv_sum;
116 }
117 for j in 0..nk {
119 let w = scores[j];
120 let v_row = &values[j * d..(j + 1) * d];
121 for f in 0..d {
122 out_row[f] += w * v_row[f];
123 }
124 }
125 });
126 output
127}
128
129pub fn gaussian_attention(
135 queries: &[f64],
136 keys: &[f64],
137 values: &[f64],
138 nq: usize,
139 nk: usize,
140 d: usize,
141 sigma: f64,
142) -> Vec<f64> {
143 let inv_2sigma2 = 1.0 / (2.0 * sigma * sigma);
144 let mut output = vec![0.0_f64; nq * d];
145
146 output
147 .par_chunks_mut(d)
148 .enumerate()
149 .for_each(|(i, out_row)| {
150 let q_row = &queries[i * d..(i + 1) * d];
151 let mut log_weights = vec![0.0_f64; nk];
152 let mut max_lw = f64::NEG_INFINITY;
153 for j in 0..nk {
154 let k_row = &keys[j * d..(j + 1) * d];
155 let mut dist_sq = 0.0;
156 for f in 0..d {
157 let diff = q_row[f] - k_row[f];
158 dist_sq += diff * diff;
159 }
160 log_weights[j] = -dist_sq * inv_2sigma2;
161 if log_weights[j] > max_lw {
162 max_lw = log_weights[j];
163 }
164 }
165 let mut sum_exp = 0.0;
166 for lw in &mut log_weights {
167 *lw = (*lw - max_lw).exp();
168 sum_exp += *lw;
169 }
170 let inv_sum = 1.0 / (sum_exp + 1e-30);
171 for j in 0..nk {
172 let w = log_weights[j] * inv_sum;
173 let v_row = &values[j * d..(j + 1) * d];
174 for f in 0..d {
175 out_row[f] += w * v_row[f];
176 }
177 }
178 });
179 output
180}
181
/// Advances a diagonal complex state-space model by one time step.
///
/// State update: `h ← Ā ⊙ h + B̄ x`, where `Ā` is diagonal complex (entries
/// in `a_bar_re`/`a_bar_im`) and `B̄` is a `d_state x d_model` row-major
/// complex matrix. Output: `y = Re(C h) + D x`, with `C` stored as
/// `d_model x d_state` row-major complex and `D` a real
/// `d_model x d_model` matrix. `h_re`/`h_im` are updated in place; the
/// real output vector is returned and is computed from the *new* state.
pub fn ssm_step_diagonal(
    a_bar_re: &[f64],
    a_bar_im: &[f64],
    b_bar_re: &[f64],
    b_bar_im: &[f64],
    c_re: &[f64],
    c_im: &[f64],
    d_mat: &[f64],
    h_re: &mut [f64],
    h_im: &mut [f64],
    x: &[f64],
    d_state: usize,
    d_model: usize,
) -> Vec<f64> {
    // h ← Ā h + B̄ x, one diagonal entry at a time.
    for s in 0..d_state {
        let (hr, hi) = (h_re[s], h_im[s]);
        // Complex multiply of the diagonal entry with the current state.
        let decay_re = a_bar_re[s] * hr - a_bar_im[s] * hi;
        let decay_im = a_bar_re[s] * hi + a_bar_im[s] * hr;
        // Complex input projection: row s of B̄ dotted with the real input.
        let b_row_re = &b_bar_re[s * d_model..(s + 1) * d_model];
        let b_row_im = &b_bar_im[s * d_model..(s + 1) * d_model];
        let bx_re: f64 = b_row_re.iter().zip(x).map(|(b, xv)| b * xv).sum();
        let bx_im: f64 = b_row_im.iter().zip(x).map(|(b, xv)| b * xv).sum();
        h_re[s] = decay_re + bx_re;
        h_im[s] = decay_im + bx_im;
    }

    // y = Re(C h) + D x, against the freshly updated state.
    (0..d_model)
        .map(|m| {
            let ch_re: f64 = (0..d_state)
                .map(|s| c_re[m * d_state + s] * h_re[s] - c_im[m * d_state + s] * h_im[s])
                .sum();
            let dx: f64 = d_mat[m * d_model..(m + 1) * d_model]
                .iter()
                .zip(x)
                .map(|(dv, xv)| dv * xv)
                .sum();
            ch_re + dx
        })
        .collect()
}
241
242pub fn infonce_loss(
249 anchors: &[f64],
250 positives: &[f64],
251 n: usize,
252 d: usize,
253 temperature: f64,
254) -> f64 {
255 if n == 0 || d == 0 {
256 return 0.0;
257 }
258 let inv_tau = 1.0 / temperature;
259
260 let norm = |v: &[f64]| -> Vec<f64> {
262 let mut out = v.to_vec();
263 for i in 0..n {
264 let row = &mut out[i * d..(i + 1) * d];
265 let nrm: f64 = row.iter().map(|x| x * x).sum::<f64>().sqrt() + 1e-30;
266 for x in row.iter_mut() {
267 *x /= nrm;
268 }
269 }
270 out
271 };
272
273 let a_norm = norm(anchors);
274 let p_norm = norm(positives);
275
276 let total_loss: f64 = (0..n)
277 .into_par_iter()
278 .map(|i| {
279 let a_row = &a_norm[i * d..(i + 1) * d];
280 let p_row = &p_norm[i * d..(i + 1) * d];
282 let pos_sim: f64 = a_row.iter().zip(p_row).map(|(a, p)| a * p).sum();
283
284 let mut max_sim = f64::NEG_INFINITY;
286 let mut sims = vec![0.0_f64; n];
287 for j in 0..n {
288 let pj = &p_norm[j * d..(j + 1) * d];
289 let sim: f64 = a_row.iter().zip(pj).map(|(a, p)| a * p).sum();
290 sims[j] = sim * inv_tau;
291 if sims[j] > max_sim {
292 max_sim = sims[j];
293 }
294 }
295 let sum_exp: f64 = sims.iter().map(|s| (s - max_sim).exp()).sum();
296 let log_softmax = pos_sim * inv_tau - max_sim - sum_exp.ln();
297 -log_softmax
298 })
299 .sum();
300
301 total_loss / n as f64
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tokenise_empty() {
        let tokens = tokenise_spikes(&[], 1.0);
        assert!(tokens.is_empty());
    }

    #[test]
    fn test_tokenise_single() {
        // A lone spike at sample index 2 with dt = 0.5 lands at t = 1.0.
        let train = vec![0, 0, 1, 0, 0];
        let tokens = tokenise_spikes(&[&train], 0.5);
        assert_eq!(tokens.len(), 1);
        assert_eq!(tokens[0].0, 0);
        assert!((tokens[0].1 - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_tokenise_sorted() {
        // Unit 1 spikes before unit 0; the output must be time-ordered.
        let t0 = vec![0, 0, 0, 0, 1];
        let t1 = vec![0, 1, 0, 0, 0];
        let tokens = tokenise_spikes(&[&t0, &t1], 1.0);
        assert_eq!(tokens.len(), 2);
        assert!(tokens[0].1 <= tokens[1].1);
    }

    #[test]
    fn test_sinusoidal_pe_shape() {
        let timestamps = vec![0.0, 1.0, 2.0];
        let pe = sinusoidal_position_encode(&timestamps, 8);
        assert_eq!(pe.len(), 3 * 8);
    }

    #[test]
    fn test_sinusoidal_pe_zero() {
        // At t = 0 every sin column is 0 and every cos column is 1.
        let pe = sinusoidal_position_encode(&[0.0], 4);
        assert!((pe[0] - 0.0).abs() < 1e-10);
        assert!((pe[1] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_attention_shape() {
        let q = vec![1.0, 0.0, 0.0, 1.0];
        let k = vec![1.0, 0.0, 0.0, 1.0, 0.5, 0.5];
        let v = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let out = scaled_dot_product_attention(&q, &k, &v, 2, 3, 2);
        assert_eq!(out.len(), 4);
    }

    #[test]
    fn test_gaussian_attention_concentrates() {
        // With a tiny sigma, essentially all weight lands on the nearest key.
        let q = vec![0.0, 0.0];
        let k = vec![0.0, 0.0, 100.0, 100.0];
        let v = vec![1.0, 0.0, 0.0, 1.0];
        let out = gaussian_attention(&q, &k, &v, 1, 2, 2, 0.01);
        assert!((out[0] - 1.0).abs() < 1e-3);
        assert!((out[1] - 0.0).abs() < 1e-3);
    }

    #[test]
    fn test_ssm_step_output_size() {
        let d_state = 2;
        let d_model = 3;
        let a_re = vec![0.9, 0.8];
        let a_im = vec![0.1, 0.2];
        let b_re = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6];
        let b_im = vec![0.0; 6];
        let c_re = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6];
        let c_im = vec![0.0; 6];
        let d_mat = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let mut h_re = vec![0.0; 2];
        let mut h_im = vec![0.0; 2];
        let x = vec![1.0, 0.0, 0.0];
        let y = ssm_step_diagonal(
            &a_re, &a_im, &b_re, &b_im, &c_re, &c_im, &d_mat, &mut h_re, &mut h_im, &x,
            d_state, d_model,
        );
        assert_eq!(y.len(), 3);
    }

    #[test]
    fn test_ssm_state_update() {
        // h starts at 0; after one step h = A*0 + B*x = 1.0.
        let mut h_re = vec![0.0];
        let mut h_im = vec![0.0];
        ssm_step_diagonal(
            &[0.9],
            &[0.0],
            &[1.0],
            &[0.0],
            &[1.0],
            &[0.0],
            &[0.0],
            &mut h_re,
            &mut h_im,
            &[1.0],
            1,
            1,
        );
        assert!((h_re[0] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_infonce_identical_pairs() {
        // Orthonormal anchors paired with themselves: loss is finite and
        // non-negative.
        let n = 3;
        let d = 4;
        let data = vec![
            1.0, 0.0, 0.0, 0.0, //
            0.0, 1.0, 0.0, 0.0, //
            0.0, 0.0, 1.0, 0.0,
        ];
        let loss = infonce_loss(&data, &data, n, d, 1.0);
        assert!(loss >= 0.0);
    }

    #[test]
    fn test_infonce_temperature() {
        // A colder temperature sharpens the softmax, shrinking the loss for
        // perfectly matched pairs.
        let n = 4;
        let d = 2;
        let a = vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0, -1.0, 1.0];
        let p = a.clone();
        let loss_cold = infonce_loss(&a, &p, n, d, 0.1);
        let loss_hot = infonce_loss(&a, &p, n, d, 10.0);
        assert!(loss_cold < loss_hot);
    }
}