diff --git a/senpy/plugins/conversion/centroids.py b/senpy/plugins/conversion/centroids.py index 2dd1c97..2c0e735 100644 --- a/senpy/plugins/conversion/centroids.py +++ b/senpy/plugins/conversion/centroids.py @@ -1,5 +1,6 @@ from senpy.plugins import EmotionConversionPlugin from senpy.models import EmotionSet, Emotion, Error +from collections import defaultdict import logging logger = logging.getLogger(__name__) @@ -37,14 +38,22 @@ class CentroidConversion(EmotionConversionPlugin): def _forward_conversion(self, original): """Sum the VAD value of all categories found.""" res = Emotion() + totalIntensities = defaultdict(float) for e in original.onyx__hasEmotion: category = e.onyx__hasEmotionCategory + intensity = e.get("onyx__hasEmotionIntensity",1) if category in self.centroids: + totalIntensities[category] += intensity for dim, value in self.centroids[category].items(): try: - res[dim] += value + res[dim] += value * intensity except Exception: - res[dim] = value + res[dim] = value * intensity + for dim,intensity in totalIntensities.items(): + if intensity != 0: + res[dim] /= intensity + else: + res[dim] = self.centroids.get('neutral', {dim:0})[dim] return res def _backwards_conversion(self, original):