516 lines
17 KiB
Python
516 lines
17 KiB
Python
"""
Tests for Backtesting Module.

This module contains unit tests for the backtesting functionality,
including accuracy validation, prediction comparison, single-match and
batch backtesting, export formats, and filtering matches by league and period.
"""
|
|
import pytest
|
|
from datetime import datetime
|
|
from app.ml.backtesting import (
|
|
validate_accuracy,
|
|
compare_prediction,
|
|
run_backtesting_single_match,
|
|
run_backtesting_batch,
|
|
export_to_json,
|
|
export_to_csv,
|
|
export_to_html,
|
|
filter_matches_by_league,
|
|
filter_matches_by_period,
|
|
ACCURACY_VALIDATED_THRESHOLD,
|
|
ACCURACY_ALERT_THRESHOLD
|
|
)
|
|
|
|
|
|
class TestValidateAccuracy:
    """Unit tests for ``validate_accuracy`` and its threshold constants.

    Tiers exercised below: ``accuracy >= 60`` -> 'VALIDATED',
    ``55 <= accuracy < 60`` -> 'BELOW_TARGET', ``accuracy < 55`` -> 'REVISION_REQUIRED'.
    """

    def test_validate_accuracy_above_threshold(self):
        """An accuracy comfortably above 60% is VALIDATED."""
        assert validate_accuracy(65.0) == 'VALIDATED'

    def test_validate_accuracy_at_threshold(self):
        """Exactly 60% is still VALIDATED (inclusive lower bound)."""
        assert validate_accuracy(60.0) == 'VALIDATED'

    def test_validate_accuracy_below_target(self):
        """An accuracy between 55% and 60% is BELOW_TARGET."""
        assert validate_accuracy(58.0) == 'BELOW_TARGET'

    def test_validate_accuracy_alert(self):
        """An accuracy under 55% requires revision."""
        assert validate_accuracy(50.0) == 'REVISION_REQUIRED'

    def test_validate_accuracy_boundary(self):
        """Exactly 55% falls in BELOW_TARGET, not REVISION_REQUIRED."""
        assert validate_accuracy(55.0) == 'BELOW_TARGET'

    def test_validate_accuracy_extreme_high(self):
        """A perfect 100% score is VALIDATED."""
        assert validate_accuracy(100.0) == 'VALIDATED'

    def test_validate_accuracy_zero(self):
        """A 0% score requires revision."""
        assert validate_accuracy(0.0) == 'REVISION_REQUIRED'

    def test_validate_thresholds_constants(self):
        """The exported thresholds are 60% (validated) and 55% (alert)."""
        assert (ACCURACY_VALIDATED_THRESHOLD, ACCURACY_ALERT_THRESHOLD) == (60.0, 55.0)
|
|
|
|
class TestComparePrediction:
    """Unit tests for ``compare_prediction`` (predicted label vs. actual label)."""

    def test_compare_home_correct(self):
        """A 'home' prediction matches a 'home' outcome."""
        assert compare_prediction('home', 'home') is True

    def test_compare_away_correct(self):
        """An 'away' prediction matches an 'away' outcome."""
        assert compare_prediction('away', 'away') is True

    def test_compare_draw_correct(self):
        """A 'draw' prediction matches a 'draw' outcome."""
        assert compare_prediction('draw', 'draw') is True

    def test_compare_home_incorrect(self):
        """A 'home' prediction does not match an 'away' outcome."""
        assert compare_prediction('home', 'away') is False

    def test_compare_case_insensitive(self):
        """Labels match regardless of letter case on either side."""
        pairs = [('HOME', 'home'), ('Home', 'home'), ('home', 'HOME')]
        # Evaluate every pair before asserting, as the original test did.
        outcomes = [compare_prediction(predicted, actual) for predicted, actual in pairs]
        for outcome in outcomes:
            assert outcome is True
|
|
|
|
class TestRunBacktestingSingleMatch:
    """Tests for run_backtesting_single_match function."""

    def test_single_match_correct_prediction(self):
        """Test backtesting for a single match with correct prediction."""
        result = run_backtesting_single_match(
            match_id=1,
            home_team='PSG',
            away_team='OM',
            home_energy=65.0,
            away_energy=45.0,
            actual_winner='home'
        )

        assert result['match_id'] == 1
        assert result['home_team'] == 'PSG'
        assert result['away_team'] == 'OM'
        assert result['home_energy'] == 65.0
        assert result['away_energy'] == 45.0
        assert result['actual_winner'] == 'home'
        assert result['correct'] is True
        assert 'prediction' in result
        assert result['prediction']['predicted_winner'] == 'home'

    def test_single_match_incorrect_prediction(self):
        """Test backtesting for a single match with incorrect prediction.

        Uses the same energies as the correct-prediction case (65 vs 45), for
        which 'home' is the expected pick. Pinning the exact pick (instead of
        only "not 'away'") catches a regression to 'draw' or a missing value.
        """
        result = run_backtesting_single_match(
            match_id=2,
            home_team='PSG',
            away_team='OM',
            home_energy=65.0,
            away_energy=45.0,
            actual_winner='away'
        )

        assert result['match_id'] == 2
        assert result['actual_winner'] == 'away'
        assert result['correct'] is False
        assert result['prediction']['predicted_winner'] == 'home'

    def test_single_match_with_equal_energy(self):
        """Test backtesting for a match with equal energy scores."""
        result = run_backtesting_single_match(
            match_id=3,
            home_team='PSG',
            away_team='OM',
            home_energy=50.0,
            away_energy=50.0,
            actual_winner='draw'
        )

        assert result['prediction']['predicted_winner'] == 'draw'
        assert result['correct'] is True
|
|
|
|
class TestRunBacktestingBatch:
    """Tests for run_backtesting_batch function."""

    @staticmethod
    def _match(match_id, home_team, away_team, home_energy, away_energy,
               actual_winner, league='Ligue 1', **extra):
        """Build one match record with the fields these tests pass to run_backtesting_batch.

        Keyword ``extra`` fields (e.g. ``date``) are appended after ``league``.
        """
        record = {
            'match_id': match_id,
            'home_team': home_team,
            'away_team': away_team,
            'home_energy': home_energy,
            'away_energy': away_energy,
            'actual_winner': actual_winner,
            'league': league,
        }
        record.update(extra)
        return record

    def test_batch_all_correct(self):
        """Test backtesting with all correct predictions."""
        matches = [
            self._match(1, 'PSG', 'OM', 65.0, 45.0, 'home'),
            self._match(2, 'Lyon', 'Monaco', 45.0, 65.0, 'away'),
        ]

        result = run_backtesting_batch(matches)

        assert result['total_matches'] == 2
        assert result['correct_predictions'] == 2
        assert result['incorrect_predictions'] == 0
        assert result['accuracy'] == 100.0
        assert result['status'] == 'VALIDATED'
        assert len(result['results']) == 2

    def test_batch_mixed_results(self):
        """Test backtesting with mixed correct/incorrect predictions."""
        matches = [
            self._match(1, 'PSG', 'OM', 65.0, 45.0, 'home'),
            # Home side has the higher energy but the away side won: a miss.
            self._match(2, 'Lyon', 'Monaco', 65.0, 45.0, 'away'),
        ]

        result = run_backtesting_batch(matches)

        assert result['total_matches'] == 2
        assert result['correct_predictions'] == 1
        assert result['incorrect_predictions'] == 1
        assert result['accuracy'] == 50.0
        assert result['status'] == 'REVISION_REQUIRED'

    def test_batch_with_leagues(self):
        """Test backtesting with multiple leagues."""
        matches = [
            self._match(1, 'PSG', 'OM', 65.0, 45.0, 'home'),
            self._match(2, 'Man City', 'Liverpool', 70.0, 50.0, 'home',
                        league='Premier League'),
        ]

        result = run_backtesting_batch(matches)

        assert 'metrics_by_league' in result
        assert 'Ligue 1' in result['metrics_by_league']
        assert 'Premier League' in result['metrics_by_league']
        assert result['metrics_by_league']['Ligue 1']['total'] == 1
        assert result['metrics_by_league']['Premier League']['total'] == 1

    def test_batch_empty(self):
        """Test backtesting with no matches."""
        result = run_backtesting_batch([])

        assert result['total_matches'] == 0
        assert result['correct_predictions'] == 0
        assert result['incorrect_predictions'] == 0
        assert result['accuracy'] == 0.0

    def test_batch_missing_required_field(self):
        """Test backtesting with missing required field raises error."""
        # Built literally (not via _match) so that 'away_team' can be omitted.
        matches = [
            {
                'match_id': 1,
                'home_team': 'PSG',
                # Missing 'away_team'
                'home_energy': 65.0,
                'away_energy': 45.0,
                'actual_winner': 'home'
            }
        ]

        # pytest.raises(match=...) does a regex search on str(exception).
        with pytest.raises(ValueError, match="missing required fields"):
            run_backtesting_batch(matches)

    def test_batch_with_dates(self):
        """Test backtesting with match dates."""
        match_date = datetime(2025, 1, 15, 20, 0, 0)
        matches = [
            self._match(1, 'PSG', 'OM', 65.0, 45.0, 'home', date=match_date),
        ]

        result = run_backtesting_batch(matches)

        # The per-match result carries the date as an ISO-8601 string.
        assert result['results'][0]['date'] == match_date.isoformat()
|
|
|
|
class TestExportFormats:
    """Tests for the JSON, CSV and HTML exporters of a backtesting result."""

    @staticmethod
    def _report(total, correct, accuracy, status, results):
        """Assemble a backtesting result dict with the fixed metadata these tests share."""
        return {
            'total_matches': total,
            'correct_predictions': correct,
            'incorrect_predictions': total - correct,
            'accuracy': accuracy,
            'status': status,
            'results': results,
            'metrics_by_league': {},
            'timestamp': '2026-01-17T10:00:00Z',
            'validation_thresholds': {'validated': 60.0, 'alert': 55.0},
        }

    @staticmethod
    def _psg_om_row():
        """One correctly predicted PSG vs OM per-match result row."""
        return {
            'match_id': 1,
            'league': 'Ligue 1',
            'date': '2026-01-15T20:00:00Z',
            'home_team': 'PSG',
            'away_team': 'OM',
            'home_energy': 65.0,
            'away_energy': 45.0,
            'prediction': {'predicted_winner': 'home', 'confidence': 40.0},
            'actual_winner': 'home',
            'correct': True,
        }

    def test_export_to_json(self):
        """JSON export yields a string that mentions the summary fields."""
        report = self._report(2, 1, 50.0, 'REVISION_REQUIRED', [])

        rendered = export_to_json(report)

        assert isinstance(rendered, str)
        for needle in ('total_matches', 'accuracy'):
            assert needle in rendered

    def test_export_to_csv(self):
        """CSV export yields a string with the header and the row's team names."""
        report = self._report(1, 1, 100.0, 'VALIDATED', [self._psg_om_row()])

        rendered = export_to_csv(report)

        assert isinstance(rendered, str)
        for needle in ('match_id', 'PSG', 'OM'):
            assert needle in rendered

    def test_export_to_html(self):
        """HTML export yields a full document showing the title, accuracy and status."""
        report = self._report(1, 1, 100.0, 'VALIDATED', [self._psg_om_row()])

        rendered = export_to_html(report)

        assert isinstance(rendered, str)
        for needle in ('<html>', '</html>', 'Backtesting Report', '100.0%', 'VALIDATED'):
            assert needle in rendered
|
|
|
|
class TestFilterMatchesByLeague:
    """Tests for filter_matches_by_league function."""

    @staticmethod
    def _matches(*leagues):
        """Build minimal match dicts, numbered from 1, one per league name given."""
        return [
            {'league': name, 'match_id': idx}
            for idx, name in enumerate(leagues, start=1)
        ]

    def test_filter_by_single_league(self):
        """Only matches of the requested league survive the filter."""
        pool = self._matches('Ligue 1', 'Premier League', 'Ligue 1')

        kept = filter_matches_by_league(pool, ['Ligue 1'])

        assert [m['league'] for m in kept] == ['Ligue 1'] * 2

    def test_filter_by_multiple_leagues(self):
        """Several leagues may be requested at once; input order is kept."""
        pool = self._matches('Ligue 1', 'Premier League', 'La Liga')

        kept = filter_matches_by_league(pool, ['Ligue 1', 'Premier League'])

        assert [m['league'] for m in kept] == ['Ligue 1', 'Premier League']

    def test_filter_no_leagues(self):
        """An empty league list disables filtering."""
        pool = self._matches('Ligue 1', 'Premier League')

        kept = filter_matches_by_league(pool, [])

        assert len(kept) == 2

    def test_filter_none_leagues(self):
        """A None league list disables filtering."""
        pool = self._matches('Ligue 1', 'Premier League')

        kept = filter_matches_by_league(pool, None)

        assert len(kept) == 2
|
|
|
|
class TestFilterMatchesByPeriod:
    """Tests for filter_matches_by_period function."""

    @staticmethod
    def _matches(*days):
        """Build match dicts dated in January 2025, numbered from 1, one per day-of-month."""
        return [
            {'date': datetime(2025, 1, day), 'match_id': idx}
            for idx, day in enumerate(days, start=1)
        ]

    @staticmethod
    def _ids(matches):
        """Sorted match ids of *matches*, for order-insensitive comparison."""
        return sorted(m['match_id'] for m in matches)

    def test_filter_by_start_date(self):
        """Matches before start_date are dropped."""
        pool = self._matches(10, 20, 5)

        kept = filter_matches_by_period(pool, start_date=datetime(2025, 1, 15))

        assert [m['match_id'] for m in kept] == [2]

    def test_filter_by_end_date(self):
        """Matches after end_date are dropped."""
        pool = self._matches(10, 20, 5)

        kept = filter_matches_by_period(pool, end_date=datetime(2025, 1, 15))

        assert self._ids(kept) == [1, 3]

    def test_filter_by_date_range(self):
        """Both bounds apply together, and matches exactly on a bound are kept."""
        pool = self._matches(10, 20, 15, 5)

        kept = filter_matches_by_period(
            pool,
            start_date=datetime(2025, 1, 10),
            end_date=datetime(2025, 1, 15),
        )

        assert self._ids(kept) == [1, 3]

    def test_filter_no_dates(self):
        """With neither bound set, every match is returned."""
        pool = self._matches(10, 20)

        kept = filter_matches_by_period(pool, start_date=None, end_date=None)

        assert len(kept) == 2

    def test_filter_no_date_field(self):
        """A match lacking a 'date' key is excluded once a bound is applied."""
        pool = [
            {'date': datetime(2025, 1, 10), 'match_id': 1},
            {'match_id': 2},  # No date field
        ]

        kept = filter_matches_by_period(pool, start_date=datetime(2025, 1, 1))

        assert [m['match_id'] for m in kept] == [1]