Skip to main content

sc_neurocore_engine/ir/
qformat.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Commercial license available
3// © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
4// © Code 2020–2026 Miroslav Šotek. All rights reserved.
5// ORCID: 0009-0009-3560-0851
6// Contact: www.anulum.li | protoscience@anulum.li
7// SC-NeuroCore — Rust Q-format and mixed dense contracts
8
9use std::error::Error;
10use std::fmt;
11
/// Signed two's-complement fixed-point format `Qm.n`.
///
/// `integer_bits` (`m`) includes the sign bit and `fraction_bits` (`n`) sets the
/// resolution `2^-n`, so one code occupies `m + n` bits.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct QFormat {
    /// Integer bits, sign bit included.
    pub integer_bits: u8,
    /// Fractional bits; one code step equals `2^-fraction_bits`.
    pub fraction_bits: u8,
}

impl QFormat {
    /// Q8.8: 16-bit codes with 8 fractional bits.
    pub const fn q8_8() -> Self {
        Self {
            integer_bits: 8,
            fraction_bits: 8,
        }
    }

    /// Q16.16: 32-bit codes with 16 fractional bits.
    pub const fn q16_16() -> Self {
        Self {
            integer_bits: 16,
            fraction_bits: 16,
        }
    }

    /// Validated constructor.
    ///
    /// # Errors
    ///
    /// * [`QFormatError::MissingSignBit`] when `integer_bits` is zero.
    /// * [`QFormatError::TotalBitsTooWide`] when `integer_bits + fraction_bits`
    ///   lies outside `1..=63`.
    pub fn new(integer_bits: u8, fraction_bits: u8) -> Result<Self, QFormatError> {
        if integer_bits == 0 {
            return Err(QFormatError::MissingSignBit);
        }
        // Summed in u16 so the width check itself cannot overflow.
        match u16::from(integer_bits) + u16::from(fraction_bits) {
            1..=63 => Ok(Self {
                integer_bits,
                fraction_bits,
            }),
            too_wide => Err(QFormatError::TotalBitsTooWide(too_wide)),
        }
    }

    /// Total code width, `integer_bits + fraction_bits`.
    pub fn total_bits(self) -> u8 {
        let Self {
            integer_bits,
            fraction_bits,
        } = self;
        integer_bits + fraction_bits
    }

    /// Code units per 1.0, i.e. `2^fraction_bits`.
    pub fn scale(self) -> i128 {
        1_i128 << self.fraction_bits
    }

    /// Most negative representable value, `-2^(total_bits - 1) / scale`.
    pub fn min_value(self) -> f64 {
        -(self.sign_weight() as f64) / self.scale() as f64
    }

    /// Most positive representable value, `(2^(total_bits - 1) - 1) / scale`.
    pub fn max_value(self) -> f64 {
        (self.sign_weight() - 1) as f64 / self.scale() as f64
    }

    /// Human-readable name such as `"Q8.8"`.
    pub fn label(self) -> String {
        let Self {
            integer_bits,
            fraction_bits,
        } = self;
        format!("Q{integer_bits}.{fraction_bits}")
    }

    /// Raw-code weight of the sign bit, `2^(total_bits - 1)`.
    fn sign_weight(self) -> i128 {
        1_i128 << (self.total_bits() - 1)
    }
}

/// Reasons a Q-format or mixed Q-format contract is rejected.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum QFormatError {
    /// `integer_bits` was zero, leaving no room for the sign bit.
    MissingSignBit,
    /// The combined width (carried here) is outside `1..=63`.
    TotalBitsTooWide(u16),
    /// The accumulator has fewer total bits than the weights.
    AccumulatorNarrower,
    /// The accumulator has fewer fractional bits than the weights.
    AccumulatorFractionLoss,
    /// The accumulator's value range does not contain the weight range.
    AccumulatorRangeLoss,
}

impl fmt::Display for QFormatError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::TotalBitsTooWide(bits) => {
                return write!(f, "Q-format total bits exceed i64 range: {bits}");
            }
            Self::MissingSignBit => "integer_bits must include the sign bit",
            Self::AccumulatorNarrower => {
                "accumulator format must not be narrower than weight format"
            }
            Self::AccumulatorFractionLoss => {
                "accumulator format must preserve weight fractional precision"
            }
            Self::AccumulatorRangeLoss => "accumulator format must cover the full weight range",
        };
        f.write_str(message)
    }
}

impl Error for QFormatError {}

/// Paired weight / accumulator formats for a mixed-precision MAC.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct QFormatMixed {
    /// Storage format of the weights.
    pub weight_fmt: QFormat,
    /// Format of the accumulator.
    pub accum_fmt: QFormat,
    /// Per-tensor scaling flag; stored verbatim and not interpreted by the methods below.
    pub scale_per_tensor: bool,
}

impl QFormatMixed {
    /// Default contract: Q8.8 weights accumulated in Q16.16 with per-tensor scaling.
    pub fn q8_8_q16_16() -> Self {
        Self {
            weight_fmt: QFormat::q8_8(),
            accum_fmt: QFormat::q16_16(),
            scale_per_tensor: true,
        }
    }

    /// Validated constructor.
    ///
    /// # Errors
    ///
    /// Checked in this order:
    /// * [`QFormatError::AccumulatorNarrower`] when the accumulator has fewer total bits.
    /// * [`QFormatError::AccumulatorFractionLoss`] when it has fewer fractional bits.
    /// * [`QFormatError::AccumulatorRangeLoss`] when its `[min, max]` does not contain
    ///   the weight range (possible even when both width checks pass, e.g. Q8.8 vs Q4.16).
    pub fn new(
        weight_fmt: QFormat,
        accum_fmt: QFormat,
        scale_per_tensor: bool,
    ) -> Result<Self, QFormatError> {
        if accum_fmt.total_bits() < weight_fmt.total_bits() {
            return Err(QFormatError::AccumulatorNarrower);
        }
        if accum_fmt.fraction_bits < weight_fmt.fraction_bits {
            return Err(QFormatError::AccumulatorFractionLoss);
        }
        // min/max are finite (derived from integer shifts), so `<=`/`>=` are exact negations.
        let covers_low_end = accum_fmt.min_value() <= weight_fmt.min_value();
        let covers_high_end = accum_fmt.max_value() >= weight_fmt.max_value();
        if !(covers_low_end && covers_high_end) {
            return Err(QFormatError::AccumulatorRangeLoss);
        }
        Ok(Self {
            weight_fmt,
            accum_fmt,
            scale_per_tensor,
        })
    }

    /// Extra accumulator bits beyond the weight width.
    ///
    /// Uses plain `u8` subtraction: `new` guarantees a non-negative result, but a
    /// hand-built value with a narrower accumulator would underflow.
    pub fn accumulator_guard_bits(self) -> u8 {
        self.accum_fmt.total_bits() - self.weight_fmt.total_bits()
    }
}
146
/// Block-floating-point (BFP) weight encoding parameters.
///
/// Weights are stored as signed mantissas. Each run of `block_size` consecutive
/// weights shares one unsigned exponent code of `exponent_bits` bits, which is
/// interpreted relative to [`BlockFloatingMode::exponent_bias`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BlockFloatingMode {
    /// Mantissa width in bits, sign included.
    pub mantissa_bits: u8,
    /// Width of each shared exponent code.
    pub exponent_bits: u8,
    /// Number of consecutive weights sharing one exponent.
    pub block_size: usize,
}

impl BlockFloatingMode {
    /// Validated constructor.
    ///
    /// # Errors
    ///
    /// * [`BlockFloatingError::MantissaTooNarrow`] when `mantissa_bits < 2`.
    /// * [`BlockFloatingError::InvalidExponentBits`] unless `exponent_bits` is in `1..=7`.
    /// * [`BlockFloatingError::EmptyBlock`] when `block_size` is zero.
    pub fn new(
        mantissa_bits: u8,
        exponent_bits: u8,
        block_size: usize,
    ) -> Result<Self, BlockFloatingError> {
        if mantissa_bits < 2 {
            return Err(BlockFloatingError::MantissaTooNarrow);
        }
        if exponent_bits == 0 || exponent_bits > 7 {
            return Err(BlockFloatingError::InvalidExponentBits);
        }
        if block_size == 0 {
            return Err(BlockFloatingError::EmptyBlock);
        }
        Ok(Self {
            mantissa_bits,
            exponent_bits,
            block_size,
        })
    }

    /// 16-bit mantissas, 3-bit exponents, 32 weights per block.
    pub fn bfp16_e3_x32() -> Self {
        Self {
            mantissa_bits: 16,
            exponent_bits: 3,
            block_size: 32,
        }
    }

    /// Exponent bias, `2^(exponent_bits - 1) - 1`.
    pub fn exponent_bias(self) -> i32 {
        (1_i32 << (self.exponent_bits - 1)) - 1
    }

    /// Smallest effective exponent (code 0 minus the bias).
    pub fn min_exponent(self) -> i32 {
        -self.exponent_bias()
    }

    /// Largest effective exponent (largest code minus the bias).
    pub fn max_exponent(self) -> i32 {
        ((1_i32 << self.exponent_bits) - 1) - self.exponent_bias()
    }

    /// Largest accepted mantissa magnitude, `2^(mantissa_bits - 1) - 1` (symmetric range).
    pub fn mantissa_range(self) -> i128 {
        (1_i128 << (self.mantissa_bits - 1)) - 1
    }

    /// Largest valid exponent code, `2^exponent_bits - 1`.
    pub fn exponent_code_max(self) -> u8 {
        ((1_u16 << self.exponent_bits) - 1) as u8
    }

    /// Number of shared exponents needed for `parameter_count` weights, i.e.
    /// `ceil(parameter_count / block_size)`. Zero parameters need no exponents.
    ///
    /// # Errors
    ///
    /// [`BlockFloatingError::EmptyBlock`] when `block_size` is zero and
    /// `parameter_count` is positive. This is only reachable for values built
    /// directly rather than through [`BlockFloatingMode::new`].
    pub fn block_exponent_count(self, parameter_count: usize) -> Result<usize, BlockFloatingError> {
        if parameter_count == 0 {
            return Ok(0);
        }
        if self.block_size == 0 {
            return Err(BlockFloatingError::EmptyBlock);
        }
        // `div_ceil` cannot overflow. The previous `(count + block_size - 1) / block_size`
        // rejected counts near `usize::MAX` whose block count is still representable.
        Ok(parameter_count.div_ceil(self.block_size))
    }

    /// Describes how `parameter_count` weights split into exponent blocks.
    ///
    /// `last_block_size` is the length of the final (possibly partial) block,
    /// or 0 when there are no parameters.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`BlockFloatingMode::block_exponent_count`].
    pub fn block_exponent_layout(
        self,
        parameter_count: usize,
    ) -> Result<BlockExponentLayout, BlockFloatingError> {
        let exponent_count = self.block_exponent_count(parameter_count)?;
        // A zero count means zero parameters; skip `%` so a zero block size never divides.
        let last_block_size = if exponent_count == 0 {
            0
        } else {
            match parameter_count % self.block_size {
                0 => self.block_size,
                remainder => remainder,
            }
        };
        Ok(BlockExponentLayout {
            parameter_count,
            block_size: self.block_size,
            exponent_count,
            last_block_size,
        })
    }

    /// Checks that `exponent_count` matches the block count for `parameter_count`.
    ///
    /// # Errors
    ///
    /// [`BlockFloatingError::ExponentCountMismatch`] on a mismatch, or any error
    /// from [`BlockFloatingMode::block_exponent_count`].
    pub fn validate_exponent_count(
        self,
        parameter_count: usize,
        exponent_count: usize,
    ) -> Result<(), BlockFloatingError> {
        let expected = self.block_exponent_count(parameter_count)?;
        if exponent_count != expected {
            return Err(BlockFloatingError::ExponentCountMismatch {
                expected,
                actual: exponent_count,
            });
        }
        Ok(())
    }
}

/// Block partitioning of a weight tensor under a [`BlockFloatingMode`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BlockExponentLayout {
    /// Number of weights described.
    pub parameter_count: usize,
    /// Weights per full block.
    pub block_size: usize,
    /// Number of shared exponents (blocks).
    pub exponent_count: usize,
    /// Length of the final block (equals `block_size` when the last block is full; 0 if empty).
    pub last_block_size: usize,
}

/// Reasons a block-floating-point mode or layout is rejected.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BlockFloatingError {
    /// `mantissa_bits` was below 2.
    MantissaTooNarrow,
    /// `exponent_bits` was outside `1..=7`.
    InvalidExponentBits,
    /// `block_size` was zero.
    EmptyBlock,
    /// Retained for API compatibility. `block_exponent_count` no longer produces it,
    /// because `div_ceil` cannot overflow.
    ParameterCountOverflow,
    /// The supplied exponent count differs from the block count.
    ExponentCountMismatch { expected: usize, actual: usize },
}

impl fmt::Display for BlockFloatingError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::MantissaTooNarrow => write!(f, "mantissa bits must be at least 2"),
            Self::InvalidExponentBits => write!(f, "exponent bits must be in 1..=7"),
            Self::EmptyBlock => write!(f, "block size must be positive"),
            Self::ParameterCountOverflow => write!(f, "parameter count overflows block layout"),
            Self::ExponentCountMismatch { expected, actual } => {
                write!(
                    f,
                    "exponent count mismatch: expected {expected}, got {actual}"
                )
            }
        }
    }
}

impl Error for BlockFloatingError {}
286
/// Outcome of a single-vector dense kernel that emits Q16.16 codes.
///
/// As filled in by the dense kernels in this module, the vectors hold one entry
/// per output neuron, and `overflow` equals `overflow_count > 0`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MixedDenseResult {
    /// Saturated Q16.16 output codes.
    pub outputs_q1616: Vec<i32>,
    /// Whether any output was clamped to the `i32` range.
    pub overflow: bool,
    /// Number of clamped outputs.
    pub overflow_count: usize,
    /// Number of outputs whose non-zero contraction rounded to zero.
    pub underflow_count: usize,
    /// Conservative per-output magnitude bounds in Q16.16 code units.
    pub abs_bounds_q1616: Vec<i64>,
}
295
/// Summary of overflow, underflow and saturation events in a set of Q16.16 outputs.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct PrecisionTrapReport {
    /// Number of output codes inspected.
    pub output_count: usize,
    /// `overflow_count > 0`.
    pub overflow: bool,
    /// Overflow events as reported by the kernel.
    pub overflow_count: usize,
    /// `underflow_count > 0`.
    pub underflow: bool,
    /// Underflow events as reported by the kernel.
    pub underflow_count: usize,
    /// Outputs exactly equal to `i32::MIN`.
    pub saturated_min_count: usize,
    /// Outputs exactly equal to `i32::MAX`.
    pub saturated_max_count: usize,
}

impl PrecisionTrapReport {
    /// Builds a report from Q16.16 output codes plus the kernel's own overflow and
    /// underflow tallies.
    ///
    /// Saturation is inferred from the codes alone. Any output equal to
    /// `i32::MIN` or `i32::MAX` is counted, even if it was produced without clamping.
    pub fn from_q1616(
        outputs_q1616: &[i32],
        overflow_count: usize,
        underflow_count: usize,
    ) -> Self {
        let (saturated_min_count, saturated_max_count) = outputs_q1616.iter().fold(
            (0_usize, 0_usize),
            |(low, high), &code| match code {
                i32::MIN => (low + 1, high),
                i32::MAX => (low, high + 1),
                _ => (low, high),
            },
        );
        Self {
            output_count: outputs_q1616.len(),
            overflow: overflow_count > 0,
            overflow_count,
            underflow: underflow_count > 0,
            underflow_count,
            saturated_min_count,
            saturated_max_count,
        }
    }
}
332
/// Observed and statically bounded precision envelope of a Q16.16 dense result.
///
/// Field meanings follow `MixedDenseResult::precision_envelope_report`, which
/// populates this struct.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct PrecisionEnvelopeReport {
    /// Number of output codes.
    pub output_count: usize,
    /// Overflow flag copied from the result.
    pub overflow: bool,
    /// Number of clamped outputs.
    pub overflow_count: usize,
    /// `underflow_count > 0`.
    pub underflow: bool,
    /// Number of non-zero contractions that rounded to zero.
    pub underflow_count: usize,
    /// `overflow_count == 0`.
    pub observed_overflow_free: bool,
    /// `underflow_count == 0`.
    pub observed_underflow_free: bool,
    /// Whether the largest magnitude bound is at most `i32::MAX`.
    pub conservative_overflow_free: bool,
    /// Largest absolute output code.
    pub max_abs_output_q1616: i64,
    /// Largest per-output magnitude bound.
    pub max_abs_bound_q1616: i64,
    /// Safe magnitude limit used for the checks (`i32::MAX`).
    pub conservative_safe_bound_q1616: i64,
    /// Safe limit minus the largest bound (saturating; negative when exceeded).
    pub min_headroom_q1616: i64,
    /// Signed width needed to hold the largest bound.
    pub required_total_bits_q1616: u8,
    /// `required_total_bits_q1616` minus 16 fractional bits, at least 1.
    pub required_integer_bits_q1616: u8,
    /// `32 - required_total_bits_q1616`; negative means wider than 32 bits is needed.
    pub width_headroom_bits_q1616: i16,
    /// Whether the required width exceeds 32 bits.
    pub saturation_required: bool,
    /// Negation of `saturation_required`.
    pub static_overflow_proven_safe: bool,
}
353
354impl MixedDenseResult {
355    pub fn precision_trap_report(&self) -> PrecisionTrapReport {
356        PrecisionTrapReport::from_q1616(
357            &self.outputs_q1616,
358            self.overflow_count,
359            self.underflow_count,
360        )
361    }
362
363    pub fn precision_envelope_report(&self) -> PrecisionEnvelopeReport {
364        let max_abs_output_q1616 = self
365            .outputs_q1616
366            .iter()
367            .map(|&value| abs_i32_to_i64(value))
368            .max()
369            .unwrap_or(0);
370        let max_abs_bound_q1616 = self.abs_bounds_q1616.iter().copied().max().unwrap_or(0);
371        let conservative_safe_bound_q1616 = i64::from(i32::MAX);
372        let min_headroom_q1616 = conservative_safe_bound_q1616.saturating_sub(max_abs_bound_q1616);
373        let required_total_bits_q1616 = required_signed_total_bits(max_abs_bound_q1616);
374        let required_integer_bits_q1616 = required_integer_bits_q1616(required_total_bits_q1616);
375        let width_headroom_bits_q1616 = 32_i16 - i16::from(required_total_bits_q1616);
376        let saturation_required = required_total_bits_q1616 > 32;
377        PrecisionEnvelopeReport {
378            output_count: self.outputs_q1616.len(),
379            overflow: self.overflow,
380            overflow_count: self.overflow_count,
381            underflow: self.underflow_count > 0,
382            underflow_count: self.underflow_count,
383            observed_overflow_free: self.overflow_count == 0,
384            observed_underflow_free: self.underflow_count == 0,
385            conservative_overflow_free: max_abs_bound_q1616 <= conservative_safe_bound_q1616,
386            max_abs_output_q1616,
387            max_abs_bound_q1616,
388            conservative_safe_bound_q1616,
389            min_headroom_q1616,
390            required_total_bits_q1616,
391            required_integer_bits_q1616,
392            width_headroom_bits_q1616,
393            saturation_required,
394            static_overflow_proven_safe: !saturation_required,
395        }
396    }
397}
398
/// Number of two's-complement bits needed to hold `±abs_bound_q1616`.
///
/// This is the magnitude's bit length plus one sign bit. Non-positive inputs need 1 bit.
fn required_signed_total_bits(abs_bound_q1616: i64) -> u8 {
    match u64::try_from(abs_bound_q1616) {
        Ok(magnitude) if magnitude > 0 => (u64::BITS - magnitude.leading_zeros()) as u8 + 1,
        _ => 1,
    }
}
405
/// Integer bits (sign included) left after reserving Q16.16's 16 fractional bits,
/// never fewer than one.
fn required_integer_bits_q1616(required_total_bits_q1616: u8) -> u8 {
    if required_total_bits_q1616 > 17 {
        required_total_bits_q1616 - 16
    } else {
        1
    }
}
409
/// Absolute value of an `i32`, widened to `i64` so that `|i32::MIN|` is representable.
fn abs_i32_to_i64(value: i32) -> i64 {
    i64::from(value.unsigned_abs())
}
417
/// Narrows an `i128` to `i64`, clamping out-of-range values to `i64::MIN` / `i64::MAX`.
fn i128_to_i64_saturating(value: i128) -> i64 {
    let clamped = value.clamp(i128::from(i64::MIN), i128::from(i64::MAX));
    // Lossless: `clamped` now lies within the i64 range.
    clamped as i64
}
427
/// Shape and length errors reported by the mixed Q8.8 × Q16.16 dense kernels.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MixedDenseError {
    /// `n_inputs` or `n_outputs` was zero.
    EmptyShape,
    /// `n_outputs * n_inputs` overflowed `usize`.
    ShapeOverflow,
    /// The weight slice length differs from `n_outputs * n_inputs`.
    WeightLengthMismatch { expected: usize, actual: usize },
    /// The input slice length does not fit the expected shape.
    InputLengthMismatch { expected: usize, actual: usize },
}

impl fmt::Display for MixedDenseError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::EmptyShape => f.write_str("dense shape must have positive inputs and outputs"),
            Self::ShapeOverflow => f.write_str("dense shape overflows addressable memory"),
            Self::WeightLengthMismatch { expected, actual } => write!(
                f,
                "weight length mismatch: expected {expected}, got {actual}"
            ),
            Self::InputLengthMismatch { expected, actual } => write!(
                f,
                "input length mismatch: expected {expected}, got {actual}"
            ),
        }
    }
}

impl Error for MixedDenseError {}
458
/// Validation errors reported by [`block_floating_dense_q16`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum BlockFloatingDenseError {
    /// `n_inputs` or `n_outputs` was zero.
    EmptyShape,
    /// The weight or block count could not be computed.
    ShapeOverflow,
    /// The mantissa slice length differs from `n_outputs * n_inputs`.
    MantissaLengthMismatch { expected: usize, actual: usize },
    /// The exponent slice length differs from the block count.
    ExponentLengthMismatch { expected: usize, actual: usize },
    /// The input slice length differs from `n_inputs`.
    InputLengthMismatch { expected: usize, actual: usize },
    /// First mantissa whose magnitude exceeds the mode's mantissa range.
    MantissaOutOfRange { index: usize, value: i16 },
    /// First exponent code above the mode's maximum code.
    ExponentOutOfRange { index: usize, value: u8 },
}

impl fmt::Display for BlockFloatingDenseError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::EmptyShape => f.write_str("dense shape must have positive inputs and outputs"),
            Self::ShapeOverflow => f.write_str("dense shape overflows addressable memory"),
            Self::MantissaLengthMismatch { expected, actual } => write!(
                f,
                "mantissa length mismatch: expected {expected}, got {actual}"
            ),
            Self::ExponentLengthMismatch { expected, actual } => write!(
                f,
                "exponent length mismatch: expected {expected}, got {actual}"
            ),
            Self::InputLengthMismatch { expected, actual } => write!(
                f,
                "input length mismatch: expected {expected}, got {actual}"
            ),
            Self::MantissaOutOfRange { index, value } => write!(
                f,
                "mantissa at index {index} exceeds configured range: {value}"
            ),
            Self::ExponentOutOfRange { index, value } => write!(
                f,
                "exponent at index {index} exceeds configured range: {value}"
            ),
        }
    }
}

impl Error for BlockFloatingDenseError {}
510
511pub fn mixed_dense_q88_q1616(
512    weights_q88: &[i16],
513    inputs_q1616: &[i32],
514    n_outputs: usize,
515    n_inputs: usize,
516) -> Result<MixedDenseResult, MixedDenseError> {
517    if n_inputs == 0 || n_outputs == 0 {
518        return Err(MixedDenseError::EmptyShape);
519    }
520    let expected_weights = n_outputs
521        .checked_mul(n_inputs)
522        .ok_or(MixedDenseError::ShapeOverflow)?;
523    if weights_q88.len() != expected_weights {
524        return Err(MixedDenseError::WeightLengthMismatch {
525            expected: expected_weights,
526            actual: weights_q88.len(),
527        });
528    }
529    if inputs_q1616.len() != n_inputs {
530        return Err(MixedDenseError::InputLengthMismatch {
531            expected: n_inputs,
532            actual: inputs_q1616.len(),
533        });
534    }
535
536    let mut outputs_q1616 = Vec::with_capacity(n_outputs);
537    let mut abs_bounds_q1616 = Vec::with_capacity(n_outputs);
538    let mut overflow_count = 0_usize;
539    let mut underflow_count = 0_usize;
540    for output_idx in 0..n_outputs {
541        let mut sum: i128 = 0;
542        let mut abs_bound: i128 = 0;
543        let row_start = output_idx * n_inputs;
544        for input_idx in 0..n_inputs {
545            let weight = i128::from(weights_q88[row_start + input_idx]);
546            let input = i128::from(inputs_q1616[input_idx]);
547            sum += weight * input;
548            abs_bound += weight.abs() * input.abs();
549        }
550        let scaled = sum >> 8;
551        let scaled_bound = (abs_bound + ((1_i128 << 8) - 1)) >> 8;
552        abs_bounds_q1616.push(i128_to_i64_saturating(scaled_bound));
553        if scaled > i128::from(i32::MAX) {
554            outputs_q1616.push(i32::MAX);
555            overflow_count += 1;
556        } else if scaled < i128::from(i32::MIN) {
557            outputs_q1616.push(i32::MIN);
558            overflow_count += 1;
559        } else {
560            if sum != 0 && scaled == 0 {
561                underflow_count += 1;
562            }
563            outputs_q1616.push(scaled as i32);
564        }
565    }
566
567    Ok(MixedDenseResult {
568        outputs_q1616,
569        overflow: overflow_count > 0,
570        overflow_count,
571        underflow_count,
572        abs_bounds_q1616,
573    })
574}
575
/// Element-wise output of the batched Q8.8 × Q16.16 dense MAC.
///
/// All three vectors are laid out row-major with `n_batch * n_outputs` entries;
/// the entry for batch row `b` and output `o` sits at `b * n_outputs + o`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MixedDenseBatchResult {
    /// Q16.16 accumulator codes after saturation.
    pub outputs_q1616: Vec<i32>,
    /// Set where the accumulator fell outside the Q16.16 code range.
    pub overflow: Vec<bool>,
    /// Set where a non-zero contraction floored to zero without overflowing.
    pub underflow: Vec<bool>,
}
589
590/// Batched integer mixed-precision Q8.8 × Q16.16 dense MAC.
591///
592/// `weights_q88` is a row-major `n_outputs * n_inputs` Q8.8 matrix; `inputs_q1616`
593/// is a row-major `n_batch * n_inputs` Q16.16 code buffer. Each output divides the
594/// integer contraction by the Q8.8 weight scale (an arithmetic shift, i.e. floor
595/// division) and saturates to the Q16.16 code range, matching the Python floor and
596/// the Julia/Go/Mojo backends bit-for-bit.
597pub fn mixed_dense_forward_batch_q88_q1616(
598    weights_q88: &[i16],
599    inputs_q1616: &[i32],
600    n_outputs: usize,
601    n_inputs: usize,
602) -> Result<MixedDenseBatchResult, MixedDenseError> {
603    if n_inputs == 0 || n_outputs == 0 {
604        return Err(MixedDenseError::EmptyShape);
605    }
606    let expected_weights = n_outputs
607        .checked_mul(n_inputs)
608        .ok_or(MixedDenseError::ShapeOverflow)?;
609    if weights_q88.len() != expected_weights {
610        return Err(MixedDenseError::WeightLengthMismatch {
611            expected: expected_weights,
612            actual: weights_q88.len(),
613        });
614    }
615    if inputs_q1616.is_empty() || !inputs_q1616.len().is_multiple_of(n_inputs) {
616        return Err(MixedDenseError::InputLengthMismatch {
617            expected: n_inputs,
618            actual: inputs_q1616.len(),
619        });
620    }
621
622    let n_batch = inputs_q1616.len() / n_inputs;
623    let count = n_batch * n_outputs;
624    let mut outputs_q1616 = Vec::with_capacity(count);
625    let mut overflow = Vec::with_capacity(count);
626    let mut underflow = Vec::with_capacity(count);
627    for batch_idx in 0..n_batch {
628        let input_row = &inputs_q1616[batch_idx * n_inputs..(batch_idx + 1) * n_inputs];
629        for output_idx in 0..n_outputs {
630            let weight_row = &weights_q88[output_idx * n_inputs..(output_idx + 1) * n_inputs];
631            let mut sum: i128 = 0;
632            for input_idx in 0..n_inputs {
633                sum += i128::from(weight_row[input_idx]) * i128::from(input_row[input_idx]);
634            }
635            let scaled = sum >> 8;
636            if scaled > i128::from(i32::MAX) {
637                outputs_q1616.push(i32::MAX);
638                overflow.push(true);
639                underflow.push(false);
640            } else if scaled < i128::from(i32::MIN) {
641                outputs_q1616.push(i32::MIN);
642                overflow.push(true);
643                underflow.push(false);
644            } else {
645                outputs_q1616.push(scaled as i32);
646                overflow.push(false);
647                underflow.push(sum != 0 && scaled == 0);
648            }
649        }
650    }
651    Ok(MixedDenseBatchResult {
652        outputs_q1616,
653        overflow,
654        underflow,
655    })
656}
657
658pub fn block_floating_dense_q16(
659    mantissas: &[i16],
660    exponents: &[u8],
661    inputs_q1616: &[i32],
662    n_outputs: usize,
663    n_inputs: usize,
664    mode: BlockFloatingMode,
665) -> Result<MixedDenseResult, BlockFloatingDenseError> {
666    if n_inputs == 0 || n_outputs == 0 {
667        return Err(BlockFloatingDenseError::EmptyShape);
668    }
669    let expected_weights = n_outputs
670        .checked_mul(n_inputs)
671        .ok_or(BlockFloatingDenseError::ShapeOverflow)?;
672    let expected_blocks = mode
673        .block_exponent_count(expected_weights)
674        .map_err(|_| BlockFloatingDenseError::ShapeOverflow)?;
675
676    if mantissas.len() != expected_weights {
677        return Err(BlockFloatingDenseError::MantissaLengthMismatch {
678            expected: expected_weights,
679            actual: mantissas.len(),
680        });
681    }
682    if exponents.len() != expected_blocks {
683        return Err(BlockFloatingDenseError::ExponentLengthMismatch {
684            expected: expected_blocks,
685            actual: exponents.len(),
686        });
687    }
688    if inputs_q1616.len() != n_inputs {
689        return Err(BlockFloatingDenseError::InputLengthMismatch {
690            expected: n_inputs,
691            actual: inputs_q1616.len(),
692        });
693    }
694
695    let mantissa_range = mode.mantissa_range();
696    for (index, &mantissa) in mantissas.iter().enumerate() {
697        if i128::from(mantissa).abs() > mantissa_range {
698            return Err(BlockFloatingDenseError::MantissaOutOfRange {
699                index,
700                value: mantissa,
701            });
702        }
703    }
704    let exponent_code_max = mode.exponent_code_max();
705    for (index, &exponent) in exponents.iter().enumerate() {
706        if exponent > exponent_code_max {
707            return Err(BlockFloatingDenseError::ExponentOutOfRange {
708                index,
709                value: exponent,
710            });
711        }
712    }
713
714    let mut outputs_q1616 = Vec::with_capacity(n_outputs);
715    let mut abs_bounds_q1616 = Vec::with_capacity(n_outputs);
716    let mut overflow_count = 0_usize;
717    let mut underflow_count = 0_usize;
718    for output_idx in 0..n_outputs {
719        let mut sum: i128 = 0;
720        let mut abs_bound: i128 = 0;
721        let mut dropped_sub_lsb_product = false;
722        let row_start = output_idx * n_inputs;
723        for input_idx in 0..n_inputs {
724            let linear_idx = row_start + input_idx;
725            let block_idx = linear_idx / mode.block_size;
726            let product = i128::from(mantissas[linear_idx]) * i128::from(inputs_q1616[input_idx]);
727            let shift = i32::from(exponents[block_idx]) - mode.exponent_bias();
728            if shift >= 0 {
729                sum += product << shift;
730                abs_bound += product.abs() << shift;
731            } else {
732                sum += product >> (-shift);
733                let divisor_shift = -shift;
734                if product != 0 && (product >> divisor_shift) == 0 {
735                    dropped_sub_lsb_product = true;
736                }
737                abs_bound += (product.abs() + ((1_i128 << divisor_shift) - 1)) >> divisor_shift;
738            }
739        }
740        abs_bounds_q1616.push(i128_to_i64_saturating(abs_bound));
741        if sum > i128::from(i32::MAX) {
742            outputs_q1616.push(i32::MAX);
743            overflow_count += 1;
744        } else if sum < i128::from(i32::MIN) {
745            outputs_q1616.push(i32::MIN);
746            overflow_count += 1;
747        } else {
748            if sum == 0 && dropped_sub_lsb_product {
749                underflow_count += 1;
750            }
751            outputs_q1616.push(sum as i32);
752        }
753    }
754
755    Ok(MixedDenseResult {
756        outputs_q1616,
757        overflow: overflow_count > 0,
758        overflow_count,
759        underflow_count,
760        abs_bounds_q1616,
761    })
762}
763
764#[cfg(test)]
765mod tests {
766    use super::*;
767
768    #[test]
769    fn mixed_dense_batch_matches_single_per_output() {
770        let weights = [256_i16, -128, 64, 512];
771        let inputs = [512_i32, 1024, 256, 768];
772        let batch = mixed_dense_forward_batch_q88_q1616(&weights, &inputs, 2, 2).unwrap();
773        // n_batch = 2, n_outputs = 2.
774        assert_eq!(batch.outputs_q1616.len(), 4);
775        for batch_idx in 0..2 {
776            let row = &inputs[batch_idx * 2..batch_idx * 2 + 2];
777            let single = mixed_dense_q88_q1616(&weights, row, 2, 2).unwrap();
778            for output_idx in 0..2 {
779                assert_eq!(
780                    batch.outputs_q1616[batch_idx * 2 + output_idx],
781                    single.outputs_q1616[output_idx]
782                );
783            }
784        }
785    }
786
787    #[test]
788    fn mixed_dense_batch_floor_division_is_signed() {
789        // raw = -1 -> -1 >> 8 = -1 (floor, not truncation toward zero).
790        let batch = mixed_dense_forward_batch_q88_q1616(&[1], &[-1], 1, 1).unwrap();
791        assert_eq!(batch.outputs_q1616, vec![-1]);
792        assert!(!batch.overflow[0]);
793        assert!(!batch.underflow[0]);
794    }
795
796    #[test]
797    fn mixed_dense_batch_flags_overflow_and_underflow() {
798        let weights = [i16::MAX; 4];
799        let inputs = [2_000_000_000_i32, 2_000_000_000, 1, 1];
800        let batch = mixed_dense_forward_batch_q88_q1616(&weights, &inputs, 1, 4).unwrap();
801        assert!(batch.overflow[0]);
802        assert_eq!(batch.outputs_q1616[0], i32::MAX);
803        // A tiny non-zero contraction that rounds to zero is an underflow.
804        let under = mixed_dense_forward_batch_q88_q1616(&[1], &[1], 1, 1).unwrap();
805        assert_eq!(under.outputs_q1616, vec![0]);
806        assert!(under.underflow[0]);
807        assert!(!under.overflow[0]);
808    }
809
810    #[test]
811    fn mixed_dense_batch_rejects_bad_shapes() {
812        assert_eq!(
813            mixed_dense_forward_batch_q88_q1616(&[1], &[1], 0, 1).unwrap_err(),
814            MixedDenseError::EmptyShape
815        );
816        assert!(matches!(
817            mixed_dense_forward_batch_q88_q1616(&[1, 1], &[1], 1, 1).unwrap_err(),
818            MixedDenseError::WeightLengthMismatch { .. }
819        ));
820        assert!(matches!(
821            mixed_dense_forward_batch_q88_q1616(&[1, 1], &[1, 1, 1], 1, 2).unwrap_err(),
822            MixedDenseError::InputLengthMismatch { .. }
823        ));
824    }
825
826    #[test]
827    fn qformat_mixed_default_matches_python_contract() {
828        let fmt = QFormatMixed::q8_8_q16_16();
829
830        assert_eq!(fmt.weight_fmt.label(), "Q8.8");
831        assert_eq!(fmt.accum_fmt.label(), "Q16.16");
832        assert_eq!(fmt.accumulator_guard_bits(), 16);
833    }
834
835    #[test]
836    fn rejects_accumulator_precision_loss() {
837        let result = QFormatMixed::new(
838            QFormat::new(8, 12).unwrap(),
839            QFormat::new(16, 8).unwrap(),
840            true,
841        );
842
843        assert_eq!(result.unwrap_err(), QFormatError::AccumulatorFractionLoss);
844    }
845
846    #[test]
847    fn mixed_dense_matches_manual_q88_q1616_codes() {
848        let weights = [128_i16, -64_i16, 256_i16, 32_i16];
849        let inputs = [32768_i32, -16384_i32];
850
851        let result = mixed_dense_q88_q1616(&weights, &inputs, 2, 2).unwrap();
852
853        assert_eq!(result.outputs_q1616, vec![20480, 30720]);
854        assert!(!result.overflow);
855        assert_eq!(result.overflow_count, 0);
856        assert_eq!(result.underflow_count, 0);
857        assert_eq!(result.abs_bounds_q1616, vec![20480, 34816]);
858
859        let envelope = result.precision_envelope_report();
860        assert!(envelope.observed_overflow_free);
861        assert!(envelope.observed_underflow_free);
862        assert!(envelope.conservative_overflow_free);
863        assert_eq!(envelope.max_abs_output_q1616, 30720);
864        assert_eq!(envelope.max_abs_bound_q1616, 34816);
865        assert_eq!(envelope.required_total_bits_q1616, 17);
866        assert_eq!(envelope.required_integer_bits_q1616, 1);
867        assert_eq!(envelope.width_headroom_bits_q1616, 15);
868        assert!(!envelope.saturation_required);
869        assert!(envelope.static_overflow_proven_safe);
870    }
871
872    #[test]
873    fn mixed_dense_negative_products_follow_arithmetic_shift() {
874        let result = mixed_dense_q88_q1616(&[128_i16], &[-1_i32], 1, 1).unwrap();
875
876        assert_eq!(result.outputs_q1616, vec![-1]);
877    }
878
879    #[test]
880    fn mixed_dense_reports_sub_lsb_underflow() {
881        let result = mixed_dense_q88_q1616(&[1_i16], &[1_i32], 1, 1).unwrap();
882
883        assert_eq!(result.outputs_q1616, vec![0]);
884        assert_eq!(result.overflow_count, 0);
885        assert_eq!(result.underflow_count, 1);
886
887        let report = result.precision_trap_report();
888        assert!(!report.overflow);
889        assert!(report.underflow);
890        assert_eq!(report.underflow_count, 1);
891
892        let envelope = result.precision_envelope_report();
893        assert!(envelope.observed_overflow_free);
894        assert!(!envelope.observed_underflow_free);
895    }
896
897    #[test]
898    fn mixed_dense_saturates_overflow() {
899        let weights = [i16::MAX, i16::MAX];
900        let inputs = [i32::MAX, i32::MAX];
901
902        let result = mixed_dense_q88_q1616(&weights, &inputs, 1, 2).unwrap();
903
904        assert_eq!(result.outputs_q1616, vec![i32::MAX]);
905        assert!(result.overflow);
906        assert_eq!(result.overflow_count, 1);
907        assert_eq!(result.underflow_count, 0);
908
909        let report = result.precision_trap_report();
910        assert_eq!(report.output_count, 1);
911        assert!(report.overflow);
912        assert_eq!(report.overflow_count, 1);
913        assert!(!report.underflow);
914        assert_eq!(report.underflow_count, 0);
915        assert_eq!(report.saturated_max_count, 1);
916        assert_eq!(report.saturated_min_count, 0);
917
918        let envelope = result.precision_envelope_report();
919        assert!(!envelope.observed_overflow_free);
920        assert!(envelope.observed_underflow_free);
921        assert!(!envelope.conservative_overflow_free);
922        assert_eq!(envelope.output_count, 1);
923        assert_eq!(envelope.overflow_count, 1);
924        assert_eq!(envelope.underflow_count, 0);
925        assert!(envelope.max_abs_bound_q1616 > envelope.conservative_safe_bound_q1616);
926        assert!(envelope.saturation_required);
927        assert!(!envelope.static_overflow_proven_safe);
928    }
929
930    #[test]
931    fn mixed_dense_rejects_shape_mismatches() {
932        assert_eq!(
933            mixed_dense_q88_q1616(&[], &[1], 1, 0).unwrap_err(),
934            MixedDenseError::EmptyShape
935        );
936        assert_eq!(
937            mixed_dense_q88_q1616(&[1], &[1], 2, 1).unwrap_err(),
938            MixedDenseError::WeightLengthMismatch {
939                expected: 2,
940                actual: 1,
941            }
942        );
943        assert_eq!(
944            mixed_dense_q88_q1616(&[1, 2], &[1], 1, 2).unwrap_err(),
945            MixedDenseError::InputLengthMismatch {
946                expected: 2,
947                actual: 1,
948            }
949        );
950    }
951
952    #[test]
953    fn block_floating_mode_reports_full_exponent_range() {
954        let mode = BlockFloatingMode::new(8, 2, 2).unwrap();
955
956        assert_eq!(mode.exponent_bias(), 1);
957        assert_eq!(mode.min_exponent(), -1);
958        assert_eq!(mode.max_exponent(), 2);
959        assert_eq!(mode.exponent_code_max(), 3);
960    }
961
962    #[test]
963    fn block_floating_mode_computes_exponent_layout() {
964        let mode = BlockFloatingMode::new(16, 3, 32).unwrap();
965        let layout = mode.block_exponent_layout(65).unwrap();
966
967        assert_eq!(layout.parameter_count, 65);
968        assert_eq!(layout.block_size, 32);
969        assert_eq!(layout.exponent_count, 3);
970        assert_eq!(layout.last_block_size, 1);
971        assert_eq!(mode.block_exponent_count(0).unwrap(), 0);
972        assert_eq!(
973            mode.validate_exponent_count(65, 2).unwrap_err(),
974            BlockFloatingError::ExponentCountMismatch {
975                expected: 3,
976                actual: 2,
977            }
978        );
979    }
980
981    #[test]
982    fn block_floating_dense_matches_manual_shifted_products() {
983        let mode = BlockFloatingMode::new(16, 3, 2).unwrap();
984        let bias = mode.exponent_bias() as u8;
985        let mantissas = [2_i16, -4_i16, 8_i16, 16_i16];
986        let exponents = [bias, bias - 1];
987        let inputs = [32768_i32, -16384_i32];
988
989        let result = block_floating_dense_q16(&mantissas, &exponents, &inputs, 2, 2, mode).unwrap();
990
991        assert_eq!(result.outputs_q1616, vec![131072, 0]);
992        assert!(!result.overflow);
993        assert_eq!(result.underflow_count, 0);
994        assert_eq!(result.abs_bounds_q1616, vec![131072, 262144]);
995
996        let envelope = result.precision_envelope_report();
997        assert!(envelope.observed_overflow_free);
998        assert!(envelope.observed_underflow_free);
999        assert!(envelope.conservative_overflow_free);
1000        assert_eq!(envelope.max_abs_output_q1616, 131072);
1001        assert_eq!(envelope.max_abs_bound_q1616, 262144);
1002    }
1003
1004    #[test]
1005    fn block_floating_dense_seeded_exponent_edges_match_manual_q1616_codes() {
1006        let mode = BlockFloatingMode::new(16, 3, 2).unwrap();
1007        let mantissas = [
1008            1_i16,
1009            -2_i16,
1010            i16::MAX,
1011            -i16::MAX,
1012            -3_i16,
1013            4_i16,
1014            -i16::MAX,
1015            i16::MAX,
1016        ];
1017        let exponents = [
1018            0_u8,
1019            mode.exponent_code_max(),
1020            0_u8,
1021            mode.exponent_code_max(),
1022        ];
1023        let inputs = [32768_i32, -16384_i32, 1_i32, -1_i32];
1024
1025        let result = block_floating_dense_q16(&mantissas, &exponents, &inputs, 2, 4, mode)
1026            .expect("seeded exponent-edge dimensions are valid");
1027
1028        assert_eq!(result.outputs_q1616, vec![1_056_736, -1_069_024]);
1029        assert_eq!(result.overflow_count, 0);
1030        assert_eq!(result.underflow_count, 0);
1031        assert_eq!(result.abs_bounds_q1616, vec![1_056_736, 1_069_024]);
1032
1033        let envelope = result.precision_envelope_report();
1034        assert!(envelope.observed_overflow_free);
1035        assert!(envelope.observed_underflow_free);
1036        assert!(envelope.conservative_overflow_free);
1037        assert_eq!(envelope.max_abs_bound_q1616, 1_069_024);
1038        assert_eq!(envelope.min_headroom_q1616, 2_146_414_623);
1039    }
1040
1041    #[test]
1042    fn block_floating_dense_max_exponent_edge_saturates_and_reports_trap() {
1043        let mode = BlockFloatingMode::new(16, 3, 2).unwrap();
1044        let mantissas = [i16::MAX, i16::MAX];
1045        let exponents = [mode.exponent_code_max()];
1046        let inputs = [32767_i32 << 16, 32767_i32 << 16];
1047
1048        let result = block_floating_dense_q16(&mantissas, &exponents, &inputs, 1, 2, mode)
1049            .expect("max-exponent trap dimensions are valid");
1050
1051        assert_eq!(result.outputs_q1616, vec![i32::MAX]);
1052        assert!(result.overflow);
1053        assert_eq!(result.overflow_count, 1);
1054        assert_eq!(result.underflow_count, 0);
1055
1056        let report = result.precision_trap_report();
1057        assert!(report.overflow);
1058        assert_eq!(report.overflow_count, 1);
1059        assert!(!report.underflow);
1060        assert_eq!(report.saturated_max_count, 1);
1061
1062        let envelope = result.precision_envelope_report();
1063        assert!(!envelope.observed_overflow_free);
1064        assert!(envelope.observed_underflow_free);
1065        assert!(!envelope.conservative_overflow_free);
1066        assert!(envelope.max_abs_bound_q1616 > envelope.conservative_safe_bound_q1616);
1067    }
1068
1069    #[test]
1070    fn block_floating_dense_reports_sub_lsb_underflow() {
1071        let mode = BlockFloatingMode::new(16, 3, 1).unwrap();
1072        let result = block_floating_dense_q16(&[1_i16], &[0_u8], &[1_i32], 1, 1, mode).unwrap();
1073
1074        assert_eq!(result.outputs_q1616, vec![0]);
1075        assert_eq!(result.overflow_count, 0);
1076        assert_eq!(result.underflow_count, 1);
1077
1078        let report = result.precision_trap_report();
1079        assert!(!report.overflow);
1080        assert!(report.underflow);
1081        assert_eq!(report.underflow_count, 1);
1082
1083        let envelope = result.precision_envelope_report();
1084        assert!(envelope.observed_overflow_free);
1085        assert!(!envelope.observed_underflow_free);
1086        assert_eq!(envelope.max_abs_bound_q1616, 1);
1087    }
1088
1089    #[test]
1090    fn block_floating_dense_saturates_large_outputs() {
1091        let mode = BlockFloatingMode::bfp16_e3_x32();
1092        let mantissas = vec![i16::MAX; 64];
1093        let exponents = vec![mode.exponent_code_max(); 2];
1094        let inputs = vec![i32::MAX; 64];
1095
1096        let result =
1097            block_floating_dense_q16(&mantissas, &exponents, &inputs, 1, 64, mode).unwrap();
1098
1099        assert_eq!(result.outputs_q1616, vec![i32::MAX]);
1100        assert!(result.overflow);
1101        assert_eq!(result.overflow_count, 1);
1102        assert_eq!(result.underflow_count, 0);
1103
1104        let report = result.precision_trap_report();
1105        assert_eq!(report.output_count, 1);
1106        assert!(report.overflow);
1107        assert_eq!(report.overflow_count, 1);
1108        assert!(!report.underflow);
1109        assert_eq!(report.underflow_count, 0);
1110        assert_eq!(report.saturated_max_count, 1);
1111        assert_eq!(report.saturated_min_count, 0);
1112
1113        let envelope = result.precision_envelope_report();
1114        assert!(!envelope.observed_overflow_free);
1115        assert!(envelope.observed_underflow_free);
1116        assert!(!envelope.conservative_overflow_free);
1117        assert_eq!(envelope.output_count, 1);
1118        assert_eq!(envelope.overflow_count, 1);
1119        assert_eq!(envelope.underflow_count, 0);
1120        assert!(envelope.max_abs_bound_q1616 > envelope.conservative_safe_bound_q1616);
1121    }
1122
1123    #[test]
1124    fn block_floating_dense_rejects_invalid_lengths_and_ranges() {
1125        let mode = BlockFloatingMode::new(8, 2, 2).unwrap();
1126
1127        assert_eq!(
1128            block_floating_dense_q16(&[], &[1], &[1], 1, 0, mode).unwrap_err(),
1129            BlockFloatingDenseError::EmptyShape
1130        );
1131        assert_eq!(
1132            block_floating_dense_q16(&[1], &[1], &[1], 2, 1, mode).unwrap_err(),
1133            BlockFloatingDenseError::MantissaLengthMismatch {
1134                expected: 2,
1135                actual: 1,
1136            }
1137        );
1138        assert_eq!(
1139            block_floating_dense_q16(&[1, 2], &[], &[1, 2], 1, 2, mode).unwrap_err(),
1140            BlockFloatingDenseError::ExponentLengthMismatch {
1141                expected: 1,
1142                actual: 0,
1143            }
1144        );
1145        assert_eq!(
1146            block_floating_dense_q16(&[128, 0], &[1], &[1, 2], 1, 2, mode).unwrap_err(),
1147            BlockFloatingDenseError::MantissaOutOfRange {
1148                index: 0,
1149                value: 128,
1150            }
1151        );
1152    }
1153}
1154
/// Pins the block-floating dense kernel to the envelope figures expected from the
/// Python quantiser for a seeded 32x64 layer, plus a deliberately saturating probe.
#[cfg(test)]
mod block_floating_benchmark_contract_tests {
    use super::*;
    use std::cmp::Ordering;

    const N_INPUTS: usize = 64;
    const N_OUTPUTS: usize = 32;

    /// Divides `value` by a positive `divisor`, rounding to the nearest integer with
    /// exact ties going to the even quotient, and narrows the result to an `i16`.
    ///
    /// # Panics
    ///
    /// Panics if `divisor` is not positive, or if the rounded quotient does not fit
    /// in an `i16` mantissa.
    fn round_div_nearest_even(value: i32, divisor: i32) -> i16 {
        assert!(divisor > 0, "divisor must be positive, got {divisor}");
        // Widen to i64 so that `abs()` of i32::MIN and `remainder * 2` cannot overflow.
        let magnitude = i64::from(value).abs();
        let divisor = i64::from(divisor);
        let quotient = magnitude / divisor;
        let remainder = magnitude % divisor;
        let rounded_magnitude = match (remainder * 2).cmp(&divisor) {
            Ordering::Less => quotient,
            Ordering::Greater => quotient + 1,
            // Exact tie: move to the even neighbour.
            Ordering::Equal => quotient + (quotient & 1),
        };
        let rounded = if value < 0 {
            -rounded_magnitude
        } else {
            rounded_magnitude
        };
        i16::try_from(rounded).expect("rounded weight code must fit in an i16 mantissa")
    }

    #[test]
    fn block_floating_benchmark_matches_python_quantiser_envelope() {
        let mode = BlockFloatingMode::bfp16_e3_x32();
        let parameter_count = N_INPUTS * N_OUTPUTS;
        // Ceil-divide: a trailing partial block still needs its own exponent.
        let exponent_count = parameter_count.div_ceil(mode.block_size);

        let mantissas = (0..parameter_count)
            .map(|idx| {
                let raw_weight_code = ((idx * 23 + 3) % 1025) as i32 - 512;
                round_div_nearest_even(raw_weight_code, 64)
            })
            .collect::<Vec<_>>();
        let exponents = vec![0_u8; exponent_count];
        let inputs = (0..N_INPUTS)
            .map(|idx| (((idx * 19 + 5) % 257) as i32 - 128) << 8)
            .collect::<Vec<_>>();

        let result =
            block_floating_dense_q16(&mantissas, &exponents, &inputs, N_OUTPUTS, N_INPUTS, mode)
                .expect("benchmark contract dimensions are valid");
        let envelope = result.precision_envelope_report();

        assert_eq!(result.overflow_count, 0);
        assert_eq!(envelope.max_abs_bound_q1616, 610_816);
        assert!(envelope.conservative_overflow_free);

        let saturating_mantissas = vec![16_384_i16; parameter_count];
        let saturating_exponents = vec![2_u8; exponent_count];
        let saturating_inputs = vec![32767_i32 << 16; N_INPUTS];
        let saturating_result = block_floating_dense_q16(
            &saturating_mantissas,
            &saturating_exponents,
            &saturating_inputs,
            N_OUTPUTS,
            N_INPUTS,
            mode,
        )
        .expect("saturating benchmark contract dimensions are valid");
        let saturating_envelope = saturating_result.precision_envelope_report();

        assert_eq!(saturating_result.overflow_count, N_OUTPUTS);
        assert_eq!(
            saturating_envelope.max_abs_bound_q1616,
            1_125_865_547_104_256
        );
        assert!(!saturating_envelope.conservative_overflow_free);
    }
}
/// Pins the mixed Q8.8 x Q16.16 dense kernel to the Python envelope figures for a
/// seeded 32x64 layer, plus a deliberately saturating probe.
#[cfg(test)]
mod mixed_dense_benchmark_contract_tests {
    use super::*;

    const N_INPUTS: usize = 64;
    const N_OUTPUTS: usize = 32;

    /// Seeded row-major Q8.8 weight codes in `-256..=256`.
    fn seeded_weights() -> Vec<i16> {
        (0..N_INPUTS * N_OUTPUTS)
            .map(|idx| {
                let code = ((idx * 17 + 11) % 513) as i32 - 256;
                code as i16
            })
            .collect()
    }

    /// Seeded Q16.16 input codes: integers in `-128..=128` shifted left by 8 bits.
    fn seeded_inputs() -> Vec<i32> {
        (0..N_INPUTS)
            .map(|idx| {
                let centred = (idx as i32 * 19 + 5) % 257 - 128;
                centred << 8
            })
            .collect()
    }

    #[test]
    fn mixed_dense_benchmark_contract_matches_python_envelope() {
        let weights = seeded_weights();
        let inputs = seeded_inputs();
        let safe = mixed_dense_q88_q1616(&weights, &inputs, N_OUTPUTS, N_INPUTS)
            .expect("benchmark contract dimensions must be valid");
        let safe_env = safe.precision_envelope_report();

        assert_eq!(safe.overflow_count, 0);
        assert!(safe_env.conservative_overflow_free);
        assert!(!safe_env.saturation_required);
        assert_eq!(safe_env.max_abs_bound_q1616, 531_400);
        // i32::MAX - 531_400
        assert_eq!(safe_env.min_headroom_q1616, 2_146_952_247);
        assert_eq!(safe_env.required_total_bits_q1616, 21);
        assert_eq!(safe_env.required_integer_bits_q1616, 5);
        assert_eq!(safe_env.width_headroom_bits_q1616, 11);

        // Every weight is 127.0 and every input 32767.0, far beyond the Q16.16 range.
        let hot_weights = vec![127_i16 << 8; N_INPUTS * N_OUTPUTS];
        let hot_inputs = vec![32767_i32 << 16; N_INPUTS];
        let probe = mixed_dense_q88_q1616(&hot_weights, &hot_inputs, N_OUTPUTS, N_INPUTS)
            .expect("saturating probe dimensions must be valid");
        let probe_env = probe.precision_envelope_report();

        assert_eq!(probe.overflow_count, N_OUTPUTS);
        assert!(!probe_env.conservative_overflow_free);
        assert!(probe_env.saturation_required);
        assert_eq!(probe_env.max_abs_bound_q1616, 17_454_214_414_336);
        assert_eq!(probe_env.required_total_bits_q1616, 45);
        assert_eq!(probe_env.required_integer_bits_q1616, 29);
        assert_eq!(probe_env.width_headroom_bits_q1616, -13);
    }
}