# Fragment: restore previously saved weights, if any, before (re)training.
# Saving to a ".ckpt" path also writes a companion ".index" file, so the
# presence of that index file tells us a checkpoint was saved earlier.
# NOTE(review): `os` and `model` are assumed to be defined by the enclosing
# script — this snippet is not self-contained.
checkpoint_save_path = "./checkpoint/fashion.ckpt"  # directory + checkpoint file stem
if os.path.exists(checkpoint_save_path + '.index'):
    # An index file exists, so weights were saved before: load them.
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)
import tensorflow as tf
import os  # used to probe the filesystem for an existing checkpoint

# Load MNIST and scale pixel values from [0, 255] into [0, 1].
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Simple fully-connected classifier: flatten -> 128 ReLU -> 10-way softmax.
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
])

# Labels are integer class ids, hence the sparse loss/metric; the softmax
# output layer means the loss receives probabilities (from_logits=False).
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

# Resume from a previous run when a checkpoint index file is present
# (saving ".ckpt" weights also writes a ".index" companion file).
checkpoint_save_path = "./checkpoint/fashion.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

# Persist only the weights, and only when validation performance improves.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=True,
    save_best_only=True,
)

history = model.fit(x_train, y_train,
                    batch_size=32,
                    epochs=5,
                    validation_data=(x_test, y_test),
                    validation_freq=1,
                    callbacks=[cp_callback])

model.summary()
import tensorflow as tf
import os  # used to probe the filesystem for an existing checkpoint
import numpy as np

# Print full arrays (no "..." truncation) so the weight dump is complete.
np.set_printoptions(threshold=np.inf)

# Load MNIST and scale pixel values from [0, 255] into [0, 1].
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Simple fully-connected classifier: flatten -> 128 ReLU -> 10-way softmax.
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
])

# Labels are integer class ids, hence the sparse loss/metric; the softmax
# output layer means the loss receives probabilities (from_logits=False).
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

# Resume from a previous run when a checkpoint index file is present
# (saving ".ckpt" weights also writes a ".index" companion file).
checkpoint_save_path = "./checkpoint/fashion.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

# Persist only the weights, and only when validation performance improves.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=True,
    save_best_only=True,
)

history = model.fit(x_train, y_train,
                    batch_size=32,
                    epochs=5,
                    validation_data=(x_test, y_test),
                    validation_freq=1,
                    callbacks=[cp_callback])

model.summary()

print(model.trainable_variables)

# FIX: use a context manager instead of open()/close() so the file is
# closed even if writing raises (the original leaked the handle on error).
with open('./weights.txt', 'w') as file:
    for v in model.trainable_variables:
        file.write(str(v.name) + '\n')
        file.write(str(v.shape) + '\n')
        file.write(str(v.numpy()) + '\n')
import tensorflow as tf
import os  # used to probe the filesystem for an existing checkpoint
import numpy as np
from matplotlib import pyplot as plt  # plotting module for the acc/loss curves

# Print full arrays (no "..." truncation) so the weight dump is complete.
np.set_printoptions(threshold=np.inf)

# Load MNIST and scale pixel values from [0, 255] into [0, 1].
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Simple fully-connected classifier: flatten -> 128 ReLU -> 10-way softmax.
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
])

# Labels are integer class ids, hence the sparse loss/metric; the softmax
# output layer means the loss receives probabilities (from_logits=False).
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

# Resume from a previous run when a checkpoint index file is present
# (saving ".ckpt" weights also writes a ".index" companion file).
checkpoint_save_path = "./checkpoint/fashion.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

# Persist only the weights, and only when validation performance improves.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=True,
    save_best_only=True,
)

history = model.fit(x_train, y_train,
                    batch_size=32,
                    epochs=5,
                    validation_data=(x_test, y_test),
                    validation_freq=1,
                    callbacks=[cp_callback])

model.summary()

print(model.trainable_variables)

# FIX: use a context manager instead of open()/close() so the file is
# closed even if writing raises (the original leaked the handle on error).
with open('./weights.txt', 'w') as file:
    for v in model.trainable_variables:
        file.write(str(v.name) + '\n')
        file.write(str(v.shape) + '\n')
        file.write(str(v.numpy()) + '\n')

###############################################    show    ###############################################

# Plot training/validation accuracy and loss curves side by side.
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
from PIL import Image
import numpy as np
import tensorflow as tf

model_save_path = './checkpoint/mnist.ckpt'

# Rebuild the same network architecture used at training time so the
# saved weights can be loaded into it.
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
])

model.load_weights(model_save_path)  # load the trained parameters

# Ask how many images to classify.
preNum = int(input("input the number of test pictures:"))

for i in range(preNum):
    # Read the image to classify and resize it to (28, 28) to match the
    # training data.
    image_path = input("the path of test picture:")
    img = Image.open(image_path)
    # FIX: Image.ANTIALIAS was deprecated and removed in Pillow 10;
    # it was an alias of Image.LANCZOS, so this is the identical filter.
    img = img.resize((28, 28), Image.LANCZOS)

    img_arr = np.array(img.convert('L'))  # convert to grayscale
    img_arr = 255 - img_arr  # invert "dark-on-light" into "light-on-dark"
    #####or#####
    # for i in range(28):  # alternative: hard-threshold to suppress noise
    #     for j in range(28):
    #         if img_arr[i][j] < 200:
    #             img_arr[i][j] = 255
    #         else:
    #             img_arr[i][j] = 0

    img_arr = img_arr / 255.0  # normalize to [0, 1]
    print("img_arr:", img_arr.shape)
    # The model consumes batches, so prepend a batch dimension of 1.
    x_predict = img_arr[tf.newaxis, ...]
    print("x_predict:", x_predict.shape)

    result = model.predict(x_predict)  # class probabilities, shape (1, 10)
    pred = tf.argmax(result, axis=1)   # most likely digit
    print('\n')
    tf.print(pred)