1use rand::{Rng, RngExt};
15
16use crate::simd;
17
/// A packed bitstream of `length` logical bits.
///
/// Bit `i` of the stream is stored at `(data[i / 64] >> (i % 64)) & 1`, i.e. each word is
/// filled least-significant bit first (see `pack` / `unpack`).
#[derive(Debug, Clone)]
pub struct BitStreamTensor {
    /// Packed 64-bit words. `pack` and `pack_fast` leave any padding bits past `length` zero.
    pub data: Vec<u64>,
    /// Number of logical bits held in `data`.
    pub length: usize,
}
26
27impl BitStreamTensor {
28 pub fn from_words(data: Vec<u64>, length: usize) -> Self {
30 assert!(length > 0, "bitstream length must be > 0");
31 Self { data, length }
32 }
33
34 pub fn xor_inplace(&mut self, other: &BitStreamTensor) {
38 assert_eq!(
39 self.length, other.length,
40 "Bitstream lengths must match for XOR."
41 );
42 for (a, b) in self.data.iter_mut().zip(other.data.iter()) {
43 *a ^= *b;
44 }
45 }
46
47 pub fn xor(&self, other: &BitStreamTensor) -> BitStreamTensor {
49 assert_eq!(
50 self.length, other.length,
51 "Bitstream lengths must match for XOR."
52 );
53 let data = self
54 .data
55 .iter()
56 .zip(other.data.iter())
57 .map(|(&a, &b)| a ^ b)
58 .collect();
59 BitStreamTensor {
60 data,
61 length: self.length,
62 }
63 }
64
65 pub fn rotate_right(&mut self, shift: usize) {
69 if self.length == 0 || shift.is_multiple_of(self.length) {
70 return;
71 }
72 let mut bits = unpack(self);
73 bits.rotate_right(shift % self.length);
74 *self = pack(&bits);
75 }
76
77 pub fn hamming_distance(&self, other: &BitStreamTensor) -> f32 {
79 assert_eq!(
80 self.length, other.length,
81 "Bitstream lengths must match for Hamming distance."
82 );
83 let xor_count: u64 = crate::simd::fused_xor_popcount_dispatch(&self.data, &other.data);
84 xor_count as f32 / self.length as f32
85 }
86
87 pub fn bundle(vectors: &[&BitStreamTensor]) -> BitStreamTensor {
92 assert!(!vectors.is_empty(), "Cannot bundle zero vectors.");
93 let length = vectors[0].length;
94 let words = vectors[0].data.len();
95
96 if vectors.len() == 1 {
97 return vectors[0].clone();
98 }
99
100 let mut data = vec![0u64; words];
104
105 if vectors.len() == 3 {
106 for (i, item) in data.iter_mut().enumerate().take(words) {
107 let a = vectors[0].data[i];
108 let b = vectors[1].data[i];
109 let c = vectors[2].data[i];
110 *item = (a & b) | (b & c) | (a & c);
111 }
112 } else {
113 let threshold = vectors.len() / 2;
115 for (i, item) in data.iter_mut().enumerate().take(words) {
116 for bit in 0..64 {
117 let mut count = 0;
118 for v in vectors {
119 if (v.data[i] >> bit) & 1 == 1 {
120 count += 1;
121 }
122 }
123 if count > threshold {
124 *item |= 1u64 << bit;
125 }
126 }
127 }
128 }
129
130 BitStreamTensor { data, length }
131 }
132}
133
134pub fn pack(bits: &[u8]) -> BitStreamTensor {
136 let length = bits.len();
137 let words = length.div_ceil(64);
138 let mut data = vec![0_u64; words];
139
140 for (idx, bit) in bits.iter().copied().enumerate() {
141 if bit != 0 {
142 data[idx / 64] |= 1_u64 << (idx % 64);
143 }
144 }
145
146 BitStreamTensor { data, length }
147}
148
149pub fn pack_fast(bits: &[u8]) -> BitStreamTensor {
151 let length = bits.len();
152 let words = length.div_ceil(64);
153 let mut data = vec![0_u64; words];
154
155 for (word_idx, word) in data.iter_mut().enumerate() {
156 let base = word_idx * 64;
157 let chunk = &bits[base..std::cmp::min(base + 64, length)];
158
159 for (byte_idx, byte_chunk) in chunk.chunks(8).enumerate() {
160 let mut packed_byte: u8 = 0;
161 for (bit_idx, &bit) in byte_chunk.iter().enumerate() {
162 packed_byte |= u8::from(bit != 0) << bit_idx;
163 }
164 *word |= (packed_byte as u64) << (byte_idx * 8);
165 }
166 }
167
168 BitStreamTensor { data, length }
169}
170
171pub fn unpack(tensor: &BitStreamTensor) -> Vec<u8> {
173 let mut bits = vec![0_u8; tensor.length];
174
175 for (idx, bit) in bits.iter_mut().enumerate().take(tensor.length) {
176 let word = tensor.data[idx / 64];
177 *bit = ((word >> (idx % 64)) & 1) as u8;
178 }
179
180 bits
181}
182
183pub fn bitwise_and(a: &BitStreamTensor, b: &BitStreamTensor) -> BitStreamTensor {
185 assert_eq!(
186 a.length, b.length,
187 "Bitstream lengths must match for bitwise AND."
188 );
189 assert_eq!(
190 a.data.len(),
191 b.data.len(),
192 "Packed bitstream shapes must match for bitwise AND."
193 );
194
195 let data = a
196 .data
197 .iter()
198 .zip(b.data.iter())
199 .map(|(lhs, rhs)| lhs & rhs)
200 .collect();
201
202 BitStreamTensor {
203 data,
204 length: a.length,
205 }
206}
207
/// Counts the set bits of `x` using a branch-free SWAR (SIMD-within-a-register) reduction.
///
/// Sums adjacent 1-, 2- and 4-bit fields, then adds up the eight byte counts with a single
/// multiply. The result equals `u64::from(x.count_ones())`.
pub fn swar_popcount_word(x: u64) -> u64 {
    const ODD_BITS: u64 = 0x5555_5555_5555_5555;
    const BIT_PAIRS: u64 = 0x3333_3333_3333_3333;
    const NIBBLES: u64 = 0x0f0f_0f0f_0f0f_0f0f;
    const BYTE_ONES: u64 = 0x0101_0101_0101_0101;

    let per_pair = x.wrapping_sub((x >> 1) & ODD_BITS);
    let per_nibble = (per_pair & BIT_PAIRS) + ((per_pair >> 2) & BIT_PAIRS);
    let per_byte = (per_nibble + (per_nibble >> 4)) & NIBBLES;
    // Multiplying by 0x0101..01 accumulates all byte counts into the top byte.
    per_byte.wrapping_mul(BYTE_ONES) >> 56
}
215
216pub fn popcount_words_portable(data: &[u64]) -> u64 {
218 data.iter().copied().map(swar_popcount_word).sum()
219}
220
221pub fn popcount(tensor: &BitStreamTensor) -> u64 {
223 popcount_words_portable(&tensor.data)
224}
225
226pub fn bernoulli_stream<R: Rng + ?Sized>(prob: f64, length: usize, rng: &mut R) -> Vec<u8> {
230 let p = prob.clamp(0.0, 1.0);
231 let mut out = vec![0_u8; length];
232 for bit in &mut out {
233 *bit = if rng.random::<f64>() < p { 1 } else { 0 };
234 }
235 out
236}
237
238pub fn bernoulli_packed<R: Rng + ?Sized>(prob: f64, length: usize, rng: &mut R) -> Vec<u64> {
243 let p = prob.clamp(0.0, 1.0);
244 let words = length.div_ceil(64);
245 let mut data = vec![0_u64; words];
246 for (word_idx, word) in data.iter_mut().enumerate() {
247 let bits_in_word = std::cmp::min(64, length.saturating_sub(word_idx * 64));
248 for bit in 0..bits_in_word {
249 if rng.random::<f64>() < p {
250 *word |= 1_u64 << bit;
251 }
252 }
253 }
254 data
255}
256
257pub fn bernoulli_packed_fast<R: Rng + ?Sized>(prob: f64, length: usize, rng: &mut R) -> Vec<u64> {
271 let words = length.div_ceil(64);
272 if prob <= 0.0 {
273 return vec![0_u64; words];
274 }
275 if prob >= 1.0 {
276 let mut data = vec![u64::MAX; words];
277 let trailing = length % 64;
278 if trailing > 0 {
279 data[words - 1] = (1_u64 << trailing) - 1;
280 }
281 return data;
282 }
283 let threshold = (prob.clamp(0.0, 1.0) * 256.0) as u8;
284 let mut data = vec![0_u64; words];
285 let mut buf = [0_u8; 64];
286
287 for (word_idx, word) in data.iter_mut().enumerate() {
288 let bits_in_word = std::cmp::min(64, length.saturating_sub(word_idx * 64));
289 rng.fill(&mut buf[..bits_in_word]);
290 for (bit, &rb) in buf[..bits_in_word].iter().enumerate() {
291 if rb < threshold {
292 *word |= 1_u64 << bit;
293 }
294 }
295 }
296 data
297}
298
299pub fn bernoulli_packed_simd<R: Rng + ?Sized>(prob: f64, length: usize, rng: &mut R) -> Vec<u64> {
304 let words = length.div_ceil(64);
305 if prob <= 0.0 {
306 return vec![0_u64; words];
307 }
308 if prob >= 1.0 {
309 let mut data = vec![u64::MAX; words];
310 let trailing = length % 64;
311 if trailing > 0 {
312 data[words - 1] = (1_u64 << trailing) - 1;
313 }
314 return data;
315 }
316 let threshold = (prob.clamp(0.0, 1.0) * 256.0) as u8;
317 let mut data = vec![0_u64; words];
318 let full_words = length / 64;
319 let mut buf = [0_u8; 1024];
320 let mut chunks = data[..full_words].chunks_exact_mut(16);
321
322 for w_chunk in chunks.by_ref() {
323 rng.fill(&mut buf);
324 crate::simd::bernoulli_compare_batch_1024(&buf, threshold, w_chunk);
325 }
326
327 for word in chunks.into_remainder() {
328 let mut small_buf = [0_u8; 64];
329 rng.fill(&mut small_buf);
330 *word = simd_bernoulli_compare_exposed(&small_buf, threshold);
331 }
332
333 if full_words < words {
334 let remaining = length - full_words * 64;
335 rng.fill(&mut buf[..remaining]);
336 let mut tail = 0_u64;
337 for (bit, &rb) in buf[..remaining].iter().enumerate() {
338 if rb < threshold {
339 tail |= 1_u64 << bit;
340 }
341 }
342 data[full_words] = tail;
343 }
344
345 data
346}
347
348pub fn encode_and_popcount<R: Rng + ?Sized>(
354 weight_words: &[u64],
355 prob: f64,
356 length: usize,
357 rng: &mut R,
358) -> u64 {
359 if prob <= 0.0 {
360 return 0;
361 }
362 if prob >= 1.0 {
363 let full_words = length / 64;
365 let mut total = 0_u64;
366 for w in weight_words.iter().take(full_words) {
367 total += w.count_ones() as u64;
368 }
369 let trailing = length % 64;
370 if trailing > 0 && full_words < weight_words.len() {
371 let mask = (1_u64 << trailing) - 1;
372 total += (weight_words[full_words] & mask).count_ones() as u64;
373 }
374 return total;
375 }
376 let threshold = (prob.clamp(0.0, 1.0) * 256.0) as u8;
377 let full_words = length / 64;
378 let mut total = 0_u64;
379 let mut buf = [0_u8; 1024]; let mut chunks = weight_words[..full_words].chunks_exact(16);
381
382 let mut encoded_batch = [0_u64; 16];
383 for w_chunk in chunks.by_ref() {
384 rng.fill(&mut buf);
385 crate::simd::bernoulli_compare_batch_1024(&buf, threshold, &mut encoded_batch);
386 for (i, &w_word) in w_chunk.iter().enumerate() {
387 total += (encoded_batch[i] & w_word).count_ones() as u64;
388 }
389 }
390
391 for &w_word in chunks.remainder() {
392 let mut small_buf = [0_u8; 64];
393 rng.fill(&mut small_buf);
394 let encoded = simd_bernoulli_compare_exposed(&small_buf, threshold);
395 total += (encoded & w_word).count_ones() as u64;
396 }
397
398 let remaining = length.saturating_sub(full_words * 64);
399 if remaining > 0 && full_words < weight_words.len() {
400 rng.fill(&mut buf[..remaining]);
401 let mut encoded = 0_u64;
402 for (bit, &rb) in buf[..remaining].iter().enumerate() {
403 if rb < threshold {
404 encoded |= 1_u64 << bit;
405 }
406 }
407 total += (encoded & weight_words[full_words]).count_ones() as u64;
408 }
409
410 total
411}
412
413#[inline]
415pub fn simd_bernoulli_compare_exposed(buf: &[u8], threshold: u8) -> u64 {
416 debug_assert!(buf.len() >= 64, "buffer must contain at least 64 bytes");
417
418 #[cfg(target_arch = "x86_64")]
419 {
420 if is_x86_feature_detected!("avx512bw") {
421 return unsafe { simd::avx512::bernoulli_compare_avx512(buf, threshold) };
423 }
424 if is_x86_feature_detected!("avx2") {
425 let lo = unsafe { simd::avx2::bernoulli_compare_avx2(&buf[0..32], threshold) };
427 let hi = unsafe { simd::avx2::bernoulli_compare_avx2(&buf[32..64], threshold) };
429 return (lo as u64) | ((hi as u64) << 32);
430 }
431 }
432
433 let mut mask = 0_u64;
434 for (bit, &rb) in buf.iter().take(64).enumerate() {
435 if rb < threshold {
436 mask |= 1_u64 << bit;
437 }
438 }
439 mask
440}
441
442pub fn encode_matrix_prob_to_packed<R: Rng + ?Sized>(
446 values: &[f64],
447 rows: usize,
448 cols: usize,
449 length: usize,
450 words: usize,
451 rng: &mut R,
452) -> Vec<Vec<u64>> {
453 let mut packed = Vec::with_capacity(rows * cols);
454 for value in values.iter().take(rows * cols) {
455 let mut row = bernoulli_packed_simd(*value, length, rng);
456 row.resize(words, 0);
457 packed.push(row);
458 }
459 packed
460}
461
#[cfg(test)]
mod tests {
    use super::{
        bernoulli_packed, bernoulli_packed_fast, bernoulli_packed_simd, bernoulli_stream,
        bitwise_and, encode_and_popcount, pack, pack_fast, popcount, swar_popcount_word, unpack,
        BitStreamTensor,
    };

    #[test]
    fn pack_unpack_roundtrip() {
        let bits = vec![1, 0, 1, 1, 0, 1, 0, 0, 1];
        let packed = pack(&bits);
        let unpacked = unpack(&packed);
        assert_eq!(bits, unpacked);
    }

    #[test]
    fn pack_fast_matches_pack() {
        let cases = [0_usize, 1, 7, 8, 9, 63, 64, 65, 127, 128, 256, 1025];
        for length in cases {
            let bits: Vec<u8> = (0..length).map(|i| ((i * 7 + 3) % 2) as u8).collect();
            let slow = pack(&bits);
            let fast = pack_fast(&bits);
            assert_eq!(fast.length, slow.length);
            assert_eq!(fast.data, slow.data, "Mismatch at length={length}");
        }
    }

    #[test]
    fn pack_fast_roundtrip() {
        let bits: Vec<u8> = (0..2048).map(|i| ((i * 5 + 1) % 2) as u8).collect();
        let packed = pack_fast(&bits);
        let unpacked = unpack(&packed);
        assert_eq!(bits, unpacked);
    }

    #[test]
    fn and_and_popcount() {
        let a = pack(&[1, 0, 1, 1, 0, 0, 1, 1]);
        let b = pack(&[1, 1, 1, 0, 0, 1, 1, 0]);
        let c = bitwise_and(&a, &b);
        assert_eq!(unpack(&c), vec![1, 0, 1, 0, 0, 0, 1, 0]);
        assert_eq!(popcount(&c), 3);
    }

    #[test]
    fn swar_popcount_matches_count_ones() {
        let samples = [
            0_u64,
            1,
            u64::MAX,
            1 << 63,
            0xA5A5_A5A5_5A5A_5A5A,
            0x0123_4567_89AB_CDEF,
        ];
        for x in samples {
            assert_eq!(swar_popcount_word(x), u64::from(x.count_ones()), "x={x:#x}");
        }
    }

    #[test]
    fn xor_is_self_inverse() {
        let a = pack(&[1, 0, 1, 1, 0, 1, 0, 0, 1]);
        let b = pack(&[0, 1, 1, 0, 0, 1, 1, 1, 0]);
        let mut c = a.xor(&b);
        assert_eq!(unpack(&c), vec![1, 1, 0, 1, 0, 0, 1, 1, 1]);
        c.xor_inplace(&b);
        assert_eq!(c.data, a.data);
    }

    #[test]
    fn rotate_right_moves_bits_cyclically() {
        let mut t = pack(&[1, 0, 0, 0, 1]);
        t.rotate_right(1);
        assert_eq!(unpack(&t), vec![1, 1, 0, 0, 0]);
        // A shift equal to the length is a no-op.
        t.rotate_right(5);
        assert_eq!(unpack(&t), vec![1, 1, 0, 0, 0]);

        // Rotation across the 64-bit word boundary.
        let mut bits = vec![0_u8; 70];
        bits[63] = 1;
        let mut wide = pack(&bits);
        wide.rotate_right(3);
        assert_eq!(unpack(&wide).iter().position(|&b| b == 1), Some(66));
        assert_eq!(popcount(&wide), 1);
    }

    #[test]
    fn bundle_takes_bitwise_majority() {
        let a = pack(&[1, 1, 0, 0, 1, 0]);
        let b = pack(&[1, 0, 1, 0, 1, 0]);
        let c = pack(&[0, 1, 1, 0, 1, 1]);

        // Three inputs: closed-form path.
        let three = BitStreamTensor::bundle(&[&a, &b, &c]);
        assert_eq!(unpack(&three), vec![1, 1, 1, 0, 1, 0]);

        // Five inputs: generic counting path.
        let five = BitStreamTensor::bundle(&[&a, &b, &c, &c, &c]);
        assert_eq!(unpack(&five), unpack(&c));

        // Even count: ties resolve to 0, so two inputs behave like AND.
        let two = BitStreamTensor::bundle(&[&a, &b]);
        assert_eq!(unpack(&two), vec![1, 0, 0, 0, 1, 0]);

        // A single input is returned unchanged.
        assert_eq!(BitStreamTensor::bundle(&[&a]).data, a.data);
    }

    #[test]
    fn hamming_distance_is_fraction_of_differing_bits() {
        let a = pack(&[1, 0, 1, 1, 0, 0, 1, 1]);
        let b = pack(&[1, 1, 1, 0, 0, 1, 1, 0]);
        assert_eq!(a.hamming_distance(&a), 0.0);
        assert!((a.hamming_distance(&b) - 0.5).abs() < 1e-6);
    }

    #[test]
    fn bernoulli_packed_matches_stream_then_pack() {
        use rand::SeedableRng;
        use rand_chacha::ChaCha8Rng;

        let prob = 0.35;
        let length = 200;

        let mut rng1 = ChaCha8Rng::seed_from_u64(999);
        let stream = bernoulli_stream(prob, length, &mut rng1);
        let packed_via_stream = pack(&stream).data;

        let mut rng2 = ChaCha8Rng::seed_from_u64(999);
        let packed_direct = bernoulli_packed(prob, length, &mut rng2);

        assert_eq!(
            packed_via_stream, packed_direct,
            "bernoulli_packed must produce bit-identical output"
        );
    }

    #[test]
    fn bernoulli_packed_fast_statistics() {
        use rand::SeedableRng;
        use rand_chacha::ChaCha8Rng;

        let prob = 0.35;
        let length = 10_000;
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let packed = bernoulli_packed_fast(prob, length, &mut rng);
        let count: u64 = packed.iter().map(|w| w.count_ones() as u64).sum();
        let measured = count as f64 / length as f64;
        assert!(
            (measured - prob).abs() < 0.03,
            "Expected ~{prob}, got {measured}"
        );
    }

    #[test]
    fn bernoulli_packed_fast_deterministic() {
        use rand::SeedableRng;
        use rand_chacha::ChaCha8Rng;

        let mut rng1 = ChaCha8Rng::seed_from_u64(99);
        let a = bernoulli_packed_fast(0.5, 512, &mut rng1);

        let mut rng2 = ChaCha8Rng::seed_from_u64(99);
        let b = bernoulli_packed_fast(0.5, 512, &mut rng2);

        assert_eq!(a, b, "Same seed must produce identical output");
    }

    #[test]
    fn bernoulli_packed_simd_statistics() {
        use rand::SeedableRng;
        use rand_chacha::ChaCha8Rng;

        let prob = 0.35;
        let length = 10_000;
        let mut rng = ChaCha8Rng::seed_from_u64(1337);
        let packed = bernoulli_packed_simd(prob, length, &mut rng);
        let count: u64 = packed.iter().map(|w| w.count_ones() as u64).sum();
        let measured = count as f64 / length as f64;
        assert!(
            (measured - prob).abs() < 0.03,
            "Expected ~{prob}, got {measured}"
        );
    }

    #[test]
    fn bernoulli_packed_simd_deterministic() {
        use rand::SeedableRng;
        use rand_chacha::ChaCha8Rng;

        let mut rng1 = ChaCha8Rng::seed_from_u64(2026);
        let a = bernoulli_packed_simd(0.5, 1024, &mut rng1);

        let mut rng2 = ChaCha8Rng::seed_from_u64(2026);
        let b = bernoulli_packed_simd(0.5, 1024, &mut rng2);

        assert_eq!(a, b, "Same seed must produce identical output");
    }

    #[test]
    fn encode_and_popcount_matches_materialized() {
        use rand::SeedableRng;
        use rand_xoshiro::Xoshiro256PlusPlus;

        let prob = 0.41;
        let lengths = [63_usize, 64, 65, 1003, 1024];
        for length in lengths {
            let words = length.div_ceil(64);
            let weights: Vec<u64> = (0..words)
                .map(|i| (i as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15) ^ 0xA5A5_A5A5_5A5A_5A5A)
                .collect();

            let mut rng1 = Xoshiro256PlusPlus::seed_from_u64(2026);
            let fused = encode_and_popcount(&weights, prob, length, &mut rng1);

            let mut rng2 = Xoshiro256PlusPlus::seed_from_u64(2026);
            let encoded = bernoulli_packed_simd(prob, length, &mut rng2);
            let expected: u64 = encoded
                .iter()
                .zip(weights.iter())
                .map(|(&e, &w)| (e & w).count_ones() as u64)
                .sum();

            assert_eq!(fused, expected, "Mismatch at length={length}");
        }
    }

    #[test]
    fn encode_and_popcount_saturated_probabilities() {
        use rand::SeedableRng;
        use rand_xoshiro::Xoshiro256PlusPlus;

        let weights = [u64::MAX, u64::MAX];
        let mut rng = Xoshiro256PlusPlus::seed_from_u64(7);
        assert_eq!(encode_and_popcount(&weights, 0.0, 100, &mut rng), 0);
        // All encoded bits below `length` are 1, so only the first 100 weight bits count.
        assert_eq!(encode_and_popcount(&weights, 1.0, 100, &mut rng), 100);
    }
}