"""Tests for the billing views in ``main.views.billing``."""

from contextlib import contextmanager
from datetime import datetime, timedelta
from unittest.mock import patch

import stripe
from django.test import TestCase
from django.urls import reverse

from main import models
from main.views import billing


def _make_subscription():
    """Return a fake Stripe subscription dict spanning the next 365 days.

    Only ``current_period_start`` and ``current_period_end`` are populated,
    as POSIX timestamps. ``datetime.now()`` is sampled once so the two ends
    of the period are exactly one year apart.
    """
    now = datetime.now()
    return {
        "current_period_end": (now + timedelta(days=365)).timestamp(),
        "current_period_start": now.timestamp(),
    }


@contextmanager
def _patch_billing_index(subscription):
    """Patch the Stripe-facing calls made while rendering billing pages.

    For the duration of the ``with`` block this replaces:

    * ``stripe.Customer.create`` -- returns a fixed customer id;
    * ``billing._get_stripe_subscription`` -- returns *subscription*
      (pass ``None`` to simulate a user with no active subscription);
    * ``billing._get_payment_methods`` and ``billing._get_invoices`` --
      plain mocks.
    """
    with (
        patch.object(
            stripe.Customer, "create", return_value={"id": "cus_123abcdefg"}
        ),
        patch.object(
            billing, "_get_stripe_subscription", return_value=subscription
        ),
        patch.object(billing, "_get_payment_methods"),
        patch.object(billing, "_get_invoices"),
    ):
        yield


class BillingCannotChangeIsPremiumTestCase(TestCase):
    """Test user cannot change their is_premium flag without going through billing."""

    def setUp(self):
        self.user = models.User.objects.create(username="alice")
        self.client.force_login(self.user)

    def test_update_billing_settings(self):
        data = {
            "username": "alice",
            "is_premium": True,
        }
        self.client.post(reverse("user_update"), data)
        self.assertFalse(models.User.objects.get(id=self.user.id).is_premium)


class BillingIndexGrandfatherTestCase(TestCase):
    """Test billing pages behave correctly for a grandfathered user."""

    def setUp(self):
        self.user = models.User.objects.create(username="alice")
        self.user.is_grandfathered = True
        self.user.save()
        self.client.force_login(self.user)

    def test_index(self):
        response = self.client.get(reverse("billing_index"))
        self.assertEqual(response.status_code, 200)
        self.assertContains(response, b"Grandfather Plan")

    def test_cannot_subscribe(self):
        response = self.client.post(reverse("billing_subscription"))
        # assertRedirects already asserts the 302 status code.
        self.assertRedirects(response, reverse("dashboard"))

    def test_cannot_cancel_get(self):
        response = self.client.get(reverse("billing_subscription_cancel"))
        self.assertRedirects(response, reverse("dashboard"))


class BillingIndexFreeTestCase(TestCase):
    """Test billing index works for free user."""

    def setUp(self):
        # objects.create() already persists the row; no extra save() needed.
        self.user = models.User.objects.create(username="alice")
        self.client.force_login(self.user)

    def test_index(self):
        with _patch_billing_index(subscription=None):
            response = self.client.get(reverse("billing_index"))
        self.assertEqual(response.status_code, 200)
        self.assertContains(response, b"Free Plan")


class BillingIndexPremiumTestCase(TestCase):
    """Test billing index works for premium user."""

    def setUp(self):
        self.user = models.User.objects.create(username="alice")
        self.user.is_premium = True
        self.user.save()
        self.client.force_login(self.user)

    def test_index(self):
        with _patch_billing_index(subscription=_make_subscription()):
            response = self.client.get(reverse("billing_index"))
        self.assertEqual(response.status_code, 200)
        self.assertContains(response, b"Premium Plan")


class BillingCardAddTestCase(TestCase):
    """Test billing card add functionality."""

    def setUp(self):
        self.user = models.User.objects.create(username="alice")
        self.user.is_premium = True
        self.user.save()
        self.client.force_login(self.user)

    def test_card_add_get(self):
        with patch.object(
            stripe.SetupIntent, "create", return_value={"client_secret": "seti_123abc"}
        ):
            response = self.client.get(reverse("billing_card"))
        self.assertEqual(response.status_code, 200)
        self.assertContains(response, b"Add card")

    def test_card_add_post(self):
        # follow=True renders the redirect target while the patches are active.
        with _patch_billing_index(subscription=_make_subscription()):
            response = self.client.post(
                reverse("billing_card"),
                data={"card_token": "tok_123"},
                follow=True,
            )
        self.assertEqual(response.status_code, 200)
        self.assertContains(response, b"Premium Plan")


class BillingCancelSubscriptionTestCase(TestCase):
    """Test billing cancel subscription."""

    def setUp(self):
        self.user = models.User.objects.create(username="alice")
        self.user.is_premium = True
        self.user.stripe_customer_id = "cus_123abcdefg"
        self.user.save()
        self.client.force_login(self.user)

    def test_cancel_subscription_get(self):
        with patch.object(
            billing, "_get_stripe_subscription", return_value=_make_subscription()
        ):
            response = self.client.get(reverse("billing_subscription_cancel"))
        self.assertEqual(response.status_code, 200)
        self.assertContains(response, b"Cancel Premium")

    def test_cancel_subscription_post(self):
        with (
            patch.object(stripe.Subscription, "delete"),
            patch.object(
                billing, "_get_stripe_subscription", return_value={"id": "sub_123"}
            ),
        ):
            response = self.client.post(reverse("billing_subscription_cancel"))
        self.assertEqual(response.status_code, 302)
        self.assertFalse(models.User.objects.get(id=self.user.id).is_premium)


class BillingCancelSubscriptionTwiceTestCase(TestCase):
    """Test billing cancel subscription when already canceled."""

    def setUp(self):
        self.user = models.User.objects.create(username="alice")
        self.user.stripe_customer_id = "cus_123abcdefg"
        self.user.save()
        self.client.force_login(self.user)

    def test_cancel_subscription_get(self):
        with _patch_billing_index(subscription=None):
            response = self.client.get(reverse("billing_subscription_cancel"))
            # Assert inside the with block: assertRedirects fetches
            # billing_index, which needs the _get_stripe_subscription patch.
            self.assertRedirects(response, reverse("billing_index"))

    def test_cancel_subscription_post(self):
        with (
            _patch_billing_index(subscription=None),
            patch.object(stripe.Subscription, "delete"),
        ):
            response = self.client.post(reverse("billing_subscription_cancel"))
            # Inside the with block for the same reason as the GET test.
            self.assertRedirects(response, reverse("billing_index"))
            self.assertFalse(models.User.objects.get(id=self.user.id).is_premium)


class BillingReenableSubscriptionTestCase(TestCase):
    """Test re-enabling subscription after cancellation."""

    def setUp(self):
        self.user = models.User.objects.create(username="alice")
        self.user.stripe_customer_id = "cus_123abcdefg"
        self.user.save()
        self.client.force_login(self.user)

    def test_reenable_subscription_post(self):
        created_subscription = {
            "id": "sub_456abcdefg",
            "latest_invoice": {
                "payment_intent": {
                    "client_secret": "seti_123abc",
                },
            },
        }
        with (
            _patch_billing_index(subscription=_make_subscription()),
            patch.object(stripe.Subscription, "delete"),
            patch.object(
                stripe.Subscription, "create", return_value=created_subscription
            ),
        ):
            response = self.client.post(reverse("billing_subscription"))
            # assertRedirects fetches billing_index, so keep it inside the patches.
            self.assertRedirects(response, reverse("billing_index"))
            self.assertTrue(models.User.objects.get(id=self.user.id).is_premium)