mirror of
https://github.com/gsi-upm/senpy
synced 2024-11-22 08:12:27 +00:00
Change conversion to Euclidean distance
* Added neutral point (if present) Closes !gsi-upm/senpy#37 (Ian's)
This commit is contained in:
parent
6b843a4384
commit
00da75153a
@ -32,8 +32,17 @@ class CentroidConversion(EmotionConversionPlugin):
|
|||||||
nv1[aliases.get(k2, k2)] = v2
|
nv1[aliases.get(k2, k2)] = v2
|
||||||
ncentroids[aliases.get(k1, k1)] = nv1
|
ncentroids[aliases.get(k1, k1)] = nv1
|
||||||
info['centroids'] = ncentroids
|
info['centroids'] = ncentroids
|
||||||
|
|
||||||
super(CentroidConversion, self).__init__(info)
|
super(CentroidConversion, self).__init__(info)
|
||||||
|
|
||||||
|
self.dimensions = set()
|
||||||
|
for c in self.centroids.values():
|
||||||
|
self.dimensions.update(c.keys())
|
||||||
|
self.neutralPoints = self.get("neutralPoints", dict())
|
||||||
|
if not self.neutralPoints:
|
||||||
|
for i in self.dimensions:
|
||||||
|
self.neutralPoints[i] = self.get("neutralValue", 0)
|
||||||
|
|
||||||
def _forward_conversion(self, original):
|
def _forward_conversion(self, original):
|
||||||
"""Sum the VAD value of all categories found."""
|
"""Sum the VAD value of all categories found."""
|
||||||
res = Emotion()
|
res = Emotion()
|
||||||
@ -49,14 +58,18 @@ class CentroidConversion(EmotionConversionPlugin):
|
|||||||
|
|
||||||
def _backwards_conversion(self, original):
|
def _backwards_conversion(self, original):
|
||||||
"""Find the closest category"""
|
"""Find the closest category"""
|
||||||
dimensions = set(c.keys() for c in centroids.values())
|
centroids = self.centroids
|
||||||
neutralPoint = self.get("origin", None)
|
neutralPoints = self.neutralPoints
|
||||||
neutralPoint = {k:neutralPoint[k] if k in neturalPoint else 0}
|
dimensions = self.dimensions
|
||||||
|
|
||||||
|
def distance_k(centroid, original, k):
|
||||||
|
# k component of the distance between the value and a given centroid
|
||||||
|
return (centroid.get(k, neutralPoints[k]) - original.get(k, neutralPoints[k]))**2
|
||||||
|
|
||||||
def distance(centroid):
|
def distance(centroid):
|
||||||
return sum((centroid.get(k, neutralPoint[k]) - original.get(k, neutralPoint[k]))**2 for k in dimensions)
|
return sum(distance_k(centroid, original, k) for k in dimensions)
|
||||||
|
|
||||||
emotion = min(centroids, key=lambda x: distance(centroids[x])
|
emotion = min(centroids, key=lambda x: distance(centroids[x]))
|
||||||
|
|
||||||
result = Emotion(onyx__hasEmotionCategory=emotion)
|
result = Emotion(onyx__hasEmotionCategory=emotion)
|
||||||
return result
|
return result
|
||||||
|
Loading…
Reference in New Issue
Block a user