"""Plot the logistic cost for y=1 and y=0 against the model output f_wb.

The figure is written to ``plt/logistic_cost_function.png``.
"""
import os

import matplotlib.pyplot as plt
import numpy as np
|
|
|
|
def logistic_function(x, w, b):
    """Return the logistic (sigmoid) of the linear score ``w * x + b``.

    Args:
        x: Input value(s); a scalar or any array-like accepted by NumPy.
        w: Weight multiplied element-wise with ``x``.
        b: Bias added to ``w * x``.

    Returns:
        ``1 / (1 + exp(-(w * x + b)))``, with the broadcast shape of the
        inputs.
    """
    # np.multiply (rather than ``*``) lets plain Python lists be used for x.
    z = np.multiply(w, x) + b
    # sigmoid(z) == exp(-log(1 + exp(-z))). logaddexp evaluates the inner log
    # without ever forming exp(-z) on its own, so strongly negative z does
    # not trigger an overflow warning the way the direct formula does.
    return np.exp(-np.logaddexp(0.0, -z))
|
|
|
|
def cost_function(f_wb, y):
    """Return the logistic loss for prediction(s) ``f_wb`` and label(s) ``y``.

    The loss is ``-log(f_wb)`` where ``y == 1`` and ``-log(1 - f_wb)`` for any
    other label value.

    Args:
        f_wb: Predicted value(s); scalar or array-like.
        y: Label(s); a scalar applied to every prediction, or an array-like
            that broadcasts against ``f_wb``.

    Returns:
        The element-wise loss, with the broadcast shape of the inputs.
    """
    f_wb = np.asarray(f_wb)
    # Select the value whose log is taken first, then take a single log.
    # This supports array labels (``if y == 1`` is ambiguous for arrays) and
    # never evaluates the log of the branch that is not selected.
    selected = np.where(np.asarray(y) == 1, f_wb, 1.0 - f_wb)
    return -np.log(selected)
|
|
|
|
# Sample the model output strictly inside (0, 1). The cost takes log(f_wb)
# and log(1 - f_wb), which are NaN outside that interval and infinite at its
# endpoints, so the previous range of [-1, 2] produced invalid values.
EPS = 0.01
f_wb = np.linspace(EPS, 1 - EPS, 100)

cost_y1 = cost_function(f_wb, 1)
cost_y0 = cost_function(f_wb, 0)

plt.figure(figsize=(12, 6), dpi=600)

# One panel per label value; both panels share the same decoration.
for panel, (label, cost) in enumerate((("y=1", cost_y1), ("y=0", cost_y0)), start=1):
    plt.subplot(1, 2, panel)
    plt.plot(f_wb, cost, label=label)
    plt.title(f"Cost Function when {label}")
    plt.xlabel("f_wb")
    plt.ylabel("Cost")
    plt.axhline(0, color="black", linewidth=0.8)  # horizontal axis line
    plt.axvline(0, color="black", linewidth=0.8)  # vertical axis line
    # Keep the x-range the figure used before, so the (0, 1) domain is still
    # shown in the context of the surrounding axis.
    plt.xlim(-1, 2)
    plt.legend()

plt.tight_layout()

# savefig does not create missing directories, so make sure the target exists.
output_path = "plt/logistic_cost_function.png"
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path)