

Machine Learning (ML)_Titanic Example

Decision tree algorithm -> you have to set the tree depth (max_depth).

 

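Reference: Decision tree learning - Wikipedia (Korean)

From Wikipedia, the free encyclopedia. Decision tree learning uses a decision tree as a predictive model that connects observations about an item to its target value. It is used in statistics and data mining, …

ko.wikipedia.org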
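To see why the depth matters, here is a minimal pure-Python sketch of a greedy, CART-style tree: at each node it picks the "feature <= threshold" split with the lowest weighted Gini impurity and stops at max_depth. It is only an illustration under simplifying assumptions, not scikit-learn's DecisionTreeClassifier implementation; the helper functions and the 8-row toy data are hypothetical.

```python
from collections import Counter


def gini(labels):
    """Gini impurity: 1 - sum(p_k^2) over the class proportions p_k."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def best_split(rows, labels):
    """Return the (feature, threshold) pair with the lowest weighted child gini, or None."""
    best, best_score = None, gini(labels)
    n = len(rows)
    for f in range(len(rows[0])):
        values = sorted({r[f] for r in rows})
        for lo, hi in zip(values, values[1:]):
            t = (lo + hi) / 2  # midpoint threshold between adjacent distinct values
            left = [y for r, y in zip(rows, labels) if r[f] <= t]
            right = [y for r, y in zip(rows, labels) if r[f] > t]
            score = (len(left) * gini(left) + len(right) * gini(right)) / n
            if score < best_score:  # only accept splits that lower the impurity
                best, best_score = (f, t), score
    return best


def build_tree(rows, labels, max_depth, depth=0):
    """Grow greedily; a leaf is a class label, an inner node is (feature, threshold, left, right)."""
    majority = Counter(labels).most_common(1)[0][0]
    if depth >= max_depth or gini(labels) == 0.0:
        return majority  # depth limit reached or node is already pure
    split = best_split(rows, labels)
    if split is None:
        return majority  # no split lowers the impurity
    f, t = split
    go_left = [r[f] <= t for r in rows]
    left_rows = [r for r, g in zip(rows, go_left) if g]
    left_y = [y for y, g in zip(labels, go_left) if g]
    right_rows = [r for r, g in zip(rows, go_left) if not g]
    right_y = [y for y, g in zip(labels, go_left) if not g]
    return (f, t,
            build_tree(left_rows, left_y, max_depth, depth + 1),
            build_tree(right_rows, right_y, max_depth, depth + 1))


def predict(tree, row):
    """Follow the splits from the root down to a leaf label."""
    while isinstance(tree, tuple):
        f, t, left, right = tree
        tree = left if row[f] <= t else right
    return tree


def tree_depth(tree):
    if not isinstance(tree, tuple):
        return 0
    return 1 + max(tree_depth(tree[2]), tree_depth(tree[3]))


# Hypothetical toy data: each row is [Sex (0=male, 1=female), Pclass]; the label is Survived
X = [[0, 1], [0, 3], [0, 3], [0, 2], [1, 1], [1, 2], [1, 3], [1, 3]]
y = [1, 0, 0, 0, 1, 1, 0, 1]

for max_depth in (1, 2, 3):
    tree = build_tree(X, y, max_depth)
    accuracy = sum(predict(tree, r) == label for r, label in zip(X, y)) / len(y)
    print(max_depth, tree_depth(tree), accuracy)
# prints: 1 1 0.75 / 2 2 0.875 / 3 3 0.875
```

A deeper tree fits the training rows at least as well (6/8 correct at depth 1, 7/8 at depth 2), but each extra level also lets it memorize noise, which is why a limit such as max_depth=5 is set later in this post.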

 

 

2. Titanic Data Classification and Prediction

1. Load Titanic Datasets

In [1]:
import pandas as pd 

train = pd.read_csv('data/titanic/train.csv', index_col='PassengerId')
print(train.shape)
print(train.info())  # info() prints the summary itself and returns None, hence the "None" below
train.head()
(891, 11)
<class 'pandas.core.frame.DataFrame'>
Int64Index: 891 entries, 1 to 891
Data columns (total 11 columns):
 #   Column    Non-Null Count  Dtype  
---  ------    --------------  -----  
 0   Survived  891 non-null    int64  
 1   Pclass    891 non-null    int64  
 2   Name      891 non-null    object 
 3   Sex       891 non-null    object 
 4   Age       714 non-null    float64
 5   SibSp     891 non-null    int64  
 6   Parch     891 non-null    int64  
 7   Ticket    891 non-null    object 
 8   Fare      891 non-null    float64
 9   Cabin     204 non-null    object 
 10  Embarked  889 non-null    object 
dtypes: float64(2), int64(4), object(5)
memory usage: 83.5+ KB
None
Out[1]:
Survived Pclass Name Sex Age SibSp Parch Ticket Fare Cabin Embarked
PassengerId
1 0 3 Braund, Mr. Owen Harris male 22.0 1 0 A/5 21171 7.2500 NaN S
2 1 1 Cumings, Mrs. John Bradley (Florence Briggs Th... female 38.0 1 0 PC 17599 71.2833 C85 C
3 1 3 Heikkinen, Miss. Laina female 26.0 0 0 STON/O2. 3101282 7.9250 NaN S
4 1 1 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 35.0 1 0 113803 53.1000 C123 S
5 0 3 Allen, Mr. William Henry male 35.0 0 0 373450 8.0500 NaN S
In [2]:
test = pd.read_csv('data/titanic/test.csv', index_col='PassengerId')
print(test.shape)
print(test.info())  # info() prints the summary itself and returns None, hence the "None" below
test.head()
(418, 10)
<class 'pandas.core.frame.DataFrame'>
Int64Index: 418 entries, 892 to 1309
Data columns (total 10 columns):
 #   Column    Non-Null Count  Dtype  
---  ------    --------------  -----  
 0   Pclass    418 non-null    int64  
 1   Name      418 non-null    object 
 2   Sex       418 non-null    object 
 3   Age       332 non-null    float64
 4   SibSp     418 non-null    int64  
 5   Parch     418 non-null    int64  
 6   Ticket    418 non-null    object 
 7   Fare      417 non-null    float64
 8   Cabin     91 non-null     object 
 9   Embarked  418 non-null    object 
dtypes: float64(2), int64(3), object(5)
memory usage: 35.9+ KB
None
Out[2]:
Pclass Name Sex Age SibSp Parch Ticket Fare Cabin Embarked
PassengerId
892 3 Kelly, Mr. James male 34.5 0 0 330911 7.8292 NaN Q
893 3 Wilkes, Mrs. James (Ellen Needs) female 47.0 1 0 363272 7.0000 NaN S
894 2 Myles, Mr. Thomas Francis male 62.0 0 0 240276 9.6875 NaN Q
895 3 Wirz, Mr. Albert male 27.0 0 0 315154 8.6625 NaN S
896 3 Hirvonen, Mrs. Alexander (Helga E Lindqvist) female 22.0 1 1 3101298 12.2875 NaN S
In [3]:
# Count the null (missing) values in each column
train.isnull().sum()
Out[3]:
Survived      0
Pclass        0
Name          0
Sex           0
Age         177
SibSp         0
Parch         0
Ticket        0
Fare          0
Cabin       687
Embarked      2
dtype: int64
In [4]:
test.isnull().sum()
Out[4]:
Pclass        0
Name          0
Sex           0
Age          86
SibSp         0
Parch         0
Ticket        0
Fare          1
Cabin       327
Embarked      0
dtype: int64
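Counts are easier to compare as fractions of the row count. In the training set this works out to roughly 19.9% missing for Age (177/891) and 77.1% for Cabin (687/891). A minimal pandas sketch of that summary, on a hypothetical mini-frame rather than the real CSV:

```python
import numpy as np
import pandas as pd

# Hypothetical mini-frame with a few Titanic-like columns (values invented for illustration)
df = pd.DataFrame({
    "Age": [22.0, np.nan, 26.0, np.nan],
    "Cabin": [np.nan, "C85", np.nan, np.nan],
    "Fare": [7.25, 71.2833, 7.925, 8.05],
})

# isnull().sum() counts missing cells per column; isnull().mean() gives the fraction missing
summary = pd.DataFrame({
    "missing": df.isnull().sum(),
    "ratio": df.isnull().mean().round(3),
})
print(summary)
```

A column that is mostly empty, like Cabin here, is usually a poor feature as-is, which is consistent with it not being used later in this post.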

2. Data Preprocessing

  • Convert string data to numbers
  • One Hot Encoding
  • Handle null data

2.1 Encoding the Sex column

  • 'male' => 0 , 'female' => 1
In [5]:
train['Sex'].unique()
Out[5]:
array(['male', 'female'], dtype=object)
In [6]:
train['Sex'].value_counts()
Out[6]:
male      577
female    314
Name: Sex, dtype: int64
In [7]:
# Replace the values of the Sex column (male -> 0, female -> 1)
train.loc[train['Sex'] == 'male', 'Sex'] = 0
train.loc[train['Sex'] == 'female', 'Sex'] = 1

train['Sex'].unique()
Out[7]:
array([0, 1], dtype=object)
In [8]:
train.head(2)
Out[8]:
Survived Pclass Name Sex Age SibSp Parch Ticket Fare Cabin Embarked
PassengerId
1 0 3 Braund, Mr. Owen Harris 0 22.0 1 0 A/5 21171 7.2500 NaN S
2 1 1 Cumings, Mrs. John Bradley (Florence Briggs Th... 1 38.0 1 0 PC 17599 71.2833 C85 C
In [9]:
# Replace the values of the Sex column (male -> 0, female -> 1)
test.loc[test['Sex'] == 'male', 'Sex'] = 0
test.loc[test['Sex'] == 'female', 'Sex'] = 1

test['Sex'].unique()
Out[9]:
array([0, 1], dtype=object)
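Note the dtype=object in the output: assigning with .loc replaces the values but leaves the column as object dtype (info() also keeps reporting Sex as object). A minimal alternative sketch uses Series.map with a dict, which produces an integer column; the tiny DataFrame is hypothetical:

```python
import pandas as pd

df = pd.DataFrame({"Sex": ["male", "female", "female", "male"]})

# map() looks each value up in the dict; any value missing from the dict becomes NaN
df["Sex"] = df["Sex"].map({"male": 0, "female": 1})
print(df["Sex"].dtype)
```

Because unmapped values turn into NaN, a quick df["Sex"].isnull().sum() afterwards also catches unexpected spellings.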

2.2 Handling null data in the Fare column

In [10]:
# Fill the single missing Fare in test with 0, then confirm that no null rows remain
test.loc[test['Fare'].isnull(),'Fare'] = 0
test.loc[test['Fare'].isnull()]
Out[10]:
Pclass Name Sex Age SibSp Parch Ticket Fare Cabin Embarked
PassengerId
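Filling with 0 treats that passenger as having paid nothing. A common alternative, sketched here as an assumption rather than the post's method, is to fill with the median training fare, which is less sensitive to the few very large fares than the mean. The frames are hypothetical stand-ins for train and test:

```python
import numpy as np
import pandas as pd

# Hypothetical stand-ins for the train and test frames
train = pd.DataFrame({"Fare": [7.25, 71.28, 7.93, 53.10, 8.05]})
test = pd.DataFrame({"Fare": [7.83, np.nan, 9.69]})

# Compute the statistic on the training data and reuse it for the test data
fare_fill = train["Fare"].median()  # middle value of the five training fares
test["Fare"] = test["Fare"].fillna(fare_fill)
print(test)
```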

2.3 Handling the Embarked column

  • One Hot Encoding
  • C=0, S=1, Q=2 (wrong: integer codes would imply an order, e.g. Q > S > C, that the ports do not have)
  • C=[True,False,False], S=[False,True,False], Q=[False,False,True] (correct)
  • Add three columns: Embarked_C, Embarked_S, Embarked_Q
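pandas can build the same indicator columns with pd.get_dummies. A minimal sketch on hypothetical data; reindexing the test columns to the training columns guarantees both frames have the same features even if a port never appears in one of them:

```python
import pandas as pd

# Hypothetical stand-ins; None mimics the two missing Embarked values in train
train = pd.DataFrame({"Embarked": ["S", "C", "Q", "S", None]})
test = pd.DataFrame({"Embarked": ["Q", "S"]})

# dtype=bool gives True/False columns like the manual comparisons in the post
train_dummies = pd.get_dummies(train["Embarked"], prefix="Embarked", dtype=bool)

# The test data has no "C" here, so reindex adds Embarked_C filled with False
test_dummies = pd.get_dummies(test["Embarked"], prefix="Embarked", dtype=bool).reindex(
    columns=train_dummies.columns, fill_value=False
)
print(train_dummies)
print(test_dummies)
```

As with the manual comparisons, a missing port yields False in all three columns. get_dummies orders the columns alphabetically (C, Q, S), which does not matter here because the features are later selected by name.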
In [11]:
train['Embarked'].value_counts()
Out[11]:
S    644
C    168
Q     77
Name: Embarked, dtype: int64
In [12]:
train['Embarked_C'] = train['Embarked'] == 'C'
train['Embarked_S'] = train['Embarked'] == 'S'
train['Embarked_Q'] = train['Embarked'] == 'Q'

print(train.shape)
# Only the last expression in a notebook cell is displayed, so show just the tail
train[['Embarked', 'Embarked_C', 'Embarked_S', 'Embarked_Q']].tail()
(891, 14)
Out[12]:
Embarked Embarked_C Embarked_S Embarked_Q
PassengerId
887 S False True False
888 S False True False
889 S False True False
890 C True False False
891 Q False False True
In [13]:
test['Embarked_C'] = test['Embarked'] == 'C'
test['Embarked_S'] = test['Embarked'] == 'S'
test['Embarked_Q'] = test['Embarked'] == 'Q'

print(test.shape)
# Only the last expression in a notebook cell is displayed, so show just the tail
test[['Embarked', 'Embarked_C', 'Embarked_S', 'Embarked_Q']].tail()
(418, 13)
Out[13]:
Embarked Embarked_C Embarked_S Embarked_Q
PassengerId
1305 S False True False
1306 C True False False
1307 S False True False
1308 S False True False
1309 C True False False

2.4 Handling the Age column

  • Compute the mean of all ages and use it to fill the null values.
In [14]:
mean_age = train['Age'].mean()
mean_age
Out[14]:
29.69911764705882
In [15]:
# Set Age to the mean age in the rows where it is null
train.loc[train['Age'].isnull(),'Age'] = mean_age
train.info()
<class 'pandas.core.frame.DataFrame'>
Int64Index: 891 entries, 1 to 891
Data columns (total 14 columns):
 #   Column      Non-Null Count  Dtype  
---  ------      --------------  -----  
 0   Survived    891 non-null    int64  
 1   Pclass      891 non-null    int64  
 2   Name        891 non-null    object 
 3   Sex         891 non-null    object 
 4   Age         891 non-null    float64
 5   SibSp       891 non-null    int64  
 6   Parch       891 non-null    int64  
 7   Ticket      891 non-null    object 
 8   Fare        891 non-null    float64
 9   Cabin       204 non-null    object 
 10  Embarked    889 non-null    object 
 11  Embarked_C  891 non-null    bool   
 12  Embarked_S  891 non-null    bool   
 13  Embarked_Q  891 non-null    bool   
dtypes: bool(3), float64(2), int64(4), object(5)
memory usage: 86.1+ KB
In [16]:
test_mean_age = test['Age'].mean()
test.loc[test['Age'].isnull(),'Age'] = test_mean_age
test.info()
<class 'pandas.core.frame.DataFrame'>
Int64Index: 418 entries, 892 to 1309
Data columns (total 13 columns):
 #   Column      Non-Null Count  Dtype  
---  ------      --------------  -----  
 0   Pclass      418 non-null    int64  
 1   Name        418 non-null    object 
 2   Sex         418 non-null    object 
 3   Age         418 non-null    float64
 4   SibSp       418 non-null    int64  
 5   Parch       418 non-null    int64  
 6   Ticket      418 non-null    object 
 7   Fare        418 non-null    float64
 8   Cabin       91 non-null     object 
 9   Embarked    418 non-null    object 
 10  Embarked_C  418 non-null    bool   
 11  Embarked_S  418 non-null    bool   
 12  Embarked_Q  418 non-null    bool   
dtypes: bool(3), float64(2), int64(3), object(5)
memory usage: 37.1+ KB
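Above, the test set is filled with its own mean age. A frequently recommended variant, sketched here as an assumption, computes the statistic once on the training data and applies it to both sets so they are transformed identically. (In this post Age is not among the feature_names used for training, so the choice does not change the model's predictions.) The frames are hypothetical:

```python
import numpy as np
import pandas as pd

# Hypothetical stand-ins for the train and test frames
train = pd.DataFrame({"Age": [22.0, 38.0, np.nan, 35.0]})
test = pd.DataFrame({"Age": [34.5, np.nan]})

mean_age = train["Age"].mean()  # NaN is skipped, so this is (22 + 38 + 35) / 3
train["Age"] = train["Age"].fillna(mean_age)
test["Age"] = test["Age"].fillna(mean_age)  # reuse the training statistic
print(train, test, sep="\n")
```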

3. Data Visualization

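  • countplot - bar chart of counts per category; only one of the x or y axes can be set.
  • barplot - bar chart where both x and y are set; by default each bar shows the mean of y for that x category.
  • pointplot - line plot (point estimates connected by lines).
  • distplot - histogram (distribution plot); deprecated since seaborn 0.11 in favor of histplot/displot.
  • lmplot - scatter plot (with a fitted regression line).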
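The bar heights of a countplot with hue='Survived' are plain group counts, so pandas can print the same numbers as a table, along with the survival rate per group. A minimal sketch on hypothetical rows:

```python
import pandas as pd

# Hypothetical rows standing in for train[['Embarked', 'Survived']]
df = pd.DataFrame({
    "Embarked": ["S", "C", "S", "Q", "S", "C"],
    "Survived": [0, 1, 1, 0, 0, 1],
})

# Rows = Embarked, columns = Survived (0/1): the bars of countplot(x='Embarked', hue='Survived')
counts = pd.crosstab(df["Embarked"], df["Survived"])

# Fraction of passengers who survived in each Embarked group
rate = df.groupby("Embarked")["Survived"].mean()
print(counts)
print(rate)
```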
In [17]:
%matplotlib inline
import seaborn as sns
In [18]:
# countplot of the Embarked column
sns.countplot(data=train, x='Embarked')
Out[18]:
<matplotlib.axes._subplots.AxesSubplot at 0x25a7a16ab00>
In [19]:
# Relationship between survival and the Embarked column
sns.countplot(data=train, x='Embarked', hue='Survived')
Out[19]:
<matplotlib.axes._subplots.AxesSubplot at 0x25a7a215160>
In [20]:
sns.countplot(data=train, x='Pclass')
Out[20]:
<matplotlib.axes._subplots.AxesSubplot at 0x25a7b25d828>
In [21]:
sns.countplot(data=train, x='Pclass', hue='Survived')
Out[21]:
<matplotlib.axes._subplots.AxesSubplot at 0x25a7a215b70>
In [22]:
sns.countplot(data=train, x='Sex')
Out[22]:
<matplotlib.axes._subplots.AxesSubplot at 0x25a7b31de48>
In [23]:
sns.countplot(data=train, x='Sex', hue='Survived')
Out[23]:
<matplotlib.axes._subplots.AxesSubplot at 0x25a7b366198>
In [24]:
# Suppress warnings (e.g. deprecation warnings from seaborn) to keep the plot outputs clean
import warnings
warnings.filterwarnings(action='ignore')
In [25]:
# Relationship between Pclass and Fare
sns.barplot(data=train, x='Pclass', y='Fare')
Out[25]:
<matplotlib.axes._subplots.AxesSubplot at 0x25a7b3b69e8>
In [26]:
sns.barplot(data=train, x='Pclass', y='Fare', hue='Survived')
Out[26]:
<matplotlib.axes._subplots.AxesSubplot at 0x25a7b413f98>
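By default seaborn's barplot draws the mean of y per category (with an error bar for the confidence interval), so the bar heights above correspond to a groupby mean. A minimal sketch with hypothetical values:

```python
import pandas as pd

# Hypothetical rows standing in for train[['Pclass', 'Survived', 'Fare']]
df = pd.DataFrame({
    "Pclass": [1, 1, 3, 3, 2],
    "Survived": [1, 0, 0, 1, 1],
    "Fare": [80.0, 60.0, 8.0, 10.0, 20.0],
})

# Bar heights of barplot(x='Pclass', y='Fare')
by_class = df.groupby("Pclass")["Fare"].mean()

# Bar heights of barplot(x='Pclass', y='Fare', hue='Survived'); empty groups become NaN
by_class_survived = df.groupby(["Pclass", "Survived"])["Fare"].mean().unstack()
print(by_class)
print(by_class_survived)
```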
In [27]:
sns.pointplot(data=train, x='Pclass', y='Fare', hue='Survived')
Out[27]:
<matplotlib.axes._subplots.AxesSubplot at 0x25a7b489550>
In [28]:
sns.distplot(train['Age'], hist=True)
Out[28]:
<matplotlib.axes._subplots.AxesSubplot at 0x25a7b4f23c8>
In [29]:
sns.distplot(train['Fare'], hist=False)
Out[29]:
<matplotlib.axes._subplots.AxesSubplot at 0x25a7b489c50>
In [30]:
# Extract the rows whose Fare is less than 100
low_fare = train.loc[train['Fare'] < 100]
print(low_fare.shape)
sns.distplot(low_fare['Fare'], hist=False)
(838, 14)
Out[30]:
<matplotlib.axes._subplots.AxesSubplot at 0x25a7b848898>
In [31]:
sns.lmplot(data=train, x='Age', y='Fare', hue='Survived')
Out[31]:
<seaborn.axisgrid.FacetGrid at 0x25a7b88fba8>
In [32]:
sns.lmplot(data=low_fare, x='Age', y='Fare', hue='Survived')
Out[32]:
<seaborn.axisgrid.FacetGrid at 0x25a7b88f4e0>

4. Train & Predict

  • Feature Engineering
    • Extract the features (input data) the model will use
  • Create X_train, y_train, X_test
  • Decision Tree algorithm: use the DecisionTreeClassifier class
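After preprocessing, Sex is still an object column and the Embarked_* columns are bool (see the info() output). scikit-learn's tree converts its input to floating point internally, but an explicit astype(float) makes the conversion visible and fails loudly if a string value slipped through. A minimal sketch, assuming hypothetical rows shaped like the preprocessed train frame:

```python
import pandas as pd

feature_names = ["Pclass", "Sex", "Fare", "Embarked_C", "Embarked_S", "Embarked_Q"]

# Hypothetical rows mimicking the preprocessed frame: Sex is object dtype, Embarked_* are bool
train = pd.DataFrame({
    "Survived": [0, 1],
    "Pclass": [3, 1],
    "Sex": pd.Series([0, 1], dtype=object),
    "Fare": [7.25, 71.2833],
    "Embarked_C": [False, True],
    "Embarked_S": [True, False],
    "Embarked_Q": [False, False],
})

X_train = train[feature_names].astype(float)  # raises ValueError if any value is non-numeric
y_train = train["Survived"]                   # the label stays out of the feature matrix
print(X_train.dtypes)
```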
In [33]:
train.columns
Out[33]:
Index(['Survived', 'Pclass', 'Name', 'Sex', 'Age', 'SibSp', 'Parch', 'Ticket',
       'Fare', 'Cabin', 'Embarked', 'Embarked_C', 'Embarked_S', 'Embarked_Q'],
      dtype='object')
In [34]:
feature_names = ['Pclass', 'Sex', 'Fare', 'Embarked_C', 'Embarked_S', 'Embarked_Q']
feature_names
Out[34]:
['Pclass', 'Sex', 'Fare', 'Embarked_C', 'Embarked_S', 'Embarked_Q']
In [35]:
# Create X_train
X_train = train[feature_names]
print(X_train.shape)
X_train.head()
(891, 6)
Out[35]:
Pclass Sex Fare Embarked_C Embarked_S Embarked_Q
PassengerId
1 3 0 7.2500 False True False
2 1 1 71.2833 True False False
3 3 1 7.9250 False True False
4 1 1 53.1000 False True False
5 3 0 8.0500 False True False
In [36]:
# Create X_test
X_test = test[feature_names]
print(X_test.shape)
X_test.head()
(418, 6)
Out[36]:
Pclass Sex Fare Embarked_C Embarked_S Embarked_Q
PassengerId
892 3 0 7.8292 False False True
893 3 1 7.0000 False True False
894 2 0 9.6875 False False True
895 3 0 8.6625 False True False
896 3 1 12.2875 False True False
In [37]:
# Create y_train (the label)
label_name = 'Survived'
y_train = train[label_name]
print(y_train.shape)
y_train.head()
(891,)
Out[37]:
PassengerId
1    0
2    1
3    1
4    1
5    0
Name: Survived, dtype: int64
In [38]:
# Create the Decision Tree model object (tree depth limited to 5)
from sklearn.tree import DecisionTreeClassifier

model = DecisionTreeClassifier(max_depth=5)
model
Out[38]:
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=5,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=None,
            splitter='best')
In [39]:
# Train (fit) the model
model.fit(X_train, y_train)
Out[39]:
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=5,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=None,
            splitter='best')
In [40]:
!pip show graphviz
Name: graphviz
Version: 0.14
Summary: Simple Python interface for Graphviz
Home-page: https://github.com/xflr6/graphviz
Author: Sebastian Bank
Author-email: sebastian.bank@uni-leipzig.de
License: MIT
Location: c:\anaconda3\lib\site-packages
Requires: 
Required-by: 
In [44]:
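from sklearn.tree import export_graphviz
import graphviz  # rendering also needs the Graphviz system binaries installed

# Write the fitted tree in DOT format, labelling features and classes (0 = Perished, 1 = Survived)
export_graphviz(model, feature_names=feature_names, class_names=['Perished','Survived'], out_file='decision-tree.dot')

# Read the DOT source back and render it inline in the notebook
with open('decision-tree.dot') as file:
    dot_graph = file.read()

graphviz.Source(dot_graph)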
Out[44]:
[Figure: rendered decision tree (max_depth=5; 53 nodes, 27 leaves). Root: Sex <= 0.5, gini = 0.473, samples = 891, value = [549, 342], class = Perished. True branch (male): Fare <= 26.269, gini = 0.306, samples = 577, value = [468, 109], class = Perished. False branch (female): Pclass <= 2.5, gini = 0.383, samples = 314, value = [81, 233], class = Survived.]
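The gini numbers in the tree can be checked by hand. Gini impurity is 1 - Σ p_k², where p_k is the fraction of a node's samples in class k. The sketch below recomputes it from the class counts of the root node (value = [549, 342]) and its two children ([468, 109] for male, [81, 233] for female), and also computes how much the Sex split lowers the sample-weighted impurity. The helper name is hypothetical:

```python
def gini_from_counts(counts):
    """Gini impurity from per-class sample counts, as in the tree's 'value = [...]' field."""
    n = sum(counts)
    return 1.0 - sum((c / n) ** 2 for c in counts)


root = [549, 342]   # root node: [Perished, Survived]
left = [468, 109]   # Sex <= 0.5 is True (male)
right = [81, 233]   # Sex <= 0.5 is False (female)

n_left, n_right = sum(left), sum(right)
weighted = (n_left * gini_from_counts(left) + n_right * gini_from_counts(right)) / (n_left + n_right)
decrease = gini_from_counts(root) - weighted

print(round(gini_from_counts(root), 3))   # 0.473
print(round(gini_from_counts(left), 3))   # 0.306
print(round(gini_from_counts(right), 3))  # 0.383
print(round(decrease, 3))                 # 0.14
```

The greedy tree chose Sex for the root because, among the candidate splits, it gave the largest such decrease.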
In [45]:
# Predict
predictions = model.predict(X_test)
print(predictions.shape)
predictions
(418,)
Out[45]:
array([0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0,
       1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
       1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1,
       1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0,
       1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0,
       0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1,
       0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1,
       1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0,
       1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,
       0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1,
       0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
       0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
       0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0,
       1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0,
       0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0,
       1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1,
       0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0],
      dtype=int64)

5. Submission

In [46]:
submit = pd.read_csv('data/titanic/gender_submission.csv', index_col='PassengerId')
print(submit.shape)
submit.head()
(418, 1)
Out[46]:
Survived
PassengerId
892 0
893 1
894 0
895 0
896 1
In [48]:
submit['Survived'] = predictions
print(submit.shape)
submit.head()
(418, 1)
Out[48]:
Survived
PassengerId
892 0
893 0
894 0
895 0
896 1
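gender_submission.csv appears to be the simple "women survive, men do not" baseline (its first rows match the Sex of passengers 892 to 896), so it is a handy reference point. A minimal sketch comparing two prediction columns, using the first five rows of Out[46] and Out[48] above as sample data:

```python
import pandas as pd

ids = [892, 893, 894, 895, 896]
baseline = pd.Series([0, 1, 0, 0, 1], index=ids)   # gender_submission values (Out[46])
predicted = pd.Series([0, 0, 0, 0, 1], index=ids)  # decision tree predictions (Out[48])

agreement = (baseline == predicted).mean()                        # fraction of identical predictions
changed_ids = baseline[baseline != predicted].index.tolist()      # PassengerIds that differ
print(agreement, changed_ids)
```

On the full 418 rows the same two lines show how far the tree moves away from the gender-only rule, which its root split on Sex already encodes.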
In [49]:
# Write the submission CSV file
submit.to_csv('data/titanic/titanic01.csv')
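Reusing gender_submission.csv works because its rows are in the same PassengerId order as test.csv. An alternative sketch builds the submission directly from the test index, since test was loaded with index_col='PassengerId'. The index and predictions below are hypothetical stand-ins for X_test.index and model.predict(X_test):

```python
import io

import numpy as np
import pandas as pd

# Hypothetical stand-ins for X_test.index and model.predict(X_test)
test_index = pd.Index([892, 893, 894], name="PassengerId")
predictions = np.array([0, 1, 0])

submit = pd.DataFrame({"Survived": predictions}, index=test_index)
csv_text = submit.to_csv()  # with no path, to_csv returns the CSV text as a string

# Round-trip check: the file has exactly the two expected columns
check = pd.read_csv(io.StringIO(csv_text))
print(list(check.columns))
```

In the notebook, passing a path such as 'data/titanic/titanic01.csv' to to_csv writes the same content to disk.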

Category: Python > MachineLearning (posted 2020.08.18)