用tf2做图片分类

# -*- coding: utf-8 -*-
# @File : train.py
# @desc : 训练模型

import tensorflow as tf
import numpy as np
import glob, os

# --- Training configuration -------------------------------------------------
modelname = 'kcEnterGame.h5'  # model file: loaded if it exists, otherwise created
trainimgpath = r'E:\AI_ND_YINHUN\KuaiCeIsEnterGame\trainIMG\*'  # matches every entry under trainIMG
BATCH_SIZE = 32
eps = 100  # epoch count handed to model.fit

# Every entry matched under trainIMG is treated as one class.
# NOTE(review): glob returns entries in file-system order, not sorted order.
# test.py rebuilds this mapping with the same glob, so both scripts must see
# the same listing order for the class indices to agree.
label_dir = glob.glob(trainimgpath)
# All JPEG files one level below the class entries.
all_image_path = glob.glob(os.path.join(trainimgpath, '*.jpg'))

# class index -> class name (last path component of each matched entry).
all_labels = {index: os.path.basename(path) for index, path in enumerate(label_dir)}
# class name -> class index.
# NOTE: despite its name, this dict is keyed by label and returns the index.
index_to_label = {name: index for index, name in all_labels.items()}
def load_and_preprocess_img(path):
    """Load one training image and look up its class index.

    The class name is taken from the name of the directory that directly
    contains the file, and is translated to an integer through the
    module-level ``index_to_label`` mapping.

    Args:
        path: Path of a JPEG file located inside a class directory.

    Returns:
        A ``(image, label)`` tuple: the image decoded to 3 channels, resized
        to 299x299, cast to float32 and divided by 255, and the class index.

    Raises:
        KeyError: If the parent directory name is not in ``index_to_label``.
    """
    # Parent directory name == class name; os.path handles the separator.
    label = os.path.basename(os.path.dirname(path))
    img_raw = tf.io.read_file(path)  # raw file bytes
    img_tensor = tf.image.decode_jpeg(img_raw, channels=3)  # decode to a 3-channel tensor
    img_tensor = tf.image.resize(img_tensor, [299, 299])  # resize to the model's input size
    img_tensor = tf.cast(img_tensor, tf.float32)  # make sure the dtype is float32
    # NOTE(review): values are scaled by 1/255 here; the pretrained Xception
    # base is usually paired with its own preprocess_input — confirm this
    # scaling is intended. It must stay identical to the one used in test.py.
    img_tensor = img_tensor / 255
    return img_tensor, index_to_label[label]
# Number of full batches per pass over the data; the floor division means any
# trailing images that do not fill a whole batch are never used.
train_batch_count = len(all_image_path) // BATCH_SIZE
print('train_batch_count:' + str(train_batch_count))
def get_detail_data():
    """Yield training batches forever, cycling over ``all_image_path``.

    Each pass walks the image list in order in slices of ``BATCH_SIZE`` and
    produces ``train_batch_count`` batches; images beyond the last full
    batch are skipped.

    Yields:
        A ``(images, labels)`` tuple of NumPy arrays for one batch.

    Raises:
        ValueError: If there is not even one full batch of images. Without
            this check the endless loop below would spin without ever
            yielding and the caller would hang.
    """
    if train_batch_count <= 0:
        raise ValueError(
            'Not enough training images for a single batch: found '
            + str(len(all_image_path)) + ', need at least ' + str(BATCH_SIZE)
        )
    while True:
        for i in range(train_batch_count):
            all_image_list = []
            all_label_list = []
            for item in all_image_path[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]:
                img, label = load_and_preprocess_img(item)
                all_image_list.append(img)
                all_label_list.append(label)
            yield np.array(all_image_list), np.array(all_label_list)

if not os.path.exists(modelname):
    # No saved model yet: build a classifier on top of a frozen,
    # ImageNet-pretrained Xception feature extractor.
    mobile_ls = tf.keras.applications.Xception(
        input_shape=(299, 299, 3), include_top=False, weights='imagenet'
    )
    mobile_ls.trainable = False  # only the new head below is trained
    x_inputs = [tf.keras.Input(shape=(299, 299, 3))]
    x1 = mobile_ls(x_inputs[0])
    x1 = tf.keras.layers.GlobalAveragePooling2D()(x1)
    x1 = tf.keras.layers.Dense(1024, activation='relu')(x1)
    # One softmax unit per class found in the training directory.
    out_tx = tf.keras.layers.Dense(len(all_labels), activation='softmax', name='out_tx')(x1)
    model = tf.keras.Model(inputs=x_inputs, outputs=[out_tx])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
        # Labels are integer class indices, hence the sparse loss.
        loss={'out_tx': 'sparse_categorical_crossentropy'},
        metrics=['acc'],
    )
else:
    # A saved model exists: continue training it.
    model = tf.keras.models.load_model(modelname)

# The generator already emits complete batches, so batch_size is not passed
# (Keras says it must not be specified for generator input);
# steps_per_epoch bounds each epoch because the generator never ends.
model.fit(get_detail_data(), epochs=eps, steps_per_epoch=train_batch_count)
model.save(modelname)

 

###############################################################################
# -*- coding: utf-8 -*-
# @File : test.py
# @desc : 分类:游戏内,游戏外

import tensorflow as tf
import numpy as np
import glob, os

modelname = 'kcEnterGame.h5'  # model file written by train.py
testmgpath = r'E:\AI_ND_YINHUN\KuaiCeIsEnterGame\testIMG4\*'  # matches every entry under testIMG4
# All JPEG files one level below the entries matched above.
all_image_path = glob.glob(os.path.join(testmgpath, '*.jpg'))
model = tf.keras.models.load_model(modelname)

# Rebuild the class index -> class name table from the TRAINING directory.
# NOTE(review): this relies on glob listing the entries in the same order as
# it did when the model was trained — confirm, or persist the mapping.
label_dir = glob.glob(r'E:\AI_ND_YINHUN\KuaiCeIsEnterGame\trainIMG\*')
all_labels = {index: os.path.basename(path) for index, path in enumerate(label_dir)}
def load_test_img(path):
    """Load one image and preprocess it for prediction.

    Applies the same steps as the training loader: decode to 3 channels,
    resize to 299x299, cast to float32 and divide by 255.

    Args:
        path: Path of a JPEG file.

    Returns:
        The preprocessed image tensor (no batch axis).
    """
    img_raw = tf.io.read_file(path)  # raw file bytes
    img_tensor = tf.image.decode_jpeg(img_raw, channels=3)  # decode to a 3-channel tensor
    img_tensor = tf.image.resize(img_tensor, [299, 299])  # resize to the model's input size
    img_tensor = tf.cast(img_tensor, tf.float32)  # make sure the dtype is float32
    img_tensor = img_tensor / 255  # same 1/255 scaling as used for training
    return img_tensor
# Accumulator for the optional CSV report; the code that would write it to
# disk is currently commented out further down.
output_str = ''
for item in all_image_path:
    testimg = load_test_img(item)
    testimg = tf.expand_dims(testimg, 0)  # add a leading batch axis for predict()
    pred = model.predict(testimg)
    # Evaluate the winning class and its score once per image.
    best_index = np.argmax(pred)
    best_score = np.max(pred)
    if best_score > 0.6:
        # Confident prediction: class index, class name, file, score.
        print(best_index, all_labels[best_index], item, str(best_score))
    else:
        # Top score at or below the 0.6 threshold: report as unrecognised.
        print('无法分辨图片:', item, str(best_score))

# file = open(‘./csvTestResult/test.csv’, ‘w’)
# file.writelines(output_str)
# file.close()

Comments

No comments yet. Why don’t you start the discussion?

发表回复

您的电子邮箱地址不会被公开。 必填项已用 * 标注