# chartbastan/backend/tests/test_backtesting.py
# Snapshot metadata (repository web view): 2026-02-01 09:31:38 +01:00,
# 516 lines, 17 KiB, Python.

"""
Tests for Backtesting Module.
This module contains unit tests for the backtesting functionality
including accuracy calculation, comparison logic, and export formats.
"""
import pytest
from datetime import datetime
from app.ml.backtesting import (
validate_accuracy,
compare_prediction,
run_backtesting_single_match,
run_backtesting_batch,
export_to_json,
export_to_csv,
export_to_html,
filter_matches_by_league,
filter_matches_by_period,
ACCURACY_VALIDATED_THRESHOLD,
ACCURACY_ALERT_THRESHOLD
)
class TestValidateAccuracy:
    """Unit tests for ``validate_accuracy``.

    Expected status bands exercised below:
      * accuracy >= 60.0          -> 'VALIDATED'
      * 55.0 <= accuracy < 60.0   -> 'BELOW_TARGET'
      * accuracy < 55.0           -> 'REVISION_REQUIRED'
    """

    def test_validate_accuracy_above_threshold(self):
        """An accuracy comfortably above 60% is validated."""
        assert validate_accuracy(65.0) == 'VALIDATED'

    def test_validate_accuracy_at_threshold(self):
        """Exactly 60% is still validated (inclusive upper band)."""
        assert validate_accuracy(60.0) == 'VALIDATED'

    def test_validate_accuracy_below_target(self):
        """Accuracy between 55% and 60% is below target."""
        assert validate_accuracy(58.0) == 'BELOW_TARGET'

    def test_validate_accuracy_alert(self):
        """Accuracy under 55% requires a model revision."""
        assert validate_accuracy(50.0) == 'REVISION_REQUIRED'

    def test_validate_accuracy_boundary(self):
        """Exactly 55% falls in the below-target band, not the alert band."""
        assert validate_accuracy(55.0) == 'BELOW_TARGET'

    def test_validate_accuracy_extreme_high(self):
        """A perfect score is validated."""
        assert validate_accuracy(100.0) == 'VALIDATED'

    def test_validate_accuracy_zero(self):
        """A zero score requires a model revision."""
        assert validate_accuracy(0.0) == 'REVISION_REQUIRED'

    def test_validate_thresholds_constants(self):
        """The exported threshold constants hold their documented values."""
        assert ACCURACY_VALIDATED_THRESHOLD == 60.0
        assert ACCURACY_ALERT_THRESHOLD == 55.0
class TestComparePrediction:
    """Unit tests for ``compare_prediction``.

    The first argument is the predicted outcome and the second the actual
    outcome; the function is expected to return the ``True``/``False``
    singletons.
    """

    def test_compare_home_correct(self):
        """Matching 'home' outcomes compare as correct."""
        assert compare_prediction('home', 'home') is True

    def test_compare_away_correct(self):
        """Matching 'away' outcomes compare as correct."""
        assert compare_prediction('away', 'away') is True

    def test_compare_draw_correct(self):
        """Matching 'draw' outcomes compare as correct."""
        assert compare_prediction('draw', 'draw') is True

    def test_compare_home_incorrect(self):
        """Predicting 'home' when the away side won is incorrect."""
        assert compare_prediction('home', 'away') is False

    def test_compare_case_insensitive(self):
        """Letter case on either side does not affect the comparison."""
        case_variants = [('HOME', 'home'), ('Home', 'home'), ('home', 'HOME')]
        # Evaluate every pair first, then check them all.
        outcomes = [compare_prediction(predicted, actual)
                    for predicted, actual in case_variants]
        assert all(outcome is True for outcome in outcomes)
class TestRunBacktestingSingleMatch:
    """Tests for run_backtesting_single_match function.

    The correct- and incorrect-prediction cases share identical energy
    inputs (65.0 home vs 45.0 away), so only the actual outcome differs.
    """

    def test_single_match_correct_prediction(self):
        """Test backtesting for a single match with correct prediction."""
        result = run_backtesting_single_match(
            match_id=1,
            home_team='PSG',
            away_team='OM',
            home_energy=65.0,
            away_energy=45.0,
            actual_winner='home'
        )
        assert result['match_id'] == 1
        assert result['home_team'] == 'PSG'
        assert result['away_team'] == 'OM'
        assert result['home_energy'] == 65.0
        assert result['away_energy'] == 45.0
        assert result['actual_winner'] == 'home'
        assert result['correct'] is True
        assert 'prediction' in result
        assert result['prediction']['predicted_winner'] == 'home'

    def test_single_match_incorrect_prediction(self):
        """Test backtesting for a single match with incorrect prediction.

        The energies match the correct-prediction case, so the model must
        still predict 'home'. Only the real outcome changes, and the
        prediction must not depend on it.
        """
        result = run_backtesting_single_match(
            match_id=2,
            home_team='PSG',
            away_team='OM',
            home_energy=65.0,
            away_energy=45.0,
            actual_winner='away'
        )
        assert result['match_id'] == 2
        assert result['actual_winner'] == 'away'
        assert result['correct'] is False
        # A weaker "!= 'away'" check would also accept 'draw', which would
        # mean the outcome leaked into the prediction.
        assert result['prediction']['predicted_winner'] == 'home'

    def test_single_match_with_equal_energy(self):
        """Test backtesting for a match with equal energy scores."""
        result = run_backtesting_single_match(
            match_id=3,
            home_team='PSG',
            away_team='OM',
            home_energy=50.0,
            away_energy=50.0,
            actual_winner='draw'
        )
        assert result['prediction']['predicted_winner'] == 'draw'
        assert result['correct'] is True
class TestRunBacktestingBatch:
    """Tests for run_backtesting_batch function."""

    def test_batch_all_correct(self):
        """Test backtesting with all correct predictions."""
        matches = [
            {
                'match_id': 1,
                'home_team': 'PSG',
                'away_team': 'OM',
                'home_energy': 65.0,
                'away_energy': 45.0,
                'actual_winner': 'home',
                'league': 'Ligue 1'
            },
            {
                'match_id': 2,
                'home_team': 'Lyon',
                'away_team': 'Monaco',
                'home_energy': 45.0,
                'away_energy': 65.0,
                'actual_winner': 'away',
                'league': 'Ligue 1'
            }
        ]
        result = run_backtesting_batch(matches)
        assert result['total_matches'] == 2
        assert result['correct_predictions'] == 2
        assert result['incorrect_predictions'] == 0
        assert result['accuracy'] == 100.0
        assert result['status'] == 'VALIDATED'
        assert len(result['results']) == 2

    def test_batch_mixed_results(self):
        """Test backtesting with mixed correct/incorrect predictions."""
        matches = [
            {
                'match_id': 1,
                'home_team': 'PSG',
                'away_team': 'OM',
                'home_energy': 65.0,
                'away_energy': 45.0,
                'actual_winner': 'home',
                'league': 'Ligue 1'
            },
            {
                # Home side has more energy but the away side wins, so this
                # prediction is expected to miss.
                'match_id': 2,
                'home_team': 'Lyon',
                'away_team': 'Monaco',
                'home_energy': 65.0,
                'away_energy': 45.0,
                'actual_winner': 'away',
                'league': 'Ligue 1'
            }
        ]
        result = run_backtesting_batch(matches)
        assert result['total_matches'] == 2
        assert result['correct_predictions'] == 1
        assert result['incorrect_predictions'] == 1
        assert result['accuracy'] == 50.0
        assert result['status'] == 'REVISION_REQUIRED'

    def test_batch_with_leagues(self):
        """Test backtesting with multiple leagues."""
        matches = [
            {
                'match_id': 1,
                'home_team': 'PSG',
                'away_team': 'OM',
                'home_energy': 65.0,
                'away_energy': 45.0,
                'actual_winner': 'home',
                'league': 'Ligue 1'
            },
            {
                'match_id': 2,
                'home_team': 'Man City',
                'away_team': 'Liverpool',
                'home_energy': 70.0,
                'away_energy': 50.0,
                'actual_winner': 'home',
                'league': 'Premier League'
            }
        ]
        result = run_backtesting_batch(matches)
        assert result['total_matches'] == 2
        assert 'metrics_by_league' in result
        assert 'Ligue 1' in result['metrics_by_league']
        assert 'Premier League' in result['metrics_by_league']
        assert result['metrics_by_league']['Ligue 1']['total'] == 1
        assert result['metrics_by_league']['Premier League']['total'] == 1

    def test_batch_empty(self):
        """Test backtesting with no matches."""
        result = run_backtesting_batch([])
        assert result['total_matches'] == 0
        assert result['correct_predictions'] == 0
        assert result['incorrect_predictions'] == 0
        assert result['accuracy'] == 0.0

    def test_batch_missing_required_field(self):
        """Test that backtesting a match with a missing required field raises ValueError."""
        matches = [
            {
                'match_id': 1,
                'home_team': 'PSG',
                # Missing 'away_team'
                'home_energy': 65.0,
                'away_energy': 45.0,
                'actual_winner': 'home'
            }
        ]
        with pytest.raises(ValueError, match="missing required fields"):
            run_backtesting_batch(matches)

    def test_batch_with_dates(self):
        """Test backtesting with match dates (serialized as ISO 8601 strings)."""
        match_date = datetime(2025, 1, 15, 20, 0, 0)
        matches = [
            {
                'match_id': 1,
                'home_team': 'PSG',
                'away_team': 'OM',
                'home_energy': 65.0,
                'away_energy': 45.0,
                'actual_winner': 'home',
                'league': 'Ligue 1',
                'date': match_date
            }
        ]
        result = run_backtesting_batch(matches)
        assert result['results'][0]['date'] == match_date.isoformat()
class TestExportFormats:
    """Tests for the JSON, CSV and HTML export functions."""

    @staticmethod
    def _single_match_report():
        """Return a fresh one-match, fully correct backtesting result dict."""
        match_row = {
            'match_id': 1,
            'league': 'Ligue 1',
            'date': '2026-01-15T20:00:00Z',
            'home_team': 'PSG',
            'away_team': 'OM',
            'home_energy': 65.0,
            'away_energy': 45.0,
            'prediction': {'predicted_winner': 'home', 'confidence': 40.0},
            'actual_winner': 'home',
            'correct': True
        }
        return {
            'total_matches': 1,
            'correct_predictions': 1,
            'incorrect_predictions': 0,
            'accuracy': 100.0,
            'status': 'VALIDATED',
            'results': [match_row],
            'metrics_by_league': {},
            'timestamp': '2026-01-17T10:00:00Z',
            'validation_thresholds': {'validated': 60.0, 'alert': 55.0}
        }

    def test_export_to_json(self):
        """JSON export yields a string that mentions the summary fields."""
        summary = {
            'total_matches': 2,
            'correct_predictions': 1,
            'incorrect_predictions': 1,
            'accuracy': 50.0,
            'status': 'REVISION_REQUIRED',
            'results': [],
            'metrics_by_league': {},
            'timestamp': '2026-01-17T10:00:00Z',
            'validation_thresholds': {'validated': 60.0, 'alert': 55.0}
        }
        exported = export_to_json(summary)
        assert isinstance(exported, str)
        for fragment in ('total_matches', 'accuracy'):
            assert fragment in exported

    def test_export_to_csv(self):
        """CSV export contains the header and the match's team names."""
        exported = export_to_csv(self._single_match_report())
        assert isinstance(exported, str)
        for fragment in ('match_id', 'PSG', 'OM'):
            assert fragment in exported

    def test_export_to_html(self):
        """HTML export is a full document carrying the headline metrics."""
        exported = export_to_html(self._single_match_report())
        assert isinstance(exported, str)
        expected_fragments = (
            '<html>',
            '</html>',
            'Backtesting Report',
            '100.0%',
            'VALIDATED',
        )
        for fragment in expected_fragments:
            assert fragment in exported
class TestFilterMatchesByLeague:
    """Tests for ``filter_matches_by_league``."""

    @staticmethod
    def _two_league_matches():
        """Return one Ligue 1 match and one Premier League match."""
        return [
            {'league': 'Ligue 1', 'match_id': 1},
            {'league': 'Premier League', 'match_id': 2}
        ]

    def test_filter_by_single_league(self):
        """Only matches from the requested league survive."""
        fixtures = [
            {'league': 'Ligue 1', 'match_id': 1},
            {'league': 'Premier League', 'match_id': 2},
            {'league': 'Ligue 1', 'match_id': 3}
        ]
        kept = filter_matches_by_league(fixtures, ['Ligue 1'])
        assert len(kept) == 2
        assert {match['league'] for match in kept} == {'Ligue 1'}

    def test_filter_by_multiple_leagues(self):
        """Several leagues can be selected; input order is preserved."""
        fixtures = [
            {'league': 'Ligue 1', 'match_id': 1},
            {'league': 'Premier League', 'match_id': 2},
            {'league': 'La Liga', 'match_id': 3}
        ]
        kept = filter_matches_by_league(fixtures, ['Ligue 1', 'Premier League'])
        assert len(kept) == 2
        assert [match['league'] for match in kept] == ['Ligue 1', 'Premier League']

    def test_filter_no_leagues(self):
        """An empty league list disables filtering."""
        kept = filter_matches_by_league(self._two_league_matches(), [])
        assert len(kept) == 2

    def test_filter_none_leagues(self):
        """A ``None`` league list disables filtering."""
        kept = filter_matches_by_league(self._two_league_matches(), None)
        assert len(kept) == 2
class TestFilterMatchesByPeriod:
    """Tests for ``filter_matches_by_period``.

    The date-range case shows both bounds are inclusive. A match without
    a ``date`` key is dropped when a start date is given.
    """

    @staticmethod
    def _january_matches(*day_and_id):
        """Build matches dated January 2025 from ``(day, match_id)`` pairs."""
        return [
            {'date': datetime(2025, 1, day), 'match_id': match_id}
            for day, match_id in day_and_id
        ]

    @staticmethod
    def _sorted_ids(matches):
        """Return the sorted ``match_id`` values of *matches*."""
        return sorted(match['match_id'] for match in matches)

    def test_filter_by_start_date(self):
        """Matches before the start date are removed."""
        fixtures = self._january_matches((10, 1), (20, 2), (5, 3))
        kept = filter_matches_by_period(fixtures, start_date=datetime(2025, 1, 15))
        assert [match['match_id'] for match in kept] == [2]

    def test_filter_by_end_date(self):
        """Matches after the end date are removed."""
        fixtures = self._january_matches((10, 1), (20, 2), (5, 3))
        kept = filter_matches_by_period(fixtures, end_date=datetime(2025, 1, 15))
        assert len(kept) == 2
        assert self._sorted_ids(kept) == [1, 3]

    def test_filter_by_date_range(self):
        """Matches on either boundary day are kept."""
        fixtures = self._january_matches((10, 1), (20, 2), (15, 3), (5, 4))
        kept = filter_matches_by_period(
            fixtures,
            start_date=datetime(2025, 1, 10),
            end_date=datetime(2025, 1, 15),
        )
        assert len(kept) == 2
        assert self._sorted_ids(kept) == [1, 3]

    def test_filter_no_dates(self):
        """With no bounds, every match is returned."""
        fixtures = self._january_matches((10, 1), (20, 2))
        kept = filter_matches_by_period(fixtures, start_date=None, end_date=None)
        assert len(kept) == 2

    def test_filter_no_date_field(self):
        """A match lacking a ``date`` key is excluded once a bound is set."""
        fixtures = [
            {'date': datetime(2025, 1, 10), 'match_id': 1},
            {'match_id': 2}  # No date field
        ]
        kept = filter_matches_by_period(fixtures, start_date=datetime(2025, 1, 1))
        assert [match['match_id'] for match in kept] == [1]