/// One-dimensional Sobol low-discrepancy sample generator over `[0, 1)`.
///
/// Points are produced with the Gray-code update: each call to
/// [`SobolEngine::sample`] XORs exactly one direction number into a 32-bit state.
/// Because direction number `i` is `1 << (31 - i)`, the state after `n` samples is
/// `seed ^ reverse_bits(gray(n))`. That is the base-2 van der Corput sequence in
/// Gray-code order, XOR-scrambled by `seed`.
#[derive(Clone, Debug)]
pub struct SobolEngine {
    /// Current 32-bit fixed-point point; the returned sample is `state / 2^32`.
    state: u32,
    /// Number of samples drawn since construction/reset, modulo `2^32`.
    index: u32,
    /// Direction numbers; entry `i` is XORed in when bit `i` is the lowest zero bit of `index`.
    direction: [u32; 32],
}

impl SobolEngine {
    /// Creates an engine whose initial state is `seed`.
    ///
    /// With `seed == 0` the first samples are `0.5`, `0.75`, `0.25`, `0.375`, ...
    pub fn new(seed: u32) -> Self {
        let mut direction = [0u32; 32];
        for (bit, v) in direction.iter_mut().enumerate() {
            *v = 1u32 << (31 - bit);
        }
        Self {
            state: seed,
            index: 0,
            direction,
        }
    }

    /// Advances the sequence and returns the next point in `[0, 1)`.
    ///
    /// The sequence has period `2^32`. The index wraps rather than overflowing, so
    /// a long-running engine does not panic in debug builds.
    pub fn sample(&mut self) -> f64 {
        // Lowest zero bit of `index` is the Gray-code bit that changes on this step.
        // It is 32 only when `index == u32::MAX`. Clamping to 31 flips the last
        // direction bit, which brings the state back to `seed` exactly as the index
        // wraps to 0.
        let c = self.index.trailing_ones() as usize;
        self.state ^= self.direction[c.min(31)];
        self.index = self.index.wrapping_add(1);
        f64::from(self.state) / (1u64 << 32) as f64
    }

    /// Restarts the sequence from the beginning with a (possibly new) `seed`.
    pub fn reset(&mut self, seed: u32) {
        self.state = seed;
        self.index = 0;
    }
}
48
49pub fn generate_sobol_bitstream(p: f64, length: usize, seed: u32) -> Vec<u8> {
51 assert!((0.0..=1.0).contains(&p), "p must be in [0,1]");
52 let mut engine = SobolEngine::new(seed);
53 (0..length)
54 .map(|_| if engine.sample() < p { 1u8 } else { 0u8 })
55 .collect()
56}
57
58pub fn generate_sobol_packed(p: f64, length: usize, seed: u32) -> Vec<u64> {
60 let bits = generate_sobol_bitstream(p, length, seed);
61 crate::bitstream::pack_fast(&bits).data
62}
63
/// One-dimensional Halton (van der Corput) sample generator.
///
/// The `k`-th call to [`HaltonEngine::sample`] (starting at `k = 1`) returns the
/// radical inverse of `seed + k` in `base`, so `seed` acts as a starting offset
/// into the sequence.
#[derive(Clone, Debug)]
pub struct HaltonEngine {
    /// Index of the most recently returned point; starts at `seed`.
    index: u64,
    /// Radix used for the radical inverse; always `>= 2`.
    base: u32,
    /// Starting offset restored by [`HaltonEngine::reset`].
    seed: u64,
}

impl HaltonEngine {
    /// Creates an engine in `base` whose first sample is the radical inverse of `seed + 1`.
    ///
    /// # Panics
    /// Panics if `base < 2`.
    pub fn new(base: u32, seed: u64) -> Self {
        assert!(base >= 2, "Halton base must be >= 2");
        Self {
            index: seed,
            base,
            seed,
        }
    }

    /// Advances the index and returns its radical inverse.
    ///
    /// The index wraps modulo `2^64` instead of overflowing, so `seed == u64::MAX`
    /// yields `0.0` first. Debug builds therefore behave like release builds
    /// instead of panicking.
    pub fn sample(&mut self) -> f64 {
        self.index = self.index.wrapping_add(1);
        Self::radical_inverse(self.index, self.base)
    }

    /// Mirrors the base-`base` digits of `n` about the radix point.
    ///
    /// For `n = d0 + d1*base + d2*base^2 + ...` this returns
    /// `d0/base + d1/base^2 + d2/base^3 + ...`. That value is mathematically in
    /// `[0, 1)`, up to floating-point rounding. `n == 0` yields `0.0`.
    fn radical_inverse(mut n: u64, base: u32) -> f64 {
        let radix = u64::from(base);
        let b = f64::from(base);
        let mut result = 0.0_f64;
        let mut denom = 1.0_f64;
        while n > 0 {
            denom *= b;
            let digit = n % radix;
            result += digit as f64 / denom;
            n /= radix;
        }
        result
    }

    /// Rewinds the engine to its construction-time `seed`.
    pub fn reset(&mut self) {
        self.index = self.seed;
    }
}
122
123pub fn generate_halton_bitstream(p: f64, length: usize, base: u32, seed: u64) -> Vec<u8> {
125 assert!((0.0..=1.0).contains(&p), "p must be in [0,1]");
126 let mut engine = HaltonEngine::new(base, seed);
127 (0..length)
128 .map(|_| if engine.sample() < p { 1u8 } else { 0u8 })
129 .collect()
130}
131
132pub fn generate_halton_packed(p: f64, length: usize, base: u32, seed: u64) -> Vec<u64> {
134 let bits = generate_halton_bitstream(p, length, base, seed);
135 crate::bitstream::pack_fast(&bits).data
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// Fraction of `1` bits in a 0/1 bitstream.
    fn ones_ratio(bits: &[u8]) -> f64 {
        bits.iter().filter(|&&b| b == 1).count() as f64 / bits.len() as f64
    }

    #[test]
    fn deterministic_output() {
        let a = generate_sobol_bitstream(0.5, 128, 0);
        let b = generate_sobol_bitstream(0.5, 128, 0);
        assert_eq!(a, b);
    }

    #[test]
    fn p_zero_all_zeros() {
        let bits = generate_sobol_bitstream(0.0, 256, 42);
        assert!(bits.iter().all(|&b| b == 0));
    }

    #[test]
    fn p_one_all_ones() {
        let bits = generate_sobol_bitstream(1.0, 256, 42);
        assert!(bits.iter().all(|&b| b == 1));
    }

    #[test]
    fn proportion_close_to_p() {
        let p = 0.3;
        let ratio = ones_ratio(&generate_sobol_bitstream(p, 4096, 0));
        assert!(
            (ratio - p).abs() < 0.02,
            "Sobol ratio {ratio} should be close to {p}"
        );
    }

    #[test]
    fn lower_discrepancy_than_bernoulli() {
        // For N = 2^k, the first N points of this engine (seed 0) hit each cell
        // [j/N, (j+1)/N) exactly once. The first returned point is the one after
        // the initial state, which can move the count by one more, so the count of
        // ones stays within 2 of N*p. An i.i.d. Bernoulli(0.5) stream of 1024 bits
        // has a standard error of 0.5/32 ≈ 0.0156, about 8x this bound.
        let n = 1024usize;
        let error = (ones_ratio(&generate_sobol_bitstream(0.5, n, 0)) - 0.5).abs();
        assert!(error < 2.0 / n as f64, "Sobol discrepancy {error} too high");
    }

    #[test]
    fn packed_roundtrip() {
        let bits = generate_sobol_bitstream(0.7, 200, 99);
        let packed = generate_sobol_packed(0.7, 200, 99);
        let unpacked = crate::bitstream::unpack(&crate::bitstream::BitStreamTensor {
            data: packed,
            length: 200,
        });
        assert_eq!(bits, unpacked);
    }

    #[test]
    fn sobol_reset_replays_sequence() {
        let mut engine = SobolEngine::new(7);
        let first: Vec<f64> = (0..16).map(|_| engine.sample()).collect();
        engine.reset(7);
        let again: Vec<f64> = (0..16).map(|_| engine.sample()).collect();
        assert_eq!(first, again);
    }

    #[test]
    fn halton_deterministic() {
        let a = generate_halton_bitstream(0.5, 128, 2, 0);
        let b = generate_halton_bitstream(0.5, 128, 2, 0);
        assert_eq!(a, b);
    }

    #[test]
    fn halton_p_zero_all_zeros() {
        let bits = generate_halton_bitstream(0.0, 256, 2, 0);
        assert!(bits.iter().all(|&b| b == 0));
    }

    #[test]
    fn halton_p_one_all_ones() {
        let bits = generate_halton_bitstream(1.0, 256, 2, 0);
        assert!(bits.iter().all(|&b| b == 1));
    }

    #[test]
    fn halton_base2_proportion() {
        let p = 0.4;
        let ratio = ones_ratio(&generate_halton_bitstream(p, 4096, 2, 0));
        assert!(
            (ratio - p).abs() < 0.02,
            "Halton base-2 ratio {ratio} should be close to {p}"
        );
    }

    #[test]
    fn halton_base3_proportion() {
        let p = 0.6;
        let ratio = ones_ratio(&generate_halton_bitstream(p, 4096, 3, 0));
        assert!(
            (ratio - p).abs() < 0.03,
            "Halton base-3 ratio {ratio} should be close to {p}"
        );
    }

    #[test]
    fn halton_packed_roundtrip() {
        let bits = generate_halton_bitstream(0.7, 200, 2, 99);
        let packed = generate_halton_packed(0.7, 200, 2, 99);
        let unpacked = crate::bitstream::unpack(&crate::bitstream::BitStreamTensor {
            data: packed,
            length: 200,
        });
        assert_eq!(bits, unpacked);
    }

    #[test]
    fn halton_reset_replays_sequence() {
        let mut engine = HaltonEngine::new(3, 5);
        let first: Vec<f64> = (0..16).map(|_| engine.sample()).collect();
        engine.reset();
        let again: Vec<f64> = (0..16).map(|_| engine.sample()).collect();
        assert_eq!(first, again);
    }

    #[test]
    fn halton_radical_inverse_base2() {
        let r1 = HaltonEngine::radical_inverse(1, 2);
        let r2 = HaltonEngine::radical_inverse(2, 2);
        let r3 = HaltonEngine::radical_inverse(3, 2);
        assert!((r1 - 0.5).abs() < 1e-10);
        assert!((r2 - 0.25).abs() < 1e-10);
        assert!((r3 - 0.75).abs() < 1e-10);
    }

    #[test]
    fn halton_radical_inverse_base3() {
        let r1 = HaltonEngine::radical_inverse(1, 3);
        let r2 = HaltonEngine::radical_inverse(2, 3);
        let r3 = HaltonEngine::radical_inverse(3, 3);
        assert!((r1 - 1.0 / 3.0).abs() < 1e-10);
        assert!((r2 - 2.0 / 3.0).abs() < 1e-10);
        assert!((r3 - 1.0 / 9.0).abs() < 1e-10);
    }

    #[test]
    fn sobol_and_halton_base2_meet_stratification_bound() {
        // Over the first N = 2^k indices, this 1D Sobol engine and base-2 Halton both
        // walk the dyadic grid {j/N} in different orders. Each skips 0 and adds one
        // extra point. Neither can converge faster at N = 2^k, so this checks both
        // against the shared bound instead of asserting an ordering.
        let p = 0.5;
        let n = 4096usize;
        let bound = 2.0 / n as f64;

        let sobol_err = (ones_ratio(&generate_sobol_bitstream(p, n, 0)) - p).abs();
        let halton_err = (ones_ratio(&generate_halton_bitstream(p, n, 2, 0)) - p).abs();

        assert!(sobol_err < bound, "Sobol error {sobol_err} too high");
        assert!(halton_err < bound, "Halton error {halton_err} too high");
    }
}