|
| 1 | +import math |
| 2 | +from copy import deepcopy |
| 3 | + |
| 4 | + |
| 5 | +class MaxEntropy: |
| 6 | + def __init__(self, EPS=0.005): |
| 7 | + self._samples = [] |
| 8 | + self._Y = set() # 标签集合,相当去去重后的y |
| 9 | + self._numXY = {} # key为(x,y),value为出现次数 |
| 10 | + self._N = 0 # 样本数 |
| 11 | + self._Ep_ = [] # 样本分布的特征期望值 |
| 12 | + self._xyID = {} # key记录(x,y),value记录id号 |
| 13 | + self._n = 0 # 特征键值(x,y)的个数 |
| 14 | + self._C = 0 # 最大特征数 |
| 15 | + self._IDxy = {} # key为(x,y),value为对应的id号 |
| 16 | + self._w = [] |
| 17 | + self._EPS = EPS # 收敛条件 |
| 18 | + self._lastw = [] # 上一次w参数值 |
| 19 | + |
| 20 | + def loadData(self, dataset): |
| 21 | + self._samples = deepcopy(dataset) |
| 22 | + for items in self._samples: |
| 23 | + y = items[0] |
| 24 | + X = items[1:] |
| 25 | + self._Y.add(y) # 集合中y若已存在则会自动忽略 |
| 26 | + for x in X: |
| 27 | + if (x, y) in self._numXY: |
| 28 | + self._numXY[(x, y)] += 1 |
| 29 | + else: |
| 30 | + self._numXY[(x, y)] = 1 |
| 31 | + |
| 32 | + self._N = len(self._samples) |
| 33 | + self._n = len(self._numXY) |
| 34 | + self._C = max([len(sample)-1 for sample in self._samples]) |
| 35 | + self._w = [0]*self._n |
| 36 | + self._lastw = self._w[:] |
| 37 | + |
| 38 | + self._Ep_ = [0] * self._n |
| 39 | + for i, xy in enumerate(self._numXY): # 计算特征函数fi关于经验分布的期望 |
| 40 | + self._Ep_[i] = self._numXY[xy]/self._N |
| 41 | + self._xyID[xy] = i |
| 42 | + self._IDxy[i] = xy |
| 43 | + |
| 44 | + def _Zx(self, X): # 计算每个Z(x)值 |
| 45 | + zx = 0 |
| 46 | + for y in self._Y: |
| 47 | + ss = 0 |
| 48 | + for x in X: |
| 49 | + if (x, y) in self._numXY: |
| 50 | + ss += self._w[self._xyID[(x, y)]] |
| 51 | + zx += math.exp(ss) |
| 52 | + return zx |
| 53 | + |
| 54 | + def _model_pyx(self, y, X): # 计算每个P(y|x) |
| 55 | + zx = self._Zx(X) |
| 56 | + ss = 0 |
| 57 | + for x in X: |
| 58 | + if (x, y) in self._numXY: |
| 59 | + ss += self._w[self._xyID[(x, y)]] |
| 60 | + pyx = math.exp(ss)/zx |
| 61 | + return pyx |
| 62 | + |
| 63 | + def _model_ep(self, index): # 计算特征函数fi关于模型的期望 |
| 64 | + x, y = self._IDxy[index] |
| 65 | + ep = 0 |
| 66 | + for sample in self._samples: |
| 67 | + if x not in sample: |
| 68 | + continue |
| 69 | + pyx = self._model_pyx(y, sample) |
| 70 | + ep += pyx/self._N |
| 71 | + return ep |
| 72 | + |
| 73 | + def _convergence(self): # 判断是否全部收敛 |
| 74 | + for last, now in zip(self._lastw, self._w): |
| 75 | + if abs(last - now) >= self._EPS: |
| 76 | + return False |
| 77 | + return True |
| 78 | + |
| 79 | + def predict(self, X): # 计算预测概率 |
| 80 | + Z = self._Zx(X) |
| 81 | + result = {} |
| 82 | + for y in self._Y: |
| 83 | + ss = 0 |
| 84 | + for x in X: |
| 85 | + if (x, y) in self._numXY: |
| 86 | + ss += self._w[self._xyID[(x, y)]] |
| 87 | + pyx = math.exp(ss)/Z |
| 88 | + result[y] = pyx |
| 89 | + return result |
| 90 | + |
| 91 | + def train(self, maxiter=1000): # 训练数据 |
| 92 | + for loop in range(maxiter): # 最大训练次数 |
| 93 | + print("iter:%d" % loop) |
| 94 | + self._lastw = self._w[:] |
| 95 | + for i in range(self._n): |
| 96 | + ep = self._model_ep(i) # 计算第i个特征的模型期望 |
| 97 | + self._w[i] += math.log(self._Ep_[i]/ep)/self._C # 更新参数 |
| 98 | + print("w:", self._w) |
| 99 | + if self._convergence(): # 判断是否收敛 |
| 100 | + break |
| 101 | + |
| 102 | +if __name__ == "__main__": |
| 103 | + dataset = [['no', 'sunny', 'hot', 'high', 'FALSE'], |
| 104 | + ['no', 'sunny', 'hot', 'high', 'TRUE'], |
| 105 | + ['yes', 'overcast', 'hot', 'high', 'FALSE'], |
| 106 | + ['yes', 'rainy', 'mild', 'high', 'FALSE'], |
| 107 | + ['yes', 'rainy', 'cool', 'normal', 'FALSE'], |
| 108 | + ['no', 'rainy', 'cool', 'normal', 'TRUE'], |
| 109 | + ['yes', 'overcast', 'cool', 'normal', 'TRUE'], |
| 110 | + ['no', 'sunny', 'mild', 'high', 'FALSE'], |
| 111 | + ['yes', 'sunny', 'cool', 'normal', 'FALSE'], |
| 112 | + ['yes', 'rainy', 'mild', 'normal', 'FALSE'], |
| 113 | + ['yes', 'sunny', 'mild', 'normal', 'TRUE'], |
| 114 | + ['yes', 'overcast', 'mild', 'high', 'TRUE'], |
| 115 | + ['yes', 'overcast', 'hot', 'normal', 'FALSE'], |
| 116 | + ['no', 'rainy', 'mild', 'high', 'TRUE']] |
| 117 | + |
| 118 | + maxent = MaxEntropy() |
| 119 | + x = ['overcast', 'mild', 'high', 'FALSE'] |
| 120 | + maxent.loadData(dataset) |
| 121 | + maxent.train() |
| 122 | + print('predict:', maxent.predict(x)) |
0 commit comments