Table of Contents

I. Preliminary Work

1. Set Up the GPU

2. Import the Data

3. Inspect the Data

II. Data Preprocessing

1. Load the Data

2. Visualize the Data

3. Check the Data Again

4. Configure the Dataset

III. Introduction to Residual Networks (ResNet)

1. What Problem Residual Networks Solve

2. Introduction to ResNet-50

IV. Building the ResNet-50 Model

V. Compiling

VI. Training the Model

VII. Model Evaluation

VIII. Saving and Loading the Model

IX. Prediction

I. Preliminary Work

1. Set Up the GPU

If you are running on a CPU, you can comment out this part of the code.

```python
import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)  # use GPU memory on demand
    tf.config.set_visible_devices([gpus[0]], "GPU")
```

2. Import the Data

```python
import matplotlib.pyplot as plt

# Chinese font support
plt.rcParams['font.sans-serif'] = ['SimHei']  # display Chinese labels correctly
plt.rcParams['axes.unicode_minus'] = False    # display the minus sign correctly

import os, PIL

# Set random seeds to make the results as reproducible as possible
import numpy as np
np.random.seed(1)

import tensorflow as tf
tf.random.set_seed(1)

from tensorflow import keras
from tensorflow.keras import layers, models

import pathlib

data_dir = "D:/jupyternotebook/DL-100-days/datasets/bird_photos"
data_dir = pathlib.Path(data_dir)
```

3. Inspect the Data

```python
image_count = len(list(data_dir.glob('*/*')))
print("Total number of images:", image_count)
```

```
Total number of images: 565
```

II. Data Preprocessing
| Folder | Images |
| --- | --- |
| Bananaquit | 166 |
| Black Throated Bushtiti | 111 |
| Black skimmer | 122 |
| Cockatoo | 166 |
1. Load the Data

Use the image_dataset_from_directory method to load the data from disk into a tf.data.Dataset.

```python
batch_size = 8
img_height = 224
img_width = 224
```

Readers on TensorFlow 2.2.0 may run into the error module 'tensorflow.keras.preprocessing' has no attribute 'image_dataset_from_directory'; upgrading TensorFlow fixes it.
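A quick way to check before calling the API (my own sketch; the assumption that the function appeared in tf.keras.preprocessing around TF 2.3 is mine, not from the original article):

```python
import tensorflow as tf

# Sanity-check the installed version before using image_dataset_from_directory.
print("TensorFlow version:", tf.__version__)

if not hasattr(tf.keras.preprocessing, "image_dataset_from_directory"):
    raise RuntimeError("Please upgrade TensorFlow, e.g.: pip install -U tensorflow")
```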

"""關於image_dataset_from_directory()的詳細介紹可以參考文章:https://mtyjkh.blog.csdn.net/article/details/117018789"""train_ds=tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=123,image_size=(img_height,img_width),batch_size=batch_size)Found 565 files belonging to 4 classes.Using 452 files for training."""關於image_dataset_from_directory()的詳細介紹可以參考文章:https://mtyjkh.blog.csdn.net/article/details/117018789"""val_ds=tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=123,image_size=(img_height,img_width),batch_size=batch_size)Found 565 files belonging to 4 classes.Using 113 files for validation.

We can print the dataset's labels via class_names. The labels correspond to the directory names in alphabetical order.

```python
class_names = train_ds.class_names
print(class_names)
```

```
['Bananaquit', 'Black Throated Bushtiti', 'Black skimmer', 'Cockatoo']
```

2. Visualize the Data

```python
plt.figure(figsize=(10, 5))  # figure width 10, height 5
plt.suptitle("微信公眾號:K同學啊")

for images, labels in train_ds.take(1):
    for i in range(8):
        ax = plt.subplot(2, 4, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")

plt.imshow(images[1].numpy().astype("uint8"))
```

3. Check the Data Again

```python
for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break
```

```
(8, 224, 224, 3)
(8,)
```
Image_batch is a tensor of shape (8, 224, 224, 3): a batch of 8 images of shape 224x224x3 (the last dimension is the RGB color channels).
Label_batch is a tensor of shape (8,); these labels correspond to the 8 images.
4. Configure the Dataset
shuffle(): shuffles the data; for a detailed introduction see https://zhuanlan.zhihu.com/p/42417456
prefetch(): prefetches data to speed things up; my previous two articles cover it in detail.
cache(): caches the dataset in memory to speed things up.
```python
AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
```

III. Introduction to Residual Networks (ResNet)

1. What Problem Residual Networks Solve

Residual networks were designed to address the degradation problem that appears when a network has too many hidden layers. Degradation means that as more hidden layers are added, the network's accuracy saturates and then drops sharply, and this drop is not caused by overfitting.
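To make the idea concrete, here is a toy residual unit (my own sketch, not the Conv/Identity blocks built later in this article): the shortcut adds the input x back onto the block's output, so the layers only have to fit the residual F(x) = H(x) - x instead of the full mapping H(x).

```python
import tensorflow as tf
from tensorflow.keras import layers

def residual_unit(x, units):
    """A toy residual unit: output = relu(F(x) + x)."""
    shortcut = x                                   # identity shortcut
    y = layers.Dense(units, activation='relu')(x)  # first layer of F(x)
    y = layers.Dense(units)(y)                     # second layer of F(x)
    y = layers.add([y, shortcut])                  # H(x) = F(x) + x
    return layers.Activation('relu')(y)

# `units` must match the last dimension of `x` so the add is valid.
inputs = tf.keras.Input(shape=(64,))
outputs = residual_unit(inputs, 64)
```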

Aside: the "two dark clouds" hanging over deep neural networks

Vanishing/exploding gradients

Simply put, when the network is too deep, training has difficulty converging. This problem can be effectively controlled by normalized initialization and intermediate normalization layers. (At this stage it's enough to know this is a thing.) A small Keras sketch of these two remedies appears after this aside.

Network degradation

As network depth increases, performance first improves until it saturates, then drops rapidly, and this degradation is not caused by overfitting.
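As a rough illustration of the two gradient remedies mentioned above (my own sketch, not code from the original article): He-style normalized initialization for the weights, plus a BatchNormalization layer between the linear transform and the activation.

```python
import tensorflow as tf
from tensorflow.keras import layers

# A toy layer stack showing the two standard remedies for vanishing/exploding
# gradients: normalized (He) weight initialization and intermediate BatchNorm.
inputs = tf.keras.Input(shape=(32,))
x = layers.Dense(64, kernel_initializer='he_normal')(inputs)  # normalized initialization
x = layers.BatchNormalization()(x)                            # intermediate normalization layer
x = layers.Activation('relu')(x)
model = tf.keras.Model(inputs, x)
```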

2. Introduction to ResNet-50

ResNet-50 has two basic building blocks, the Conv Block and the Identity Block. The Conv Block changes the dimensions of the feature map (its shortcut branch contains a convolution), so it cannot be stacked back to back; the Identity Block preserves the input dimensions and can be stacked to deepen the network.

Conv Block structure:

Identity Block structure:

Overall ResNet-50 architecture:

IV. Building the ResNet-50 Model

The following is the core of this article; try building ResNet-50 yourself from the three figures above.

```python
from keras import layers
from keras.layers import Input, Activation, BatchNormalization, Flatten
from keras.layers import Dense, Conv2D, MaxPooling2D, ZeroPadding2D, AveragePooling2D
from keras.models import Model


def identity_block(input_tensor, kernel_size, filters, stage, block):

    filters1, filters2, filters3 = filters

    name_base = str(stage) + block + '_identity_block_'

    x = Conv2D(filters1, (1, 1), name=name_base + 'conv1')(input_tensor)
    x = BatchNormalization(name=name_base + 'bn1')(x)
    x = Activation('relu', name=name_base + 'relu1')(x)

    x = Conv2D(filters2, kernel_size, padding='same', name=name_base + 'conv2')(x)
    x = BatchNormalization(name=name_base + 'bn2')(x)
    x = Activation('relu', name=name_base + 'relu2')(x)

    x = Conv2D(filters3, (1, 1), name=name_base + 'conv3')(x)
    x = BatchNormalization(name=name_base + 'bn3')(x)

    x = layers.add([x, input_tensor], name=name_base + 'add')
    x = Activation('relu', name=name_base + 'relu4')(x)
    return x


def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):

    filters1, filters2, filters3 = filters

    res_name_base = str(stage) + block + '_conv_block_res_'
    name_base = str(stage) + block + '_conv_block_'

    x = Conv2D(filters1, (1, 1), strides=strides, name=name_base + 'conv1')(input_tensor)
    x = BatchNormalization(name=name_base + 'bn1')(x)
    x = Activation('relu', name=name_base + 'relu1')(x)

    x = Conv2D(filters2, kernel_size, padding='same', name=name_base + 'conv2')(x)
    x = BatchNormalization(name=name_base + 'bn2')(x)
    x = Activation('relu', name=name_base + 'relu2')(x)

    x = Conv2D(filters3, (1, 1), name=name_base + 'conv3')(x)
    x = BatchNormalization(name=name_base + 'bn3')(x)

    shortcut = Conv2D(filters3, (1, 1), strides=strides, name=res_name_base + 'conv')(input_tensor)
    shortcut = BatchNormalization(name=res_name_base + 'bn')(shortcut)

    x = layers.add([x, shortcut], name=name_base + 'add')
    x = Activation('relu', name=name_base + 'relu4')(x)
    return x


def ResNet50(input_shape=[224, 224, 3], classes=1000):

    img_input = Input(shape=input_shape)
    x = ZeroPadding2D((3, 3))(img_input)

    x = Conv2D(64, (7, 7), strides=(2, 2), name='conv1')(x)
    x = BatchNormalization(name='bn_conv1')(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((3, 3), strides=(2, 2))(x)

    x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')

    x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')

    x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')

    x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')

    x = AveragePooling2D((7, 7), name='avg_pool')(x)

    x = Flatten()(x)
    x = Dense(classes, activation='softmax', name='fc1000')(x)

    model = Model(img_input, x, name='resnet50')

    # Load the pretrained weights
    model.load_weights("resnet50_weights_tf_dim_ordering_tf_kernels.h5")

    return model


model = ResNet50()
model.summary()
```

```
Model: "resnet50"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            [(None, 224, 224, 3) 0
__________________________________________________________________________________________________
zero_padding2d (ZeroPadding2D)  (None, 230, 230, 3)  0           input_1[0][0]
__________________________________________________________________________________________________
conv1 (Conv2D)                  (None, 112, 112, 64) 9472        zero_padding2d[0][0]
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization)   (None, 112, 112, 64) 256         conv1[0][0]
__________________________________________________________________________________________________
activation (Activation)         (None, 112, 112, 64) 0           bn_conv1[0][0]
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    (None, 55, 55, 64)   0           activation[0][0]
__________________________________________________________________________________________________
2a_conv_block_conv1 (Conv2D)    (None, 55, 55, 64)   4160        max_pooling2d[0][0]
__________________________________________________________________________________________________
... (several lines omitted here) ...
__________________________________________________________________________________________________
5c_identity_block_relu4 (Activa (None, 7, 7, 2048)   0           5c_identity_block_add[0][0]
__________________________________________________________________________________________________
avg_pool (AveragePooling2D)     (None, 1, 1, 2048)   0           5c_identity_block_relu4[0][0]
__________________________________________________________________________________________________
flatten (Flatten)               (None, 2048)         0           avg_pool[0][0]
__________________________________________________________________________________________________
fc1000 (Dense)                  (None, 1000)         2049000     flatten[0][0]
==================================================================================================
Total params: 25,636,712
Trainable params: 25,583,592
Non-trainable params: 53,120
__________________________________________________________________________________________________
```

V. Compiling

Before the model is ready for training, a few more settings are needed. These are added in the model's compile step:

Loss function (loss): measures how accurate the model is during training.
Optimizer (optimizer): determines how the model is updated based on the data it sees and its own loss function.
Metrics (metrics): used to monitor the training and testing steps. The example below uses accuracy, the fraction of images that are correctly classified.
```python
# Set up the optimizer. I changed the learning rate here.
# Note: `opt` is defined but not passed below; compile() receives the string
# "adam", i.e. Adam with its default learning rate, so `opt` is unused as written.
opt = tf.keras.optimizers.Adam(learning_rate=1e-7)

model.compile(optimizer="adam",
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
```

VI. Training the Model

```python
epochs = 10

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs
)
```

```
Epoch 1/10
57/57 [==============================] - 12s 86ms/step - loss: 2.4313 - accuracy: 0.6548 - val_loss: 213.7383 - val_accuracy: 0.3186
Epoch 2/10
57/57 [==============================] - 3s 52ms/step - loss: 0.4293 - accuracy: 0.8557 - val_loss: 9.0470 - val_accuracy: 0.2566
Epoch 3/10
57/57 [==============================] - 3s 52ms/step - loss: 0.2309 - accuracy: 0.9183 - val_loss: 1.4181 - val_accuracy: 0.7080
Epoch 4/10
57/57 [==============================] - 3s 53ms/step - loss: 0.1721 - accuracy: 0.9535 - val_loss: 2.5627 - val_accuracy: 0.6726
Epoch 5/10
57/57 [==============================] - 3s 53ms/step - loss: 0.0795 - accuracy: 0.9701 - val_loss: 0.2747 - val_accuracy: 0.8938
Epoch 6/10
57/57 [==============================] - 3s 52ms/step - loss: 0.0435 - accuracy: 0.9899 - val_loss: 0.1483 - val_accuracy: 0.9381
Epoch 7/10
57/57 [==============================] - 3s 52ms/step - loss: 0.0308 - accuracy: 0.9970 - val_loss: 0.1705 - val_accuracy: 0.9381
Epoch 8/10
57/57 [==============================] - 3s 52ms/step - loss: 0.0019 - accuracy: 1.0000 - val_loss: 0.0674 - val_accuracy: 0.9735
Epoch 9/10
57/57 [==============================] - 3s 52ms/step - loss: 8.2391e-04 - accuracy: 1.0000 - val_loss: 0.0720 - val_accuracy: 0.9735
Epoch 10/10
57/57 [==============================] - 3s 52ms/step - loss: 6.0079e-04 - accuracy: 1.0000 - val_loss: 0.0762 - val_accuracy: 0.9646
```

VII. Model Evaluation

```python
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.suptitle("微信公眾號:K同學啊")

plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
```

VIII. Saving and Loading the Model

This is the simplest way to save and load a model.

```python
# Save the model
model.save('model/my_model.h5')

# Load the model
new_model = keras.models.load_model('model/my_model.h5')
```

IX. Prediction

```python
# Use the loaded model (new_model) to inspect predictions
plt.figure(figsize=(10, 5))  # figure width 10, height 5
plt.suptitle("微信公眾號:K同學啊")

for images, labels in val_ds.take(1):
    for i in range(8):
        ax = plt.subplot(2, 4, i + 1)

        # Display the image
        plt.imshow(images[i].numpy().astype("uint8"))

        # Add a batch dimension to the image
        img_array = tf.expand_dims(images[i], 0)

        # Use the model to predict the bird in the image
        predictions = new_model.predict(img_array)
        plt.title(class_names[np.argmax(predictions)])

        plt.axis("off")
```

If you found this article helpful, remember to give it a read, a like, and a bookmark.
