Commit 0735724
Implement multi-label SVM (scikit-multilearn#139)
Twin multi-label SVM implementation by Grzegorz
1 parent 5489ab3 commit 0735724

File tree

5 files changed: 221 additions and 4 deletions

skmultilearn/adapt/__init__.py

Lines changed: 5 additions & 1 deletion

@@ -25,14 +25,18 @@
 +-----------------------------------------------+-----------------------------------------------------------+
 | :class:`~skmultilearn.adapt.MLARAM`           | a multi-Label Hierarchical ARAM Neural Network            |
 +-----------------------------------------------+-----------------------------------------------------------+
+| :class:`~skmultilearn.adapt.MLTSVM`           | twin multi-Label Support Vector Machines                  |
++-----------------------------------------------+-----------------------------------------------------------+
 """

 from .brknn import BRkNNaClassifier, BRkNNbClassifier
 from .mlknn import MLkNN
 from .mlaram import MLARAM
+from .mltsvm import MLTSVM

 __all__ = ["BRkNNaClassifier",
            "BRkNNbClassifier",
            "MLkNN",
-           "MLARAM"]
+           "MLARAM",
+           "MLTSVM"]

skmultilearn/adapt/brknn.py

Lines changed: 1 addition & 1 deletion

@@ -122,7 +122,7 @@ class BRkNNaClassifier(_BinaryRelevanceKNN):
         from sklearn.model_selection import GridSearchCV

         parameters = {'k': range(1,3)}
-        score = 'f1-macro
+        score = 'f1_macro'

         clf = GridSearchCV(BRkNNaClassifier(), parameters, scoring=score)
         clf.fit(X, y)
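
The old docstring line was broken twice: the string literal was unterminated, and 'f1-macro' is not a name that scikit-learn's scorer registry accepts; the registered name uses an underscore. A short way to see which form is valid (an illustrative check, not part of the commit):

    from sklearn.metrics import get_scorer

    get_scorer('f1_macro')    # returns a scorer object
    # get_scorer('f1-macro')  # raises ValueError: unknown scoring value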

skmultilearn/adapt/mlknn.py

Lines changed: 2 additions & 2 deletions

@@ -81,12 +81,12 @@ class MLkNN(MLClassifierBase):
         from sklearn.model_selection import GridSearchCV

         parameters = {'k': range(1,3), 's': [0.5, 0.7, 1.0]}
-        score = 'f1-macro
+        score = 'f1_macro'

         clf = GridSearchCV(MLkNN(), parameters, scoring=score)
         clf.fit(X, y)

-        print clf.best_params_, clf.best_score_
+        print(clf.best_params_, clf.best_score_)

         # output
         ({'k': 1, 's': 0.5}, 0.78988303374297597)

skmultilearn/adapt/mltsvm.py

Lines changed: 187 additions & 0 deletions

New file:

# Authors: Grzegorz Kulakowski <[email protected]>
# License: BSD 3 clause
from skmultilearn.base import MLClassifierBase

import numpy as np
import scipy.sparse as sp
from scipy.linalg import norm
from scipy.sparse.linalg import inv as inv_sparse
from scipy.linalg import inv as inv_dense


class MLTSVM(MLClassifierBase):
    """Twin multi-Label Support Vector Machines

    Parameters
    ----------
    c_k : float
        the empirical risk penalty parameter that determines the trade-off between the loss terms
    sor_omega : float (default is 1.0)
        the relaxation parameter of the successive overrelaxation (SOR) solver
    threshold : float (default is 1e-6)
        convergence threshold for the successive overrelaxation procedure
    lambda_param : float (default is 1.0)
        the regularization parameter
    max_iteration : int (default is 500)
        maximum number of iterations to use in successive overrelaxation

    References
    ----------

    If you use this classifier please cite the original paper introducing the method:

    .. code :: bibtex

        @article{chen2016mltsvm,
          title={MLTSVM: a novel twin support vector machine to multi-label learning},
          author={Chen, Wei-Jie and Shao, Yuan-Hai and Li, Chun-Na and Deng, Nai-Yang},
          journal={Pattern Recognition},
          volume={52},
          pages={61--74},
          year={2016},
          publisher={Elsevier}
        }

    Examples
    --------

    Here's a very simple example of using MLTSVM with a fixed penalty parameter:

    .. code :: python

        from skmultilearn.adapt import MLTSVM

        classifier = MLTSVM(c_k = 2**-1)

        # train
        classifier.fit(X_train, y_train)

        # predict
        predictions = classifier.predict(X_test)


    You can also use :class:`~sklearn.model_selection.GridSearchCV` to find an optimal set of parameters:

    .. code :: python

        from skmultilearn.adapt import MLTSVM
        from sklearn.model_selection import GridSearchCV

        parameters = {'c_k': [2**i for i in range(-5, 5, 2)]}
        score = 'f1_macro'

        clf = GridSearchCV(MLTSVM(), parameters, scoring=score)
        clf.fit(X, y)

        print(clf.best_params_, clf.best_score_)

        # output
        {'c_k': 0.03125} 0.347518217573
    """

    def __init__(self, c_k=0, sor_omega=1.0, threshold=1e-6, lambda_param=1.0, max_iteration=500):
        super(MLTSVM, self).__init__()
        self.max_iteration = max_iteration
        self.threshold = threshold
        self.lambda_param = lambda_param  # TODO: possibility to add different lambda to different labels
        self.c_k = c_k
        self.sor_omega = sor_omega
        self.copyable_attrs = ['c_k', 'sor_omega', 'lambda_param', 'threshold', 'max_iteration']

    def fit(self, X, Y):
        n_labels = Y.shape[1]
        m = X.shape[1]  # count of features
        self.wk_bk = np.zeros([n_labels, m + 1], dtype=float)

        if sp.issparse(X):
            identity_matrix = sp.identity(m + 1)
            _inv = inv_sparse
        else:
            identity_matrix = np.identity(m + 1)
            _inv = inv_dense

        X_bias = _hstack(X, np.ones((X.shape[0], 1), dtype=X.dtype))
        self.iteration_count = []
        for label in range(0, n_labels):
            # Assemble the dual matrix Q_k = G_k (H_k^T H_k + lambda I)^{-1} G_k^T
            # for the overrelaxation solver.
            H_k = _get_x_class_instances(X_bias, Y, label)
            G_k = _get_x_noclass_instances(X_bias, Y, label)
            Q_knoPrefixGk = _inv((H_k.T).dot(H_k) + self.lambda_param * identity_matrix).dot(G_k.T)
            Q_k = G_k.dot(Q_knoPrefixGk)
            if sp.issparse(Q_k):
                Q_k = Q_k.A  # densify for the SOR solver
            Q_k = (Q_k + Q_k.T) / 2.0  # symmetrize against numerical drift

            # Solve the dual QP and recover the hyperplane [w_k; b_k].
            alpha_k = self._successive_overrelaxation(self.sor_omega, Q_k)
            if sp.issparse(X):
                self.wk_bk[label] = -Q_knoPrefixGk.dot(alpha_k).T
            else:
                self.wk_bk[label] = (-np.dot(Q_knoPrefixGk, alpha_k)).T

        self.wk_norms = norm(self.wk_bk, axis=1)
        # Per-dataset decision threshold, scaled by the largest hyperplane norm.
        self.decision_threshold = 1.0 / np.max(self.wk_norms)

    def predict(self, X):
        X_with_bias = _hstack(X, np.ones((X.shape[0], 1), dtype=X.dtype))
        wk_norms_multiplicated = self.wk_norms[np.newaxis, :]  # reshape to form [[wk1, wk2, ..., wkk]]
        all_distances = (-X_with_bias.dot(self.wk_bk.T)) / wk_norms_multiplicated
        predicted_y = np.where(all_distances < self.decision_threshold, 1, 0)
        # TODO: it is possible to add a condition that assigns a label when a row has none
        return predicted_y

    def _successive_overrelaxation(self, omegaW, Q):
        # Initialization
        D = np.diag(Q)  # the diagonal of Q as a one-dimensional vector is enough
        D_inv = 1.0 / D  # simplified form of D^-1
        small_l = Q.shape[1]
        oldnew_alpha = np.zeros([small_l, 1])  # buffer

        is_not_enough = True
        was_going_down = False
        last_alfa_norm_change = -1

        nr_iter = 0
        while is_not_enough:  # do-while
            oldAlpha = oldnew_alpha.copy()  # snapshot of the previous iterate for the convergence check
            for j in range(0, small_l):
                # Coordinate-wise SOR update, then clip into the box [0, c_k].
                oldnew_alpha[j] = oldnew_alpha[j] - omegaW * D_inv[j] * (Q[j, :].T.dot(oldnew_alpha) - 1)
            oldnew_alpha = oldnew_alpha.clip(0.0, self.c_k)
            alfa_norm_change = norm(oldnew_alpha - oldAlpha)

            if not was_going_down and last_alfa_norm_change > alfa_norm_change:
                was_going_down = True
            is_not_enough = alfa_norm_change > self.threshold and \
                            nr_iter < self.max_iteration \
                            and ((not was_going_down) or last_alfa_norm_change > alfa_norm_change)
            # TODO: maybe add any(oldnew_alpha != oldAlpha)

            last_alfa_norm_change = alfa_norm_change
            nr_iter += 1
        self.iteration_count.append(nr_iter)
        return oldnew_alpha


def _get_x_noclass_instances(X, Y, label_class):
    # Rows whose label `label_class` is 0.
    if sp.issparse(Y):
        indices = np.where(Y[:, label_class].A == 0)[0]
    else:
        indices = np.where(Y[:, label_class] == 0)[0]
    return X[indices, :]


def _get_x_class_instances(X, Y, label_class):
    # Rows whose label `label_class` is 1.
    if sp.issparse(Y):
        indices = Y[:, label_class].nonzero()[0]
    else:
        indices = np.nonzero(Y[:, label_class])[0]
    return X[indices, :]


def _hstack(X, Y):
    # Stack the bias column onto X, preserving sparsity when present.
    if sp.issparse(X):
        return sp.hstack([X, Y], format=X.format)
    else:
        return np.hstack([X, Y])
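
For reference, the per-label problem that fit solves matches the linear MLTSVM formulation of Chen et al. (2016); the sketch below is a reading of the code above in the paper's notation, not part of the commit. With H_k the bias-augmented instances carrying label k and G_k the remaining instances, each label's hyperplane comes from the box-constrained dual

    \min_{\alpha}\ \tfrac{1}{2}\,\alpha^{\top} Q_k\, \alpha - e^{\top}\alpha
    \quad \text{s.t.} \quad 0 \le \alpha \le c_k,
    \qquad Q_k = G_k \left(H_k^{\top} H_k + \lambda I\right)^{-1} G_k^{\top},

after which the code recovers [w_k; b_k] = -(H_k^{\top} H_k + \lambda I)^{-1} G_k^{\top} \alpha. The _successive_overrelaxation helper solves this QP coordinate-wise with the clipped SOR update

    \alpha_j \leftarrow \operatorname{clip}\!\left(\alpha_j - \omega\, Q_{jj}^{-1}\left(Q_{j,:}\,\alpha - 1\right),\ 0,\ c_k\right),

sweeping over j until the change in \alpha drops below threshold or max_iteration is reached.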
Lines changed: 26 additions & 0 deletions

New test file:

import unittest

from skmultilearn.adapt import MLTSVM
from skmultilearn.tests.classifier_basetest import ClassifierBaseTest


class MLTSVMTest(ClassifierBaseTest):

    def classifiers(self):
        return [MLTSVM(c_k=2**-4)]

    def test_if_mltsvm_classification_works_on_sparse_input(self):
        for classifier in self.classifiers():
            self.assertClassifierWorksWithSparsity(classifier, 'sparse')

    def test_if_mltsvm_classification_works_on_dense_input(self):
        for classifier in self.classifiers():
            self.assertClassifierWorksWithSparsity(classifier, 'dense')

    def test_if_mltsvm_works_with_cross_validation(self):
        for classifier in self.classifiers():
            self.assertClassifierWorksWithCV(classifier)


if __name__ == '__main__':
    unittest.main()
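
The base-test harness exercises sparse input, dense input, and cross-validation. For a standalone spot check outside the harness, a minimal sketch (the synthetic dataset, its shape, and the reuse of c_k=2**-4 are illustrative assumptions, not part of the commit):

    # Illustrative smoke test for MLTSVM on synthetic multi-label data.
    from sklearn.datasets import make_multilabel_classification
    from skmultilearn.adapt import MLTSVM

    X, y = make_multilabel_classification(n_samples=100, n_features=20,
                                          n_classes=5, random_state=42)

    clf = MLTSVM(c_k=2**-4)  # same penalty the unit tests use
    clf.fit(X, y)
    predictions = clf.predict(X)
    print(predictions.shape)  # (100, 5): one 0/1 column per label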
