Matplotlib - 将线性回归线扩展到图的整个宽度

我似乎无法弄清楚如何使线性回归线(又名最佳拟合线)跨越图形的整个宽度。它似乎只是上升了左边最远的数据点和右边最远的数据点,没有进一步。我将如何解决这个问题?


import matplotlib.pyplot as plt

import numpy as np

from scipy import stats

from scipy.interpolate import *

import MySQLdb


# connect to MySQL database

def mysql_select_all():

    conn = MySQLdb.connect(host='localhost',

                           user='root',

                           passwd='XXXXX',

                           db='world')

    cursor = conn.cursor()

    sql = """

        SELECT

            GNP, Population

        FROM

            country

        WHERE

            Name LIKE 'United States'

                OR Name LIKE 'Canada'

                OR Name LIKE 'United Kingdom'

                OR Name LIKE 'Russia'

                OR Name LIKE 'Germany'

                OR Name LIKE 'Poland'

                OR Name LIKE 'Italy'

                OR Name LIKE 'China'

                OR Name LIKE 'India'

                OR Name LIKE 'Japan'

                OR Name LIKE 'Brazil';

    """


    cursor.execute(sql)

    result = cursor.fetchall()


    list_x = []

    list_y = []


    for row in result:

        list_x.append(('%r' % (row[0],)))


    for row in result:

        list_y.append(('%r' % (row[1],)))


    list_x = list(map(float, list_x))

    list_y = list(map(float, list_y))


    fig = plt.figure()

    ax1 = plt.subplot2grid((1,1), (0,0))


    p1 = np.polyfit(list_x, list_y, 1)          # this line refers to line of regression


    ax1.xaxis.labelpad = 50

    ax1.yaxis.labelpad = 50


    plt.plot(list_x, np.polyval(p1,list_x),'r-') # this refers to line of regression  

    plt.scatter(list_x, list_y, color = 'darkgreen', s = 100)

    plt.xlabel("GNP (US dollars)", fontsize=30)

    plt.ylabel("Population(in billions)", fontsize=30)

    plt.xticks([1000000, 2000000, 3000000, 4000000, 5000000, 6000000, 

                7000000, 8000000, 9000000],  rotation=45, fontsize=14)

    plt.yticks(fontsize=14)


    plt.show()

    cursor.close()


mysql_select_all()


神不在的星期二
浏览 344回答 3
3回答

米琪卡哇伊

我似乎无法弄清楚如何使线性回归线(又名最佳拟合线)跨越图形的整个宽度。它似乎只是上升了左边最远的数据点和右边最远的数据点,没有进一步。我将如何解决这个问题?import matplotlib.pyplot as pltimport numpy as npfrom scipy import statsfrom scipy.interpolate import *import MySQLdb# connect to MySQL databasedef mysql_select_all():    conn = MySQLdb.connect(host='localhost',                           user='root',                           passwd='XXXXX',                           db='world')    cursor = conn.cursor()    sql = """        SELECT            GNP, Population        FROM            country        WHERE            Name LIKE 'United States'                OR Name LIKE 'Canada'                OR Name LIKE 'United Kingdom'                OR Name LIKE 'Russia'                OR Name LIKE 'Germany'                OR Name LIKE 'Poland'                OR Name LIKE 'Italy'                OR Name LIKE 'China'                OR Name LIKE 'India'                OR Name LIKE 'Japan'                OR Name LIKE 'Brazil';    """    cursor.execute(sql)    result = cursor.fetchall()    list_x = []    list_y = []    for row in result:        list_x.append(('%r' % (row[0],)))    for row in result:        list_y.append(('%r' % (row[1],)))    list_x = list(map(float, list_x))    list_y = list(map(float, list_y))    fig = plt.figure()    ax1 = plt.subplot2grid((1,1), (0,0))    p1 = np.polyfit(list_x, list_y, 1)          # this line refers to line of regression    ax1.xaxis.labelpad = 50    ax1.yaxis.labelpad = 50    plt.plot(list_x, np.polyval(p1,list_x),'r-') # this refers to line of regression      plt.scatter(list_x, list_y, color = 'darkgreen', s = 100)    plt.xlabel("GNP (US dollars)", fontsize=30)    plt.ylabel("Population(in billions)", fontsize=30)    plt.xticks([1000000, 2000000, 3000000, 4000000, 5000000, 6000000,                 7000000, 8000000, 9000000],  rotation=45, fontsize=14)    plt.yticks(fontsize=14)    plt.show()    cursor.close()mysql_select_all()而延长之后,

隔江千里

如果您希望绘图不超出 x 轴上的数据,只需执行以下操作:fig, ax = plt.subplots()ax.margins(x=0)# Don't use plt.plotax.plot(list_x, np.polyval(p1,list_x),'r-')ax.scatter(list_x, list_y, color = 'darkgreen', s = 100)ax.set_xlabel("GNP (US dollars)", fontsize=30)ax.set_ylabel("Population(in billions)", fontsize=30)ax.set_xticks([1000000, 2000000, 3000000, 4000000, 5000000, 6000000, 7000000, 8000000, 9000000],  rotation=45, fontsize=14)ax.tick_params(axis='y', labelsize=14)
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python