# Complete code
#-*-coding:utf-8-*-
import pandas as pd
import math
from collections import defaultdict
# Load the training and test data at import time. Both files are parsed as CSV
# by pandas; the paths are relative to the current working directory.
# `train` is read by loadData() and `test` by getResult() as module globals.
train = pd.read_csv("./data/train.txt")
test = pd.read_csv("./data/test.txt")
def loadData():
    """Split the module-level ``train`` frame into male and female rows.

    The rest of the pipeline treats ``gender == 1`` as male: ``getBase`` uses
    ``train['gender'].mean()`` as the male prior and ``getGender`` yields 1
    when the male score wins.  The split follows the same encoding so that
    each class's character frequencies are paired with its own prior and
    the predictions use the same labels as the training data.

    Returns:
        A tuple ``(names_male, names_female, totals)``.  The first two are the
        rows of ``train`` filtered by gender; ``totals`` maps ``'m'`` and
        ``'f'`` to the number of rows in each split.
    """
    gender = train['gender']
    names_male = train[gender == 1]
    names_female = train[gender == 0]
    # Per-class row counts, used later to normalise character frequencies.
    totals = {
        'f': len(names_female),
        'm': len(names_male),
    }
    return names_male, names_female, totals
def _iter_names(names):
    """Return an iterator over the name strings held in *names*."""
    # Iterating a DataFrame yields its column labels, not its rows, so the
    # 'name' column has to be selected explicitly.  Any other iterable
    # (Series, list, ...) is assumed to already contain the names.
    if isinstance(names, pd.DataFrame):
        return iter(names['name'])
    return iter(names)


def _char_frequencies(names, total):
    """Return a defaultdict mapping each character to occurrences / total."""
    counts = defaultdict(int)
    for name in _iter_names(names):
        for char in name:
            counts[char] += 1
    # Unseen characters read as 0.0, matching what the callers expect from a
    # defaultdict-backed table.
    freqs = defaultdict(float)
    for char, count in counts.items():
        freqs[char] = float(count) / total
    return freqs


# Compute, per gender, how often each character occurs in the training names.
def calFreq(names_male, names_female, totals):
    """Build the per-gender character frequency tables.

    Args:
        names_male: male training names, either a DataFrame with a ``'name'``
            column or any iterable of name strings.
        names_female: female training names, same accepted forms.
        totals: mapping with keys ``'m'`` and ``'f'`` giving the number of
            names in each class; used as the normalising denominator.

    Returns:
        A tuple ``(freq_list_m, freq_list_f)`` of defaultdicts mapping a
        character to its number of occurrences divided by the class total.
        A character repeated inside one name is counted once per occurrence.
    """
    freq_list_f = _char_frequencies(names_female, totals['f'])
    # The male table is built from the male names only; nothing is written
    # into the female table here.
    freq_list_m = _char_frequencies(names_male, totals['m'])
    return freq_list_m, freq_list_f
# Smooth the frequencies so that a character never seen in the training data
# does not end up with probability zero.
def LaplaceSmooth(char, freq_list, total, alpha=1.0):
    """Return the add-alpha smoothed probability of *char*.

    Args:
        char: the character to look up.
        freq_list: mapping from character to relative frequency
            (occurrences divided by *total*).
        total: the number of names the frequencies were normalised by.
        alpha: pseudo-count added for every character.

    Returns:
        ``(count + alpha) / (total + distinct_chars * alpha)``, where
        ``count`` is recovered as ``freq_list[char] * total`` and
        ``distinct_chars`` is the number of entries in *freq_list*.
    """
    # Scale the frequency back to a count.  .get() is used instead of [] so
    # that probing an unseen character does not insert it into a defaultdict
    # and silently enlarge distinct_chars for every later call.
    count = freq_list.get(char, 0) * total
    distinct_chars = len(freq_list)
    return (count + alpha) / (total + distinct_chars * alpha)
# Log-odds of a character under one gender's smoothed frequency table.
def GetLogProb(char, frequency_list, total):
    """Return ``log(p) - log(1 - p)`` for the smoothed probability of *char*.

    ``p`` is obtained from ``LaplaceSmooth(char, frequency_list, total)``
    with its default smoothing parameter.
    """
    p = LaplaceSmooth(char, frequency_list, total)
    log_present = math.log(p)
    log_absent = math.log(1 - p)
    return log_present - log_absent
def getBase(freq_list_m, freq_list_f, train):
    """Compute the per-class base log score that every name starts from.

    Args:
        freq_list_m: mapping from character to frequency for the male class.
        freq_list_f: mapping from character to frequency for the female class.
        train: object whose ``train['gender'].mean()`` gives the share of
            rows labelled 1.

    Returns:
        A dict with keys ``'f'`` and ``'m'``.  Each value is the log of the
        class prior (``1 - mean`` for ``'f'``, ``mean`` for ``'m'``) plus the
        sum of ``log(1 - freq)`` over every entry of that class's table.
    """
    share_of_ones = train['gender'].mean()

    female_base = math.log(1 - share_of_ones) + sum(
        math.log(1 - freq) for freq in freq_list_f.values()
    )
    male_base = math.log(share_of_ones) + sum(
        math.log(1 - freq) for freq in freq_list_m.values()
    )
    return {'f': female_base, 'm': male_base}
def calLogProb(name, bases, totals, freq_list_m, freq_list_f):
    """Score *name* under both gender models.

    Starts from the base scores in *bases* (keys ``'m'`` and ``'f'``) and adds
    ``GetLogProb`` of every character of *name* for the matching class.

    Returns:
        A dict ``{'male': <score>, 'female': <score>}``.
    """
    scores = {'male': bases['m'], 'female': bases['f']}
    for ch in name:
        scores['male'] += GetLogProb(ch, freq_list_m, totals['m'])
        scores['female'] += GetLogProb(ch, freq_list_f, totals['f'])
    return scores
def getGender(logProbs):
    """Return whether the ``'male'`` score is strictly greater than ``'female'``."""
    male_score = logProbs['male']
    female_score = logProbs['female']
    return male_score > female_score
def getResult(bases, totals, freq_list_m, freq_list_f):
    """Predict a gender for every name in the module-level ``test`` frame.

    Each name in ``test['name']`` is scored with ``calLogProb`` and turned
    into 0 or 1 via ``int(getGender(...))``.  The predictions are stored in a
    new ``'pred'`` column of ``test`` and the first 20 rows are printed.

    Returns:
        The list of integer predictions, in the order of ``test['name']``.
    """
    predictions = [
        int(getGender(calLogProb(name, bases, totals, freq_list_m, freq_list_f)))
        for name in test['name']
    ]
    test['pred'] = predictions
    print(test.head(20))
    return predictions
def main():
    """Run the whole pipeline: split, count, compute bases, predict."""
    male_rows, female_rows, class_totals = loadData()
    male_freqs, female_freqs = calFreq(male_rows, female_rows, class_totals)
    class_bases = getBase(male_freqs, female_freqs, train)
    # The predictions are also written into test['pred'] and printed by
    # getResult, so the returned list is not needed here.
    getResult(class_bases, class_totals, male_freqs, female_freqs)


# NOTE(review): this runs unconditionally, so it also runs on import; there is
# no `if __name__ == "__main__"` guard.
main()