我似乎无法弄清楚如何使线性回归线(又名最佳拟合线)跨越图形的整个宽度。它似乎只是上升了左边最远的数据点和右边最远的数据点,没有进一步。我将如何解决这个问题?
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
from scipy.interpolate import *
import MySQLdb
# connect to MySQL database
def mysql_select_all():
conn = MySQLdb.connect(host='localhost',
user='root',
passwd='XXXXX',
db='world')
cursor = conn.cursor()
sql = """
SELECT
GNP, Population
FROM
country
WHERE
Name LIKE 'United States'
OR Name LIKE 'Canada'
OR Name LIKE 'United Kingdom'
OR Name LIKE 'Russia'
OR Name LIKE 'Germany'
OR Name LIKE 'Poland'
OR Name LIKE 'Italy'
OR Name LIKE 'China'
OR Name LIKE 'India'
OR Name LIKE 'Japan'
OR Name LIKE 'Brazil';
"""
cursor.execute(sql)
result = cursor.fetchall()
list_x = []
list_y = []
for row in result:
list_x.append(('%r' % (row[0],)))
for row in result:
list_y.append(('%r' % (row[1],)))
list_x = list(map(float, list_x))
list_y = list(map(float, list_y))
fig = plt.figure()
ax1 = plt.subplot2grid((1,1), (0,0))
p1 = np.polyfit(list_x, list_y, 1) # this line refers to line of regression
ax1.xaxis.labelpad = 50
ax1.yaxis.labelpad = 50
plt.plot(list_x, np.polyval(p1,list_x),'r-') # this refers to line of regression
plt.scatter(list_x, list_y, color = 'darkgreen', s = 100)
plt.xlabel("GNP (US dollars)", fontsize=30)
plt.ylabel("Population(in billions)", fontsize=30)
plt.xticks([1000000, 2000000, 3000000, 4000000, 5000000, 6000000,
7000000, 8000000, 9000000], rotation=45, fontsize=14)
plt.yticks(fontsize=14)
plt.show()
cursor.close()
mysql_select_all()
米琪卡哇伊
隔江千里
相关分类