Skip to main content

sc_neurocore_engine/scpn/
dcls.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Commercial license available
3// Copyright (C) 2020-2026 Miroslav Sotek. All rights reserved.
4// ORCID: 0009-0009-3560-0851
5// Contact: www.anulum.li | protoscience@anulum.li
6// SC-NeuroCore - Delay-coded learnable-spike Q8.8 reference
7
8//! Bit-true Q8.8 reference for the delay-coded learnable-spike tent kernel.
9
10use std::error::Error;
11use std::fmt;
12
/// Default fractional precision for Q8.8 DCLS contracts.
pub const DEFAULT_FRACTION: u32 = 8;

/// Default fixed-point data width for DCLS weights and outputs.
pub const DEFAULT_DATA_WIDTH: u32 = 16;

/// Default accumulator width for Q16.16 DCLS accumulators.
pub const DEFAULT_ACCUMULATOR_WIDTH: u32 = 32;

/// The value 1.0 encoded in Q8.8 (`2^DEFAULT_FRACTION`).
const Q88_ONE: i64 = 1 << DEFAULT_FRACTION;

/// `i32` bounds widened to `i64`, used to saturate the Q16.16 accumulator.
const I32_MAX_AS_I64: i64 = i32::MAX as i64;
const I32_MIN_AS_I64: i64 = i32::MIN as i64;

/// Q16.16 images of `i16::MAX` / `i16::MIN` (scaled by `Q88_ONE`), used as the
/// saturation thresholds for the Q8.8 output.
const I16_MAX_Q16_16: i64 = (i16::MAX as i64) * Q88_ONE;
const I16_MIN_Q16_16: i64 = (i16::MIN as i64) * Q88_ONE;
27
28/// Configuration for the bit-true DCLS Q8.8 forward pass.
29#[derive(Debug, Clone, Copy, PartialEq, Eq)]
30pub struct DclsLayerConfig {
31    pub data_width: u32,
32    pub fraction: u32,
33    pub accumulator_width: u32,
34}
35
36impl Default for DclsLayerConfig {
37    fn default() -> Self {
38        Self {
39            data_width: DEFAULT_DATA_WIDTH,
40            fraction: DEFAULT_FRACTION,
41            accumulator_width: DEFAULT_ACCUMULATOR_WIDTH,
42        }
43    }
44}
45
/// Outcome of one saturating DCLS tent-kernel contraction.
///
/// Returned by [`dcls_max_forward_q88`] and [`dcls_max_forward_q88_with_config`].
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub struct DclsForwardResult {
    /// Q8.8 output: the Q16.16 sum shifted right by `DEFAULT_FRACTION`,
    /// saturated to the `i16` range.
    pub output_q88: i16,
    /// Q16.16 sum saturated to the `i32` range.
    pub accumulator_q16_16: i32,
    /// `true` when the sum left the `i32` range or the Q16.16 span of the
    /// `i16` output, i.e. when accumulator or output saturation occurred.
    pub overflow: bool,
    /// Number of taps whose spike byte is non-zero, counted regardless of
    /// whether their tent gate turned out to be zero.
    pub active_tap_count: usize,
    /// Largest Q8.8 tent gate applied to an active spike; `0` when no tap is active.
    pub max_gate_q88: i16,
}
60
/// Errors raised when a DCLS arithmetic contract is violated.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DclsError {
    /// The requested format is not the supported Q8.8 / Q16.16 contract.
    UnsupportedFormat {
        data_width: u32,
        fraction: u32,
        accumulator_width: u32,
    },
    /// No taps were supplied (empty spike slice, or `n_taps == 0` in the batch API).
    EmptyTaps,
    /// Spike and weight buffers do not have the required lengths.
    MismatchedLengths {
        spikes: usize,
        weights: usize,
    },
    /// The tent half-width is zero or negative.
    InvalidSigma {
        sigma_q88: i16,
    },
    /// The Q8.8 delay for this tap index could not be computed in `i64`.
    TapIndexOverflow {
        tap_index: usize,
    },
    /// The batch API received no `(centre, sigma)` pairs.
    EmptyChannels,
    /// The batch centre and sigma slices differ in length.
    ChannelLengthMismatch {
        centres: usize,
        sigmas: usize,
    },
}

impl fmt::Display for DclsError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every payload field is `Copy`, so bind by value straight from `*self`.
        match *self {
            Self::UnsupportedFormat {
                data_width,
                fraction,
                accumulator_width,
            } => write!(
                f,
                "unsupported DCLS format data_width={}, fraction={}, accumulator_width={}",
                data_width, fraction, accumulator_width
            ),
            Self::EmptyTaps => f.write_str("DCLS forward pass requires at least one tap"),
            Self::MismatchedLengths { spikes, weights } => write!(
                f,
                "DCLS spike/weight length mismatch: spikes={}, weights={}",
                spikes, weights
            ),
            Self::InvalidSigma { sigma_q88 } => write!(
                f,
                "DCLS tent sigma must be positive, got {}",
                sigma_q88
            ),
            Self::TapIndexOverflow { tap_index } => write!(
                f,
                "DCLS tap index {} cannot be represented as Q8.8",
                tap_index
            ),
            Self::EmptyChannels => f.write_str("DCLS batch requires at least one output channel"),
            Self::ChannelLengthMismatch { centres, sigmas } => write!(
                f,
                "DCLS centre/sigma length mismatch: centres={}, sigmas={}",
                centres, sigmas
            ),
        }
    }
}

impl Error for DclsError {}
121
122impl DclsLayerConfig {
123    fn validate(self) -> Result<(), DclsError> {
124        if self.data_width != DEFAULT_DATA_WIDTH
125            || self.fraction != DEFAULT_FRACTION
126            || self.accumulator_width != DEFAULT_ACCUMULATOR_WIDTH
127        {
128            return Err(DclsError::UnsupportedFormat {
129                data_width: self.data_width,
130                fraction: self.fraction,
131                accumulator_width: self.accumulator_width,
132            });
133        }
134        Ok(())
135    }
136}
137
138/// Return the Q8.8 triangular tent gate for a tap delay index.
139pub fn tent_gate_q88(tap_index: usize, centre_q88: i16, sigma_q88: i16) -> Result<i16, DclsError> {
140    if sigma_q88 <= 0 {
141        return Err(DclsError::InvalidSigma { sigma_q88 });
142    }
143    let delay_q88 = i64::try_from(tap_index)
144        .ok()
145        .and_then(|index| index.checked_shl(DEFAULT_FRACTION))
146        .ok_or(DclsError::TapIndexOverflow { tap_index })?;
147    let centre = i64::from(centre_q88);
148    let sigma = i64::from(sigma_q88);
149    let distance = (delay_q88 - centre).abs();
150    if distance >= sigma {
151        return Ok(0);
152    }
153    let numerator = sigma - distance;
154    let gate = (numerator << DEFAULT_FRACTION) / sigma;
155    Ok(gate.clamp(0, Q88_ONE) as i16)
156}
157
158/// Execute the DCLS Q8.8 tent contraction with a Q16.16 accumulator.
159pub fn dcls_max_forward_q88(
160    spikes: &[u8],
161    weights_q88: &[i16],
162    centre_q88: i16,
163    sigma_q88: i16,
164) -> Result<DclsForwardResult, DclsError> {
165    dcls_max_forward_q88_with_config(
166        spikes,
167        weights_q88,
168        centre_q88,
169        sigma_q88,
170        DclsLayerConfig::default(),
171    )
172}
173
174/// Execute the DCLS Q8.8 tent contraction with an explicit format contract.
175pub fn dcls_max_forward_q88_with_config(
176    spikes: &[u8],
177    weights_q88: &[i16],
178    centre_q88: i16,
179    sigma_q88: i16,
180    config: DclsLayerConfig,
181) -> Result<DclsForwardResult, DclsError> {
182    config.validate()?;
183    if spikes.is_empty() {
184        return Err(DclsError::EmptyTaps);
185    }
186    if spikes.len() != weights_q88.len() {
187        return Err(DclsError::MismatchedLengths {
188            spikes: spikes.len(),
189            weights: weights_q88.len(),
190        });
191    }
192    if sigma_q88 <= 0 {
193        return Err(DclsError::InvalidSigma { sigma_q88 });
194    }
195
196    let mut accumulator = 0_i64;
197    let mut active_tap_count = 0_usize;
198    let mut max_gate_q88 = 0_i16;
199    for (tap_index, (&spike, &weight)) in spikes.iter().zip(weights_q88.iter()).enumerate() {
200        if spike == 0 {
201            continue;
202        }
203        active_tap_count += 1;
204        let gate = tent_gate_q88(tap_index, centre_q88, sigma_q88)?;
205        max_gate_q88 = max_gate_q88.max(gate);
206        accumulator += i64::from(weight) * i64::from(gate);
207    }
208
209    let (accumulator_q16_16, accumulator_overflow) = saturate_i32(accumulator);
210    let (output_q88, output_overflow) = saturate_q88_output(accumulator);
211    Ok(DclsForwardResult {
212        output_q88,
213        accumulator_q16_16,
214        overflow: accumulator_overflow || output_overflow,
215        active_tap_count,
216        max_gate_q88,
217    })
218}
219
/// Per-channel results of a batched DCLS-max tent contraction.
///
/// Every vector is indexed by output channel, and all of them have one entry
/// per `(centre, sigma)` pair passed to [`dcls_max_forward_batch_q88`]. Entry
/// `c` of each vector mirrors the matching field of a standalone
/// `DclsForwardResult` for channel `c`.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub struct DclsBatchResult {
    /// Saturated Q8.8 output of each channel.
    pub outputs_q88: Vec<i16>,
    /// Saturated Q16.16 accumulator of each channel.
    pub accumulators_q16_16: Vec<i32>,
    /// Whether each channel saturated its accumulator or output.
    pub overflow: Vec<bool>,
    /// Number of non-zero spike taps consumed by each channel.
    pub active_tap_counts: Vec<usize>,
    /// Largest tent gate applied to an active spike in each channel.
    pub max_gates_q88: Vec<i16>,
}
238
239/// Execute the DCLS-max tent contraction across many output channels.
240///
241/// Channel `c` contracts its own `n_taps`-long spike/weight row through a tent
242/// kernel with channel-specific learnable `centre`/`sigma`. The spike and
243/// weight buffers are row-major: channel `c` occupies `[c * n_taps, (c + 1) *
244/// n_taps)`. The per-channel result is bit-identical to
245/// [`dcls_max_forward_q88`].
246pub fn dcls_max_forward_batch_q88(
247    spikes: &[u8],
248    weights_q88: &[i16],
249    centres_q88: &[i16],
250    sigmas_q88: &[i16],
251    n_taps: usize,
252) -> Result<DclsBatchResult, DclsError> {
253    if n_taps == 0 {
254        return Err(DclsError::EmptyTaps);
255    }
256    let n_channels = centres_q88.len();
257    if n_channels == 0 {
258        return Err(DclsError::EmptyChannels);
259    }
260    if sigmas_q88.len() != n_channels {
261        return Err(DclsError::ChannelLengthMismatch {
262            centres: n_channels,
263            sigmas: sigmas_q88.len(),
264        });
265    }
266    let expected = n_channels * n_taps;
267    if spikes.len() != expected || weights_q88.len() != expected {
268        return Err(DclsError::MismatchedLengths {
269            spikes: spikes.len(),
270            weights: weights_q88.len(),
271        });
272    }
273
274    let mut outputs_q88 = Vec::with_capacity(n_channels);
275    let mut accumulators_q16_16 = Vec::with_capacity(n_channels);
276    let mut overflow = Vec::with_capacity(n_channels);
277    let mut active_tap_counts = Vec::with_capacity(n_channels);
278    let mut max_gates_q88 = Vec::with_capacity(n_channels);
279    for channel in 0..n_channels {
280        let base = channel * n_taps;
281        let result = dcls_max_forward_q88(
282            &spikes[base..base + n_taps],
283            &weights_q88[base..base + n_taps],
284            centres_q88[channel],
285            sigmas_q88[channel],
286        )?;
287        outputs_q88.push(result.output_q88);
288        accumulators_q16_16.push(result.accumulator_q16_16);
289        overflow.push(result.overflow);
290        active_tap_counts.push(result.active_tap_count);
291        max_gates_q88.push(result.max_gate_q88);
292    }
293    Ok(DclsBatchResult {
294        outputs_q88,
295        accumulators_q16_16,
296        overflow,
297        active_tap_counts,
298        max_gates_q88,
299    })
300}
301
302fn saturate_i32(value: i64) -> (i32, bool) {
303    if value > I32_MAX_AS_I64 {
304        (i32::MAX, true)
305    } else if value < I32_MIN_AS_I64 {
306        (i32::MIN, true)
307    } else {
308        (value as i32, false)
309    }
310}
311
312fn saturate_q88_output(accumulator_q16_16: i64) -> (i16, bool) {
313    if accumulator_q16_16 > I16_MAX_Q16_16 {
314        (i16::MAX, true)
315    } else if accumulator_q16_16 < I16_MIN_Q16_16 {
316        (i16::MIN, true)
317    } else {
318        ((accumulator_q16_16 >> DEFAULT_FRACTION) as i16, false)
319    }
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared 3-tap fixture: weights plus centre = 1.0 and sigma = 2.0 in Q8.8.
    const WEIGHTS: [i16; 3] = [256, 128, -64];
    const CENTRE: i16 = 256;
    const SIGMA: i16 = 512;

    #[test]
    fn tent_gate_matches_special_cases() {
        // (tap_index, expected gate): peak on the centre, half at distance
        // 1.0, zero beyond sigma.
        for (tap_index, expected) in [(1_usize, 256_i16), (0, 128), (3, 0)] {
            assert_eq!(
                tent_gate_q88(tap_index, CENTRE, SIGMA),
                Ok(expected),
                "tap {tap_index}"
            );
        }
    }

    #[test]
    fn forward_matches_hand_computed_q16_16_accumulator() {
        // Gates are [128, 256, 128]: 256*128 + 128*256 - 64*128 = 57_344.
        let result = dcls_max_forward_q88(&[1, 1, 1], &WEIGHTS, CENTRE, SIGMA).unwrap();
        assert_eq!(result.accumulator_q16_16, 57_344);
        assert_eq!(result.output_q88, 224);
        assert_eq!(result.active_tap_count, 3);
        assert_eq!(result.max_gate_q88, 256);
        assert!(!result.overflow);
    }

    #[test]
    fn zero_spike_taps_do_not_contribute() {
        let result = dcls_max_forward_q88(&[0, 1, 0], &WEIGHTS, CENTRE, SIGMA).unwrap();
        assert_eq!(result.accumulator_q16_16, 32_768);
        assert_eq!(result.output_q88, 128);
        assert_eq!(result.active_tap_count, 1);
    }

    #[test]
    fn invalid_sigma_fails_closed() {
        let err = dcls_max_forward_q88(&[1], &[256], 0, 0).unwrap_err();
        assert_eq!(err, DclsError::InvalidSigma { sigma_q88: 0 });
    }

    #[test]
    fn mismatched_taps_fail_closed() {
        let err = dcls_max_forward_q88(&[1, 1], &[256], 0, 256).unwrap_err();
        assert_eq!(
            err,
            DclsError::MismatchedLengths {
                spikes: 2,
                weights: 1
            }
        );
    }

    #[test]
    fn saturating_output_sets_overflow() {
        const TAPS: usize = 1024;
        let spikes = [1_u8; TAPS];
        let weights = [i16::MAX; TAPS];
        let result = dcls_max_forward_q88(&spikes, &weights, 0, i16::MAX).unwrap();
        assert_eq!(result.output_q88, i16::MAX);
        assert!(result.overflow);
    }

    #[test]
    fn batch_matches_per_channel_single_forward() {
        const N_TAPS: usize = 3;
        let spikes = [1_u8, 1, 1, 0, 1, 0];
        let weights = [256_i16, 128, -64, 256, 128, -64];
        let centres = [CENTRE; 2];
        let sigmas = [SIGMA; 2];
        let batch =
            dcls_max_forward_batch_q88(&spikes, &weights, &centres, &sigmas, N_TAPS).unwrap();
        assert_eq!(batch.outputs_q88, vec![224, 128]);
        assert_eq!(batch.accumulators_q16_16, vec![57_344, 32_768]);
        assert_eq!(batch.active_tap_counts, vec![3, 1]);
        assert_eq!(batch.max_gates_q88, vec![256, 256]);
        assert_eq!(batch.overflow, vec![false, false]);
        // Every channel equals the standalone single contraction.
        let rows = spikes
            .chunks_exact(N_TAPS)
            .zip(weights.chunks_exact(N_TAPS));
        for (channel, (spike_row, weight_row)) in rows.enumerate() {
            let single =
                dcls_max_forward_q88(spike_row, weight_row, centres[channel], sigmas[channel])
                    .unwrap();
            assert_eq!(batch.outputs_q88[channel], single.output_q88);
            assert_eq!(
                batch.accumulators_q16_16[channel],
                single.accumulator_q16_16
            );
        }
    }

    #[test]
    fn batch_rejects_zero_taps() {
        let err = dcls_max_forward_batch_q88(&[], &[], &[256], &[512], 0).unwrap_err();
        assert_eq!(err, DclsError::EmptyTaps);
    }

    #[test]
    fn batch_rejects_empty_channels() {
        let err = dcls_max_forward_batch_q88(&[], &[], &[], &[], 3).unwrap_err();
        assert_eq!(err, DclsError::EmptyChannels);
    }

    #[test]
    fn batch_rejects_centre_sigma_mismatch() {
        let err =
            dcls_max_forward_batch_q88(&[1, 1], &[256, 128], &[256, 0], &[512], 1).unwrap_err();
        assert_eq!(
            err,
            DclsError::ChannelLengthMismatch {
                centres: 2,
                sigmas: 1
            }
        );
    }

    #[test]
    fn batch_rejects_flat_length_mismatch() {
        let err = dcls_max_forward_batch_q88(&[1, 1, 1], &WEIGHTS, &[CENTRE], &[SIGMA], 2)
            .unwrap_err();
        assert_eq!(
            err,
            DclsError::MismatchedLengths {
                spikes: 3,
                weights: 3
            }
        );
    }

    #[test]
    fn batch_propagates_invalid_sigma() {
        let err = dcls_max_forward_batch_q88(&[1], &[256], &[0], &[0], 1).unwrap_err();
        assert_eq!(err, DclsError::InvalidSigma { sigma_q88: 0 });
    }
}