| | """ |
| | Minimal GoogLeNet (Inception V1) in MLX, up to inception5b. |
| | Loads weights from a torchvision-exported npz (see export_googlenet_npz.py). |
| | """ |
| |
|
| | import mlx.core as mx |
| | import mlx.nn as nn |
| | import numpy as np |
| |
|
| |
|
def _conv_bn(in_ch, out_ch, kernel_size, stride=1, padding=0):
    """Return a Conv2d -> BatchNorm -> ReLU stack wrapped in ``nn.Sequential``.

    The convolution is created without a bias term; the batch-norm layer that
    follows uses ``eps=1e-3`` and ``momentum=0.1``.

    Args:
        in_ch: Number of input channels for the convolution.
        out_ch: Number of output channels (also the batch-norm feature count).
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
    """
    conv = nn.Conv2d(
        in_ch,
        out_ch,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=False,
    )
    norm = nn.BatchNorm(out_ch, eps=1e-3, momentum=0.1)
    activation = nn.ReLU()
    # Layer order matters: the weight loader indexes layers[0] / layers[1].
    return nn.Sequential(conv, norm, activation)
| |
|
| |
|
class Inception(nn.Module):
    """Four-branch Inception block whose branch outputs are joined on the last axis.

    Branches:
        1. 1x1 conv (``ch1`` outputs).
        2. 1x1 reduce (``ch3r``) followed by a 3x3 conv (``ch3``).
        3. 1x1 reduce (``ch5r``) followed by a 3x3 conv (``ch5``).
        4. 3x3 stride-1 max-pool followed by a 1x1 projection (``pool_proj``).

    The block therefore emits ``ch1 + ch3 + ch5 + pool_proj`` channels.
    """

    def __init__(self, in_ch, ch1, ch3r, ch3, ch5r, ch5, pool_proj):
        super().__init__()

        # Branch 1: plain 1x1 convolution.
        self.branch1 = _conv_bn(in_ch, ch1, kernel_size=1)

        # Branch 2: 1x1 bottleneck, then 3x3 with "same" padding.
        self.branch2_1 = _conv_bn(in_ch, ch3r, kernel_size=1)
        self.branch2_2 = _conv_bn(ch3r, ch3, kernel_size=3, padding=1)

        # Branch 3: 1x1 bottleneck, then a 3x3 conv even though the parameter
        # names say "5".
        # NOTE(review): presumably this mirrors the exporting torchvision
        # model so the saved kernel shapes match -- confirm before changing.
        self.branch3_1 = _conv_bn(in_ch, ch5r, kernel_size=1)
        self.branch3_2 = _conv_bn(ch5r, ch5, kernel_size=3, padding=1)

        # Branch 4: size-preserving max-pool, then a 1x1 projection.
        self.branch4_pool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.branch4_2 = _conv_bn(in_ch, pool_proj, kernel_size=1)

    def __call__(self, x):
        """Run all four branches on ``x`` and concatenate them along axis -1."""
        out_1x1 = self.branch1(x)

        reduced_3x3 = self.branch2_1(x)
        out_3x3 = self.branch2_2(reduced_3x3)

        reduced_5x5 = self.branch3_1(x)
        out_5x5 = self.branch3_2(reduced_5x5)

        pooled = self.branch4_pool(x)
        out_pool = self.branch4_2(pooled)

        return mx.concatenate([out_1x1, out_3x3, out_5x5, out_pool], axis=-1)
| |
|
| |
|
class GoogLeNet(nn.Module):
    """GoogLeNet feature extractor: stem convs plus inception3a..inception5b.

    There is no classifier head. Calling the model returns a dict of the
    intermediate inception activations, keyed by layer name.

    NOTE(review): the pooling layers here use explicit padding only. The
    torchvision reference presumably uses ``ceil_mode=True`` (and a 2x2 kernel
    for ``maxpool4``), so spatial sizes may differ slightly from the exporting
    model -- confirm this is acceptable for the intended use.
    """

    def __init__(self):
        super().__init__()
        # Stem.
        self.conv1 = _conv_bn(3, 64, 7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)

        self.conv2 = _conv_bn(64, 64, 1)
        self.conv3 = _conv_bn(64, 192, 3, padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)

        # Stage 3. Each block's in_ch is the previous block's
        # ch1 + ch3 + ch5 + pool_proj.
        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)

        # Stage 4.
        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Stage 5.
        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

    def forward_with_endpoints(self, x):
        """Run the network and record every inception block's output.

        Args:
            x: Input batch. The first conv expects 3 input channels.
                # assumes channels-last (N, H, W, C) layout -- TODO confirm

        Returns:
            A ``(final, endpoints)`` tuple where ``final`` is the output of
            ``inception5b`` and ``endpoints`` maps ``"inception3a"`` ..
            ``"inception5b"`` to the corresponding activations.
        """
        endpoints = {}
        x = self.conv1(x)
        x = self.maxpool1(x)

        x = self.conv2(x)
        x = self.conv3(x)
        x = self.maxpool2(x)

        x = self.inception3a(x)
        endpoints["inception3a"] = x
        x = self.inception3b(x)
        endpoints["inception3b"] = x
        x = self.maxpool3(x)

        x = self.inception4a(x)
        endpoints["inception4a"] = x
        x = self.inception4b(x)
        endpoints["inception4b"] = x
        x = self.inception4c(x)
        endpoints["inception4c"] = x
        x = self.inception4d(x)
        endpoints["inception4d"] = x
        x = self.inception4e(x)
        endpoints["inception4e"] = x
        x = self.maxpool4(x)

        x = self.inception5a(x)
        endpoints["inception5a"] = x
        x = self.inception5b(x)
        endpoints["inception5b"] = x
        return x, endpoints

    def __call__(self, x):
        """Return only the endpoints dict from ``forward_with_endpoints``."""
        _, endpoints = self.forward_with_endpoints(x)
        return endpoints

    def load_npz(self, path: str):
        """Load conv and batch-norm parameters from an ``.npz`` archive.

        Keys follow the ``<layer>.conv.weight`` / ``<layer>.bn.<param>`` scheme.
        A tensor may be stored directly under its key, or quantised as
        ``<key>_int8`` with a matching ``<key>_scale``; the quantised form is
        dequantised as ``int8 * scale`` in the scale's dtype.

        NOTE(review): this only assigns parameters; it does not switch the
        module to eval mode. Presumably callers do that so BatchNorm uses the
        loaded running statistics -- confirm.

        Args:
            path: Path to the ``.npz`` file.

        Raises:
            ValueError: If neither ``key`` nor ``key_int8`` exists for a
                required tensor.
        """
        # np.load on an .npz returns a lazy NpzFile holding an open file
        # handle. The context manager closes it once every tensor has been
        # copied into an mx.array.
        with np.load(path) as data:

            def load_weight(key, target_module, param_name="weight", transpose=False):
                # Prefer the plain tensor; otherwise fall back to int8 + scale.
                if key in data:
                    w = data[key]
                elif f"{key}_int8" in data:
                    w_int8 = data[f"{key}_int8"]
                    scale = data[f"{key}_scale"]
                    w = w_int8.astype(scale.dtype) * scale
                else:
                    raise ValueError(f"Missing key {key} (or {key}_int8) in npz")

                # Conv kernels are moved from (out, in, kh, kw) to
                # (out, kh, kw, in).
                if transpose and w.ndim == 4:
                    w = np.transpose(w, (0, 2, 3, 1))

                target_module[param_name] = mx.array(w)

            def load_conv_bn(prefix, seq_mod: nn.Sequential):
                # _conv_bn builds [conv, bn, relu]; only the first two carry
                # parameters.
                conv = seq_mod.layers[0]
                bn = seq_mod.layers[1]

                load_weight(f"{prefix}.conv.weight", conv, transpose=True)

                load_weight(f"{prefix}.bn.weight", bn)
                for stat in ("bias", "running_mean", "running_var"):
                    load_weight(f"{prefix}.bn.{stat}", bn, param_name=stat)

            def load_inception(prefix, module: Inception):
                load_conv_bn(f"{prefix}.branch1", module.branch1)
                load_conv_bn(f"{prefix}.branch2.0", module.branch2_1)
                load_conv_bn(f"{prefix}.branch2.1", module.branch2_2)
                load_conv_bn(f"{prefix}.branch3.0", module.branch3_1)
                load_conv_bn(f"{prefix}.branch3.1", module.branch3_2)
                # branch4 index 0 is the parameter-free max-pool.
                load_conv_bn(f"{prefix}.branch4.1", module.branch4_2)

            for stem_name in ("conv1", "conv2", "conv3"):
                load_conv_bn(stem_name, getattr(self, stem_name))

            for block_name in (
                "inception3a",
                "inception3b",
                "inception4a",
                "inception4b",
                "inception4c",
                "inception4d",
                "inception4e",
                "inception5a",
                "inception5b",
            ):
                load_inception(block_name, getattr(self, block_name))
| |
|