You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

75 lines
2.1 KiB

import numpy as np
import matplotlib.pyplot as plt
from math import cos, sin
from scipy.constants import golden, pi
def gen_spiral_dataset(n_examples=500, n_classes=2, a=None, b=None, pi_space=3):
    """Generate ``n_classes`` interleaved logarithmic spirals.

    Every spiral is the same logarithmic spiral ``r(theta) = a**(b*theta)``
    sampled at ``n_examples`` evenly spaced angles in ``[0, pi_space*pi]``.
    Spiral ``i`` is that curve rotated about the origin by
    ``2*pi*i/n_classes``, so the arms are evenly spaced.

    Args:
        n_examples: Number of points sampled along each spiral.
        n_classes: Number of spirals (one per class).
        a: Base of the radial growth; defaults to the golden ratio.
        b: Exponent multiplier per radian; defaults to ``2/pi``, so the
            radius grows by a factor of ``a`` every quarter turn (with the
            default ``a`` this is the golden spiral).
        pi_space: Angular extent of each spiral, in multiples of pi.

    Returns:
        Array of shape ``(2, n_examples, n_classes)``. ``XY[0]`` holds the
        x coordinates and ``XY[1]`` the y coordinates; the last axis indexes
        the spiral.
    """
    n_spirals = n_classes
    # default: golden spiral
    if a is None:
        a = golden
    if b is None:
        b = 2 / pi
    theta = np.linspace(0, pi_space * pi, num=n_examples)
    # The un-rotated logarithmic spiral is the same for every class, so
    # compute it once; classes differ only by their rotation angle.
    radius = a ** (theta * b)
    base = np.vstack((radius * np.cos(theta), radius * np.sin(theta)))
    XY = np.zeros((2, n_examples, n_spirals))
    rot_division = (2 * pi) / n_spirals
    for i in range(n_spirals):
        phi = i * rot_division
        # Standard 2-D counter-clockwise rotation matrix.
        R = np.array([[cos(phi), -sin(phi)], [sin(phi), cos(phi)]])
        XY[:, :, i] = R @ base
    return XY
def load_spiral_dataset(n_examples=300, n_classes=2):
    """Return the spiral dataset as a feature matrix and a label vector.

    Args:
        n_examples: Number of points per spiral (class).
        n_classes: Number of spirals / classes.

    Returns:
        ``(X, y)``. ``X`` has shape ``(n_examples * n_classes, 2)`` and
        contains the points of class 0 first, then class 1, and so on.
        ``y`` has shape ``(n_examples * n_classes,)`` and holds each row's
        class index.
    """
    XY = gen_spiral_dataset(n_examples, n_classes)
    n_points, n_spirals = XY.shape[1], XY.shape[2]
    # Each class slice is (2, n_points); transpose it into (n_points, 2) rows
    # and stack all classes vertically.
    X = np.vstack([XY[:, :, i].T for i in range(n_spirals)])
    y = np.repeat(np.arange(n_spirals), n_points)
    return X, y
def plot_dataset(X, y, cmap=None):
    """Scatter-plot 2-D samples coloured by their label.

    Args:
        X: 2-D array; column 0 is plotted on the x axis and column 1 on the
            y axis.
        y: Per-sample values used to colour the points.
        cmap: Matplotlib colormap for ``y``; defaults to ``plt.cm.RdBu``.
    """
    if cmap is None:
        cmap = plt.cm.RdBu
    plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap, lw=.5, s=10)
def plot_decision_surface(X, y, classifier):
    """Draw a classifier's predictions over a grid, with the samples on top.

    The grid covers the bounding box of ``X`` padded by 0.5 on every side,
    with a step of 0.02. ``classifier.predict`` is evaluated at every grid
    point and the result is drawn as filled contours. The samples are then
    overlaid via ``plot_dataset``.

    Args:
        X: 2-D array of samples; columns 0 and 1 are the plotted features.
        y: Per-sample labels passed through to ``plot_dataset``.
        classifier: Object exposing ``predict`` that accepts an
            ``(n_points, 2)`` array. Presumably a fitted scikit-learn-style
            estimator — confirm against callers.
    """
    h = .02  # grid step in data units
    cm = plt.cm.RdBu
    x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
    y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))
    # Flatten the grid into an (n_points, 2) batch of query points.
    z = classifier.predict(np.c_[xx.ravel(), yy.ravel()])
    # asarray so classifiers returning plain lists can still be reshaped.
    z = np.asarray(z).reshape(xx.shape)
    plt.contourf(xx, yy, z, cmap=cm, alpha=.8)
    plot_dataset(X, y)