Linear Regression Practice

2024. 3. 12. 01:50·Machine Learning/Linear Regression

 

 

 
 


 
In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


np.random.seed(2021)
 
 

1. Univariate Regression

 
 

1.1 Sample Data

Generate the data that was used as the example in the lecture.

 
In [39]:
X = np.array([1,2,3,4])
y = np.array([2,1,4,3])
 
 

Let's draw it as a scatter plot.

 
In [40]:
plt.scatter(X,y) 
 
[Figure: scatter plot of the sample data (X vs. y)]
 
 
 

1.2 Data Transformation

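In scikit-learn, the data used to train a model must be a 2-D array of shape (n, c):

  • n is the number of samples.
  • c is the number of features.

Our data consists of 4 samples with 1 feature, but X currently has shape (4,), so it has to be reshaped to (4, 1).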

 
In [41]:
X
 
Out[41]:
array([1, 2, 3, 4])
 
In [42]:
X.shape
 
Out[42]:
(4,)
 
In [43]:
X.ravel()
 
Out[43]:
array([1, 2, 3, 4])
 
In [6]:
data = X.reshape(-1, 1)
 
In [7]:
data
 
Out[7]:
array([[1],
       [2],
       [3],
       [4]])
 
In [8]:
data.shape
 
Out[8]:
(4, 1)
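NumPy offers several equivalent ways to turn the 1-D array into the (n, 1) column shape that scikit-learn expects. A minimal sketch (the names a, b and c are only for illustration):

```python
import numpy as np

X = np.array([1, 2, 3, 4])  # shape (4,): 4 samples, no explicit feature axis

# Each of these produces a (4, 1) array: 4 samples (n) x 1 feature (c)
a = X.reshape(-1, 1)           # -1 tells NumPy to infer the number of rows
b = X[:, np.newaxis]           # add a new length-1 axis as the second dimension
c = np.expand_dims(X, axis=1)  # same as b, written as a function call

print(a.shape, b.shape, c.shape)
```

Note that X.reshape(1, -1) would instead give shape (1, 4), which scikit-learn would read as a single sample with 4 features.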
 
 

1.3 Linear Regression

 
In [9]:
from sklearn.linear_model import LinearRegression
 
 

Create a LinearRegression instance and assign it to the variable model.

 
In [10]:
model = LinearRegression()
 
 

1.3.1 Training

We build a linear regression model with LinearRegression from the scikit-learn package.

The model is trained with the fit method:

model.fit(X=..., y=...)

X is the training data, and y is the target (the correct answers) used for training.

 
In [11]:
model.fit(X=data, y=y)
# positional form works too: model.fit(data, y)
 
Out[11]:
LinearRegression()
 
 

1.3.2 Inspecting the Model Equation

 
 

First, check the bias (intercept).
In sklearn it is available as intercept_.

 
In [12]:
model.intercept_
 
Out[12]:
1.0000000000000004
 
 

Next are the regression coefficients,
available as coef_.

 
In [13]:
model.coef_
 
Out[13]:
array([0.6])
 
 

From these two results we obtain the following regression line (the trailing 4 in 1.0000000000000004 is floating-point rounding error):
y = 1.0 + 0.6 * x
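For a single feature, least squares has a closed form, so the values reported by sklearn can be checked by hand. A minimal NumPy sketch using the same four points:

```python
import numpy as np

X = np.array([1, 2, 3, 4])
y = np.array([2, 1, 4, 3])

# Closed-form ordinary least squares for one feature:
#   slope     = sum((x - x_mean) * (y - y_mean)) / sum((x - x_mean) ** 2)
#   intercept = y_mean - slope * x_mean
x_mean, y_mean = X.mean(), y.mean()
slope = np.sum((X - x_mean) * (y - y_mean)) / np.sum((X - x_mean) ** 2)
intercept = y_mean - slope * x_mean

print(slope, intercept)
```

Here x_mean = y_mean = 2.5, the numerator is 3 and the denominator is 5, which gives a slope of 0.6 and an intercept of 2.5 - 0.6 * 2.5 = 1.0, matching coef_ and intercept_.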

 
 

1.3.3 Prediction

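Now let's look at how to make predictions with the trained model.
Predictions are made with the predict method:

model.predict(X=...)

X is the data to predict on; like the training data, it must be a 2-D array with the same number of features c.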

 
In [14]:
pred = model.predict(data)
 
 

The predictions are:

 
In [15]:
pred
 
Out[15]:
array([1.6, 2.2, 2.8, 3.4])
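For a linear model, predict simply computes intercept + X @ coef. A sketch that reproduces the predictions from the rounded parameters above and, as an extra step not in the original notebook, measures the fit with mean squared error and R²:

```python
import numpy as np

data = np.array([1, 2, 3, 4]).reshape(-1, 1)
y = np.array([2, 1, 4, 3])

# Parameters as reported by model.intercept_ and model.coef_ above (rounded)
intercept = 1.0
coef = np.array([0.6])

# A linear model's prediction: intercept + X @ coef
pred = intercept + data @ coef

# Mean squared error and coefficient of determination (R^2)
mse = np.mean((y - pred) ** 2)
r2 = 1 - np.sum((y - pred) ** 2) / np.sum((y - y.mean()) ** 2)

print(pred, mse, r2)
```

The residuals are 0.4, -1.2, 1.2 and -0.4, so the MSE is 3.2 / 4 = 0.8 and R² is 1 - 3.2 / 5 = 0.36. scikit-learn's model.score(data, y) returns this same R² value.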
 
 

1.4 Plotting the Regression Line

 
In [16]:
plt.scatter(X, y)
plt.plot(X, pred, color='green')
 
[Figure: the sample data points with the fitted regression line drawn in green]
 
 
 

2. Multivariate Regression

 
 

2.1 Sample Data

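We generate data for multivariate regression from a known equation, y = 1 + 2·x1 + 3·x2 + 4·x3 + 5·x4 + noise, and then compare the learned regression equation against it.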

 
In [17]:
bias = 1
beta = np.array([2,3,4,5]).reshape(4, 1)
noise = np.random.randn(100, 1)
 
In [18]:
X = np.random.randn(100, 4)
y = bias + X.dot(beta)
y_with_noise = y + noise
 
In [19]:
X
 
Out[19]:
array([[-0.74057432,  0.45823318,  1.2572469 , -0.44170407],
       [ 0.5413351 ,  0.5672784 , -0.43538181,  0.76681236],
       [ 0.48084856, -0.88128189, -0.68428578,  0.45401778],
       [-0.72516608, -1.43882635,  2.15079225,  0.04869812],
       [ 0.64711296, -1.07448408,  0.24474502,  0.37090588],
       [ 0.41423217,  1.19669386,  1.0824733 ,  0.27691761],
       [ 0.9303673 ,  0.33499986,  1.79409578, -0.95259799],
       [-0.56992369,  0.24691765,  0.51129208, -0.43980433],
       [ 1.79024025,  0.18092909,  0.38195648,  0.51485676],
       [ 0.59167189,  1.9003216 ,  0.60031973, -0.12461753],
       [ 0.14638537,  0.56526801, -0.59881165,  0.3800241 ],
       [-0.62191636,  0.15565243, -1.41513694,  1.9178002 ],
       [ 1.02157999, -0.22717161,  0.45970277,  0.28276229],
       [-1.1789878 ,  1.52858149,  1.17351873, -2.40336246],
       [-1.50338253,  1.63827571,  1.06207973, -1.21583473],
       [-1.46939903,  0.49315748, -1.39057797, -0.07524408],
       [-0.25350979,  1.42584117,  0.7820977 ,  1.62809804],
       [ 0.88045748,  0.84317564, -0.97519168, -1.28824701],
       [-0.00759803, -0.05529913,  0.27129061,  0.41730519],
       [-0.62572604, -0.24468954,  0.79309269, -1.39581176],
       [ 0.73573987, -0.06172918, -0.3672049 ,  0.58134237],
       [-1.28615425,  0.85178815,  0.68446682, -1.34668335],
       [-0.97748383, -1.51492099, -1.07103031,  0.3534274 ],
       [-0.82771752, -1.49650381, -1.51769502, -1.03799842],
       [ 0.66256697, -0.62422805, -0.64376242, -0.68625396],
       [ 0.44122202,  0.8558804 ,  0.14771668, -1.59463314],
       [ 2.63689422, -0.71652769, -0.38792999, -2.40507443],
       [ 0.44150245,  0.43045242, -0.88209488,  1.68726578],
       [ 0.99927274,  0.02555529, -1.28445107, -0.42062073],
       [ 0.2411335 ,  0.80197881, -0.18311936,  1.12545835],
       [-1.16712404,  0.11875   , -0.71326992,  0.03757425],
       [-0.56068183,  0.41330341,  1.54379299,  0.26889845],
       [-1.41343222,  0.34213165,  0.30810781,  2.1475419 ],
       [ 1.45007548,  1.06764732, -0.49998118, -0.35048743],
       [-0.66072427, -2.86887919,  0.84810079,  0.32761252],
       [ 0.77461123,  1.14553304,  0.36818244,  0.84573897],
       [-0.54294848, -0.24272975, -1.02472931, -1.04020844],
       [-0.21859429,  1.07470903,  0.46975773, -1.48787122],
       [ 0.90342123, -0.4278425 ,  0.05522557,  0.63542067],
       [-0.51586728, -0.47974281, -0.53773363, -0.71747224],
       [-0.11895655,  1.22985511, -0.59529878,  1.59596503],
       [ 1.58365672, -0.18513785,  0.01664184,  0.23233448],
       [ 0.17113206,  0.36166644,  0.10056732,  1.40077752],
       [ 1.30539747,  1.68118719, -0.01691732, -0.07586923],
       [ 1.52117725, -0.57051203, -0.47835704, -0.07823544],
       [-0.06546613, -0.21519949, -0.08452669, -2.33712013],
       [-1.38028993, -0.64703717,  0.17090629, -0.50929329],
       [ 1.35604995, -1.06376326, -1.38632217,  1.52555296],
       [ 1.20746151,  2.66912323,  0.111012  , -1.12655055],
       [-0.1203488 , -1.22651695, -0.72269499, -0.61902635],
       [-0.98808119, -0.53241478,  1.18224599,  0.7708145 ],
       [-0.41672036, -0.26689619, -1.95664789,  0.38417318],
       [ 0.83647461,  1.31463184, -1.34398534, -0.58182729],
       [-0.59238583,  0.29186324, -0.99028615, -0.27389121],
       [ 0.95600915,  0.14189529, -0.5806097 , -0.73475932],
       [-0.97007202, -0.13890943,  0.66550248, -1.58259194],
       [-0.98049008, -0.53750607, -1.32857645,  0.878594  ],
       [-0.22055441,  0.44068977,  0.69973888, -0.18154933],
       [ 1.87109284,  0.61404307, -0.81724303, -1.08864352],
       [ 0.80398121,  1.19523417, -0.01146552,  0.23343722],
       [-1.576667  ,  0.62188379, -1.21328967, -2.14360915],
       [ 1.31343337,  0.64783768,  1.16014925, -1.63126822],
       [-0.03428886,  1.41702415, -1.03135455, -0.17010534],
       [-0.00609236,  0.86387628,  1.17238709, -0.6914441 ],
       [ 0.0295551 ,  0.2011155 ,  0.2132878 ,  0.88624586],
       [ 0.04886997,  0.88327758,  0.65919495, -1.17254459],
       [ 0.31177366, -1.43699763, -0.151088  , -0.99412634],
       [ 0.81495268, -0.09735112,  0.21956592, -0.10568017],
       [ 0.16441589, -1.24165574, -0.95066508, -0.41358392],
       [-1.62658109, -0.116679  , -0.34360021,  1.65210199],
       [ 0.48844069, -0.97246307, -0.58324361,  1.14238629],
       [ 1.74461983, -0.84109207, -0.4203053 , -1.52443011],
       [-0.20102392,  0.03271448, -1.09380531, -0.60641662],
       [ 0.27653019, -0.04274745,  0.07166194, -0.15410687],
       [ 0.97356258,  0.25414126,  1.05879595,  0.78107037],
       [-0.54635823,  1.54650131,  1.46080181, -0.40578878],
       [ 1.43086773, -0.21069973, -0.71425605, -0.39001469],
       [-1.12363691, -0.79680479, -0.67386398, -0.2369378 ],
       [ 0.81682211, -0.85001682,  0.77316875,  0.33417035],
       [ 1.03533486, -0.53665376,  0.51443378, -1.26670177],
       [-0.40808224,  0.40808453, -1.30409641, -0.76737038],
       [-0.37076285, -1.80065998, -1.06737769, -0.1261404 ],
       [ 0.4594771 ,  0.85236292,  0.30652867,  0.883799  ],
       [-0.11744223, -0.47305388,  0.58072793,  1.88723219],
       [-0.55393436,  0.51869336,  0.07742329, -1.5313427 ],
       [ 0.74281248, -0.48434702,  0.88518399,  0.089355  ],
       [ 0.64250899,  1.59613679, -1.22869427,  0.1393659 ],
       [-0.74097265,  0.07893833, -1.58730732, -0.04141167],
       [ 0.4676652 ,  2.2467043 , -0.37958554, -0.49243673],
       [ 0.42652303, -0.37825877, -0.08124883, -0.68827697],
       [-0.26473031,  2.21297061,  0.10567097, -0.22285691],
       [ 0.40809498, -0.33301199,  0.57865527,  0.93836114],
       [ 0.41350806,  1.51169216, -0.65265573, -1.23133217],
       [ 1.24648683, -0.64402588, -0.29799678, -3.0663246 ],
       [-0.73077444, -0.200466  , -0.75506598,  3.42526957],
       [-0.91501093,  0.65910168,  0.47102331,  1.05282785],
       [ 0.44404478,  0.85914353, -0.776369  , -0.37351724],
       [-1.53562235,  0.20178569, -0.11016686, -0.19232803],
       [ 0.08044615, -0.88979422, -0.20075015, -1.31246388],
       [-0.23276233, -1.51087601,  0.75087126,  0.45705608]])
 
In [20]:
X[:10]
 
Out[20]:
array([[-0.74057432,  0.45823318,  1.2572469 , -0.44170407],
       [ 0.5413351 ,  0.5672784 , -0.43538181,  0.76681236],
       [ 0.48084856, -0.88128189, -0.68428578,  0.45401778],
       [-0.72516608, -1.43882635,  2.15079225,  0.04869812],
       [ 0.64711296, -1.07448408,  0.24474502,  0.37090588],
       [ 0.41423217,  1.19669386,  1.0824733 ,  0.27691761],
       [ 0.9303673 ,  0.33499986,  1.79409578, -0.95259799],
       [-0.56992369,  0.24691765,  0.51129208, -0.43980433],
       [ 1.79024025,  0.18092909,  0.38195648,  0.51485676],
       [ 0.59167189,  1.9003216 ,  0.60031973, -0.12461753]])
 
In [21]:
y[:10]
 
Out[21]:
array([[ 3.71401813],
       [ 5.87703996],
       [-1.14920281],
       [ 4.0798484 ],
       [ 1.90428314],
       [11.13302717],
       [ 6.27912737],
       [ 0.44705225],
       [ 9.22537749],
       [ 9.66249986]])
 
 

2.2 Multivariate Regression

 
In [22]:
model = LinearRegression()
model.fit(X, y_with_noise)
 
Out[22]:
LinearRegression()
 
 

2.3 Inspecting the Regression Equation

 
In [23]:
model.intercept_
 
Out[23]:
array([1.11417641])
 
In [24]:
model.coef_
 
Out[24]:
array([[1.99383579, 2.94374717, 3.97346537, 4.84223742]])
 
 

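Compared with the true equation (bias 1; coefficients 2, 3, 4, 5), the bias was not estimated as well: 1.114 is off by about 11%. The coefficients, on the other hand, were estimated fairly accurately: 1.994, 2.944, 3.973 and 4.842 are each within about 3% of their true values. These deviations come from the noise added to y. Because y_with_noise is a 2-D array of shape (100, 1), intercept_ and coef_ are returned with shapes (1,) and (1, 4).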
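To check that the deviations come from the noise rather than from the model, one can fit the same kind of data with and without noise. This sketch uses NumPy's least-squares solver instead of sklearn and a freshly generated dataset (the seed 0 is an arbitrary choice), so its numbers will differ from the notebook's:

```python
import numpy as np

rng = np.random.default_rng(0)  # arbitrary seed, independent of the notebook's data
X = rng.standard_normal((100, 4))
true_bias = 1.0
true_beta = np.array([2.0, 3.0, 4.0, 5.0])

y_clean = true_bias + X @ true_beta              # no noise
y_noisy = y_clean + rng.standard_normal(100)     # add standard normal noise

# Design matrix with a leading column of ones for the bias
A = np.hstack([np.ones((len(X), 1)), X])

est_clean, *_ = np.linalg.lstsq(A, y_clean, rcond=None)
est_noisy, *_ = np.linalg.lstsq(A, y_noisy, rcond=None)

print(est_clean)  # [bias, beta_1, ..., beta_4]
print(est_noisy)
```

Without noise, the estimates match [1, 2, 3, 4, 5] up to floating-point error; with noise, they scatter around those values.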

 
 

2.4 Statistical Method (Normal Equation)

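This time we compute the regression equation directly with the closed-form least-squares solution (the normal equation), β̂ = (XᵀX)⁻¹Xᵀy, where X has a leading column of ones for the bias. The output is identical to the sklearn estimates above.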
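The formula follows from minimizing the sum of squared errors. A short derivation, writing X for the design matrix with the column of ones (stat_X in the code) and β for the stacked bias and coefficients:

$$
\begin{aligned}
L(\beta) &= \lVert y - X\beta \rVert^2 = (y - X\beta)^\top (y - X\beta) \\
\nabla_\beta L(\beta) &= -2X^\top (y - X\beta) = 0 \\
X^\top X \beta &= X^\top y \\
\hat\beta &= (X^\top X)^{-1} X^\top y
\end{aligned}
$$

The last step assumes XᵀX is invertible, i.e. that the columns of X are linearly independent.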

 
In [25]:
bias_X = np.ones((len(X), 1))                         # column of ones for the bias term
stat_X = np.hstack([bias_X, X])                       # design matrix [1, X], shape (100, 5)
X_X_transpose = stat_X.transpose().dot(stat_X)        # XᵀX, shape (5, 5)
X_X_transpose_inverse = np.linalg.inv(X_X_transpose)  # (XᵀX)⁻¹
 
In [26]:
stat_beta = X_X_transpose_inverse.dot(stat_X.transpose()).dot(y_with_noise)  # β̂ = (XᵀX)⁻¹ Xᵀ y
 
In [27]:
stat_beta
 
Out[27]:
array([[1.11417641],
       [1.99383579],
       [2.94374717],
       [3.97346537],
       [4.84223742]])
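Explicitly inverting XᵀX works here, but solving the linear system directly is usually preferred because it is cheaper and numerically more stable. A sketch on freshly generated data (the seed and sizes are arbitrary assumptions) showing that np.linalg.solve and np.linalg.lstsq give the same β̂ as the inverse-based formula:

```python
import numpy as np

rng = np.random.default_rng(42)
X = rng.standard_normal((100, 4))
y = 1 + X @ np.array([2.0, 3.0, 4.0, 5.0]) + rng.standard_normal(100)

A = np.hstack([np.ones((len(X), 1)), X])  # design matrix with bias column

beta_inv = np.linalg.inv(A.T @ A) @ A.T @ y          # textbook normal equation
beta_solve = np.linalg.solve(A.T @ A, A.T @ y)       # solve (AᵀA)β = Aᵀy without forming the inverse
beta_lstsq, *_ = np.linalg.lstsq(A, y, rcond=None)  # SVD-based least squares on A itself

print(beta_inv)
print(beta_solve)
print(beta_lstsq)
```

lstsq works on A directly rather than on AᵀA, which matters when features are nearly collinear and AᵀA is ill-conditioned.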
 
 

3. Polynomial Regression

 
 

3.1 Sample Data

 
 

Let's generate nonlinear data from y = 1 + 2x + 3x² plus noise.

 
In [28]:
bias = 1
beta = np.array([2,3]).reshape(2, 1)
noise = np.random.randn(100, 1)
 
In [29]:
X = np.random.randn(100, 1)
X_poly = np.hstack([X, X**2])
 
In [30]:
X_poly[:10]
 
Out[30]:
array([[-0.92491748,  0.85547234],
       [ 1.01010697,  1.02031609],
       [ 0.75532068,  0.57050933],
       [ 0.8542725 ,  0.7297815 ],
       [ 0.70907623,  0.5027891 ],
       [ 0.82063592,  0.67344331],
       [ 0.12924242,  0.0167036 ],
       [ 0.07362991,  0.00542136],
       [-0.25463986,  0.06484146],
       [ 0.06335837,  0.00401428]])
 
In [31]:
y = bias + X_poly.dot(beta)
y_with_noise = y + noise
 
In [32]:
plt.scatter(X, y_with_noise)
 
[Figure: scatter plot of X vs. y_with_noise]
 
 
 

3.2 Polynomial Regression

 
 

3.2.1 Training

 
In [33]:
model = LinearRegression()
model.fit(X_poly, y_with_noise)
 
Out[33]:
LinearRegression()
 
 

3.2.2 Inspecting the Regression Equation

 
In [34]:
model.intercept_
 
Out[34]:
array([1.0466608])
 
In [35]:
model.coef_
 
Out[35]:
array([[2.02459592, 2.96023615]])
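Because the model is linear in the features x and x², the same kind of fit can also be obtained with NumPy's polynomial fitting. A sketch on freshly generated data from the same formula (the seed is arbitrary); note that np.polyfit returns coefficients from the highest degree down:

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.standard_normal(100)
y = 1 + 2 * x + 3 * x**2 + rng.standard_normal(100)

# Least-squares fit of a degree-2 polynomial; returns [c2, c1, c0]
c2, c1, c0 = np.polyfit(x, y, deg=2)

# Same fit via explicit features [1, x, x^2] and ordinary least squares
A = np.column_stack([np.ones_like(x), x, x**2])
b0, b1, b2 = np.linalg.lstsq(A, y, rcond=None)[0]

print(c0, c1, c2)  # bias, coefficient of x, coefficient of x^2
print(b0, b1, b2)
```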
 
 

3.2.3 Prediction

 
In [36]:
pred = model.predict(X_poly)
 
 

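3.3 Checking the Predictions with a Plot

  • The predictions follow a nonlinear (quadratic) curve in X, even though the model itself is linear in the features x and x².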
 
In [37]:
plt.scatter(X, pred)
 
[Figure: scatter plot of X vs. the predicted values pred, which lie on a parabola]
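plt.scatter(X, pred) shows the predictions only at the sampled X values. To draw the fitted curve as a smooth line, one option is to evaluate the model on an evenly spaced grid. A sketch using the rounded parameters from 3.2.2 (the grid range -3 to 3 is an arbitrary choice):

```python
import numpy as np

# Parameters as reported by model.intercept_ and model.coef_ in 3.2.2 (rounded)
intercept = 1.0466608
coef = np.array([2.02459592, 2.96023615])

# Evenly spaced, already sorted x values, expanded to the same [x, x^2] features
x_grid = np.linspace(-3, 3, 200)
grid_poly = np.column_stack([x_grid, x_grid**2])

# intercept + X @ coef, the same linear computation predict performs
y_grid = intercept + grid_poly @ coef

# With matplotlib: plt.scatter(X, y_with_noise); plt.plot(x_grid, y_grid, color='green')
```

Because x_grid is sorted, plt.plot connects the points in order; plotting plt.plot(X, pred) with the unsorted X would instead draw zig-zag lines between random points.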
 
 
Posted by Juson on Juson's Data Study.