1use std::error::Error;
10use std::fmt;
11
/// Signed fixed-point ("Q") layout: `integer_bits` bits (sign bit included)
/// in front of the binary point and `fraction_bits` bits behind it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct QFormat {
    /// Bits in front of the binary point, counting the sign bit.
    pub integer_bits: u8,
    /// Bits behind the binary point.
    pub fraction_bits: u8,
}

impl QFormat {
    /// The 16-bit Q8.8 layout.
    pub const fn q8_8() -> Self {
        Self {
            integer_bits: 8,
            fraction_bits: 8,
        }
    }

    /// The 32-bit Q16.16 layout.
    pub const fn q16_16() -> Self {
        Self {
            integer_bits: 16,
            fraction_bits: 16,
        }
    }

    /// Builds a validated layout.
    ///
    /// # Errors
    ///
    /// - [`QFormatError::MissingSignBit`] when `integer_bits` is zero.
    /// - [`QFormatError::TotalBitsTooWide`] when the combined width is zero or
    ///   exceeds 63 bits.
    pub fn new(integer_bits: u8, fraction_bits: u8) -> Result<Self, QFormatError> {
        // Widened to u16 so adding two u8 widths can never wrap.
        let width = u16::from(integer_bits) + u16::from(fraction_bits);
        match (integer_bits, width) {
            (0, _) => Err(QFormatError::MissingSignBit),
            (_, 0) | (_, 64..=u16::MAX) => Err(QFormatError::TotalBitsTooWide(width)),
            _ => Ok(Self {
                integer_bits,
                fraction_bits,
            }),
        }
    }

    /// Total width `integer_bits + fraction_bits`, computed in `u8`.
    pub fn total_bits(self) -> u8 {
        let Self {
            integer_bits,
            fraction_bits,
        } = self;
        integer_bits + fraction_bits
    }

    /// Code value that represents 1.0, i.e. `2^fraction_bits`.
    pub fn scale(self) -> i128 {
        1_i128 << self.fraction_bits
    }

    /// `2^(total_bits - 1)`: the magnitude of the most negative code.
    fn sign_weight(self) -> i128 {
        1_i128 << (self.total_bits() - 1)
    }

    /// Smallest representable real value, `-2^(total - 1) / 2^fraction`.
    pub fn min_value(self) -> f64 {
        let magnitude = self.sign_weight() as f64;
        -magnitude / self.scale() as f64
    }

    /// Largest representable real value, `(2^(total - 1) - 1) / 2^fraction`.
    pub fn max_value(self) -> f64 {
        let top_code = self.sign_weight() - 1;
        top_code as f64 / self.scale() as f64
    }

    /// Human-readable name such as `Q8.8`.
    pub fn label(self) -> String {
        let Self {
            integer_bits,
            fraction_bits,
        } = self;
        format!("Q{integer_bits}.{fraction_bits}")
    }
}

/// Reasons a Q-format or mixed Q-format configuration is rejected.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QFormatError {
    /// `integer_bits` was zero, leaving no room for the sign bit.
    MissingSignBit,
    /// The combined width (carried in the payload) is outside `1..=63`.
    TotalBitsTooWide(u16),
    /// The accumulator has fewer total bits than the weight format.
    AccumulatorNarrower,
    /// The accumulator has fewer fraction bits than the weight format.
    AccumulatorFractionLoss,
    /// The accumulator's real-valued range does not contain the weight range.
    AccumulatorRangeLoss,
}

impl fmt::Display for QFormatError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Only the width error carries a payload; the rest are fixed text.
        let message = match self {
            Self::TotalBitsTooWide(bits) => {
                return write!(f, "Q-format total bits exceed i64 range: {bits}");
            }
            Self::MissingSignBit => "integer_bits must include the sign bit",
            Self::AccumulatorNarrower => {
                "accumulator format must not be narrower than weight format"
            }
            Self::AccumulatorFractionLoss => {
                "accumulator format must preserve weight fractional precision"
            }
            Self::AccumulatorRangeLoss => "accumulator format must cover the full weight range",
        };
        f.write_str(message)
    }
}

impl Error for QFormatError {}
102
103#[derive(Debug, Clone, Copy, PartialEq, Eq)]
104pub struct QFormatMixed {
105 pub weight_fmt: QFormat,
106 pub accum_fmt: QFormat,
107 pub scale_per_tensor: bool,
108}
109
110impl QFormatMixed {
111 pub fn q8_8_q16_16() -> Self {
112 Self {
113 weight_fmt: QFormat::q8_8(),
114 accum_fmt: QFormat::q16_16(),
115 scale_per_tensor: true,
116 }
117 }
118
119 pub fn new(
120 weight_fmt: QFormat,
121 accum_fmt: QFormat,
122 scale_per_tensor: bool,
123 ) -> Result<Self, QFormatError> {
124 if accum_fmt.total_bits() < weight_fmt.total_bits() {
125 return Err(QFormatError::AccumulatorNarrower);
126 }
127 if accum_fmt.fraction_bits < weight_fmt.fraction_bits {
128 return Err(QFormatError::AccumulatorFractionLoss);
129 }
130 if accum_fmt.min_value() > weight_fmt.min_value()
131 || accum_fmt.max_value() < weight_fmt.max_value()
132 {
133 return Err(QFormatError::AccumulatorRangeLoss);
134 }
135 Ok(Self {
136 weight_fmt,
137 accum_fmt,
138 scale_per_tensor,
139 })
140 }
141
142 pub fn accumulator_guard_bits(self) -> u8 {
143 self.accum_fmt.total_bits() - self.weight_fmt.total_bits()
144 }
145}
146
/// Block floating-point weight encoding. Every weight stores a signed
/// mantissa, and each run of `block_size` consecutive weights shares one
/// biased exponent code of `exponent_bits` bits.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BlockFloatingMode {
    /// Width of each mantissa, sign bit included.
    pub mantissa_bits: u8,
    /// Width of each shared exponent code.
    pub exponent_bits: u8,
    /// Number of consecutive parameters that share one exponent.
    pub block_size: usize,
}

impl BlockFloatingMode {
    /// Builds a validated mode.
    ///
    /// # Errors
    ///
    /// - [`BlockFloatingError::MantissaTooNarrow`] when `mantissa_bits < 2`.
    /// - [`BlockFloatingError::InvalidExponentBits`] unless `exponent_bits` is in `1..=7`.
    /// - [`BlockFloatingError::EmptyBlock`] when `block_size` is zero.
    pub fn new(
        mantissa_bits: u8,
        exponent_bits: u8,
        block_size: usize,
    ) -> Result<Self, BlockFloatingError> {
        if mantissa_bits < 2 {
            return Err(BlockFloatingError::MantissaTooNarrow);
        }
        if exponent_bits == 0 || exponent_bits > 7 {
            return Err(BlockFloatingError::InvalidExponentBits);
        }
        if block_size == 0 {
            return Err(BlockFloatingError::EmptyBlock);
        }
        Ok(Self {
            mantissa_bits,
            exponent_bits,
            block_size,
        })
    }

    /// 16-bit mantissas, 3-bit exponents, 32 parameters per block.
    pub fn bfp16_e3_x32() -> Self {
        Self {
            mantissa_bits: 16,
            exponent_bits: 3,
            block_size: 32,
        }
    }

    /// Exponent bias `2^(exponent_bits - 1) - 1`. A stored code `c` encodes
    /// the exponent `c - bias`.
    pub fn exponent_bias(self) -> i32 {
        (1_i32 << (self.exponent_bits - 1)) - 1
    }

    /// Exponent encoded by code 0, i.e. `-bias`.
    pub fn min_exponent(self) -> i32 {
        -self.exponent_bias()
    }

    /// Exponent encoded by the largest code, `2^exponent_bits - 1 - bias`.
    pub fn max_exponent(self) -> i32 {
        ((1_i32 << self.exponent_bits) - 1) - self.exponent_bias()
    }

    /// Largest mantissa magnitude allowed, `2^(mantissa_bits - 1) - 1`. The
    /// range is symmetric, so `-2^(mantissa_bits - 1)` is excluded.
    ///
    /// [`Self::new`] does not cap `mantissa_bits`. Widths of 128 or more
    /// therefore return `i128::MAX`, which is exact for 128, instead of
    /// overflowing the shift. A zero width yields 0 rather than underflowing.
    pub fn mantissa_range(self) -> i128 {
        if self.mantissa_bits >= 128 {
            i128::MAX
        } else {
            (1_i128 << self.mantissa_bits.saturating_sub(1)) - 1
        }
    }

    /// Largest encodable exponent code, `2^exponent_bits - 1`.
    pub fn exponent_code_max(self) -> u8 {
        ((1_u16 << self.exponent_bits) - 1) as u8
    }

    /// Number of shared exponents needed for `parameter_count` parameters,
    /// i.e. `ceil(parameter_count / block_size)`. Zero parameters need zero
    /// exponents.
    ///
    /// # Errors
    ///
    /// [`BlockFloatingError::EmptyBlock`] when `block_size` is zero. This is
    /// only reachable for modes built with a struct literal instead of
    /// [`Self::new`].
    pub fn block_exponent_count(self, parameter_count: usize) -> Result<usize, BlockFloatingError> {
        if self.block_size == 0 {
            return Err(BlockFloatingError::EmptyBlock);
        }
        // `div_ceil` cannot overflow. The `(n + block_size - 1) / block_size`
        // idiom did, and failed spuriously for counts near `usize::MAX`.
        Ok(parameter_count.div_ceil(self.block_size))
    }

    /// Describes how `parameter_count` parameters split into exponent blocks.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Self::block_exponent_count`].
    pub fn block_exponent_layout(
        self,
        parameter_count: usize,
    ) -> Result<BlockExponentLayout, BlockFloatingError> {
        // Computed first so a zero `block_size` is rejected before the modulo below.
        let exponent_count = self.block_exponent_count(parameter_count)?;
        // The last block is partial unless the count divides evenly. An empty
        // parameter set has no last block at all.
        let last_block_size = match parameter_count % self.block_size {
            0 if parameter_count == 0 => 0,
            0 => self.block_size,
            remainder => remainder,
        };
        Ok(BlockExponentLayout {
            parameter_count,
            block_size: self.block_size,
            exponent_count,
            last_block_size,
        })
    }

    /// Checks that `exponent_count` matches what `parameter_count` requires.
    ///
    /// # Errors
    ///
    /// [`BlockFloatingError::ExponentCountMismatch`] on a mismatch, or any
    /// error from [`Self::block_exponent_count`].
    pub fn validate_exponent_count(
        self,
        parameter_count: usize,
        exponent_count: usize,
    ) -> Result<(), BlockFloatingError> {
        let expected = self.block_exponent_count(parameter_count)?;
        if exponent_count != expected {
            return Err(BlockFloatingError::ExponentCountMismatch {
                expected,
                actual: exponent_count,
            });
        }
        Ok(())
    }
}

/// How a parameter tensor splits into exponent-sharing blocks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BlockExponentLayout {
    /// Number of parameters described.
    pub parameter_count: usize,
    /// Parameters per full block.
    pub block_size: usize,
    /// Number of shared exponents (blocks).
    pub exponent_count: usize,
    /// Size of the final block: 0 when there are no parameters, otherwise in `1..=block_size`.
    pub last_block_size: usize,
}

/// Errors from configuring or sizing a [`BlockFloatingMode`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BlockFloatingError {
    /// `mantissa_bits` was below 2.
    MantissaTooNarrow,
    /// `exponent_bits` was outside `1..=7`.
    InvalidExponentBits,
    /// `block_size` was zero.
    EmptyBlock,
    /// Not produced by [`BlockFloatingMode::block_exponent_count`], whose
    /// ceiling division cannot overflow. Retained so existing matches compile.
    ParameterCountOverflow,
    /// A supplied exponent count disagreed with the layout.
    ExponentCountMismatch { expected: usize, actual: usize },
}

impl fmt::Display for BlockFloatingError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::MantissaTooNarrow => write!(f, "mantissa bits must be at least 2"),
            Self::InvalidExponentBits => write!(f, "exponent bits must be in 1..=7"),
            Self::EmptyBlock => write!(f, "block size must be positive"),
            Self::ParameterCountOverflow => write!(f, "parameter count overflows block layout"),
            Self::ExponentCountMismatch { expected, actual } => {
                write!(
                    f,
                    "exponent count mismatch: expected {expected}, got {actual}"
                )
            }
        }
    }
}

impl Error for BlockFloatingError {}
286
/// Outputs of a dense kernel in this module, as Q16.16 codes, together with
/// the overflow/underflow bookkeeping collected while computing them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MixedDenseResult {
    /// One Q16.16 code per output (the kernels saturate these to the `i32` range).
    pub outputs_q1616: Vec<i32>,
    /// Whether any output saturated (the kernels set this to `overflow_count > 0`).
    pub overflow: bool,
    /// Number of saturated outputs.
    pub overflow_count: usize,
    /// Number of outputs the producing kernel flagged as sub-LSB underflow.
    pub underflow_count: usize,
    /// Per-output upper bound on the output magnitude in Q16.16 units,
    /// saturated to `i64`.
    pub abs_bounds_q1616: Vec<i64>,
}
295
/// Summary of overflow/underflow traps for a set of Q16.16 outputs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PrecisionTrapReport {
    /// Number of outputs inspected.
    pub output_count: usize,
    /// `overflow_count > 0`.
    pub overflow: bool,
    /// Overflow count supplied by the caller.
    pub overflow_count: usize,
    /// `underflow_count > 0`.
    pub underflow: bool,
    /// Underflow count supplied by the caller.
    pub underflow_count: usize,
    /// Outputs exactly equal to `i32::MIN`.
    pub saturated_min_count: usize,
    /// Outputs exactly equal to `i32::MAX`.
    pub saturated_max_count: usize,
}

impl PrecisionTrapReport {
    /// Builds a report from output codes and the trap counts gathered while
    /// producing them. Saturation counts come from scanning for codes equal to
    /// `i32::MIN` / `i32::MAX`.
    pub fn from_q1616(
        outputs_q1616: &[i32],
        overflow_count: usize,
        underflow_count: usize,
    ) -> Self {
        // A single pass counts both rails.
        let (saturated_min_count, saturated_max_count) =
            outputs_q1616
                .iter()
                .fold((0_usize, 0_usize), |(at_min, at_max), &code| match code {
                    i32::MIN => (at_min + 1, at_max),
                    i32::MAX => (at_min, at_max + 1),
                    _ => (at_min, at_max),
                });
        Self {
            output_count: outputs_q1616.len(),
            overflow: overflow_count > 0,
            overflow_count,
            underflow: underflow_count > 0,
            underflow_count,
            saturated_min_count,
            saturated_max_count,
        }
    }
}
332
333#[derive(Debug, Clone, Copy, PartialEq, Eq)]
334pub struct PrecisionEnvelopeReport {
335 pub output_count: usize,
336 pub overflow: bool,
337 pub overflow_count: usize,
338 pub underflow: bool,
339 pub underflow_count: usize,
340 pub observed_overflow_free: bool,
341 pub observed_underflow_free: bool,
342 pub conservative_overflow_free: bool,
343 pub max_abs_output_q1616: i64,
344 pub max_abs_bound_q1616: i64,
345 pub conservative_safe_bound_q1616: i64,
346 pub min_headroom_q1616: i64,
347 pub required_total_bits_q1616: u8,
348 pub required_integer_bits_q1616: u8,
349 pub width_headroom_bits_q1616: i16,
350 pub saturation_required: bool,
351 pub static_overflow_proven_safe: bool,
352}
353
354impl MixedDenseResult {
355 pub fn precision_trap_report(&self) -> PrecisionTrapReport {
356 PrecisionTrapReport::from_q1616(
357 &self.outputs_q1616,
358 self.overflow_count,
359 self.underflow_count,
360 )
361 }
362
363 pub fn precision_envelope_report(&self) -> PrecisionEnvelopeReport {
364 let max_abs_output_q1616 = self
365 .outputs_q1616
366 .iter()
367 .map(|&value| abs_i32_to_i64(value))
368 .max()
369 .unwrap_or(0);
370 let max_abs_bound_q1616 = self.abs_bounds_q1616.iter().copied().max().unwrap_or(0);
371 let conservative_safe_bound_q1616 = i64::from(i32::MAX);
372 let min_headroom_q1616 = conservative_safe_bound_q1616.saturating_sub(max_abs_bound_q1616);
373 let required_total_bits_q1616 = required_signed_total_bits(max_abs_bound_q1616);
374 let required_integer_bits_q1616 = required_integer_bits_q1616(required_total_bits_q1616);
375 let width_headroom_bits_q1616 = 32_i16 - i16::from(required_total_bits_q1616);
376 let saturation_required = required_total_bits_q1616 > 32;
377 PrecisionEnvelopeReport {
378 output_count: self.outputs_q1616.len(),
379 overflow: self.overflow,
380 overflow_count: self.overflow_count,
381 underflow: self.underflow_count > 0,
382 underflow_count: self.underflow_count,
383 observed_overflow_free: self.overflow_count == 0,
384 observed_underflow_free: self.underflow_count == 0,
385 conservative_overflow_free: max_abs_bound_q1616 <= conservative_safe_bound_q1616,
386 max_abs_output_q1616,
387 max_abs_bound_q1616,
388 conservative_safe_bound_q1616,
389 min_headroom_q1616,
390 required_total_bits_q1616,
391 required_integer_bits_q1616,
392 width_headroom_bits_q1616,
393 saturation_required,
394 static_overflow_proven_safe: !saturation_required,
395 }
396 }
397}
398
/// Smallest two's-complement width, sign bit included, whose positive range
/// reaches `abs_bound_q1616`. Non-positive bounds need only the sign bit.
fn required_signed_total_bits(abs_bound_q1616: i64) -> u8 {
    match u64::try_from(abs_bound_q1616) {
        Ok(magnitude) if magnitude > 0 => {
            // Significant magnitude bits, plus one for the sign.
            (u64::BITS - magnitude.leading_zeros()) as u8 + 1
        }
        _ => 1,
    }
}
405
/// Integer bits (sign included) implied by a Q16.16 total width: the width
/// minus the 16 fraction bits, never less than 1.
fn required_integer_bits_q1616(required_total_bits_q1616: u8) -> u8 {
    const FRACTION_BITS: u8 = 16;
    if required_total_bits_q1616 > FRACTION_BITS {
        required_total_bits_q1616 - FRACTION_BITS
    } else {
        1
    }
}
409
/// Absolute value of an `i32`, widened to `i64` so `|i32::MIN| = 2^31` is
/// representable.
fn abs_i32_to_i64(value: i32) -> i64 {
    // Widening first means the negation can never overflow.
    i64::from(value).abs()
}
417
/// Narrows an `i128` to `i64`, clamping out-of-range values to the nearest bound.
fn i128_to_i64_saturating(value: i128) -> i64 {
    i64::try_from(value).unwrap_or_else(|_| {
        if value.is_negative() {
            i64::MIN
        } else {
            i64::MAX
        }
    })
}
427
/// Input validation failures for the Q8.8 × Q16.16 dense kernels.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MixedDenseError {
    /// `n_outputs` or `n_inputs` was zero.
    EmptyShape,
    /// A size product overflowed `usize`.
    ShapeOverflow,
    /// `weights_q88.len()` differed from the required length.
    WeightLengthMismatch { expected: usize, actual: usize },
    /// `inputs_q1616.len()` did not fit the shape.
    InputLengthMismatch { expected: usize, actual: usize },
}

impl fmt::Display for MixedDenseError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Both length mismatches share one template, keyed by operand name.
        let (operand, expected, actual) = match *self {
            Self::EmptyShape => {
                return f.write_str("dense shape must have positive inputs and outputs");
            }
            Self::ShapeOverflow => return f.write_str("dense shape overflows addressable memory"),
            Self::WeightLengthMismatch { expected, actual } => ("weight", expected, actual),
            Self::InputLengthMismatch { expected, actual } => ("input", expected, actual),
        };
        write!(f, "{operand} length mismatch: expected {expected}, got {actual}")
    }
}

impl Error for MixedDenseError {}
458
/// Input validation failures for the block floating-point dense kernel.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BlockFloatingDenseError {
    /// `n_outputs` or `n_inputs` was zero.
    EmptyShape,
    /// The weight count or exponent layout could not be sized.
    ShapeOverflow,
    /// `mantissas.len()` differed from `n_outputs * n_inputs`.
    MantissaLengthMismatch { expected: usize, actual: usize },
    /// `exponents.len()` differed from the number of blocks.
    ExponentLengthMismatch { expected: usize, actual: usize },
    /// `inputs_q1616.len()` differed from `n_inputs`.
    InputLengthMismatch { expected: usize, actual: usize },
    /// First mantissa whose magnitude exceeds the mode's range.
    MantissaOutOfRange { index: usize, value: i16 },
    /// First exponent code above the mode's maximum code.
    ExponentOutOfRange { index: usize, value: u8 },
}

impl fmt::Display for BlockFloatingDenseError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The three length checks share one template; all others return early.
        let (operand, expected, actual) = match *self {
            Self::EmptyShape => {
                return f.write_str("dense shape must have positive inputs and outputs");
            }
            Self::ShapeOverflow => return f.write_str("dense shape overflows addressable memory"),
            Self::MantissaOutOfRange { index, value } => {
                return write!(f, "mantissa at index {index} exceeds configured range: {value}");
            }
            Self::ExponentOutOfRange { index, value } => {
                return write!(f, "exponent at index {index} exceeds configured range: {value}");
            }
            Self::MantissaLengthMismatch { expected, actual } => ("mantissa", expected, actual),
            Self::ExponentLengthMismatch { expected, actual } => ("exponent", expected, actual),
            Self::InputLengthMismatch { expected, actual } => ("input", expected, actual),
        };
        write!(f, "{operand} length mismatch: expected {expected}, got {actual}")
    }
}

impl Error for BlockFloatingDenseError {}
510
511pub fn mixed_dense_q88_q1616(
512 weights_q88: &[i16],
513 inputs_q1616: &[i32],
514 n_outputs: usize,
515 n_inputs: usize,
516) -> Result<MixedDenseResult, MixedDenseError> {
517 if n_inputs == 0 || n_outputs == 0 {
518 return Err(MixedDenseError::EmptyShape);
519 }
520 let expected_weights = n_outputs
521 .checked_mul(n_inputs)
522 .ok_or(MixedDenseError::ShapeOverflow)?;
523 if weights_q88.len() != expected_weights {
524 return Err(MixedDenseError::WeightLengthMismatch {
525 expected: expected_weights,
526 actual: weights_q88.len(),
527 });
528 }
529 if inputs_q1616.len() != n_inputs {
530 return Err(MixedDenseError::InputLengthMismatch {
531 expected: n_inputs,
532 actual: inputs_q1616.len(),
533 });
534 }
535
536 let mut outputs_q1616 = Vec::with_capacity(n_outputs);
537 let mut abs_bounds_q1616 = Vec::with_capacity(n_outputs);
538 let mut overflow_count = 0_usize;
539 let mut underflow_count = 0_usize;
540 for output_idx in 0..n_outputs {
541 let mut sum: i128 = 0;
542 let mut abs_bound: i128 = 0;
543 let row_start = output_idx * n_inputs;
544 for input_idx in 0..n_inputs {
545 let weight = i128::from(weights_q88[row_start + input_idx]);
546 let input = i128::from(inputs_q1616[input_idx]);
547 sum += weight * input;
548 abs_bound += weight.abs() * input.abs();
549 }
550 let scaled = sum >> 8;
551 let scaled_bound = (abs_bound + ((1_i128 << 8) - 1)) >> 8;
552 abs_bounds_q1616.push(i128_to_i64_saturating(scaled_bound));
553 if scaled > i128::from(i32::MAX) {
554 outputs_q1616.push(i32::MAX);
555 overflow_count += 1;
556 } else if scaled < i128::from(i32::MIN) {
557 outputs_q1616.push(i32::MIN);
558 overflow_count += 1;
559 } else {
560 if sum != 0 && scaled == 0 {
561 underflow_count += 1;
562 }
563 outputs_q1616.push(scaled as i32);
564 }
565 }
566
567 Ok(MixedDenseResult {
568 outputs_q1616,
569 overflow: overflow_count > 0,
570 overflow_count,
571 underflow_count,
572 abs_bounds_q1616,
573 })
574}
575
576#[derive(Debug, Clone, PartialEq, Eq)]
581pub struct MixedDenseBatchResult {
582 pub outputs_q1616: Vec<i32>,
584 pub overflow: Vec<bool>,
586 pub underflow: Vec<bool>,
588}
589
590pub fn mixed_dense_forward_batch_q88_q1616(
598 weights_q88: &[i16],
599 inputs_q1616: &[i32],
600 n_outputs: usize,
601 n_inputs: usize,
602) -> Result<MixedDenseBatchResult, MixedDenseError> {
603 if n_inputs == 0 || n_outputs == 0 {
604 return Err(MixedDenseError::EmptyShape);
605 }
606 let expected_weights = n_outputs
607 .checked_mul(n_inputs)
608 .ok_or(MixedDenseError::ShapeOverflow)?;
609 if weights_q88.len() != expected_weights {
610 return Err(MixedDenseError::WeightLengthMismatch {
611 expected: expected_weights,
612 actual: weights_q88.len(),
613 });
614 }
615 if inputs_q1616.is_empty() || !inputs_q1616.len().is_multiple_of(n_inputs) {
616 return Err(MixedDenseError::InputLengthMismatch {
617 expected: n_inputs,
618 actual: inputs_q1616.len(),
619 });
620 }
621
622 let n_batch = inputs_q1616.len() / n_inputs;
623 let count = n_batch * n_outputs;
624 let mut outputs_q1616 = Vec::with_capacity(count);
625 let mut overflow = Vec::with_capacity(count);
626 let mut underflow = Vec::with_capacity(count);
627 for batch_idx in 0..n_batch {
628 let input_row = &inputs_q1616[batch_idx * n_inputs..(batch_idx + 1) * n_inputs];
629 for output_idx in 0..n_outputs {
630 let weight_row = &weights_q88[output_idx * n_inputs..(output_idx + 1) * n_inputs];
631 let mut sum: i128 = 0;
632 for input_idx in 0..n_inputs {
633 sum += i128::from(weight_row[input_idx]) * i128::from(input_row[input_idx]);
634 }
635 let scaled = sum >> 8;
636 if scaled > i128::from(i32::MAX) {
637 outputs_q1616.push(i32::MAX);
638 overflow.push(true);
639 underflow.push(false);
640 } else if scaled < i128::from(i32::MIN) {
641 outputs_q1616.push(i32::MIN);
642 overflow.push(true);
643 underflow.push(false);
644 } else {
645 outputs_q1616.push(scaled as i32);
646 overflow.push(false);
647 underflow.push(sum != 0 && scaled == 0);
648 }
649 }
650 }
651 Ok(MixedDenseBatchResult {
652 outputs_q1616,
653 overflow,
654 underflow,
655 })
656}
657
658pub fn block_floating_dense_q16(
659 mantissas: &[i16],
660 exponents: &[u8],
661 inputs_q1616: &[i32],
662 n_outputs: usize,
663 n_inputs: usize,
664 mode: BlockFloatingMode,
665) -> Result<MixedDenseResult, BlockFloatingDenseError> {
666 if n_inputs == 0 || n_outputs == 0 {
667 return Err(BlockFloatingDenseError::EmptyShape);
668 }
669 let expected_weights = n_outputs
670 .checked_mul(n_inputs)
671 .ok_or(BlockFloatingDenseError::ShapeOverflow)?;
672 let expected_blocks = mode
673 .block_exponent_count(expected_weights)
674 .map_err(|_| BlockFloatingDenseError::ShapeOverflow)?;
675
676 if mantissas.len() != expected_weights {
677 return Err(BlockFloatingDenseError::MantissaLengthMismatch {
678 expected: expected_weights,
679 actual: mantissas.len(),
680 });
681 }
682 if exponents.len() != expected_blocks {
683 return Err(BlockFloatingDenseError::ExponentLengthMismatch {
684 expected: expected_blocks,
685 actual: exponents.len(),
686 });
687 }
688 if inputs_q1616.len() != n_inputs {
689 return Err(BlockFloatingDenseError::InputLengthMismatch {
690 expected: n_inputs,
691 actual: inputs_q1616.len(),
692 });
693 }
694
695 let mantissa_range = mode.mantissa_range();
696 for (index, &mantissa) in mantissas.iter().enumerate() {
697 if i128::from(mantissa).abs() > mantissa_range {
698 return Err(BlockFloatingDenseError::MantissaOutOfRange {
699 index,
700 value: mantissa,
701 });
702 }
703 }
704 let exponent_code_max = mode.exponent_code_max();
705 for (index, &exponent) in exponents.iter().enumerate() {
706 if exponent > exponent_code_max {
707 return Err(BlockFloatingDenseError::ExponentOutOfRange {
708 index,
709 value: exponent,
710 });
711 }
712 }
713
714 let mut outputs_q1616 = Vec::with_capacity(n_outputs);
715 let mut abs_bounds_q1616 = Vec::with_capacity(n_outputs);
716 let mut overflow_count = 0_usize;
717 let mut underflow_count = 0_usize;
718 for output_idx in 0..n_outputs {
719 let mut sum: i128 = 0;
720 let mut abs_bound: i128 = 0;
721 let mut dropped_sub_lsb_product = false;
722 let row_start = output_idx * n_inputs;
723 for input_idx in 0..n_inputs {
724 let linear_idx = row_start + input_idx;
725 let block_idx = linear_idx / mode.block_size;
726 let product = i128::from(mantissas[linear_idx]) * i128::from(inputs_q1616[input_idx]);
727 let shift = i32::from(exponents[block_idx]) - mode.exponent_bias();
728 if shift >= 0 {
729 sum += product << shift;
730 abs_bound += product.abs() << shift;
731 } else {
732 sum += product >> (-shift);
733 let divisor_shift = -shift;
734 if product != 0 && (product >> divisor_shift) == 0 {
735 dropped_sub_lsb_product = true;
736 }
737 abs_bound += (product.abs() + ((1_i128 << divisor_shift) - 1)) >> divisor_shift;
738 }
739 }
740 abs_bounds_q1616.push(i128_to_i64_saturating(abs_bound));
741 if sum > i128::from(i32::MAX) {
742 outputs_q1616.push(i32::MAX);
743 overflow_count += 1;
744 } else if sum < i128::from(i32::MIN) {
745 outputs_q1616.push(i32::MIN);
746 overflow_count += 1;
747 } else {
748 if sum == 0 && dropped_sub_lsb_product {
749 underflow_count += 1;
750 }
751 outputs_q1616.push(sum as i32);
752 }
753 }
754
755 Ok(MixedDenseResult {
756 outputs_q1616,
757 overflow: overflow_count > 0,
758 overflow_count,
759 underflow_count,
760 abs_bounds_q1616,
761 })
762}
763
764#[cfg(test)]
765mod tests {
766 use super::*;
767
768 #[test]
769 fn mixed_dense_batch_matches_single_per_output() {
770 let weights = [256_i16, -128, 64, 512];
771 let inputs = [512_i32, 1024, 256, 768];
772 let batch = mixed_dense_forward_batch_q88_q1616(&weights, &inputs, 2, 2).unwrap();
773 assert_eq!(batch.outputs_q1616.len(), 4);
775 for batch_idx in 0..2 {
776 let row = &inputs[batch_idx * 2..batch_idx * 2 + 2];
777 let single = mixed_dense_q88_q1616(&weights, row, 2, 2).unwrap();
778 for output_idx in 0..2 {
779 assert_eq!(
780 batch.outputs_q1616[batch_idx * 2 + output_idx],
781 single.outputs_q1616[output_idx]
782 );
783 }
784 }
785 }
786
787 #[test]
788 fn mixed_dense_batch_floor_division_is_signed() {
789 let batch = mixed_dense_forward_batch_q88_q1616(&[1], &[-1], 1, 1).unwrap();
791 assert_eq!(batch.outputs_q1616, vec![-1]);
792 assert!(!batch.overflow[0]);
793 assert!(!batch.underflow[0]);
794 }
795
796 #[test]
797 fn mixed_dense_batch_flags_overflow_and_underflow() {
798 let weights = [i16::MAX; 4];
799 let inputs = [2_000_000_000_i32, 2_000_000_000, 1, 1];
800 let batch = mixed_dense_forward_batch_q88_q1616(&weights, &inputs, 1, 4).unwrap();
801 assert!(batch.overflow[0]);
802 assert_eq!(batch.outputs_q1616[0], i32::MAX);
803 let under = mixed_dense_forward_batch_q88_q1616(&[1], &[1], 1, 1).unwrap();
805 assert_eq!(under.outputs_q1616, vec![0]);
806 assert!(under.underflow[0]);
807 assert!(!under.overflow[0]);
808 }
809
810 #[test]
811 fn mixed_dense_batch_rejects_bad_shapes() {
812 assert_eq!(
813 mixed_dense_forward_batch_q88_q1616(&[1], &[1], 0, 1).unwrap_err(),
814 MixedDenseError::EmptyShape
815 );
816 assert!(matches!(
817 mixed_dense_forward_batch_q88_q1616(&[1, 1], &[1], 1, 1).unwrap_err(),
818 MixedDenseError::WeightLengthMismatch { .. }
819 ));
820 assert!(matches!(
821 mixed_dense_forward_batch_q88_q1616(&[1, 1], &[1, 1, 1], 1, 2).unwrap_err(),
822 MixedDenseError::InputLengthMismatch { .. }
823 ));
824 }
825
826 #[test]
827 fn qformat_mixed_default_matches_python_contract() {
828 let fmt = QFormatMixed::q8_8_q16_16();
829
830 assert_eq!(fmt.weight_fmt.label(), "Q8.8");
831 assert_eq!(fmt.accum_fmt.label(), "Q16.16");
832 assert_eq!(fmt.accumulator_guard_bits(), 16);
833 }
834
835 #[test]
836 fn rejects_accumulator_precision_loss() {
837 let result = QFormatMixed::new(
838 QFormat::new(8, 12).unwrap(),
839 QFormat::new(16, 8).unwrap(),
840 true,
841 );
842
843 assert_eq!(result.unwrap_err(), QFormatError::AccumulatorFractionLoss);
844 }
845
846 #[test]
847 fn mixed_dense_matches_manual_q88_q1616_codes() {
848 let weights = [128_i16, -64_i16, 256_i16, 32_i16];
849 let inputs = [32768_i32, -16384_i32];
850
851 let result = mixed_dense_q88_q1616(&weights, &inputs, 2, 2).unwrap();
852
853 assert_eq!(result.outputs_q1616, vec![20480, 30720]);
854 assert!(!result.overflow);
855 assert_eq!(result.overflow_count, 0);
856 assert_eq!(result.underflow_count, 0);
857 assert_eq!(result.abs_bounds_q1616, vec![20480, 34816]);
858
859 let envelope = result.precision_envelope_report();
860 assert!(envelope.observed_overflow_free);
861 assert!(envelope.observed_underflow_free);
862 assert!(envelope.conservative_overflow_free);
863 assert_eq!(envelope.max_abs_output_q1616, 30720);
864 assert_eq!(envelope.max_abs_bound_q1616, 34816);
865 assert_eq!(envelope.required_total_bits_q1616, 17);
866 assert_eq!(envelope.required_integer_bits_q1616, 1);
867 assert_eq!(envelope.width_headroom_bits_q1616, 15);
868 assert!(!envelope.saturation_required);
869 assert!(envelope.static_overflow_proven_safe);
870 }
871
872 #[test]
873 fn mixed_dense_negative_products_follow_arithmetic_shift() {
874 let result = mixed_dense_q88_q1616(&[128_i16], &[-1_i32], 1, 1).unwrap();
875
876 assert_eq!(result.outputs_q1616, vec![-1]);
877 }
878
879 #[test]
880 fn mixed_dense_reports_sub_lsb_underflow() {
881 let result = mixed_dense_q88_q1616(&[1_i16], &[1_i32], 1, 1).unwrap();
882
883 assert_eq!(result.outputs_q1616, vec![0]);
884 assert_eq!(result.overflow_count, 0);
885 assert_eq!(result.underflow_count, 1);
886
887 let report = result.precision_trap_report();
888 assert!(!report.overflow);
889 assert!(report.underflow);
890 assert_eq!(report.underflow_count, 1);
891
892 let envelope = result.precision_envelope_report();
893 assert!(envelope.observed_overflow_free);
894 assert!(!envelope.observed_underflow_free);
895 }
896
897 #[test]
898 fn mixed_dense_saturates_overflow() {
899 let weights = [i16::MAX, i16::MAX];
900 let inputs = [i32::MAX, i32::MAX];
901
902 let result = mixed_dense_q88_q1616(&weights, &inputs, 1, 2).unwrap();
903
904 assert_eq!(result.outputs_q1616, vec![i32::MAX]);
905 assert!(result.overflow);
906 assert_eq!(result.overflow_count, 1);
907 assert_eq!(result.underflow_count, 0);
908
909 let report = result.precision_trap_report();
910 assert_eq!(report.output_count, 1);
911 assert!(report.overflow);
912 assert_eq!(report.overflow_count, 1);
913 assert!(!report.underflow);
914 assert_eq!(report.underflow_count, 0);
915 assert_eq!(report.saturated_max_count, 1);
916 assert_eq!(report.saturated_min_count, 0);
917
918 let envelope = result.precision_envelope_report();
919 assert!(!envelope.observed_overflow_free);
920 assert!(envelope.observed_underflow_free);
921 assert!(!envelope.conservative_overflow_free);
922 assert_eq!(envelope.output_count, 1);
923 assert_eq!(envelope.overflow_count, 1);
924 assert_eq!(envelope.underflow_count, 0);
925 assert!(envelope.max_abs_bound_q1616 > envelope.conservative_safe_bound_q1616);
926 assert!(envelope.saturation_required);
927 assert!(!envelope.static_overflow_proven_safe);
928 }
929
930 #[test]
931 fn mixed_dense_rejects_shape_mismatches() {
932 assert_eq!(
933 mixed_dense_q88_q1616(&[], &[1], 1, 0).unwrap_err(),
934 MixedDenseError::EmptyShape
935 );
936 assert_eq!(
937 mixed_dense_q88_q1616(&[1], &[1], 2, 1).unwrap_err(),
938 MixedDenseError::WeightLengthMismatch {
939 expected: 2,
940 actual: 1,
941 }
942 );
943 assert_eq!(
944 mixed_dense_q88_q1616(&[1, 2], &[1], 1, 2).unwrap_err(),
945 MixedDenseError::InputLengthMismatch {
946 expected: 2,
947 actual: 1,
948 }
949 );
950 }
951
952 #[test]
953 fn block_floating_mode_reports_full_exponent_range() {
954 let mode = BlockFloatingMode::new(8, 2, 2).unwrap();
955
956 assert_eq!(mode.exponent_bias(), 1);
957 assert_eq!(mode.min_exponent(), -1);
958 assert_eq!(mode.max_exponent(), 2);
959 assert_eq!(mode.exponent_code_max(), 3);
960 }
961
962 #[test]
963 fn block_floating_mode_computes_exponent_layout() {
964 let mode = BlockFloatingMode::new(16, 3, 32).unwrap();
965 let layout = mode.block_exponent_layout(65).unwrap();
966
967 assert_eq!(layout.parameter_count, 65);
968 assert_eq!(layout.block_size, 32);
969 assert_eq!(layout.exponent_count, 3);
970 assert_eq!(layout.last_block_size, 1);
971 assert_eq!(mode.block_exponent_count(0).unwrap(), 0);
972 assert_eq!(
973 mode.validate_exponent_count(65, 2).unwrap_err(),
974 BlockFloatingError::ExponentCountMismatch {
975 expected: 3,
976 actual: 2,
977 }
978 );
979 }
980
981 #[test]
982 fn block_floating_dense_matches_manual_shifted_products() {
983 let mode = BlockFloatingMode::new(16, 3, 2).unwrap();
984 let bias = mode.exponent_bias() as u8;
985 let mantissas = [2_i16, -4_i16, 8_i16, 16_i16];
986 let exponents = [bias, bias - 1];
987 let inputs = [32768_i32, -16384_i32];
988
989 let result = block_floating_dense_q16(&mantissas, &exponents, &inputs, 2, 2, mode).unwrap();
990
991 assert_eq!(result.outputs_q1616, vec![131072, 0]);
992 assert!(!result.overflow);
993 assert_eq!(result.underflow_count, 0);
994 assert_eq!(result.abs_bounds_q1616, vec![131072, 262144]);
995
996 let envelope = result.precision_envelope_report();
997 assert!(envelope.observed_overflow_free);
998 assert!(envelope.observed_underflow_free);
999 assert!(envelope.conservative_overflow_free);
1000 assert_eq!(envelope.max_abs_output_q1616, 131072);
1001 assert_eq!(envelope.max_abs_bound_q1616, 262144);
1002 }
1003
1004 #[test]
1005 fn block_floating_dense_seeded_exponent_edges_match_manual_q1616_codes() {
1006 let mode = BlockFloatingMode::new(16, 3, 2).unwrap();
1007 let mantissas = [
1008 1_i16,
1009 -2_i16,
1010 i16::MAX,
1011 -i16::MAX,
1012 -3_i16,
1013 4_i16,
1014 -i16::MAX,
1015 i16::MAX,
1016 ];
1017 let exponents = [
1018 0_u8,
1019 mode.exponent_code_max(),
1020 0_u8,
1021 mode.exponent_code_max(),
1022 ];
1023 let inputs = [32768_i32, -16384_i32, 1_i32, -1_i32];
1024
1025 let result = block_floating_dense_q16(&mantissas, &exponents, &inputs, 2, 4, mode)
1026 .expect("seeded exponent-edge dimensions are valid");
1027
1028 assert_eq!(result.outputs_q1616, vec![1_056_736, -1_069_024]);
1029 assert_eq!(result.overflow_count, 0);
1030 assert_eq!(result.underflow_count, 0);
1031 assert_eq!(result.abs_bounds_q1616, vec![1_056_736, 1_069_024]);
1032
1033 let envelope = result.precision_envelope_report();
1034 assert!(envelope.observed_overflow_free);
1035 assert!(envelope.observed_underflow_free);
1036 assert!(envelope.conservative_overflow_free);
1037 assert_eq!(envelope.max_abs_bound_q1616, 1_069_024);
1038 assert_eq!(envelope.min_headroom_q1616, 2_146_414_623);
1039 }
1040
1041 #[test]
1042 fn block_floating_dense_max_exponent_edge_saturates_and_reports_trap() {
1043 let mode = BlockFloatingMode::new(16, 3, 2).unwrap();
1044 let mantissas = [i16::MAX, i16::MAX];
1045 let exponents = [mode.exponent_code_max()];
1046 let inputs = [32767_i32 << 16, 32767_i32 << 16];
1047
1048 let result = block_floating_dense_q16(&mantissas, &exponents, &inputs, 1, 2, mode)
1049 .expect("max-exponent trap dimensions are valid");
1050
1051 assert_eq!(result.outputs_q1616, vec![i32::MAX]);
1052 assert!(result.overflow);
1053 assert_eq!(result.overflow_count, 1);
1054 assert_eq!(result.underflow_count, 0);
1055
1056 let report = result.precision_trap_report();
1057 assert!(report.overflow);
1058 assert_eq!(report.overflow_count, 1);
1059 assert!(!report.underflow);
1060 assert_eq!(report.saturated_max_count, 1);
1061
1062 let envelope = result.precision_envelope_report();
1063 assert!(!envelope.observed_overflow_free);
1064 assert!(envelope.observed_underflow_free);
1065 assert!(!envelope.conservative_overflow_free);
1066 assert!(envelope.max_abs_bound_q1616 > envelope.conservative_safe_bound_q1616);
1067 }
1068
1069 #[test]
1070 fn block_floating_dense_reports_sub_lsb_underflow() {
1071 let mode = BlockFloatingMode::new(16, 3, 1).unwrap();
1072 let result = block_floating_dense_q16(&[1_i16], &[0_u8], &[1_i32], 1, 1, mode).unwrap();
1073
1074 assert_eq!(result.outputs_q1616, vec![0]);
1075 assert_eq!(result.overflow_count, 0);
1076 assert_eq!(result.underflow_count, 1);
1077
1078 let report = result.precision_trap_report();
1079 assert!(!report.overflow);
1080 assert!(report.underflow);
1081 assert_eq!(report.underflow_count, 1);
1082
1083 let envelope = result.precision_envelope_report();
1084 assert!(envelope.observed_overflow_free);
1085 assert!(!envelope.observed_underflow_free);
1086 assert_eq!(envelope.max_abs_bound_q1616, 1);
1087 }
1088
1089 #[test]
1090 fn block_floating_dense_saturates_large_outputs() {
1091 let mode = BlockFloatingMode::bfp16_e3_x32();
1092 let mantissas = vec![i16::MAX; 64];
1093 let exponents = vec![mode.exponent_code_max(); 2];
1094 let inputs = vec![i32::MAX; 64];
1095
1096 let result =
1097 block_floating_dense_q16(&mantissas, &exponents, &inputs, 1, 64, mode).unwrap();
1098
1099 assert_eq!(result.outputs_q1616, vec![i32::MAX]);
1100 assert!(result.overflow);
1101 assert_eq!(result.overflow_count, 1);
1102 assert_eq!(result.underflow_count, 0);
1103
1104 let report = result.precision_trap_report();
1105 assert_eq!(report.output_count, 1);
1106 assert!(report.overflow);
1107 assert_eq!(report.overflow_count, 1);
1108 assert!(!report.underflow);
1109 assert_eq!(report.underflow_count, 0);
1110 assert_eq!(report.saturated_max_count, 1);
1111 assert_eq!(report.saturated_min_count, 0);
1112
1113 let envelope = result.precision_envelope_report();
1114 assert!(!envelope.observed_overflow_free);
1115 assert!(envelope.observed_underflow_free);
1116 assert!(!envelope.conservative_overflow_free);
1117 assert_eq!(envelope.output_count, 1);
1118 assert_eq!(envelope.overflow_count, 1);
1119 assert_eq!(envelope.underflow_count, 0);
1120 assert!(envelope.max_abs_bound_q1616 > envelope.conservative_safe_bound_q1616);
1121 }
1122
1123 #[test]
1124 fn block_floating_dense_rejects_invalid_lengths_and_ranges() {
1125 let mode = BlockFloatingMode::new(8, 2, 2).unwrap();
1126
1127 assert_eq!(
1128 block_floating_dense_q16(&[], &[1], &[1], 1, 0, mode).unwrap_err(),
1129 BlockFloatingDenseError::EmptyShape
1130 );
1131 assert_eq!(
1132 block_floating_dense_q16(&[1], &[1], &[1], 2, 1, mode).unwrap_err(),
1133 BlockFloatingDenseError::MantissaLengthMismatch {
1134 expected: 2,
1135 actual: 1,
1136 }
1137 );
1138 assert_eq!(
1139 block_floating_dense_q16(&[1, 2], &[], &[1, 2], 1, 2, mode).unwrap_err(),
1140 BlockFloatingDenseError::ExponentLengthMismatch {
1141 expected: 1,
1142 actual: 0,
1143 }
1144 );
1145 assert_eq!(
1146 block_floating_dense_q16(&[128, 0], &[1], &[1, 2], 1, 2, mode).unwrap_err(),
1147 BlockFloatingDenseError::MantissaOutOfRange {
1148 index: 0,
1149 value: 128,
1150 }
1151 );
1152 }
1153}
1154
#[cfg(test)]
mod block_floating_benchmark_contract_tests {
    use super::*;
    use std::cmp::Ordering;

    /// Number of input activations in the benchmark contract.
    const N_INPUTS: usize = 64;
    /// Number of output rows in the benchmark contract.
    const N_OUTPUTS: usize = 32;

    /// Divides `value` by `divisor`, rounds to the nearest integer with exact
    /// ties going to the even neighbour, and returns the result as an `i16`.
    ///
    /// Rounding is applied to the magnitude and the sign is restored afterwards,
    /// so the result is symmetric around zero (`-3/2 -> -2`, `-5/2 -> -2`). This
    /// is the round-half-to-even rule used by Python's built-in `round`.
    ///
    /// # Panics
    ///
    /// Panics if `divisor` is not positive, or if the rounded quotient does not
    /// fit in an `i16`. Either case means the test fixture itself is wrong, so it
    /// must not be silently truncated.
    fn round_div_nearest_even(value: i32, divisor: i32) -> i16 {
        assert!(divisor > 0, "divisor must be positive, got {divisor}");
        // Widen to i64 so `abs()` of i32::MIN and the doubled remainder used in
        // the tie comparison cannot overflow.
        let magnitude = i64::from(value).abs();
        let divisor = i64::from(divisor);
        let quotient = magnitude / divisor;
        let twice_remainder = 2 * (magnitude % divisor);
        let rounded_magnitude = match twice_remainder.cmp(&divisor) {
            Ordering::Less => quotient,
            Ordering::Greater => quotient + 1,
            // Exactly halfway: bump odd quotients up to the even neighbour.
            Ordering::Equal => quotient + (quotient & 1),
        };
        let rounded = if value < 0 {
            -rounded_magnitude
        } else {
            rounded_magnitude
        };
        i16::try_from(rounded).expect("rounded quotient must fit in an i16 mantissa")
    }

    #[test]
    fn block_floating_benchmark_matches_python_quantiser_envelope() {
        let mode = BlockFloatingMode::bfp16_e3_x32();
        let weight_count = N_INPUTS * N_OUTPUTS;
        // One shared exponent per `mode.block_size` weights, rounded up so a
        // partial trailing block still gets an exponent.
        let exponent_count = (weight_count + mode.block_size - 1) / mode.block_size;

        // Safe workload: deterministic raw weight codes in [-512, 512], quantised
        // by 64 with round-half-to-even, all at exponent code 0.
        let mantissas = (0..weight_count)
            .map(|idx| {
                let raw_weight_code = ((idx * 23 + 3) % 1025) as i32 - 512;
                round_div_nearest_even(raw_weight_code, 64)
            })
            .collect::<Vec<_>>();
        let exponents = vec![0_u8; exponent_count];
        let inputs = (0..N_INPUTS)
            .map(|idx| (((idx * 19 + 5) % 257) as i32 - 128) << 8)
            .collect::<Vec<_>>();

        let result =
            block_floating_dense_q16(&mantissas, &exponents, &inputs, N_OUTPUTS, N_INPUTS, mode)
                .expect("benchmark contract dimensions are valid");
        let envelope = result.precision_envelope_report();

        assert_eq!(result.overflow_count, 0);
        assert_eq!(envelope.max_abs_bound_q1616, 610_816);
        assert!(envelope.conservative_overflow_free);

        // Saturating workload: large mantissas at exponent code 2 against
        // near-maximal inputs must overflow every output.
        let saturating_mantissas = vec![16_384_i16; weight_count];
        let saturating_exponents = vec![2_u8; exponent_count];
        let saturating_inputs = vec![32767_i32 << 16; N_INPUTS];
        let saturating_result = block_floating_dense_q16(
            &saturating_mantissas,
            &saturating_exponents,
            &saturating_inputs,
            N_OUTPUTS,
            N_INPUTS,
            mode,
        )
        .expect("saturating benchmark contract dimensions are valid");
        let saturating_envelope = saturating_result.precision_envelope_report();

        assert_eq!(saturating_result.overflow_count, N_OUTPUTS);
        assert_eq!(
            saturating_envelope.max_abs_bound_q1616,
            1_125_865_547_104_256
        );
        assert!(!saturating_envelope.conservative_overflow_free);
    }
}
#[cfg(test)]
mod mixed_dense_benchmark_contract_tests {
    use super::*;

    const N_INPUTS: usize = 64;
    const N_OUTPUTS: usize = 32;

    #[test]
    fn mixed_dense_benchmark_contract_matches_python_envelope() {
        // Safe workload: deterministic weight codes in [-256, 256] and input codes
        // in [-128, 128] shifted left by 8. The pinned envelope figures are the
        // reference values for this contract (per the test name, taken from the
        // Python benchmark).
        let weights: Vec<i16> = (0..N_INPUTS * N_OUTPUTS)
            .map(|idx| ((idx * 17 + 11) % 513) as i16 - 256)
            .collect();
        let inputs: Vec<i32> = (0..N_INPUTS)
            .map(|idx| (((idx * 19 + 5) % 257) as i32 - 128) << 8)
            .collect();

        let safe = mixed_dense_q88_q1616(&weights, &inputs, N_OUTPUTS, N_INPUTS)
            .expect("benchmark contract dimensions must be valid");
        assert_eq!(safe.overflow_count, 0);

        let safe_env = safe.precision_envelope_report();
        assert!(safe_env.conservative_overflow_free);
        assert!(!safe_env.saturation_required);
        assert_eq!(safe_env.max_abs_bound_q1616, 531_400);
        assert_eq!(safe_env.min_headroom_q1616, 2_146_952_247);
        assert_eq!(safe_env.required_total_bits_q1616, 21);
        assert_eq!(safe_env.required_integer_bits_q1616, 5);
        assert_eq!(safe_env.width_headroom_bits_q1616, 11);

        // Saturating probe: every weight is 127 << 8 and every input 32767 << 16,
        // so each output's bound needs more bits than the output format provides.
        let probe_weights = vec![127_i16 << 8; N_INPUTS * N_OUTPUTS];
        let probe_inputs = [32767_i32 << 16; N_INPUTS];

        let probe = mixed_dense_q88_q1616(&probe_weights, &probe_inputs, N_OUTPUTS, N_INPUTS)
            .expect("saturating probe dimensions must be valid");
        assert_eq!(probe.overflow_count, N_OUTPUTS);

        let probe_env = probe.precision_envelope_report();
        assert!(!probe_env.conservative_overflow_free);
        assert!(probe_env.saturation_required);
        assert_eq!(probe_env.max_abs_bound_q1616, 17_454_214_414_336);
        assert_eq!(probe_env.required_total_bits_q1616, 45);
        assert_eq!(probe_env.required_integer_bits_q1616, 29);
        assert_eq!(probe_env.width_headroom_bits_q1616, -13);
    }
}