DecisionTreeClassifier predict_proba returns 0 or 1

Phyast10 · Jan 12, 2018 · Viewed 7.3k times

I'm trying to use the decision tree classifier to identify two classes (labeled 0 and 1) based on certain parameters. I train it on a training dataset and then run it on the test dataset. When I try to calculate the probability for each data point in the test dataset, it returns only 0 or 1. I wonder what the problem is.

Here is the sample code:

from sklearn import tree

clf = tree.DecisionTreeClassifier(random_state=0)
trained = clf.fit(data, identifier)  # training data, where identifier is 0 or 1
predict = trained.predict(test_data)

The results from this are:

In [9]: predict

Out[9]: 
array([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0,
       1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0,
       0, 0, 1, 0, 0, 0])

In [10]: trained.predict_proba(test_data)[:,1]

Out[10]: 
array([ 0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  1.,  1.,  1.,
        0.,  0.,  1.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  1.,  1.,  0.,
        1.,  0.,  1.,  0.,  0.,  1.,  0.,  0.,  0.,  1.,  1.,  1.,  1.,
        0.,  1.,  1.,  0.,  0.,  1.,  0.,  0.,  0.,  1.,  0.,  0.,  0.])

I would like to generate an ROC curve from these probabilities, but at this point the computation returns just 3 data points for FPR/TPR.
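(For context, computing the ROC with scikit-learn looks roughly like the sketch below; the roc_curve call and the test_labels array are my reconstruction, not part of the original post. Because the scores are exactly 0 or 1, roc_curve has only two distinct thresholds to sweep, which is why only 3 FPR/TPR points come back.)

from sklearn.metrics import roc_curve

# Hypothetical reconstruction: test_labels holds the true classes of test_data.
# With scores that are exactly 0 or 1 there are only two distinct thresholds,
# so fpr and tpr each contain just 3 points.
scores = trained.predict_proba(test_data)[:, 1]
fpr, tpr, thresholds = roc_curve(test_labels, scores)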

Here is the complete data set; the identifier is the last column of "data".

Training data:

Spectral_Index,W1-W2,W2-W3,HR0.3-100,HR50-2,Gamma,Class
1.4304664,0.61,2.18,0.3819051,0.99992716,1.93,0
1.6969398,0.54,1.93,0.66479063,0.9999814,2.11,0
2.233997,1.02,3.18,0.55532146,0.9999979,2.07,0
2.230639,0.77,2.34,0.0012237767,1.0,1.81,0
1.7325432,0.71,2.27,0.34395835,1.0,1.9,0
1.8728518,0.8,2.14,0.4255796,1.0,1.96,0
1.9818852,0.7,2.18,-0.08978904,1.0,1.66,0
2.3864453,0.95,2.51,0.109010585,0.98401743,1.81,0
2.5911317,0.94,2.49,0.60381645,0.99991965,2.03,0
1.9564596,0.81,2.29,0.3843,0.9999495,2.08,0
2.1506176,0.93,2.62,0.28551856,0.9999999,1.91,0
1.9069784,0.62,1.76,0.041608978,1.0,1.86,0
1.6216202,0.77,2.11,-0.14271076,1.0,1.7,0
2.276335,0.68,2.14,0.40399882,1.0,2.06,0
2.2430172,1.0,2.94,0.61844856,1.0,2.12,0
1.0226197,0.66,2.07,-0.14886126,1.0,1.84,0
2.2564504,1.06,2.77,0.6974536,0.99844635,2.16,0
2.2819016,0.88,2.37,0.30696234,0.999996,1.86,0
1.4881139,0.7,2.09,0.40853307,1.0,1.82,0
2.4640048,0.9,2.39,0.35103577,1.0,2.02,0
2.656071,0.72,2.29,0.21568911,0.9999046,2.11,0
1.7204628,0.62,2.01,0.19794853,1.0,1.8,0
1.9134961,0.86,2.27,0.37281907,1.0,1.94,0
1.3061943,0.67,2.01,0.3463318,0.99999976,1.86,0
1.8845558,0.64,2.01,0.12364135,0.9999834,1.84,0
2.4409518,1.12,3.31,0.7502838,1.0,2.17,0
1.9501582,0.85,2.34,0.29961613,0.9999974,1.92,0
2.1314192,1.03,2.62,0.69623667,1.0,2.28,0
1.7345899,0.69,2.61,0.38524705,0.99999887,2.09,0
1.7095753,0.75,2.08,0.21696341,0.9999987,1.95,0
1.9115254,0.83,2.17,-0.046689913,1.0,1.85,0
1.565369,0.67,2.01,-0.04827315,0.9999915,1.79,0
2.2971635,0.59,2.1,0.35741857,1.0,2.0,0
3.042759,1.06,2.94,0.70878696,0.9999844,2.15,0
2.340724,0.96,2.74,0.42822766,0.99999416,1.97,0
1.8552977,0.74,2.09,0.07262661,1.0,1.69,0
2.0324602,0.66,2.05,-0.07643526,0.9999982,1.83,0
1.8508979,0.67,1.96,0.054557554,0.99997455,1.75,0
2.7983437,0.96,2.58,0.8554537,0.9999992,2.2,0
2.1728642,1.09,3.05,0.61488354,1.0,2.04,0
3.113785,0.66,1.85,0.48011553,0.99995273,1.95,0
3.0665417,0.78,2.19,0.27814054,1.0,1.86,0
2.0060341,0.83,2.39,0.20785762,0.9999502,1.85,0
2.1786506,0.57,2.0,0.33096096,1.0,1.91,0
1.823961,0.72,1.96,-0.103285044,1.0,1.6,0
1.612012,0.68,2.15,-0.3136376,0.65517294,1.52,0
2.1615896,0.87,2.4,0.47535577,1.0,2.04,0
2.3053634,1.06,2.92,0.67040676,0.9991328,2.15,0
1.7525402,0.73,2.12,0.25563625,0.9999979,1.92,0
2.7306526,0.91,2.35,0.68943393,-0.4308276,2.1,0
2.2549937,1.07,2.91,0.6077795,0.9999626,2.04,0
2.0924683,0.69,2.04,-0.068183094,0.3497915,1.77,0
2.210627,0.84,2.09,0.6309954,0.99999976,1.99,0
2.4609168,0.67,2.08,0.29552716,0.99964327,1.96,0
2.5169518,0.84,2.45,0.35437247,0.9999745,1.92,0
2.1841373,0.9,2.51,0.5617463,1.0,2.15,0
3.0673068,0.8,2.22,0.17641401,1.0,1.9,0
2.6202004,0.97,2.47,0.36663872,1.0,2.03,0
1.9694642,0.95,2.54,0.33140072,0.99998665,2.04,0
1.8766946,0.84,2.32,-0.024992371,0.99999803,1.94,0
2.9352057,1.2,2.96,0.6385377,0.9951195,2.18,0
1.4075257,0.86,2.27,0.046303034,0.9999998,1.81,0
1.8769667,0.6,2.0,0.08842805,0.15410244,1.83,0
1.2585826,0.71,1.96,0.005930161,0.78259146,1.72,0
2.2046561,0.9,2.37,0.62021697,1.0,2.07,0
1.0217602,0.49,1.89,-0.26944694,0.9999997,1.66,0
2.1021683,1.05,2.78,0.5306551,1.0,2.14,0
2.4789429,0.94,2.52,0.34224525,0.9999965,2.01,0
2.1449182,0.8,2.32,0.37609425,0.9997282,2.25,0
2.7071185,0.83,2.36,0.75363404,1.0,2.31,0
1.8445525,1.04,2.76,0.6075378,0.88632137,2.14,0
1.6024263,1.09,2.63,0.64461184,1.0,2.18,0
2.0292685,0.53,2.15,0.090091705,1.0,1.92,0
2.0858748,0.71,1.86,0.14351326,0.9999994,1.88,0
2.1292083,0.81,2.31,0.33257455,1.0,1.95,0
1.6344122,0.84,2.38,0.6371139,0.9999998,2.11,0
1.7532507,0.75,2.04,0.16182575,1.0,1.78,0
2.2479355,0.97,2.72,0.41953298,1.0,2.04,0
2.5790315,1.07,2.96,0.7216893,0.9999953,2.11,0
3.0039942,1.03,2.44,0.8042694,0.9998856,2.25,1
3.7599833,1.16,3.23,0.9095345,0.66683024,2.39,1
2.8912013,1.05,2.67,0.85215354,0.9967052,2.27,1
3.8784094,1.11,3.18,0.6971026,1.0,2.19,1
2.1862392,1.13,2.7,0.65855825,1.0,2.28,1
2.7684402,1.16,2.79,0.9261603,-0.9540385,2.35,1
1.7551649,0.56,2.18,0.23092282,1.0,1.98,1
2.804592,1.13,2.98,0.84827685,1.0,2.3,1
1.9874831,1.0,2.98,0.87599415,1.0,2.21,1
2.5059428,1.16,2.79,0.97649753,0.9997586,2.42,1
2.812127,1.12,3.11,0.87392867,1.0,2.21,1
2.9445121,1.06,3.17,0.8849491,1.0,2.41,1
2.7388847,1.11,2.78,0.84986275,0.96669436,2.32,1
2.1416433,1.1,3.61,0.7671358,0.9999998,2.29,1
2.3661094,1.05,3.16,0.73194104,0.99990827,2.14,1
2.761189,1.09,2.81,0.7681978,-0.99955946,2.23,1
2.6658804,1.02,3.36,0.8036201,0.98403203,2.28,1
2.720667,0.99,2.78,0.97055733,0.9781505,2.48,1
2.6812658,0.98,3.05,0.73290765,1.0,2.09,1
1.4784714,0.62,1.97,0.418,1.0,2.02,0
1.7488811,0.7,2.05,0.418,0.99999624,2.02,0

Test data:

Spectral_Index,W1-W2,W2-W3,HR0.3-100,HR50-2,Gamma
1.6724254,0.95,2.58,0.92031854,1.0,2.15
2.552926,0.93,2.74,0.63588345,-0.30092865,2.18
2.5737462,0.86,2.22,0.43023747,1.0,2.08
2.1701677,0.62,2.19,0.6892167,1.0,2.15
3.6152358,0.96,2.58,0.67760235,0.99704355,2.06
3.6193092,0.82,2.34,0.4083981,0.9973078,2.04
2.0209844,1.02,2.86,0.8595182,-0.9979041,2.36
2.166221,1.07,3.0,0.7177616,-0.99961376,2.3
2.7933478,0.94,2.4,0.678935,1.0,2.12
2.2969048,0.86,2.29,0.18689133,1.0,1.96
3.1255674,1.15,2.77,0.9290483,0.6387009,2.28
2.3548958,1.01,2.46,0.75331503,-1.0,2.21
3.9791226,1.15,3.04,0.87006325,-0.99919724,2.43
2.3430493,0.85,2.42,0.81132597,-0.9999996,2.04
3.7431624,0.79,2.57,0.704,0.99952716,2.20784
3.1846259,1.14,2.85,0.9104803,0.99891067,2.3
3.1416001,0.73,2.26,0.5679769,1.0,1.98
2.670179,0.85,2.66,0.7376513,0.97939825,2.1
3.010911,0.79,2.38,0.21750104,0.21187924,1.82
1.4430648,0.9,2.38,0.7361963,0.999758,2.11
2.8149416,1.07,2.62,0.94750744,0.9967568,2.4
3.8395922,1.09,2.91,0.27485812,0.99887043,2.05
3.1686394,0.66,2.11,0.529385,1.0,1.9
3.190167,1.09,3.1,0.8501991,0.9507157,2.23
3.8597586,1.13,3.64,0.89043206,0.17880388,2.42
2.1516426,0.85,2.24,0.6673518,0.9985168,2.2
2.1318088,0.98,2.64,0.85542095,1.0,2.22
1.6740437,0.97,2.99,0.86632746,0.9983954,2.41
4.273427,1.01,2.71,0.8941501,0.64256436,2.47
2.284782,0.92,2.7,0.5820462,0.6981752,2.1
3.343603,1.06,2.84,0.6901738,0.83269715,2.13
5.766362,1.2,3.74,0.99009913,0.99998844,2.49
2.1547525,0.95,3.02,0.75229234,0.99604213,2.57
2.9853358,0.91,2.37,0.62881154,-0.98792726,2.06
2.8614197,0.82,2.15,0.75643075,1.0,2.19
3.6815813,1.14,3.24,0.8886577,-0.030438267,2.39
4.539201,1.17,2.83,0.93989134,0.23378997,2.55
3.35261,1.1,2.73,0.9184936,0.9998006,2.41
3.6697345,1.16,3.57,0.9515105,0.9999988,2.43
1.9781204,0.91,2.85,-0.06649571,0.9999991,1.7
2.6618617,1.1,3.24,0.8348949,-0.9834342,2.29
3.8140056,1.18,3.25,0.8766021,1.0,2.39
2.1926181,1.05,2.3,0.6880097,1.0,2.3
2.0248337,0.83,2.29,0.3604591,0.46159065,2.05
3.904931,1.13,2.46,0.9100119,1.0,2.32
1.9945884,0.94,2.5,0.4632657,0.9869119,2.05
3.3342967,1.1,3.04,0.51323855,-0.5262294,2.23
2.3138714,0.91,2.36,0.90414697,0.9999977,2.29
2.3118904,1.04,3.01,0.87289846,0.998577,2.29
2.246307,1.07,2.72,0.6147379,0.9999993,2.11
1.6369493,0.89,2.34,0.61421084,0.9997295,2.22
3.6198807,0.93,2.62,0.7463702,0.9994778,2.07

Answer

kazemakase · Jan 15, 2018

There is no problem - the tree behaves exactly as expected.

A decision tree computes the class probabilities as the fraction of training samples of each class that fall into a given leaf.
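For example, a leaf that holds four training samples of class 0 and one of class 1 reports probabilities [0.8, 0.2] for every test sample that lands in it. A minimal sketch with made-up 1-D data (not taken from the question):

import numpy as np
from sklearn import tree

# Toy data: the classes overlap, so a depth-1 tree ends up with impure leaves.
X = np.arange(9).reshape(-1, 1)
y = np.array([0, 0, 0, 1, 0, 1, 1, 1, 1])

stump = tree.DecisionTreeClassifier(max_depth=1, random_state=0).fit(X, y)
# The left leaf contains four samples of class 0 and one of class 1,
# so predict_proba reports those leaf fractions:
print(stump.predict_proba([[1.0]]))  # -> [[0.8 0.2]]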

The documentation says:

The default values for the parameters controlling the size of the trees (e.g. max_depth, min_samples_leaf, etc.) lead to fully grown and unpruned trees

That is, the tree is grown until it perfectly (over)fits the training data. All training samples in a given leaf then belong to the same class, so a test sample that falls into that leaf either matches the class (p=1) or does not (p=0).
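If you want graded probabilities for an ROC curve, a common remedy (my addition, not part of the original answer) is to constrain the tree so that leaves keep a mix of classes, e.g. via min_samples_leaf or max_depth, or to average many trees with a RandomForestClassifier:

from sklearn import tree

# Hypothetical fix: forcing every leaf to keep at least 5 training samples
# leaves most leaves impure, so predict_proba returns fractional values
# instead of hard 0/1 scores.
clf = tree.DecisionTreeClassifier(min_samples_leaf=5, random_state=0)
trained = clf.fit(data, identifier)
scores = trained.predict_proba(test_data)[:, 1]  # now graded values in [0, 1]

Note that a proper ROC curve still needs the true labels of the scored samples; since the posted test data is unlabeled, you would have to score a held-out, labeled portion of the training data instead.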