1use rand::{Rng, RngExt};
14
15use crate::simd;
16
/// A packed bitstream of `length` logical bits.
///
/// Bits are stored LSB-first: logical bit `i` lives in `data[i / 64]` at bit
/// position `i % 64` (the layout produced by `pack` and read by `unpack`).
#[derive(Debug, Clone)]
pub struct BitStreamTensor {
    /// Packed 64-bit words holding the bitstream.
    pub data: Vec<u64>,
    /// Number of logical bits in the stream (may be less than `64 * data.len()`).
    pub length: usize,
}
25
26impl BitStreamTensor {
27 pub fn from_words(data: Vec<u64>, length: usize) -> Self {
29 assert!(length > 0, "bitstream length must be > 0");
30 Self { data, length }
31 }
32
33 pub fn xor_inplace(&mut self, other: &BitStreamTensor) {
37 assert_eq!(
38 self.length, other.length,
39 "Bitstream lengths must match for XOR."
40 );
41 for (a, b) in self.data.iter_mut().zip(other.data.iter()) {
42 *a ^= *b;
43 }
44 }
45
46 pub fn xor(&self, other: &BitStreamTensor) -> BitStreamTensor {
48 assert_eq!(
49 self.length, other.length,
50 "Bitstream lengths must match for XOR."
51 );
52 let data = self
53 .data
54 .iter()
55 .zip(other.data.iter())
56 .map(|(&a, &b)| a ^ b)
57 .collect();
58 BitStreamTensor {
59 data,
60 length: self.length,
61 }
62 }
63
64 pub fn rotate_right(&mut self, shift: usize) {
68 if self.length == 0 || shift.is_multiple_of(self.length) {
69 return;
70 }
71 let mut bits = unpack(self);
72 bits.rotate_right(shift % self.length);
73 *self = pack(&bits);
74 }
75
76 pub fn hamming_distance(&self, other: &BitStreamTensor) -> f32 {
78 assert_eq!(
79 self.length, other.length,
80 "Bitstream lengths must match for Hamming distance."
81 );
82 let xor_count: u64 = crate::simd::fused_xor_popcount_dispatch(&self.data, &other.data);
83 xor_count as f32 / self.length as f32
84 }
85
86 pub fn bundle(vectors: &[&BitStreamTensor]) -> BitStreamTensor {
91 assert!(!vectors.is_empty(), "Cannot bundle zero vectors.");
92 let length = vectors[0].length;
93 let words = vectors[0].data.len();
94
95 if vectors.len() == 1 {
96 return vectors[0].clone();
97 }
98
99 let mut data = vec![0u64; words];
103
104 if vectors.len() == 3 {
105 for (i, item) in data.iter_mut().enumerate().take(words) {
106 let a = vectors[0].data[i];
107 let b = vectors[1].data[i];
108 let c = vectors[2].data[i];
109 *item = (a & b) | (b & c) | (a & c);
110 }
111 } else {
112 let threshold = vectors.len() / 2;
114 for (i, item) in data.iter_mut().enumerate().take(words) {
115 for bit in 0..64 {
116 let mut count = 0;
117 for v in vectors {
118 if (v.data[i] >> bit) & 1 == 1 {
119 count += 1;
120 }
121 }
122 if count > threshold {
123 *item |= 1u64 << bit;
124 }
125 }
126 }
127 }
128
129 BitStreamTensor { data, length }
130 }
131}
132
133pub fn pack(bits: &[u8]) -> BitStreamTensor {
135 let length = bits.len();
136 let words = length.div_ceil(64);
137 let mut data = vec![0_u64; words];
138
139 for (idx, bit) in bits.iter().copied().enumerate() {
140 if bit != 0 {
141 data[idx / 64] |= 1_u64 << (idx % 64);
142 }
143 }
144
145 BitStreamTensor { data, length }
146}
147
148pub fn pack_fast(bits: &[u8]) -> BitStreamTensor {
150 let length = bits.len();
151 let words = length.div_ceil(64);
152 let mut data = vec![0_u64; words];
153
154 for (word_idx, word) in data.iter_mut().enumerate() {
155 let base = word_idx * 64;
156 let chunk = &bits[base..std::cmp::min(base + 64, length)];
157
158 for (byte_idx, byte_chunk) in chunk.chunks(8).enumerate() {
159 let mut packed_byte: u8 = 0;
160 for (bit_idx, &bit) in byte_chunk.iter().enumerate() {
161 packed_byte |= u8::from(bit != 0) << bit_idx;
162 }
163 *word |= (packed_byte as u64) << (byte_idx * 8);
164 }
165 }
166
167 BitStreamTensor { data, length }
168}
169
170pub fn unpack(tensor: &BitStreamTensor) -> Vec<u8> {
172 let mut bits = vec![0_u8; tensor.length];
173
174 for (idx, bit) in bits.iter_mut().enumerate().take(tensor.length) {
175 let word = tensor.data[idx / 64];
176 *bit = ((word >> (idx % 64)) & 1) as u8;
177 }
178
179 bits
180}
181
182pub fn bitwise_and(a: &BitStreamTensor, b: &BitStreamTensor) -> BitStreamTensor {
184 assert_eq!(
185 a.length, b.length,
186 "Bitstream lengths must match for bitwise AND."
187 );
188 assert_eq!(
189 a.data.len(),
190 b.data.len(),
191 "Packed bitstream shapes must match for bitwise AND."
192 );
193
194 let data = a
195 .data
196 .iter()
197 .zip(b.data.iter())
198 .map(|(lhs, rhs)| lhs & rhs)
199 .collect();
200
201 BitStreamTensor {
202 data,
203 length: a.length,
204 }
205}
206
/// Counts the set bits of `x` with the classic SWAR (SIMD-within-a-register)
/// reduction, without relying on a hardware popcount instruction.
pub fn swar_popcount_word(x: u64) -> u64 {
    const PAIR_MASK: u64 = 0x5555_5555_5555_5555;
    const NIBBLE_MASK: u64 = 0x3333_3333_3333_3333;
    const BYTE_MASK: u64 = 0x0f0f_0f0f_0f0f_0f0f;
    const BYTE_ONES: u64 = 0x0101_0101_0101_0101;

    // Per 2-bit field: number of set bits in that field.
    let pairs = x.wrapping_sub((x >> 1) & PAIR_MASK);
    // Per 4-bit field: sum of two adjacent 2-bit counts.
    let nibbles = (pairs & NIBBLE_MASK) + ((pairs >> 2) & NIBBLE_MASK);
    // Per byte: sum of two adjacent 4-bit counts.
    let bytes = (nibbles + (nibbles >> 4)) & BYTE_MASK;
    // Multiplying by 0x0101... accumulates every byte count into the top byte.
    bytes.wrapping_mul(BYTE_ONES) >> 56
}
214
215pub fn popcount_words_portable(data: &[u64]) -> u64 {
217 data.iter().copied().map(swar_popcount_word).sum()
218}
219
220pub fn popcount(tensor: &BitStreamTensor) -> u64 {
222 popcount_words_portable(&tensor.data)
223}
224
225pub fn bernoulli_stream<R: Rng + ?Sized>(prob: f64, length: usize, rng: &mut R) -> Vec<u8> {
229 let p = prob.clamp(0.0, 1.0);
230 let mut out = vec![0_u8; length];
231 for bit in &mut out {
232 *bit = if rng.random::<f64>() < p { 1 } else { 0 };
233 }
234 out
235}
236
237pub fn bernoulli_packed<R: Rng + ?Sized>(prob: f64, length: usize, rng: &mut R) -> Vec<u64> {
242 let p = prob.clamp(0.0, 1.0);
243 let words = length.div_ceil(64);
244 let mut data = vec![0_u64; words];
245 for (word_idx, word) in data.iter_mut().enumerate() {
246 let bits_in_word = std::cmp::min(64, length.saturating_sub(word_idx * 64));
247 for bit in 0..bits_in_word {
248 if rng.random::<f64>() < p {
249 *word |= 1_u64 << bit;
250 }
251 }
252 }
253 data
254}
255
256pub fn bernoulli_packed_fast<R: Rng + ?Sized>(prob: f64, length: usize, rng: &mut R) -> Vec<u64> {
270 let words = length.div_ceil(64);
271 if prob >= 1.0 {
272 let mut data = vec![u64::MAX; words];
273 let trailing = length % 64;
274 if trailing > 0 {
275 data[words - 1] = (1_u64 << trailing) - 1;
276 }
277 return data;
278 }
279 let threshold = (prob.clamp(0.0, 1.0) * 256.0) as u8;
280 let mut data = vec![0_u64; words];
281 let mut buf = [0_u8; 64];
282
283 for (word_idx, word) in data.iter_mut().enumerate() {
284 let bits_in_word = std::cmp::min(64, length.saturating_sub(word_idx * 64));
285 rng.fill(&mut buf[..bits_in_word]);
286 for (bit, &rb) in buf[..bits_in_word].iter().enumerate() {
287 if rb < threshold {
288 *word |= 1_u64 << bit;
289 }
290 }
291 }
292 data
293}
294
295pub fn bernoulli_packed_simd<R: Rng + ?Sized>(prob: f64, length: usize, rng: &mut R) -> Vec<u64> {
300 let words = length.div_ceil(64);
301 if prob >= 1.0 {
302 let mut data = vec![u64::MAX; words];
303 let trailing = length % 64;
304 if trailing > 0 {
305 data[words - 1] = (1_u64 << trailing) - 1;
306 }
307 return data;
308 }
309 let threshold = (prob.clamp(0.0, 1.0) * 256.0) as u8;
310 let mut data = vec![0_u64; words];
311 let full_words = length / 64;
312 let mut buf = [0_u8; 64];
313
314 for word in data.iter_mut().take(full_words) {
315 rng.fill(&mut buf);
316 *word = simd_bernoulli_compare(&buf, threshold);
317 }
318
319 if full_words < words {
320 let remaining = length - full_words * 64;
321 rng.fill(&mut buf[..remaining]);
322 let mut tail = 0_u64;
323 for (bit, &rb) in buf[..remaining].iter().enumerate() {
324 if rb < threshold {
325 tail |= 1_u64 << bit;
326 }
327 }
328 data[full_words] = tail;
329 }
330
331 data
332}
333
334pub fn encode_and_popcount<R: Rng + ?Sized>(
340 weight_words: &[u64],
341 prob: f64,
342 length: usize,
343 rng: &mut R,
344) -> u64 {
345 if prob >= 1.0 {
346 let full_words = length / 64;
348 let mut total = 0_u64;
349 for w in weight_words.iter().take(full_words) {
350 total += w.count_ones() as u64;
351 }
352 let trailing = length % 64;
353 if trailing > 0 && full_words < weight_words.len() {
354 let mask = (1_u64 << trailing) - 1;
355 total += (weight_words[full_words] & mask).count_ones() as u64;
356 }
357 return total;
358 }
359 let threshold = (prob.clamp(0.0, 1.0) * 256.0) as u8;
360 let full_words = length / 64;
361 let mut total = 0_u64;
362 let mut buf = [0_u8; 64];
363
364 for &w_word in weight_words.iter().take(full_words) {
365 rng.fill(&mut buf);
366 let encoded = simd_bernoulli_compare(&buf, threshold);
367 total += (encoded & w_word).count_ones() as u64;
368 }
369
370 let remaining = length.saturating_sub(full_words * 64);
371 if remaining > 0 && full_words < weight_words.len() {
372 rng.fill(&mut buf[..remaining]);
373 let mut encoded = 0_u64;
374 for (bit, &rb) in buf[..remaining].iter().enumerate() {
375 if rb < threshold {
376 encoded |= 1_u64 << bit;
377 }
378 }
379 total += (encoded & weight_words[full_words]).count_ones() as u64;
380 }
381
382 total
383}
384
385#[inline]
387fn simd_bernoulli_compare(buf: &[u8], threshold: u8) -> u64 {
388 debug_assert!(buf.len() >= 64, "buffer must contain at least 64 bytes");
389
390 #[cfg(target_arch = "x86_64")]
391 {
392 if is_x86_feature_detected!("avx512bw") {
393 return unsafe { simd::avx512::bernoulli_compare_avx512(buf, threshold) };
395 }
396 if is_x86_feature_detected!("avx2") {
397 let lo = unsafe { simd::avx2::bernoulli_compare_avx2(&buf[0..32], threshold) };
399 let hi = unsafe { simd::avx2::bernoulli_compare_avx2(&buf[32..64], threshold) };
401 return (lo as u64) | ((hi as u64) << 32);
402 }
403 }
404
405 let mut mask = 0_u64;
406 for (bit, &rb) in buf.iter().take(64).enumerate() {
407 if rb < threshold {
408 mask |= 1_u64 << bit;
409 }
410 }
411 mask
412}
413
414pub fn encode_matrix_prob_to_packed<R: Rng + ?Sized>(
418 values: &[f64],
419 rows: usize,
420 cols: usize,
421 length: usize,
422 words: usize,
423 rng: &mut R,
424) -> Vec<Vec<u64>> {
425 let mut packed = Vec::with_capacity(rows * cols);
426 for value in values.iter().take(rows * cols) {
427 let mut row = bernoulli_packed(*value, length, rng);
428 row.resize(words, 0);
429 packed.push(row);
430 }
431 packed
432}
433
#[cfg(test)]
mod tests {
    use super::{
        BitStreamTensor, bernoulli_packed, bernoulli_packed_fast, bernoulli_packed_simd,
        bernoulli_stream, bitwise_and, encode_and_popcount, pack, pack_fast, popcount, unpack,
    };

    #[test]
    fn pack_unpack_roundtrip() {
        let bits = vec![1, 0, 1, 1, 0, 1, 0, 0, 1];
        let packed = pack(&bits);
        let unpacked = unpack(&packed);
        assert_eq!(bits, unpacked);
    }

    #[test]
    fn pack_fast_matches_pack() {
        let cases = [0_usize, 1, 7, 8, 9, 63, 64, 65, 127, 128, 256, 1025];
        for length in cases {
            // Values 0..=4, so non-zero bytes other than 1 are exercised too:
            // both packers must treat any non-zero byte as a set bit.
            let bits: Vec<u8> = (0..length).map(|i| ((i * 7 + 3) % 5) as u8).collect();
            let slow = pack(&bits);
            let fast = pack_fast(&bits);
            assert_eq!(fast.length, slow.length);
            assert_eq!(fast.data, slow.data, "Mismatch at length={length}");
        }
    }

    #[test]
    fn pack_fast_roundtrip() {
        let bits: Vec<u8> = (0..2048).map(|i| ((i * 5 + 1) % 2) as u8).collect();
        let packed = pack_fast(&bits);
        let unpacked = unpack(&packed);
        assert_eq!(bits, unpacked);
    }

    #[test]
    fn and_and_popcount() {
        let a = pack(&[1, 0, 1, 1, 0, 0, 1, 1]);
        let b = pack(&[1, 1, 1, 0, 0, 1, 1, 0]);
        let c = bitwise_and(&a, &b);
        assert_eq!(unpack(&c), vec![1, 0, 1, 0, 0, 0, 1, 0]);
        assert_eq!(popcount(&c), 3);
    }

    #[test]
    fn xor_and_xor_inplace_agree() {
        let a = pack(&[1, 0, 1, 1, 0]);
        let b = pack(&[1, 1, 0, 1, 1]);
        let c = a.xor(&b);
        assert_eq!(unpack(&c), vec![0, 1, 1, 0, 1]);

        let mut d = a.clone();
        d.xor_inplace(&b);
        assert_eq!(d.data, c.data);
        assert_eq!(d.length, 5);
    }

    #[test]
    fn rotate_right_moves_bits_toward_higher_indices() {
        let mut t = pack(&[1, 0, 0, 1, 0]);
        t.rotate_right(1);
        assert_eq!(unpack(&t), vec![0, 1, 0, 0, 1]);
        // A whole-cycle shift is a no-op; larger shifts wrap modulo length.
        t.rotate_right(5);
        assert_eq!(unpack(&t), vec![0, 1, 0, 0, 1]);
        t.rotate_right(6);
        assert_eq!(unpack(&t), vec![1, 0, 1, 0, 0]);
    }

    #[test]
    fn rotate_right_crosses_word_boundary() {
        let mut bits = vec![0_u8; 70];
        bits[69] = 1;
        let mut t = pack(&bits);
        t.rotate_right(1);
        assert_eq!(t.length, 70);
        assert_eq!(t.data, vec![1, 0]);
    }

    #[test]
    fn bundle_majority_vote() {
        let a = pack(&[1, 1, 0, 0]);
        let b = pack(&[1, 0, 1, 0]);
        let c = pack(&[0, 1, 1, 0]);

        // Single input: returned unchanged.
        assert_eq!(BitStreamTensor::bundle(&[&a]).data, a.data);
        // Three inputs: specialised majority-of-three path.
        assert_eq!(
            unpack(&BitStreamTensor::bundle(&[&a, &b, &c])),
            vec![1, 1, 1, 0]
        );
        // Two inputs: a 1-1 tie resolves to 0, so this degenerates to AND.
        assert_eq!(
            unpack(&BitStreamTensor::bundle(&[&a, &b])),
            vec![1, 0, 0, 0]
        );

        // Five inputs: general counting path; a bit needs at least 3 votes.
        let v1 = pack(&[1, 1, 1, 0]);
        let v2 = pack(&[1, 1, 0, 0]);
        let v3 = pack(&[1, 0, 0, 1]);
        let v4 = pack(&[0, 0, 1, 0]);
        let v5 = pack(&[0, 0, 1, 1]);
        let bundled = BitStreamTensor::bundle(&[&v1, &v2, &v3, &v4, &v5]);
        assert_eq!(bundled.length, 4);
        assert_eq!(unpack(&bundled), vec![1, 0, 1, 0]);
    }

    #[test]
    fn bernoulli_packed_matches_stream_then_pack() {
        use rand::SeedableRng;
        use rand_chacha::ChaCha8Rng;

        let prob = 0.35;
        let length = 200;

        let mut rng1 = ChaCha8Rng::seed_from_u64(999);
        let stream = bernoulli_stream(prob, length, &mut rng1);
        let packed_via_stream = pack(&stream).data;

        let mut rng2 = ChaCha8Rng::seed_from_u64(999);
        let packed_direct = bernoulli_packed(prob, length, &mut rng2);

        assert_eq!(
            packed_via_stream, packed_direct,
            "bernoulli_packed must produce bit-identical output"
        );
    }

    #[test]
    fn saturated_probability_sets_exactly_length_bits() {
        use rand::SeedableRng;
        use rand_chacha::ChaCha8Rng;

        let length = 70;
        let expected = vec![u64::MAX, (1_u64 << 6) - 1];
        let mut rng = ChaCha8Rng::seed_from_u64(7);
        assert_eq!(bernoulli_packed(1.0, length, &mut rng), expected);
        assert_eq!(bernoulli_packed_fast(1.0, length, &mut rng), expected);
        assert_eq!(bernoulli_packed_simd(1.0, length, &mut rng), expected);
        assert_eq!(
            encode_and_popcount(&[u64::MAX; 2], 1.0, length, &mut rng),
            70
        );
    }

    #[test]
    fn zero_probability_sets_no_bits() {
        use rand::SeedableRng;
        use rand_chacha::ChaCha8Rng;

        let mut rng = ChaCha8Rng::seed_from_u64(8);
        assert_eq!(bernoulli_stream(0.0, 10, &mut rng), vec![0; 10]);
        assert_eq!(bernoulli_packed(0.0, 130, &mut rng), vec![0; 3]);
        assert_eq!(bernoulli_packed_fast(0.0, 130, &mut rng), vec![0; 3]);
    }

    #[test]
    fn bernoulli_packed_fast_statistics() {
        use rand::SeedableRng;
        use rand_chacha::ChaCha8Rng;

        let prob = 0.35;
        let length = 10_000;
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let packed = bernoulli_packed_fast(prob, length, &mut rng);
        let count: u64 = packed.iter().map(|w| w.count_ones() as u64).sum();
        let measured = count as f64 / length as f64;
        assert!(
            (measured - prob).abs() < 0.03,
            "Expected ~{prob}, got {measured}"
        );
    }

    #[test]
    fn bernoulli_packed_fast_deterministic() {
        use rand::SeedableRng;
        use rand_chacha::ChaCha8Rng;

        let mut rng1 = ChaCha8Rng::seed_from_u64(99);
        let a = bernoulli_packed_fast(0.5, 512, &mut rng1);

        let mut rng2 = ChaCha8Rng::seed_from_u64(99);
        let b = bernoulli_packed_fast(0.5, 512, &mut rng2);

        assert_eq!(a, b, "Same seed must produce identical output");
    }

    #[test]
    fn bernoulli_packed_simd_statistics() {
        use rand::SeedableRng;
        use rand_chacha::ChaCha8Rng;

        let prob = 0.35;
        let length = 10_000;
        let mut rng = ChaCha8Rng::seed_from_u64(1337);
        let packed = bernoulli_packed_simd(prob, length, &mut rng);
        let count: u64 = packed.iter().map(|w| w.count_ones() as u64).sum();
        let measured = count as f64 / length as f64;
        assert!(
            (measured - prob).abs() < 0.03,
            "Expected ~{prob}, got {measured}"
        );
    }

    #[test]
    fn bernoulli_packed_simd_deterministic() {
        use rand::SeedableRng;
        use rand_chacha::ChaCha8Rng;

        let mut rng1 = ChaCha8Rng::seed_from_u64(2026);
        let a = bernoulli_packed_simd(0.5, 1024, &mut rng1);

        let mut rng2 = ChaCha8Rng::seed_from_u64(2026);
        let b = bernoulli_packed_simd(0.5, 1024, &mut rng2);

        assert_eq!(a, b, "Same seed must produce identical output");
    }

    #[test]
    fn encode_and_popcount_matches_materialized() {
        use rand::SeedableRng;
        use rand_xoshiro::Xoshiro256PlusPlus;

        let prob = 0.41;
        let lengths = [63_usize, 64, 65, 1003, 1024];
        for length in lengths {
            let words = length.div_ceil(64);
            let weights: Vec<u64> = (0..words)
                .map(|i| (i as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15) ^ 0xA5A5_A5A5_5A5A_5A5A)
                .collect();

            let mut rng1 = Xoshiro256PlusPlus::seed_from_u64(2026);
            let fused = encode_and_popcount(&weights, prob, length, &mut rng1);

            let mut rng2 = Xoshiro256PlusPlus::seed_from_u64(2026);
            let encoded = bernoulli_packed_simd(prob, length, &mut rng2);
            let expected: u64 = encoded
                .iter()
                .zip(weights.iter())
                .map(|(&e, &w)| (e & w).count_ones() as u64)
                .sum();

            assert_eq!(fused, expected, "Mismatch at length={length}");
        }
    }
}