1use rayon::prelude::*;
30
// Rows handled per rayon work item. NOTE(review): 512 looks chosen to amortize
// scheduling overhead while keeping chunks small enough for load balancing on
// irregular row lengths — no benchmark shown here, confirm before changing.
const CHUNK_SIZE: usize = 512;
42
43pub fn parallel_csr_spmv_add(
44 indptr: &[i32],
45 indices: &[i32],
46 data: &[f64],
47 x: &[f64],
48 y: &mut [f64],
49) {
50 y.par_chunks_mut(CHUNK_SIZE)
51 .enumerate()
52 .for_each(|(chunk_idx, chunk)| {
53 let row_start = chunk_idx * CHUNK_SIZE;
54 for (i, yi) in chunk.iter_mut().enumerate() {
55 let r = row_start + i;
56 let start = indptr[r] as usize;
57 let end = indptr[r + 1] as usize;
58 let mut sum: f64 = 0.0;
59 for k in start..end {
60 let col = indices[k] as usize;
61 sum += data[k] * x[col];
62 }
63 *yi += sum;
64 }
65 });
66}
67
68#[allow(clippy::too_many_arguments)]
79pub fn parallel_csr_multi_spmv_add(
80 indptr_blocks: &[&[i32]],
81 indices_blocks: &[&[i32]],
82 data_blocks: &[&[f64]],
83 x_blocks: &[&[f64]],
84 y: &mut [f64],
85) {
86 let n_blocks = indptr_blocks.len();
87 debug_assert_eq!(n_blocks, indices_blocks.len());
88 debug_assert_eq!(n_blocks, data_blocks.len());
89 debug_assert_eq!(n_blocks, x_blocks.len());
90
91 y.par_chunks_mut(CHUNK_SIZE)
92 .enumerate()
93 .for_each(|(chunk_idx, chunk)| {
94 let row_start = chunk_idx * CHUNK_SIZE;
95 for (i, yi) in chunk.iter_mut().enumerate() {
96 let r = row_start + i;
97 let mut sum: f64 = 0.0;
98 for b in 0..n_blocks {
99 let indptr = indptr_blocks[b];
100 let indices = indices_blocks[b];
101 let data = data_blocks[b];
102 let x = x_blocks[b];
103 let start = indptr[r] as usize;
104 let end = indptr[r + 1] as usize;
105 for k in start..end {
106 let col = indices[k] as usize;
107 sum += data[k] * x[col];
108 }
109 }
110 *yi += sum;
111 }
112 });
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    // 3x3 matrix: row 0 -> {0: 1, 2: 2}, row 1 -> {1: 3}, row 2 -> {0: 4, 2: 5}.
    #[test]
    fn test_basic_csr_spmv() {
        let indptr = vec![0_i32, 2, 3, 5];
        let indices = vec![0_i32, 2, 1, 0, 2];
        let data = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let x = vec![1.0_f64; 3];
        let mut y = vec![0.0_f64; 3];
        parallel_csr_spmv_add(&indptr, &indices, &data, &x, &mut y);
        assert_eq!(y, vec![3.0, 3.0, 9.0]);
    }

    // Row 1 holds no nonzeros; its y entry must pass through untouched.
    #[test]
    fn test_empty_row() {
        let indptr = vec![0_i32, 1, 1, 2];
        let indices = vec![0_i32, 1];
        let data = vec![10.0_f64, 20.0];
        let x = vec![1.0_f64, 2.0];
        let mut y = vec![100.0_f64; 3];
        parallel_csr_spmv_add(&indptr, &indices, &data, &x, &mut y);
        assert_eq!(y, vec![110.0, 100.0, 140.0]);
    }

    // Calling twice must double the result: the kernel adds, never assigns.
    #[test]
    fn test_accumulates_into_y() {
        let indptr = vec![0_i32, 1, 2];
        let indices = vec![0_i32, 0];
        let data = vec![3.0_f64, 5.0];
        let x = vec![2.0_f64];
        let mut y = vec![0.0_f64; 2];
        parallel_csr_spmv_add(&indptr, &indices, &data, &x, &mut y);
        assert_eq!(y, vec![6.0, 10.0]);
        parallel_csr_spmv_add(&indptr, &indices, &data, &x, &mut y);
        assert_eq!(y, vec![12.0, 20.0]);
    }

    // Diagonal matrix larger than CHUNK_SIZE, so multiple parallel chunks run.
    #[test]
    fn test_large_dense_diagonal() {
        let n = 1024_usize;
        let indptr: Vec<i32> = (0..=n as i32).collect();
        let indices: Vec<i32> = (0..n as i32).collect();
        let data: Vec<f64> = (1..=n).map(|v| v as f64).collect();
        let x = vec![1.0_f64; n];
        let mut y = vec![0.0_f64; n];
        parallel_csr_spmv_add(&indptr, &indices, &data, &x, &mut y);
        for (i, &yi) in y.iter().enumerate() {
            assert_eq!(yi, (i as f64) + 1.0);
        }
    }

    // The batched kernel must agree with three sequential single-block calls.
    #[test]
    fn test_multi_spmv_matches_sequential() {
        let indptr0 = vec![0_i32, 1, 2, 3, 4];
        let indices0 = vec![0_i32, 1, 2, 0];
        let data0 = vec![1.0_f64, 2.0, 3.0, 4.0];
        let x0 = vec![10.0_f64, 20.0, 30.0];

        let indptr1 = vec![0_i32, 0, 1, 1, 2];
        let indices1 = vec![1_i32, 2];
        let data1 = vec![5.0_f64, 6.0];
        let x1 = vec![100.0_f64, 200.0, 300.0];

        let indptr2 = vec![0_i32, 1, 1, 2, 3];
        let indices2 = vec![2_i32, 0, 1];
        let data2 = vec![7.0_f64, 8.0, 9.0];
        let x2 = vec![1000.0_f64, 2000.0, 3000.0];

        let mut y_seq = vec![0.0_f64; 4];
        parallel_csr_spmv_add(&indptr0, &indices0, &data0, &x0, &mut y_seq);
        parallel_csr_spmv_add(&indptr1, &indices1, &data1, &x1, &mut y_seq);
        parallel_csr_spmv_add(&indptr2, &indices2, &data2, &x2, &mut y_seq);

        let mut y_batched = vec![0.0_f64; 4];
        let indptrs: Vec<&[i32]> = vec![&indptr0, &indptr1, &indptr2];
        let indices_b: Vec<&[i32]> = vec![&indices0, &indices1, &indices2];
        let data_b: Vec<&[f64]> = vec![&data0, &data1, &data2];
        let xs: Vec<&[f64]> = vec![&x0, &x1, &x2];
        parallel_csr_multi_spmv_add(&indptrs, &indices_b, &data_b, &xs, &mut y_batched);

        assert_eq!(y_seq, y_batched);
    }
}