"""
Pytest configuration and fixtures for backend tests.
"""
from collections.abc import Iterator
from datetime import datetime, timezone

import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.pool import StaticPool

from app.database import Base, get_db
from app.models import (
    User,
    Tweet,
    RedditPost,
    RedditComment,
    SentimentScore,
    EnergyScore,
)

# Create in-memory SQLite database for testing.
SQLALCHEMY_DATABASE_URL = "sqlite:///:memory:"

# Each new DBAPI connection to "sqlite:///:memory:" opens its own, empty
# database. StaticPool hands out one shared connection, so the tables created
# by the fixtures are the same ones seen by requests that TestClient serves on
# another thread (check_same_thread=False allows that cross-thread use).
engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


def _utcnow() -> datetime:
    """Return the current UTC time as a naive datetime.

    Produces the same value as ``datetime.utcnow()`` without its
    DeprecationWarning (deprecated since Python 3.12).
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


@pytest.fixture(scope="function")
def db_session() -> Iterator[Session]:
    """
    Create a fresh database session for each test.

    All tables are created before the test runs and dropped afterwards, so
    every test starts from an empty schema.

    Yields:
        Session bound to the in-memory test engine.
    """
    Base.metadata.create_all(bind=engine)
    session = TestingSessionLocal()
    try:
        yield session
    finally:
        session.close()
        # Drop all tables so no rows or schema leak into the next test.
        Base.metadata.drop_all(bind=engine)


@pytest.fixture(scope="function")
def client(db_session: Session):
    """
    Create a test client with a database session override.

    The app's ``get_db`` dependency is replaced so that every request uses the
    same session as the test's fixtures. All dependency overrides are cleared
    on teardown, even if the client fails to start or stop.

    Yields:
        A ``fastapi.testclient.TestClient`` for ``app.main.app``.
    """
    # Imported lazily so merely collecting tests does not import the app.
    from fastapi.testclient import TestClient
    from app.main import app

    def override_get_db():
        # db_session owns the session's lifecycle, so nothing is closed here.
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    try:
        with TestClient(app) as test_client:
            yield test_client
    finally:
        app.dependency_overrides.clear()


@pytest.fixture(scope="function")
def sample_user(db_session: Session) -> User:
    """
    Create a sample user for testing.

    Returns:
        User object (committed and refreshed).
    """
    user = User(
        username="testuser",
        email="test@example.com",
        hashed_password="hashed_password",
        is_premium=False,
    )
    db_session.add(user)
    db_session.commit()
    db_session.refresh(user)
    return user


@pytest.fixture(scope="function")
def sample_tweet(db_session: Session, sample_user: User) -> Tweet:
    """
    Create a sample tweet for testing.

    Returns:
        Tweet object owned by ``sample_user`` (committed and refreshed).
    """
    tweet = Tweet(
        tweet_id="1234567890",
        text="This is a test tweet",
        author="test_author",
        created_at=_utcnow(),
        user_id=sample_user.id,
        retweets=10,
        likes=20,
        replies=5,
    )
    db_session.add(tweet)
    db_session.commit()
    db_session.refresh(tweet)
    return tweet


@pytest.fixture(scope="function")
def sample_reddit_post(db_session: Session, sample_user: User) -> RedditPost:
    """
    Create a sample Reddit post for testing.

    Returns:
        RedditPost object owned by ``sample_user`` (committed and refreshed).
    """
    reddit_post = RedditPost(
        post_id="reddit123",
        title="Test Reddit Post",
        text="This is a test Reddit post",
        author="reddit_author",
        subreddit="test_subreddit",
        created_at=_utcnow(),
        user_id=sample_user.id,
        upvotes=15,
        comments=3,
    )
    db_session.add(reddit_post)
    db_session.commit()
    db_session.refresh(reddit_post)
    return reddit_post


@pytest.fixture(scope="function")
def sample_sentiment_scores(
    db_session: Session,
    sample_tweet: Tweet,
    sample_reddit_post: RedditPost,
) -> list[SentimentScore]:
    """
    Create sample sentiment scores for testing.

    Returns:
        List of SentimentScore objects: ``[tweet score, reddit post score]``,
        linked to their entities by ``entity_id``/``entity_type``.
    """
    twitter_sentiment = SentimentScore(
        entity_id=sample_tweet.tweet_id,
        entity_type="tweet",
        score=0.5,
        sentiment_type="positive",
        positive=0.6,
        negative=0.2,
        neutral=0.2,
        created_at=_utcnow(),
    )
    reddit_sentiment = SentimentScore(
        entity_id=sample_reddit_post.post_id,
        entity_type="reddit_post",
        score=0.3,
        sentiment_type="positive",
        positive=0.4,
        negative=0.3,
        neutral=0.3,
        created_at=_utcnow(),
    )
    db_session.add(twitter_sentiment)
    db_session.add(reddit_sentiment)
    db_session.commit()
    db_session.refresh(twitter_sentiment)
    db_session.refresh(reddit_sentiment)
    return [twitter_sentiment, reddit_sentiment]