diff --git a/ml3/spiral.py b/ml3/spiral.py index 5292aea..e53b373 100644 --- a/ml3/spiral.py +++ b/ml3/spiral.py @@ -57,7 +57,7 @@ def plot_dataset(X,y): cm = plt.cm.RdBu plt.scatter(X[:,0], X[:,1], c=y, cmap=cm, lw=.5, s=10) -def plot_decision_surface(X, y, classifier): +def plot_decision_surface(X, y, classifier, h=0.02): h = .02 cm = plt.cm.RdBu