Skip to content

Commit

Permalink
Coding two-path kMobileNet.
Browse files Browse the repository at this point in the history
  • Loading branch information
joaopauloschuler committed Jun 7, 2022
1 parent 3baea02 commit 22c1975
Showing 1 changed file with 34 additions and 6 deletions.
40 changes: 34 additions & 6 deletions cai/mobilenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,11 @@ def kMobileNet(input_shape=None,
pooling=None,
classes=1000,
activation=keras.activations.swish,
kType=0):
kType=0,
l_ratio=0.0,
ab_ratio=0.0,
skip_stride_cnt=-1
):
"""Instantiates the k optimized MobileNet architecture.
# Arguments
Expand Down Expand Up @@ -495,6 +499,9 @@ def kMobileNet(input_shape=None,
into, only to be specified if `include_top` is True, and
if no `weights` argument is specified.
kType: k optimized convolutional type.
l_ratio: proportion of first layer filters dedicated to light.
ab_ratio: proportion of first layer filters dedicated to color.
skip_stride_cnt: number of layers to skip stride. This parameter is used with smalll images such as CIFAR-10.
# Returns
A Keras model instance.
Expand All @@ -507,23 +514,44 @@ def kMobileNet(input_shape=None,
"""
img_input = keras.layers.Input(shape=input_shape)

x = _conv_block(img_input, 32, alpha, strides=(2, 2))
if backend.image_data_format() == 'channels_first':
channel_axis = 1
else:
channel_axis = 3

local_strides = (1, 1) if (skip_stride_cnt >=0) else (2, 2)

if (l_ratio > 0.0) and (ab_ratio>0.0):
l_branch = _conv_block(img_input, int(round(32*l_ratio)), alpha, strides=local_strides)
ab_branch = _conv_block(img_input, int(round(32*ab_ratio)), alpha, strides=local_strides)
x = keras.layers.Concatenate(axis=channel_axis, name='l-ab-paths-concat')([l_branch, ab_branch])
elif (l_ratio > 0.0) and (ab_ratio<=0.0):
x = _conv_block(img_input, int(round(32*l_ratio)), alpha, strides=local_strides)
elif (l_ratio <= 0.0) and (ab_ratio>0.0):
x = _conv_block(img_input, int(round(32*ab_ratio)), alpha, strides=local_strides)
else:
x = _conv_block(img_input, 32, alpha, strides=local_strides)

x = kdepthwise_conv_block(x, 64, alpha, depth_multiplier, block_id=1, activation=activation, kType=kType)

x = kdepthwise_conv_block(x, 128, alpha, depth_multiplier, strides=(2, 2), block_id=2, activation=activation, kType=kType)
local_strides = (1, 1) if (skip_stride_cnt >=1) else (2, 2)
x = kdepthwise_conv_block(x, 128, alpha, depth_multiplier, strides=local_strides, block_id=2, activation=activation, kType=kType)
x = kdepthwise_conv_block(x, 128, alpha, depth_multiplier, block_id=3, activation=activation, kType=kType)

x = kdepthwise_conv_block(x, 256, alpha, depth_multiplier, strides=(2, 2), block_id=4, activation=activation, kType=kType)
local_strides = (1, 1) if (skip_stride_cnt >=2) else (2, 2)
x = kdepthwise_conv_block(x, 256, alpha, depth_multiplier, strides=local_strides, block_id=4, activation=activation, kType=kType)
x = kdepthwise_conv_block(x, 256, alpha, depth_multiplier, block_id=5, activation=activation, kType=kType)

x = kdepthwise_conv_block(x, 512, alpha, depth_multiplier, strides=(2, 2), block_id=6, activation=activation, kType=kType)
local_strides = (1, 1) if (skip_stride_cnt >=3) else (2, 2)
x = kdepthwise_conv_block(x, 512, alpha, depth_multiplier, strides=local_strides, block_id=6, activation=activation, kType=kType)
x = kdepthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=7, activation=activation, kType=kType)
x = kdepthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=8, activation=activation, kType=kType)
x = kdepthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=9, activation=activation, kType=kType)
x = kdepthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=10, activation=activation, kType=kType)
x = kdepthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=11, activation=activation, kType=kType)

x = kdepthwise_conv_block(x, 1024, alpha, depth_multiplier, strides=(2, 2), block_id=12, activation=activation, kType=kType)
local_strides = (1, 1) if (skip_stride_cnt >=4) else (2, 2)
x = kdepthwise_conv_block(x, 1024, alpha, depth_multiplier, strides=local_strides, block_id=12, activation=activation, kType=kType)
x = kdepthwise_conv_block(x, 1024, alpha, depth_multiplier, block_id=13, activation=activation, kType=kType)

if include_top:
Expand Down

0 comments on commit 22c1975

Please sign in to comment.