1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
|
import numpy as np
class my_conv(object):
    """Naive multi-channel 2-D convolution (cross-correlation) in NumPy.

    The input and the weights are channel-first arrays (``c * h * w``).
    Each input channel is slid over by the matching weight channel and the
    per-channel maps are summed into a single 2-D output map.  The kernel
    is not flipped, i.e. this computes cross-correlation.
    """

    def __init__(self, input_data, weight_data, stride, padding='SAME'):
        """Store the operands as float32 arrays.

        Args:
            input_data: array-like of shape ``(c, h, w)``.
            weight_data: array-like of shape ``(c, kh, kw)``.
            stride: positive integer step of the sliding window, used on
                both spatial axes.
            padding: one of ``'SAME'``, ``'VALID'`` or ``'FULL'``.

        Raises:
            ValueError: if ``stride`` is smaller than 1.
        """
        if stride < 1:
            raise ValueError(
                "stride must be a positive integer, got {!r}".format(stride))
        self.input = np.asarray(input_data, np.float32)
        self.weights = np.asarray(weight_data, np.float32)
        self.stride = stride
        self.padding = padding

    def my_conv2d(self):
        """Convolve every channel with its kernel and sum the results.

        ``self.input`` is ``c * h * w`` (the input data layout) and
        ``self.weights`` is ``c * kh * kw``; both must have the same number
        of channels ``c``.

        Returns:
            A float32 2-D ``numpy.ndarray`` with the summed per-channel maps,
            or an empty list when there are zero channels.

        Raises:
            ValueError: if either array is not 3-D, the channel counts differ,
                or the padding mode is unknown.
        """
        if self.input.ndim != 3 or self.weights.ndim != 3:
            raise ValueError(
                "input and weights must be 3-D (c, h, w); got shapes "
                "{} and {}".format(self.input.shape, self.weights.shape))
        c = self.input.shape[0]
        kc = self.weights.shape[0]
        if c != kc:
            raise ValueError(
                "channel mismatch: input has {} channels, weights have "
                "{}".format(c, kc))

        output = None
        for fm, kernel in zip(self.input, self.weights):
            rs = self.compute_conv(fm, kernel)
            # Use a None sentinel: comparing an ndarray with [] is an
            # elementwise operation and cannot serve as an "is empty" test.
            output = rs if output is None else output + rs
        return [] if output is None else output

    def compute_conv(self, fm, kernel):
        """Cross-correlate one 2-D feature map with one 2-D kernel.

        Padding modes (``s`` = stride, ``k`` = kernel size on an axis,
        ``h`` = input size on that axis):

        * ``'SAME'``  -- output is ``h`` long.  ``s*(h-1) + k - h`` zeros are
          added in total, split as evenly as possible with the extra one on
          the bottom/right, so every window fits inside the padded map.
        * ``'VALID'`` -- no padding; output is ``(h - k) // s + 1`` long.
        * ``'FULL'``  -- ``k - 1`` zeros on every side; output is
          ``(h + k - 2) // s + 1`` long.

        Args:
            fm: 2-D feature map of shape ``(h, w)``.
            kernel: 2-D kernel of shape ``(kh, kw)``; need not be square.

        Returns:
            A float32 ``numpy.ndarray`` with the correlation result.

        Raises:
            ValueError: if ``self.padding`` is not a supported mode.
        """
        h, w = fm.shape
        kh, kw = kernel.shape
        s = self.stride

        if self.padding == 'SAME':
            out_h, out_w = h, w
            total_h = s * (h - 1) + kh - h
            total_w = s * (w - 1) + kw - w
            top, left = total_h // 2, total_w // 2
            # Odd totals put the extra row/column at the bottom/right.
            bottom, right = total_h - top, total_w - left
        elif self.padding == 'VALID':
            top = bottom = left = right = 0
            out_h = (h - kh) // s + 1
            out_w = (w - kw) // s + 1
        elif self.padding == 'FULL':
            top = bottom = kh - 1
            left = right = kw - 1
            out_h = (h + top + bottom - kh) // s + 1
            out_w = (w + left + right - kw) // s + 1
        else:
            raise ValueError(
                "unknown padding mode {!r}; expected 'SAME', 'VALID' or "
                "'FULL'".format(self.padding))

        padded = np.zeros((h + top + bottom, w + left + right), np.float32)
        padded[top:top + h, left:left + w] = fm

        rs = np.zeros((out_h, out_w), np.float32)
        for i in range(out_h):
            r0 = i * s
            for j in range(out_w):
                c0 = j * s
                rs[i, j] = np.sum(padded[r0:r0 + kh, c0:c0 + kw] * kernel)
        return rs
if __name__ == '__main__':
    # Demo: a two-channel 5x5 input correlated with a two-channel 3x3
    # kernel at stride 1 with 'SAME' padding; prints the summed map.
    demo_input = [
        # channel 0
        [[1, 0, 1, 2, 1],
         [0, 2, 1, 0, 1],
         [1, 1, 0, 2, 0],
         [2, 2, 1, 1, 0],
         [2, 0, 1, 2, 0]],
        # channel 1
        [[2, 0, 2, 1, 1],
         [0, 1, 0, 0, 2],
         [1, 0, 0, 2, 1],
         [1, 1, 2, 1, 0],
         [1, 0, 1, 1, 1]],
    ]
    demo_weights = [
        # kernel for channel 0
        [[1, 0, 1],
         [-1, 1, 0],
         [0, -1, 0]],
        # kernel for channel 1
        [[-1, 0, 1],
         [0, 0, 1],
         [1, 1, 1]],
    ]
    layer = my_conv(demo_input, demo_weights, stride=1, padding='SAME')
    result = layer.my_conv2d()
    print(result)
|