代码:
1 import numpy as np 2 from scipy import fft 3 from scipy.io import wavfile 4 from sklearn.linear_model import LogisticRegression 5 import random 6 """ 7 使用logistic regression处理音乐数据,音乐数据训练样本的获得是使用快速傅里叶变换: 8 1.吧训练集扩大到每类100首歌二部之前的10首歌 9 2.同时使用logistic回归和KNN作为分类器 10 3.引入一些评价标准来比较logictic和KNN在测试集上的表现 11 """ 12 """ 13 n = 40 14 # hstack数据拼接 15 # 在模拟X的时候使用了两个正态分布,分别制定各自的均值,方差,生成40个点 16 X = np.hstack((norm.rvs(loc=2, size=n, scale=2), norm.rvs(loc=8, size=n, scale=3))) 17 # zeros使得数据点生成40个0,ones使得数据点生成40个1 18 y = np.hstack((np.zeros(n), np.ones(n))) 19 20 """ 21 general_list = ["classical", "jazz", "country", "pop", "rock", "metal"] 22 """ 23 # 处理原始数据,转化为特征文件 24 def create_fft(g, n): 25 read = "d:/xxx/"+g+"/conberted/"+g+"."+str(n).zfill(5)+".au.wav" 26 sample_rate, x = wavfile.read(read) 27 fft_features = abs(fft(x)[:1000]) 28 sad = "d:/trainset/"+g+"."+str(n).zfill(5)+".fft" 29 np.save(sad, fft_features) 30 31 32 for g in gener_list: 33 for n in range(100): 34 create_fft(g, n) 35 """ 36 # 加载训练集数据,分割训练集和测试集,进行分类器的训练 37 # 构造训练集 38 x = [] 39 y = [] 40 for g in general_list: 41 for n in range(100): 42 read = "D:/AnalyseData学习资源库/人工智能开发【中】/05_分类器项目案例和神经网络算法【尚学堂·百战程序员】/资料/trainset/"+g+"."+str(n).zfill(5)+".fft"+".npy" 43 fft_features = np.load(read) 44 x.append(fft_features) 45 y.append(general_list.index(g)) 46 47 x = np.array(x) 48 y = np.array(y) 49 50 # 拆分数据为训练集和测试集 51 randomIndex = random.sample(range(len(y)), int(len(y)*8/10)) 52 trainX = [] 53 trainY = [] 54 testX = [] 55 testY = [] 56 57 for i in range(len(y)): 58 if i in randomIndex: 59 trainX.append(x[i]) 60 trainY.append(y[i]) 61 else: 62 testX.append(x[i]) 63 testY.append(y[i]) 64 65 # 使用sklearn来构建和训练两种分类器 66 67 # logistic classifier 68 model = LogisticRegression() 69 # train 70 model.fit(trainX, trainY) 71 # test 72 predict = model.predict(testX) 73 print(testY) 74 print("--------------------------------------------------") 75 print(predict) 76 # error 77 import math 78 error = 0.0 79 for i in range(len(testY)): 80 if testY[i] != predict[i]: 81 error = error + 1 82 print(error/len(testY)) 83 84 print("starting read wavfile...") 85 sample_rate, test = wavfile.read("D:/AnalyseData学习资源库/人工智能开发【中】/05_分类器项目案例和神经网络算法【尚学堂·百战程序员】/资料/trainset/sample/heibao-wudizirong-remix.wav") 86 testdata_fft_features = abs(fft(test))[:1000] 87 type_index = model.predict([testdata_fft_features])[0] 88 89 print("预测音乐分类为:"+general_list[type_index])
结果: