博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
神经网络入门(电影评论分类--------二分类问题)
阅读量:4961 次
发布时间:2019-06-12

本文共 2722 字,大约阅读时间需要 9 分钟。

IMDB数据集

from keras.datasets import imdb (train_data,train_labels),(test_data,test_labels)=imdb.load_data(num_words=10000) print(train_data[0]) print(train_labels[0]) print(max([max(sequence) for sequence in train_data])) word_index=imdb.get_word_index() reverse_word_index=dict(     [(value,key) for (key,value) in word_index.items()] ) decoded_review=' '.join([reverse_word_index.get(i-3,'?') for i in train_data[0]]) print(decoded_review) #将整数序列编码为二进制矩阵 import numpy as np def vectorize_sequences(sequences,dimension=10000):     results=np.zeros((len(sequences),dimension))     for i,sequence in enumerate(sequences):         results[i,sequence]=1     return results x_train=vectorize_sequences(train_data) x_test=vectorize_sequences(test_data) print(x_train[0]) y_train=np.asarray(train_labels).astype('float32') y_test=np.asarray(test_labels).astype('float32') ####模型定义##### from keras import models from keras import layers model=models.Sequential() model.add(layers.Dense(16,activation='relu',input_shape=(10000,))) model.add(layers.Dense(16,activation='relu')) model.add(layers.Dense(1,activation='sigmoid')) ####模型编译#### model.compile(optimizer='rmsprop',loss='binary_crossentropy',metrics=['accuracy']) #####配置优化器##### from keras import optimizers model.compile(optimizer=optimizers.RMSprop(lr=0.001),loss='binary_crossentropy',metrics=['accuracy']) ####使用自定义的损失和指标### from keras import losses from keras import metrics model.compile(optimizer=optimizers.RMSprop(lr=0.001),loss=losses.binary_crossentropy,metrics=[metrics.binary_accuracy]) #####留出验证集###### x_val=x_train[:10000] partial_x_train=x_train[10000:] y_val=y_train[:10000] partial_y_train=y_train[10000:] #####训练模型####### model.compile(optimizer='rmsprop',loss='binary_crossentropy',metrics=['acc']) history=model.fit(partial_x_train,partial_y_train,epochs=20,batch_size=512,validation_data=(x_val,y_val)) history_dict=history.history print(history_dict.keys()) #####绘制训练损失和验证损失#### import matplotlib.pyplot as plt history_dict=history.history loss_values=history_dict['loss'] val_loss_values=history_dict['val_loss'] epochs=range(1,len(loss_values)+1) plt.plot(epochs,loss_values,'bo',label='Training loss')  ###'bo'表示蓝色圆点 plt.plot(epochs,val_loss_values,'b',label='Validation loss') plt.title('Training and validation loss') plt.xlabel('Epochs') plt.ylabel('Loss') plt.legend() plt.show() ######绘制训练精度和验证精度 plt.clf() acc=history_dict['acc'] val_acc=history_dict['val_acc'] plt.plot(epochs,acc,'bo',label='Training acc')  ###'bo'表示蓝色圆点 plt.plot(epochs,val_acc,'b',label='Validation acc') plt.title('Training and validation accuracy') plt.xlabel('Epochs') plt.ylabel('Accuracy') plt.legend() plt.show()

转载于:https://www.cnblogs.com/yxllfl/p/11521827.html

你可能感兴趣的文章
[QA]UrlRewriter无法解析实际存在的htm文件
查看>>
记一次因为索引维护导致批量无法继续的情况
查看>>
poj 2195 (最小费用最大流)
查看>>
HCA数据下载
查看>>
Codeforces 954 G. Castle Defense
查看>>
反射机制-----------通过它获取类中所有东西 出了注释
查看>>
svn的一个连接
查看>>
position:fixed和z-index:1
查看>>
unity, 延迟执行代码
查看>>
mysq找不到pid无法正常启动
查看>>
php实现抓取网站百度快照和百度收录数量的代码实例
查看>>
Qt那点事儿(三) 论父对象与子对象的关系
查看>>
jar 命令 打包装class文件的文件夹
查看>>
node.js express配置允许跨域
查看>>
JSP EL表达式详细介绍(转)
查看>>
要想找出正好包含5个字符的名字
查看>>
用js把图片做的富有动态感,并对以后需要用着的属性进行封装
查看>>
ArcGIS Runtime For Android 100.3天地图不加载问题
查看>>
线性表
查看>>
【转】解决eclipse新导入工程无法run as server
查看>>