
ml_classifiers

Overview

This package provides a ROS service interface to machine learning algorithms for supervised classification. Three example classifiers are included, among them a nearest neighbor classifier and a support vector machine (SVM) classifier, and additional classifiers can be added easily through pluginlib.

Installation

$ sudo apt-get install ros-fuerte-ml-classifiers

$ sudo apt-get install ros-groovy-ml-classifiers

$ sudo apt-get install ros-hydro-ml-classifiers

ROS API

ml_classifiers provides a general interface for all plugin classifiers to implement:

Plugins

ml_classifiers comes with three built-in plugins: a "zero" classifier that assigns every input to class 0, a nearest neighbor classifier, and an SVM classifier based on libSVM. More classifiers can be added easily through pluginlib by implementing the above interface.
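
To make the behavior of the two simpler plugins concrete, here is a minimal pure-Python sketch of a "zero" classifier and a one-nearest-neighbor classifier. It only illustrates the classification logic described above. The class and method names (`ZeroClassifier`, `NearestNeighborClassifier`, `add_training_point`, `train`, `classify_point`) are assumptions made for this sketch, not the package's plugin interface, and the real nearest neighbor plugin may use a different distance metric.

```python
import math


class ZeroClassifier(object):
    """Ignores its training data and labels every point as class "0"."""

    def add_training_point(self, target_class, point):
        pass  # nothing to remember

    def train(self):
        pass  # nothing to fit

    def classify_point(self, point):
        return "0"


class NearestNeighborClassifier(object):
    """Labels a point with the class of the closest stored training point."""

    def __init__(self):
        self.samples = []  # list of (target_class, point) pairs

    def add_training_point(self, target_class, point):
        self.samples.append((target_class, list(point)))

    def train(self):
        pass  # 1-NN has no fitting step; the stored samples are the model

    def classify_point(self, point):
        if not self.samples:
            raise ValueError("no training data has been added")

        def distance(sample):
            # Euclidean distance between the query and a stored point
            _, stored = sample
            return math.sqrt(sum((a - b) ** 2 for a, b in zip(stored, point)))

        target_class, _ = min(self.samples, key=distance)
        return target_class


if __name__ == "__main__":
    nn = NearestNeighborClassifier()
    for label, pt in [("1", [0.1, 0.2]), ("2", [3.1, 3.2]), ("3", [5.1, 5.2])]:
        nn.add_training_point(label, pt)
    nn.train()
    print(nn.classify_point([0.0, 0.0]))
    print(ZeroClassifier().classify_point([0.0, 0.0]))
```

The zero classifier is mainly useful as a baseline or for checking that the service plumbing works before a real classifier is plugged in.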

Usage

Compile any additional plugins you want to use, then launch the classifier server:

$ roslaunch ml_classifiers classifier_server.launch
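
Once the server is running, a client may want to confirm that every classifier service is advertised before making calls. The sketch below builds the seven service names used elsewhere on this page and waits on each one. The helper names (`classifier_services`, `wait_for_classifier_services`) and the default timeout are assumptions for illustration. The wait function is passed in as an argument so the helper can be exercised without a ROS installation.

```python
# Service names as used by the example client on this page;
# the helper functions themselves are hypothetical.
SERVICE_NAMES = [
    "add_class_data",
    "classify_data",
    "clear_classifier",
    "create_classifier",
    "load_classifier",
    "save_classifier",
    "train_classifier",
]


def classifier_services(namespace="/ml_classifiers"):
    """Return the fully qualified name of each classifier service."""
    base = namespace.rstrip("/")
    return ["%s/%s" % (base, name) for name in SERVICE_NAMES]


def wait_for_classifier_services(wait_fn, timeout=5.0, namespace="/ml_classifiers"):
    """Call wait_fn(service_name, timeout) for every classifier service.

    wait_fn is expected to block until the service is available and to
    raise an exception on timeout, as rospy.wait_for_service does.
    """
    services = classifier_services(namespace)
    for service in services:
        wait_fn(service, timeout)
    return services
```

With rospy available, passing `rospy.wait_for_service` as `wait_fn` should block until all seven services are up or raise `rospy.ROSException` if one of them does not appear within the timeout.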

An example Python script that wraps the service calls and tests the SVM classifier:

import roslib; roslib.load_manifest('ml_classifiers')
import rospy
import ml_classifiers.srv
import ml_classifiers.msg


# Wrapper for calls to the ROS classifier services and management of classifier data
class ClassifierWrapper:

    def __init__(self):
        # Set up persistent handles for each classifier service
        print('Waiting for Classifier services...')
        rospy.wait_for_service("/ml_classifiers/create_classifier")
        self.add_class_data = rospy.ServiceProxy(
            "/ml_classifiers/add_class_data",
            ml_classifiers.srv.AddClassData, persistent=True)
        self.classify_data = rospy.ServiceProxy(
            "/ml_classifiers/classify_data",
            ml_classifiers.srv.ClassifyData, persistent=True)
        self.clear_classifier = rospy.ServiceProxy(
            "/ml_classifiers/clear_classifier",
            ml_classifiers.srv.ClearClassifier, persistent=True)
        self.create_classifier = rospy.ServiceProxy(
            "/ml_classifiers/create_classifier",
            ml_classifiers.srv.CreateClassifier, persistent=True)
        self.load_classifier = rospy.ServiceProxy(
            "/ml_classifiers/load_classifier",
            ml_classifiers.srv.LoadClassifier, persistent=True)
        self.save_classifier = rospy.ServiceProxy(
            "/ml_classifiers/save_classifier",
            ml_classifiers.srv.SaveClassifier, persistent=True)
        self.train_classifier = rospy.ServiceProxy(
            "/ml_classifiers/train_classifier",
            ml_classifiers.srv.TrainClassifier, persistent=True)
        print('OK\n')

    # Add one labeled training point to the named classifier
    def addClassDataPoint(self, identifier, target_class, p):
        req = ml_classifiers.srv.AddClassDataRequest()
        req.identifier = identifier
        dp = ml_classifiers.msg.ClassDataPoint()
        dp.point = p
        dp.target_class = target_class
        req.data.append(dp)
        self.add_class_data(req)

    # Add several labeled training points in a single service call
    def addClassDataPoints(self, identifier, target_classes, pts):
        req = ml_classifiers.srv.AddClassDataRequest()
        req.identifier = identifier
        for target_class, p in zip(target_classes, pts):
            dp = ml_classifiers.msg.ClassDataPoint()
            dp.point = p
            dp.target_class = target_class
            req.data.append(dp)
        self.add_class_data(req)

    # Classify one point and return its predicted class
    def classifyPoint(self, identifier, p):
        req = ml_classifiers.srv.ClassifyDataRequest()
        req.identifier = identifier
        dp = ml_classifiers.msg.ClassDataPoint()
        dp.point = p
        req.data.append(dp)
        resp = self.classify_data(req)
        return resp.classifications[0]

    # Classify several points and return their predicted classes in order
    def classifyPoints(self, identifier, pts):
        req = ml_classifiers.srv.ClassifyDataRequest()
        req.identifier = identifier
        for p in pts:
            dp = ml_classifiers.msg.ClassDataPoint()
            dp.point = p
            req.data.append(dp)
        resp = self.classify_data(req)
        return resp.classifications

    def clearClassifier(self, identifier):
        req = ml_classifiers.srv.ClearClassifierRequest()
        req.identifier = identifier
        self.clear_classifier(req)

    # class_type is the plugin name, e.g. 'ml_classifiers/SVMClassifier'
    def createClassifier(self, identifier, class_type):
        req = ml_classifiers.srv.CreateClassifierRequest()
        req.identifier = identifier
        req.class_type = class_type
        self.create_classifier(req)

    def loadClassifier(self, identifier, class_type, filename):
        req = ml_classifiers.srv.LoadClassifierRequest()
        req.identifier = identifier
        req.class_type = class_type
        req.filename = filename
        self.load_classifier(req)

    def saveClassifier(self, identifier, filename):
        req = ml_classifiers.srv.SaveClassifierRequest()
        req.identifier = identifier
        req.filename = filename
        self.save_classifier(req)

    def trainClassifier(self, identifier):
        req = ml_classifiers.srv.TrainClassifierRequest()
        req.identifier = identifier
        self.train_classifier(req)


if __name__ == '__main__':
    cw = ClassifierWrapper()
    cw.createClassifier('test', 'ml_classifiers/SVMClassifier')

    # Five 2-D training points with class labels '1', '2' and '3'
    targs = ['1', '1', '2', '2', '3']
    pts = [[0.1, 0.2], [0.3, 0.1], [3.1, 3.2], [3.3, 4.1], [5.1, 5.2]]
    cw.addClassDataPoints('test', targs, pts)
    cw.trainClassifier('test')

    testpts = [[0.0, 0.0], [5.5, 5.5], [2.9, 3.6]]
    resp = cw.classifyPoints('test', testpts)
    print(resp)
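
The example script above does not exercise the save and load services. The following sketch shows one way to check a round trip: save a trained classifier, load it back under a second identifier, and compare the predictions of both on the same points. The function name `reload_matches`, the default identifier `'reloaded'`, and the file path in the usage comment are hypothetical. The sketch also assumes that loading under a new identifier creates a second classifier on the server, which this page does not confirm.

```python
def reload_matches(cw, identifier, class_type, filename, pts,
                   new_identifier="reloaded"):
    """Save a classifier, reload it under new_identifier, compare predictions.

    cw is any object with the saveClassifier, loadClassifier and
    classifyPoints methods of the ClassifierWrapper shown above.
    """
    cw.saveClassifier(identifier, filename)
    cw.loadClassifier(new_identifier, class_type, filename)
    original = list(cw.classifyPoints(identifier, pts))
    reloaded = list(cw.classifyPoints(new_identifier, pts))
    return original == reloaded


# Usage with a running server (the path is a placeholder):
# ok = reload_matches(cw, 'test', 'ml_classifiers/SVMClassifier',
#                     '/tmp/test_svm.model', testpts)
```

If the function returns False, the saved model and the in-memory classifier disagree on at least one of the given points.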


2019-12-07 12:49