diff --git a/senpy/plugins/conversion/centroids.py b/senpy/plugins/conversion/centroids.py index 3479c50..427e111 100644 --- a/senpy/plugins/conversion/centroids.py +++ b/senpy/plugins/conversion/centroids.py @@ -32,8 +32,17 @@ class CentroidConversion(EmotionConversionPlugin): nv1[aliases.get(k2, k2)] = v2 ncentroids[aliases.get(k1, k1)] = nv1 info['centroids'] = ncentroids + super(CentroidConversion, self).__init__(info) + self.dimensions = set() + for c in self.centroids.values(): + self.dimensions.update(c.keys()) + self.neutralPoints = self.get("neutralPoints", dict()) + if not self.neutralPoints: + for i in self.dimensions: + self.neutralPoints[i] = self.get("neutralValue", 0) + def _forward_conversion(self, original): """Sum the VAD value of all categories found.""" res = Emotion() @@ -49,15 +58,19 @@ class CentroidConversion(EmotionConversionPlugin): def _backwards_conversion(self, original): """Find the closest category""" - dimensions = set(c.keys() for c in centroids.values()) - neutralPoint = self.get("origin", None) - neutralPoint = {k:neutralPoint[k] if k in neturalPoint else 0} - + centroids = self.centroids + neutralPoints = self.neutralPoints + dimensions = self.dimensions + + def distance_k(centroid, original, k): + # k component of the distance between the value and a given centroid + return (centroid.get(k, neutralPoints[k]) - original.get(k, neutralPoints[k]))**2 + def distance(centroid): - return sum((centroid.get(k, neutralPoint[k]) - original.get(k, neutralPoint[k]))**2 for k in dimensions) + return sum(distance_k(centroid, original, k) for k in dimensions) + + emotion = min(centroids, key=lambda x: distance(centroids[x])) - emotion = min(centroids, key=lambda x: distance(centroids[x]) - result = Emotion(onyx__hasEmotionCategory=emotion) return result