Mirror of https://github.com/gsi-upm/senpy, synced 2025-08-23 18:12:20 +00:00
Add evaluation tests
@@ -10,6 +10,8 @@ from senpy.models import Results, Entry, EmotionSet, Emotion, Plugins
 from senpy import plugins
 from senpy.plugins.conversion.emotion.centroids import CentroidConversion
 
+import pandas as pd
+
 
 class ShelfDummyPlugin(plugins.SentimentPlugin, plugins.ShelfMixin):
     '''Dummy plugin for tests.'''
@@ -212,7 +214,7 @@ class PluginsTest(TestCase):
             def input(self, entry, **kwargs):
                 return entry.text
 
-            def predict(self, input):
+            def predict_one(self, input):
                 return 'SIGN' in input
 
             def output(self, output, entry, **kwargs):
@@ -242,7 +244,7 @@ class PluginsTest(TestCase):
 
             mappings = {'happy': 'marl:Positive', 'sad': 'marl:Negative'}
 
-            def predict(self, input, **kwargs):
+            def predict_one(self, input, **kwargs):
                 return 'happy' if ':)' in input else 'sad'
 
         test_cases = [
@@ -309,6 +311,40 @@ class PluginsTest(TestCase):
         res = c._backwards_conversion(e)
         assert res["onyx:hasEmotionCategory"] == "c2"
 
+    def test_evaluation(self):
+        testdata = []
+        for i in range(50):
+            testdata.append(["good", 1])
+        for i in range(50):
+            testdata.append(["bad", 0])
+        dataset = pd.DataFrame(testdata, columns=['text', 'polarity'])
+
+        class DummyPlugin(plugins.TextBox):
+            description = 'Plugin to test evaluation'
+            version = 0
+
+            def predict_one(self, input):
+                return 0
+
+        class SmartPlugin(plugins.TextBox):
+            description = 'Plugin to test evaluation'
+            version = 0
+
+            def predict_one(self, input):
+                if input == 'good':
+                    return 1
+                return 0
+
+        dpipe = DummyPlugin()
+        results = plugins.evaluate(datasets={'testdata': dataset}, plugins=[dpipe], flatten=True)
+        dumb_metrics = results[0].metrics[0]
+        assert abs(dumb_metrics['accuracy'] - 0.5) < 0.01
+
+        spipe = SmartPlugin()
+        results = plugins.evaluate(datasets={'testdata': dataset}, plugins=[spipe], flatten=True)
+        smart_metrics = results[0].metrics[0]
+        assert abs(smart_metrics['accuracy'] - 1) < 0.01
+
 
 def make_mini_test(fpath):
     def mini_test(self):
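Taken together, the new test exercises senpy's evaluation flow: build a pandas DataFrame of texts and gold labels, define a plugins.TextBox subclass whose predict_one returns a label for a single input, and hand both to plugins.evaluate. A minimal standalone sketch of that flow, using only names that appear in this diff (the plugin name KeywordPlugin is hypothetical):

# Minimal sketch of the evaluation flow exercised by test_evaluation.
# Assumes senpy and pandas are installed; plugins.TextBox, predict_one,
# and plugins.evaluate are the names used in the diff above.
import pandas as pd

from senpy import plugins

# Toy dataset in the same shape the test builds: text plus gold polarity.
dataset = pd.DataFrame([['good', 1]] * 50 + [['bad', 0]] * 50,
                       columns=['text', 'polarity'])


class KeywordPlugin(plugins.TextBox):  # hypothetical example plugin
    description = 'Example plugin for evaluation'
    version = 0

    def predict_one(self, input):
        # One label per text, mirroring SmartPlugin in the test.
        return 1 if input == 'good' else 0


# evaluate() takes named datasets and a list of plugin instances;
# flatten=True yields one result per plugin/dataset pair, and the
# test reads accuracy out of results[0].metrics[0].
results = plugins.evaluate(datasets={'toy': dataset},
                           plugins=[KeywordPlugin()],
                           flatten=True)
print(results[0].metrics[0]['accuracy'])  # close to 1.0 for this plugin

The predict → predict_one rename in the middle hunks fits the same per-entry contract: each plugin classifies one input at a time, and the box applies it across the dataset.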