# -*- coding: utf-8 -*-
"""
Created on Tue May 22 15:56:56 2018
@author: dell
"""
import numpy as np
class Perceptron(object):
    """Binary classifier trained with the perceptron learning rule.

    Parameters
    ----------
    eta : float
        Learning rate; scales every weight update.
    n_iter : int
        Number of passes (epochs) over the training set.

    Attributes
    ----------
    w_ : numpy.ndarray
        Weights after fitting. ``w_[0]`` is the bias term and ``w_[1:]`` are
        the per-feature weights.
    errors_ : list of int
        Number of samples that triggered a weight update in each epoch.
    """

    def __init__(self, eta=0.01, n_iter=10):
        self.eta = eta
        self.n_iter = n_iter

    def fit(self, X, y):
        """Fit the weights to the training data.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Training samples.
        y : array-like, shape (n_samples,)
            Target labels; expected to be -1 or 1, the values ``predict`` emits.

        Returns
        -------
        Perceptron
            ``self``, so calls can be chained, e.g. ``Perceptron().fit(X, y)``.
        """
        X = np.asarray(X, dtype=float)
        y = np.asarray(y)
        # One bias weight plus one weight per feature, all starting at zero.
        self.w_ = np.zeros(1 + X.shape[1])
        self.errors_ = []
        for _ in range(self.n_iter):
            errors = 0
            for xi, target in zip(X, y):
                # With +/-1 labels this is non-zero only for a misclassified sample.
                update = self.eta * (target - self.predict(xi))
                self.w_[1:] += update * xi
                self.w_[0] += update
                errors += int(update != 0.0)
            # Record the epoch's error count once the full pass is complete.
            self.errors_.append(errors)
        return self

    def net_input(self, X):
        """Return the linear activation ``X . w_[1:] + w_[0]``."""
        return np.dot(X, self.w_[1:]) + self.w_[0]

    def predict(self, x):
        """Return 1 where the net input is >= 0, else -1."""
        return np.where(self.net_input(x) >= 0.0, 1, -1)
# Load the Iris data (no header row; column 4 holds the class name string).
file="iris.date.csv.txt"
import pandas as pd
df=pd.read_csv(file,header=None)
import matplotlib.pyplot as plt
import numpy as np

# First 100 rows: setosa (0-49) and versicolor (50-99).
# Use ``iloc`` for both slices: it is end-exclusive, so labels and features
# cover exactly the same 100 rows. ``loc[0:100]`` is end-inclusive and would
# yield 101 labels.
y = df.iloc[0:100, 4].values
# Encode setosa as -1 and every other class as 1.
y = np.where(y == 'Iris-setosa', -1, 1)
print(y)

# Two features: column 0 (sepal length) and column 2 (petal length).
X = df.iloc[0:100, [0, 2]].values
print(X)

# End index 50 (exclusive) keeps all 50 setosa samples; 0:49 dropped one.
plt.scatter(X[0:50, 0], X[0:50, 1], color='red', marker='o', label='setosa')
plt.scatter(X[50:100, 0], X[50:100, 1], color='blue', marker='x', label='versicolor')
# Labels match the plotted columns: x = column 0 (sepal length),
# y = column 2 (petal length).
# NOTE(review): rendering these CJK labels needs a font with CJK glyphs
# configured in matplotlib; the default font shows placeholder boxes.
plt.xlabel('花萼长度')
plt.ylabel('花瓣长度')
plt.legend(loc='upper left')
plt.show()