1use rand::{RngExt, SeedableRng};
12use rand_chacha::ChaCha8Rng;
13use rayon::prelude::*;
14
/// A 2-D convolution layer over `f64` data.
///
/// The same `stride` and `padding` apply to both spatial axes, the kernel is square, and
/// there is no bias term. Padding is filled with zeros.
#[derive(Debug, Clone)]
pub struct Conv2DLayer {
    /// Number of channels each input image must have.
    pub in_channels: usize,
    /// Number of output channels (one filter per output channel).
    pub out_channels: usize,
    /// Side length of the square kernel.
    pub kernel_size: usize,
    /// Step between neighbouring kernel positions, for rows and columns alike.
    pub stride: usize,
    /// Number of zero rows/columns added on every side of the input.
    pub padding: usize,
    /// Flattened kernel weights in `[out_channels][in_channels][kernel_size][kernel_size]`
    /// row-major order.
    pub kernels: Vec<f64>,
}
29
30impl Conv2DLayer {
31 pub fn new(
32 in_channels: usize,
33 out_channels: usize,
34 kernel_size: usize,
35 stride: usize,
36 padding: usize,
37 seed: u64,
38 ) -> Self {
39 let mut rng = ChaCha8Rng::seed_from_u64(seed);
40 let size = out_channels * in_channels * kernel_size * kernel_size;
41 let kernels: Vec<f64> = (0..size).map(|_| rng.random::<f64>()).collect();
42 Self {
43 in_channels,
44 out_channels,
45 kernel_size,
46 stride,
47 padding,
48 kernels,
49 }
50 }
51
52 pub fn forward(&self, input: &[f64], h: usize, w: usize) -> (Vec<f64>, usize, usize) {
54 let k = self.kernel_size;
55 let h_out = (h + 2 * self.padding - k) / self.stride + 1;
56 let w_out = (w + 2 * self.padding - k) / self.stride + 1;
57 let c_in = self.in_channels;
58 let filter_size = c_in * k * k;
59
60 let padded = if self.padding > 0 {
62 let ph = h + 2 * self.padding;
63 let pw = w + 2 * self.padding;
64 let mut p = vec![0.0; c_in * ph * pw];
65 for c in 0..c_in {
66 for i in 0..h {
67 for j in 0..w {
68 p[c * ph * pw + (i + self.padding) * pw + (j + self.padding)] =
69 input[c * h * w + i * w + j];
70 }
71 }
72 }
73 (p, ph, pw)
74 } else {
75 (input.to_vec(), h, w)
76 };
77 let (ref inp, ph, pw) = padded;
78
79 let mut output = vec![0.0; self.out_channels * h_out * w_out];
80
81 output
82 .par_chunks_exact_mut(h_out * w_out)
83 .enumerate()
84 .for_each(|(oc, out_plane)| {
85 let filter = &self.kernels[oc * filter_size..(oc + 1) * filter_size];
86 for i in 0..h_out {
87 let mut j = 0;
88 while j + 3 < w_out {
89 let hs = i * self.stride;
90 let mut acc0 = 0.0;
91 let mut acc1 = 0.0;
92 let mut acc2 = 0.0;
93 let mut acc3 = 0.0;
94 for c in 0..c_in {
95 let input_offset = c * ph * pw;
96 let filter_offset = c * k * k;
97 for ki in 0..k {
98 let row_off = input_offset + (hs + ki) * pw;
99 let f_row_off = filter_offset + ki * k;
100 let filter_row = &filter[f_row_off..f_row_off + k];
101
102 acc0 += crate::simd::dot_f64_dispatch(
103 &inp[row_off + j * self.stride..row_off + j * self.stride + k],
104 filter_row,
105 );
106 acc1 += crate::simd::dot_f64_dispatch(
107 &inp[row_off + (j + 1) * self.stride
108 ..row_off + (j + 1) * self.stride + k],
109 filter_row,
110 );
111 acc2 += crate::simd::dot_f64_dispatch(
112 &inp[row_off + (j + 2) * self.stride
113 ..row_off + (j + 2) * self.stride + k],
114 filter_row,
115 );
116 acc3 += crate::simd::dot_f64_dispatch(
117 &inp[row_off + (j + 3) * self.stride
118 ..row_off + (j + 3) * self.stride + k],
119 filter_row,
120 );
121 }
122 }
123 out_plane[i * w_out + j] = acc0;
124 out_plane[i * w_out + j + 1] = acc1;
125 out_plane[i * w_out + j + 2] = acc2;
126 out_plane[i * w_out + j + 3] = acc3;
127 j += 4;
128 }
129 while j < w_out {
130 let hs = i * self.stride;
131 let ws = j * self.stride;
132 let mut acc = 0.0;
133 for c in 0..c_in {
134 let input_offset = c * ph * pw;
135 let filter_offset = c * k * k;
136 for ki in 0..k {
137 let inp_row = &inp[input_offset + (hs + ki) * pw + ws
138 ..input_offset + (hs + ki) * pw + ws + k];
139 let filter_row =
140 &filter[filter_offset + ki * k..filter_offset + (ki + 1) * k];
141 acc += crate::simd::dot_f64_dispatch(inp_row, filter_row);
142 }
143 }
144 out_plane[i * w_out + j] = acc;
145 j += 1;
146 }
147 }
148 });
149
150 (output, h_out, w_out)
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// Direct six-loop cross-correlation with implicit zero padding, used as an oracle for the
    /// optimised `forward`. Uses the same `[c][h][w]` input layout and `[out][in][k][k]`
    /// kernel layout.
    fn reference_forward(
        conv: &Conv2DLayer,
        input: &[f64],
        h: usize,
        w: usize,
    ) -> (Vec<f64>, usize, usize) {
        let (k, s, p) = (conv.kernel_size, conv.stride, conv.padding);
        let h_out = (h + 2 * p - k) / s + 1;
        let w_out = (w + 2 * p - k) / s + 1;
        let mut out = vec![0.0; conv.out_channels * h_out * w_out];
        for oc in 0..conv.out_channels {
            for i in 0..h_out {
                for j in 0..w_out {
                    let mut acc = 0.0;
                    for c in 0..conv.in_channels {
                        for ki in 0..k {
                            for kj in 0..k {
                                // Map padded coordinates back to the unpadded image; anything
                                // outside it is an implicit zero and contributes nothing.
                                let (Some(y), Some(x)) =
                                    ((i * s + ki).checked_sub(p), (j * s + kj).checked_sub(p))
                                else {
                                    continue;
                                };
                                if y >= h || x >= w {
                                    continue;
                                }
                                let weight = conv.kernels
                                    [((oc * conv.in_channels + c) * k + ki) * k + kj];
                                acc += weight * input[(c * h + y) * w + x];
                            }
                        }
                    }
                    out[(oc * h_out + i) * w_out + j] = acc;
                }
            }
        }
        (out, h_out, w_out)
    }

    /// Element-wise approximate equality with a readable failure message.
    fn assert_close(actual: &[f64], expected: &[f64], tol: f64) {
        assert_eq!(actual.len(), expected.len());
        for (idx, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!((a - e).abs() < tol, "index {idx}: got {a}, expected {e}");
        }
    }

    #[test]
    fn output_shape_no_padding() {
        let conv = Conv2DLayer::new(1, 2, 3, 1, 0, 42);
        let input = vec![0.5; 8 * 8];
        let (out, h, w) = conv.forward(&input, 8, 8);
        assert_eq!(h, 6);
        assert_eq!(w, 6);
        assert_eq!(out.len(), 2 * 6 * 6);
    }

    #[test]
    fn output_shape_with_padding() {
        let conv = Conv2DLayer::new(1, 2, 3, 1, 1, 42);
        let input = vec![0.5; 8 * 8];
        let (out, h, w) = conv.forward(&input, 8, 8);
        assert_eq!(h, 8);
        assert_eq!(w, 8);
        assert_eq!(out.len(), 2 * 8 * 8);
    }

    #[test]
    fn all_ones_kernel() {
        let mut conv = Conv2DLayer::new(1, 1, 3, 1, 0, 42);
        conv.kernels = vec![1.0; 9];
        let input = vec![1.0; 5 * 5];
        let (out, h, w) = conv.forward(&input, 5, 5);
        assert_eq!((h, w), (3, 3));
        // Every 3x3 window lies fully inside the image, so every output is 9.
        assert_close(&out, &[9.0; 9], 1e-10);
    }

    #[test]
    fn zero_padding_border_values() {
        let mut conv = Conv2DLayer::new(1, 1, 3, 1, 1, 42);
        conv.kernels = vec![1.0; 9];
        let input = vec![1.0; 5 * 5];
        let (out, h, w) = conv.forward(&input, 5, 5);
        assert_eq!((h, w), (5, 5));
        // Corners overlap a 2x2 patch of real input, edges 2x3, interior the full 3x3.
        assert!((out[0] - 4.0).abs() < 1e-10);
        assert!((out[2] - 6.0).abs() < 1e-10);
        assert!((out[2 * 5 + 2] - 9.0).abs() < 1e-10);
        assert!((out[4 * 5 + 4] - 4.0).abs() < 1e-10);
    }

    #[test]
    fn stride_2() {
        let conv = Conv2DLayer::new(1, 1, 3, 2, 0, 42);
        let input: Vec<f64> = (0..8 * 8).map(|v| v as f64 * 0.25).collect();
        let (out, h, w) = conv.forward(&input, 8, 8);
        assert_eq!(h, 3);
        assert_eq!(w, 3);
        let (expected, _, _) = reference_forward(&conv, &input, 8, 8);
        assert_close(&out, &expected, 1e-9);
    }

    #[test]
    fn matches_reference_across_strides_and_padding() {
        // With w = 9 the output width is 7, 9, 5 and 6 respectively, so both the 4-wide
        // column path and the leftover-column path are exercised.
        let (h, w) = (6, 9);
        for &(stride, padding) in &[(1, 0), (1, 1), (2, 1), (2, 2)] {
            let conv = Conv2DLayer::new(3, 2, 3, stride, padding, 7);
            let input: Vec<f64> = (0..3 * h * w).map(|v| (v as f64 * 0.37).sin()).collect();
            let (out, h_out, w_out) = conv.forward(&input, h, w);
            let (expected, eh, ew) = reference_forward(&conv, &input, h, w);
            assert_eq!((h_out, w_out), (eh, ew), "stride {stride} padding {padding}");
            assert_close(&out, &expected, 1e-9);
        }
    }

    #[test]
    fn same_seed_same_weights() {
        let a = Conv2DLayer::new(2, 3, 3, 1, 0, 11);
        let b = Conv2DLayer::new(2, 3, 3, 1, 0, 11);
        assert_eq!(a.kernels.len(), 3 * 2 * 3 * 3);
        assert_eq!(a.kernels, b.kernels);
    }
}
196}