Advertisement
VisualPaul

Untitled

Apr 8th, 2016
377
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.32 KB | None | 0 0
import builtins

import numpy as np
# NOTE: geopy.distance.vincenty was removed in geopy 2.0; on newer geopy use
# geopy.distance.geodesic instead.
from geopy.distance import vincenty
from sklearn.cluster import KMeans
from sklearn.neighbors import BallTree
from sklearn.preprocessing import OneHotEncoder

  7. smallKMeans = KMeans(n_clusters=3).fit(array([[ 55.73008165, 37.59531199], [ 59.91301691, 30.31944249],[ 55.67814337, 46.11249841]]))
  8. holidays = "1.01,2.01,3.01,4.01,5.01,6.01,7.01,8.01,23.02,8.03,9.03,10.03,1.05,2.05,3.05,4.05,9.05,10.05,11.05,12.06,13.06,14.06,15.06".split(',')
  9. holidays = set(tuple(map(int, x.split('.'))) for x in holidays)
  10.  
  11. def get_features(data):
  12. dist = data.dist.values
  13. lat, lon = data.lat.values, data.lon.values
  14. weekday, month = data.day_of_week.values, data.month.values
  15. hourx, houry = cos(data.hour / 23), sin(data.hour / 23)
  16. hour = data.hour
  17. hota, hotb, hotc = zeros_like(hour, dtype=float32), zeros_like(hour, dtype=float32), zeros_like(hour, dtype=float32)
  18.  
  19. hota[(data.f_class == 'econom').values] += 1.00
  20. hota[(data.s_class == 'econom').values] += 0.50
  21. hota[(data.t_class == 'econom').values] += 0.25
  22.  
  23. hotb[(data.f_class == 'business').values] += 1.00
  24. hotb[(data.s_class == 'business').values] += 0.50
  25. hotb[(data.t_class == 'business').values] += 0.25
  26.  
  27. hotc[(data.f_class == 'vip').values] += 1.00
  28. hotc[(data.s_class == 'vip').values] += 0.50
  29. hotc[(data.t_class == 'vip').values] += 0.25
  30.  
  31. isHoliday = array([[1 if x in holidays else 0] for x in zip(data['day'], data['month'])])
  32. city = smallKMeans.predict(ds[['lat', 'lon']])
  33. smallClusters = OneHotEncoder().fit_transform(city.reshape(-1, 1)).toarray()
  34. cityDistance = array([vincenty(smallKMeans.cluster_centers_[city[i]], (data['lat'][i], data['lon'][i])).meters for i in range(len(data))])
  35. ballTree = BallTree(data[['lat', 'lon']])
  36. coord = list(zip(data['lat'], data['lon']))
  37. sumDist = array([builtins.sum(vincenty(coord[i], coord[x]).meters for x in (ballTree.query((coord[i],), 5)[1][0])) for i in range(len(coord))])
  38.  
  39. weekday = OneHotEncoder().fit_transform(ds['day_of_week'].reshape(-1, 1)).toarray()
  40. features = array(list(zip(dist, lat, lon, month, hourx, houry, hour, hota, hotb, hotc)))
  41. features = hstack((features, smallClusters, isHoliday, weekday, city.reshape(-1, 1), cityDistance.reshape(-1, 1), sumDist.reshape(-1, 1)))
  42. return features
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement