1use std::error::Error;
11use std::fmt;
12
/// Number of fractional bits in the supported Q8.8 data format.
pub const DEFAULT_FRACTION: u32 = 8;

/// Total bit width of a data word in the supported format (Q8.8 values travel as `i16`).
pub const DEFAULT_DATA_WIDTH: u32 = 16;

/// Bit width of the multiply-accumulate result in the supported format (reported as `i32`).
pub const DEFAULT_ACCUMULATOR_WIDTH: u32 = 32;

// The value 1.0 in Q8.8.
const Q88_ONE: i64 = 1 << DEFAULT_FRACTION;
// `i32` limits widened to the `i64` precision the accumulator is computed in.
const I32_MAX_AS_I64: i64 = i32::MAX as i64;
const I32_MIN_AS_I64: i64 = i32::MIN as i64;
// `i16` limits scaled into Q16.16 (multiplied by 1.0 in Q8.8), used when narrowing the
// accumulator back down to a Q8.8 output.
const I16_MAX_Q16_16: i64 = (i16::MAX as i64) * Q88_ONE;
const I16_MIN_Q16_16: i64 = (i16::MIN as i64) * Q88_ONE;
27
28#[derive(Debug, Clone, Copy, PartialEq, Eq)]
30pub struct DclsLayerConfig {
31 pub data_width: u32,
32 pub fraction: u32,
33 pub accumulator_width: u32,
34}
35
36impl Default for DclsLayerConfig {
37 fn default() -> Self {
38 Self {
39 data_width: DEFAULT_DATA_WIDTH,
40 fraction: DEFAULT_FRACTION,
41 accumulator_width: DEFAULT_ACCUMULATOR_WIDTH,
42 }
43 }
44}
45
/// Result of a single-channel DCLS forward pass.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DclsForwardResult {
    /// Q8.8 output: the accumulator shifted right by the fraction width, saturated to `i16`.
    pub output_q88: i16,
    /// Sum of `weight * gate` over spiking taps (Q16.16), saturated to `i32`.
    pub accumulator_q16_16: i32,
    /// `true` when either the accumulator or the Q8.8 output had to be saturated.
    pub overflow: bool,
    /// Number of taps whose spike byte was non-zero.
    pub active_tap_count: usize,
    /// Largest tent gate (Q8.8) seen on a spiking tap; 0 when no tap spiked.
    pub max_gate_q88: i16,
}
60
/// Errors reported by the DCLS forward-pass functions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DclsError {
    /// The layer configuration is not the single supported fixed-point format.
    UnsupportedFormat {
        data_width: u32,
        fraction: u32,
        accumulator_width: u32,
    },
    /// A forward pass was requested with zero taps.
    EmptyTaps,
    /// The spike and weight buffers do not have the required lengths.
    MismatchedLengths {
        spikes: usize,
        weights: usize,
    },
    /// The tent width was zero or negative.
    InvalidSigma {
        sigma_q88: i16,
    },
    /// A tap index was too large to convert into a fixed-point delay.
    TapIndexOverflow {
        tap_index: usize,
    },
    /// A batch forward pass was requested with zero output channels.
    EmptyChannels,
    /// The per-channel centre and sigma slices have different lengths.
    ChannelLengthMismatch {
        centres: usize,
        sigmas: usize,
    },
}

impl fmt::Display for DclsError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every payload field is `Copy`, so the variants can be destructured by value.
        match *self {
            DclsError::UnsupportedFormat {
                data_width,
                fraction,
                accumulator_width,
            } => write!(
                f,
                "unsupported DCLS format data_width={data_width}, fraction={fraction}, accumulator_width={accumulator_width}"
            ),
            DclsError::EmptyTaps => f.write_str("DCLS forward pass requires at least one tap"),
            DclsError::MismatchedLengths { spikes, weights } => write!(
                f,
                "DCLS spike/weight length mismatch: spikes={spikes}, weights={weights}"
            ),
            DclsError::InvalidSigma { sigma_q88 } => {
                write!(f, "DCLS tent sigma must be positive, got {sigma_q88}")
            }
            DclsError::TapIndexOverflow { tap_index } => {
                write!(f, "DCLS tap index {tap_index} cannot be represented as Q8.8")
            }
            DclsError::EmptyChannels => {
                f.write_str("DCLS batch requires at least one output channel")
            }
            DclsError::ChannelLengthMismatch { centres, sigmas } => write!(
                f,
                "DCLS centre/sigma length mismatch: centres={centres}, sigmas={sigmas}"
            ),
        }
    }
}

impl Error for DclsError {}
121
122impl DclsLayerConfig {
123 fn validate(self) -> Result<(), DclsError> {
124 if self.data_width != DEFAULT_DATA_WIDTH
125 || self.fraction != DEFAULT_FRACTION
126 || self.accumulator_width != DEFAULT_ACCUMULATOR_WIDTH
127 {
128 return Err(DclsError::UnsupportedFormat {
129 data_width: self.data_width,
130 fraction: self.fraction,
131 accumulator_width: self.accumulator_width,
132 });
133 }
134 Ok(())
135 }
136}
137
138pub fn tent_gate_q88(tap_index: usize, centre_q88: i16, sigma_q88: i16) -> Result<i16, DclsError> {
140 if sigma_q88 <= 0 {
141 return Err(DclsError::InvalidSigma { sigma_q88 });
142 }
143 let delay_q88 = i64::try_from(tap_index)
144 .ok()
145 .and_then(|index| index.checked_shl(DEFAULT_FRACTION))
146 .ok_or(DclsError::TapIndexOverflow { tap_index })?;
147 let centre = i64::from(centre_q88);
148 let sigma = i64::from(sigma_q88);
149 let distance = (delay_q88 - centre).abs();
150 if distance >= sigma {
151 return Ok(0);
152 }
153 let numerator = sigma - distance;
154 let gate = (numerator << DEFAULT_FRACTION) / sigma;
155 Ok(gate.clamp(0, Q88_ONE) as i16)
156}
157
158pub fn dcls_max_forward_q88(
160 spikes: &[u8],
161 weights_q88: &[i16],
162 centre_q88: i16,
163 sigma_q88: i16,
164) -> Result<DclsForwardResult, DclsError> {
165 dcls_max_forward_q88_with_config(
166 spikes,
167 weights_q88,
168 centre_q88,
169 sigma_q88,
170 DclsLayerConfig::default(),
171 )
172}
173
174pub fn dcls_max_forward_q88_with_config(
176 spikes: &[u8],
177 weights_q88: &[i16],
178 centre_q88: i16,
179 sigma_q88: i16,
180 config: DclsLayerConfig,
181) -> Result<DclsForwardResult, DclsError> {
182 config.validate()?;
183 if spikes.is_empty() {
184 return Err(DclsError::EmptyTaps);
185 }
186 if spikes.len() != weights_q88.len() {
187 return Err(DclsError::MismatchedLengths {
188 spikes: spikes.len(),
189 weights: weights_q88.len(),
190 });
191 }
192 if sigma_q88 <= 0 {
193 return Err(DclsError::InvalidSigma { sigma_q88 });
194 }
195
196 let mut accumulator = 0_i64;
197 let mut active_tap_count = 0_usize;
198 let mut max_gate_q88 = 0_i16;
199 for (tap_index, (&spike, &weight)) in spikes.iter().zip(weights_q88.iter()).enumerate() {
200 if spike == 0 {
201 continue;
202 }
203 active_tap_count += 1;
204 let gate = tent_gate_q88(tap_index, centre_q88, sigma_q88)?;
205 max_gate_q88 = max_gate_q88.max(gate);
206 accumulator += i64::from(weight) * i64::from(gate);
207 }
208
209 let (accumulator_q16_16, accumulator_overflow) = saturate_i32(accumulator);
210 let (output_q88, output_overflow) = saturate_q88_output(accumulator);
211 Ok(DclsForwardResult {
212 output_q88,
213 accumulator_q16_16,
214 overflow: accumulator_overflow || output_overflow,
215 active_tap_count,
216 max_gate_q88,
217 })
218}
219
/// Per-channel results of a batched DCLS forward pass.
///
/// Every vector has one entry per output channel, in channel order. Entry `c` of each
/// vector is the corresponding [`DclsForwardResult`] field for channel `c`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DclsBatchResult {
    /// Q8.8 outputs (`DclsForwardResult::output_q88`).
    pub outputs_q88: Vec<i16>,
    /// Saturated Q16.16 accumulators (`DclsForwardResult::accumulator_q16_16`).
    pub accumulators_q16_16: Vec<i32>,
    /// Saturation flags (`DclsForwardResult::overflow`).
    pub overflow: Vec<bool>,
    /// Spiking-tap counts (`DclsForwardResult::active_tap_count`).
    pub active_tap_counts: Vec<usize>,
    /// Largest gates on spiking taps (`DclsForwardResult::max_gate_q88`).
    pub max_gates_q88: Vec<i16>,
}
238
239pub fn dcls_max_forward_batch_q88(
247 spikes: &[u8],
248 weights_q88: &[i16],
249 centres_q88: &[i16],
250 sigmas_q88: &[i16],
251 n_taps: usize,
252) -> Result<DclsBatchResult, DclsError> {
253 if n_taps == 0 {
254 return Err(DclsError::EmptyTaps);
255 }
256 let n_channels = centres_q88.len();
257 if n_channels == 0 {
258 return Err(DclsError::EmptyChannels);
259 }
260 if sigmas_q88.len() != n_channels {
261 return Err(DclsError::ChannelLengthMismatch {
262 centres: n_channels,
263 sigmas: sigmas_q88.len(),
264 });
265 }
266 let expected = n_channels * n_taps;
267 if spikes.len() != expected || weights_q88.len() != expected {
268 return Err(DclsError::MismatchedLengths {
269 spikes: spikes.len(),
270 weights: weights_q88.len(),
271 });
272 }
273
274 let mut outputs_q88 = Vec::with_capacity(n_channels);
275 let mut accumulators_q16_16 = Vec::with_capacity(n_channels);
276 let mut overflow = Vec::with_capacity(n_channels);
277 let mut active_tap_counts = Vec::with_capacity(n_channels);
278 let mut max_gates_q88 = Vec::with_capacity(n_channels);
279 for channel in 0..n_channels {
280 let base = channel * n_taps;
281 let result = dcls_max_forward_q88(
282 &spikes[base..base + n_taps],
283 &weights_q88[base..base + n_taps],
284 centres_q88[channel],
285 sigmas_q88[channel],
286 )?;
287 outputs_q88.push(result.output_q88);
288 accumulators_q16_16.push(result.accumulator_q16_16);
289 overflow.push(result.overflow);
290 active_tap_counts.push(result.active_tap_count);
291 max_gates_q88.push(result.max_gate_q88);
292 }
293 Ok(DclsBatchResult {
294 outputs_q88,
295 accumulators_q16_16,
296 overflow,
297 active_tap_counts,
298 max_gates_q88,
299 })
300}
301
302fn saturate_i32(value: i64) -> (i32, bool) {
303 if value > I32_MAX_AS_I64 {
304 (i32::MAX, true)
305 } else if value < I32_MIN_AS_I64 {
306 (i32::MIN, true)
307 } else {
308 (value as i32, false)
309 }
310}
311
312fn saturate_q88_output(accumulator_q16_16: i64) -> (i16, bool) {
313 if accumulator_q16_16 > I16_MAX_Q16_16 {
314 (i16::MAX, true)
315 } else if accumulator_q16_16 < I16_MIN_Q16_16 {
316 (i16::MIN, true)
317 } else {
318 ((accumulator_q16_16 >> DEFAULT_FRACTION) as i16, false)
319 }
320}
321
// Unit tests pinning the hand-computed Q8.8 / Q16.16 reference values and the
// fail-closed error paths of the single-channel and batched forward passes.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tent_gate_matches_special_cases() {
        assert_eq!(tent_gate_q88(1, 256, 512).unwrap(), 256);
        assert_eq!(tent_gate_q88(0, 256, 512).unwrap(), 128);
        assert_eq!(tent_gate_q88(3, 256, 512).unwrap(), 0);
    }

    #[test]
    fn forward_matches_hand_computed_q16_16_accumulator() {
        let result = dcls_max_forward_q88(&[1, 1, 1], &[256, 128, -64], 256, 512).unwrap();
        assert_eq!(result.accumulator_q16_16, 57_344);
        assert_eq!(result.output_q88, 224);
        assert_eq!(result.active_tap_count, 3);
        assert_eq!(result.max_gate_q88, 256);
        assert!(!result.overflow);
    }

    #[test]
    fn zero_spike_taps_do_not_contribute() {
        let result = dcls_max_forward_q88(&[0, 1, 0], &[256, 128, -64], 256, 512).unwrap();
        assert_eq!(result.accumulator_q16_16, 32_768);
        assert_eq!(result.output_q88, 128);
        assert_eq!(result.active_tap_count, 1);
    }

    #[test]
    fn invalid_sigma_fails_closed() {
        assert_eq!(
            dcls_max_forward_q88(&[1], &[256], 0, 0).unwrap_err(),
            DclsError::InvalidSigma { sigma_q88: 0 }
        );
    }

    #[test]
    fn mismatched_taps_fail_closed() {
        assert_eq!(
            dcls_max_forward_q88(&[1, 1], &[256], 0, 256).unwrap_err(),
            DclsError::MismatchedLengths {
                spikes: 2,
                weights: 1
            }
        );
    }

    #[test]
    fn saturating_output_sets_overflow() {
        let spikes = vec![1_u8; 1024];
        let weights = vec![i16::MAX; 1024];
        let result = dcls_max_forward_q88(&spikes, &weights, 0, i16::MAX).unwrap();
        assert_eq!(result.output_q88, i16::MAX);
        assert!(result.overflow);
    }

    #[test]
    fn batch_matches_per_channel_single_forward() {
        let spikes = [1_u8, 1, 1, 0, 1, 0];
        let weights = [256_i16, 128, -64, 256, 128, -64];
        let centres = [256_i16, 256];
        let sigmas = [512_i16, 512];
        let batch = dcls_max_forward_batch_q88(&spikes, &weights, &centres, &sigmas, 3).unwrap();
        assert_eq!(batch.outputs_q88, vec![224, 128]);
        assert_eq!(batch.accumulators_q16_16, vec![57_344, 32_768]);
        assert_eq!(batch.active_tap_counts, vec![3, 1]);
        assert_eq!(batch.max_gates_q88, vec![256, 256]);
        assert_eq!(batch.overflow, vec![false, false]);
        for channel in 0..2 {
            let base = channel * 3;
            let single = dcls_max_forward_q88(
                &spikes[base..base + 3],
                &weights[base..base + 3],
                centres[channel],
                sigmas[channel],
            )
            .unwrap();
            assert_eq!(batch.outputs_q88[channel], single.output_q88);
            assert_eq!(
                batch.accumulators_q16_16[channel],
                single.accumulator_q16_16
            );
        }
    }

    #[test]
    fn batch_rejects_zero_taps() {
        assert_eq!(
            dcls_max_forward_batch_q88(&[], &[], &[256], &[512], 0).unwrap_err(),
            DclsError::EmptyTaps
        );
    }

    #[test]
    fn batch_rejects_empty_channels() {
        assert_eq!(
            dcls_max_forward_batch_q88(&[], &[], &[], &[], 3).unwrap_err(),
            DclsError::EmptyChannels
        );
    }

    #[test]
    fn batch_rejects_centre_sigma_mismatch() {
        assert_eq!(
            dcls_max_forward_batch_q88(&[1, 1], &[256, 128], &[256, 0], &[512], 1).unwrap_err(),
            DclsError::ChannelLengthMismatch {
                centres: 2,
                sigmas: 1
            }
        );
    }

    #[test]
    fn batch_rejects_flat_length_mismatch() {
        assert_eq!(
            dcls_max_forward_batch_q88(&[1, 1, 1], &[256, 128, -64], &[256], &[512], 2)
                .unwrap_err(),
            DclsError::MismatchedLengths {
                spikes: 3,
                weights: 3
            }
        );
    }

    #[test]
    fn batch_propagates_invalid_sigma() {
        assert_eq!(
            dcls_max_forward_batch_q88(&[1], &[256], &[0], &[0], 1).unwrap_err(),
            DclsError::InvalidSigma { sigma_q88: 0 }
        );
    }
}