InvokeAI/tests/app/api/test_sliding_window_token.py
Lincoln Stein 01c67c5468
Fix (multiuser): Ask user to log back in when security token has expired (#9017)
* Initial plan

* Warn user when credentials have expired in multiuser mode

Agent-Logs-Url: https://github.com/lstein/InvokeAI/sessions/f0947cda-b15c-475d-b7f4-2d553bdf2cd6

Co-authored-by: lstein <111189+lstein@users.noreply.github.com>

* Address code review: avoid multiple localStorage reads in base query

Agent-Logs-Url: https://github.com/lstein/InvokeAI/sessions/f0947cda-b15c-475d-b7f4-2d553bdf2cd6

Co-authored-by: lstein <111189+lstein@users.noreply.github.com>

* bugfix(multiuser): ask user to log back in when authentication token expires

* feat: sliding window session expiry with token refresh

Backend:
- SlidingWindowTokenMiddleware refreshes JWT on each mutating request
  (POST/PUT/PATCH/DELETE), returning a new token in X-Refreshed-Token
  response header. GET requests don't refresh (they're often background
  fetches that shouldn't reset the inactivity timer).
- CORS expose_headers updated to allow X-Refreshed-Token.

Frontend:
- dynamicBaseQuery picks up X-Refreshed-Token from responses and
  updates localStorage so subsequent requests use the fresh expiry.
- 401 handler only triggers sessionExpiredLogout when a token was
  actually sent (not for unauthenticated background requests).
- ProtectedRoute polls localStorage every 5s and listens for storage
  events to detect token removal (e.g. manual deletion, other tabs).

Result: session expires after TOKEN_EXPIRATION_NORMAL (1 day) of
inactivity, not a fixed time after login. Any user-initiated action
resets the clock.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* chore(backend): ruff

* fix: address review feedback on auth token handling

Bug fixes:
- ProtectedRoute: only treat 401 errors as session expiry, not
  transient 500/network errors that should not force logout
- Token refresh: use explicit remember_me claim in JWT instead of
  inferring from remaining lifetime, preventing silent downgrade of
  7-day tokens to 1-day when <24h remains
- TokenData: add remember_me field, set during login

Tests (6 new):
- Mutating requests (POST/PUT/DELETE) return X-Refreshed-Token
- GET requests do not return X-Refreshed-Token
- Unauthenticated requests do not return X-Refreshed-Token
- Remember-me token refreshes to 7-day duration even near expiry
- Normal token refreshes to 1-day duration
- remember_me claim preserved through refresh cycle

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* chore(backend): ruff

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: lstein <111189+lstein@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Jonathan <34005131+JPPhoto@users.noreply.github.com>
2026-04-05 23:11:44 -04:00

169 lines
6.5 KiB
Python

"""Tests for SlidingWindowTokenMiddleware and token refresh behavior."""
from datetime import timedelta
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from invokeai.app.services.auth.token_service import TokenData, create_access_token, set_jwt_secret
@pytest.fixture(autouse=True)
def _setup_jwt_secret():
"""Ensure JWT secret is set for all tests."""
set_jwt_secret("test-secret-key-for-sliding-window-tests")
def _create_test_app() -> FastAPI:
"""Create a minimal FastAPI app with the SlidingWindowTokenMiddleware."""
from invokeai.app.api_app import SlidingWindowTokenMiddleware
test_app = FastAPI()
test_app.add_middleware(SlidingWindowTokenMiddleware)
@test_app.get("/test")
async def get_endpoint():
return {"ok": True}
@test_app.post("/test")
async def post_endpoint():
return {"ok": True}
@test_app.put("/test")
async def put_endpoint():
return {"ok": True}
@test_app.delete("/test")
async def delete_endpoint():
return {"ok": True}
return test_app
def _make_token(remember_me: bool = False, expires_delta: timedelta | None = None) -> str:
"""Create a test token."""
token_data = TokenData(
user_id="test-user",
email="test@test.com",
is_admin=False,
remember_me=remember_me,
)
return create_access_token(token_data, expires_delta)
class TestSlidingWindowTokenMiddleware:
"""Tests for SlidingWindowTokenMiddleware."""
def test_mutating_request_returns_refreshed_token(self):
"""Authenticated POST/PUT/PATCH/DELETE requests return X-Refreshed-Token."""
app = _create_test_app()
client = TestClient(app)
token = _make_token()
for method in ["post", "put", "delete"]:
response = getattr(client, method)("/test", headers={"Authorization": f"Bearer {token}"})
assert response.status_code == 200
assert "X-Refreshed-Token" in response.headers, f"{method.upper()} should return refreshed token"
def test_get_request_does_not_return_refreshed_token(self):
"""Authenticated GET requests do NOT return X-Refreshed-Token."""
app = _create_test_app()
client = TestClient(app)
token = _make_token()
response = client.get("/test", headers={"Authorization": f"Bearer {token}"})
assert response.status_code == 200
assert "X-Refreshed-Token" not in response.headers
def test_unauthenticated_request_does_not_return_refreshed_token(self):
"""Requests without a token do NOT return X-Refreshed-Token."""
app = _create_test_app()
client = TestClient(app)
response = client.post("/test")
assert response.status_code == 200
assert "X-Refreshed-Token" not in response.headers
def test_remember_me_token_refreshes_to_remember_me_duration(self):
"""A remember_me=True token refreshes with the remember-me duration, not the normal duration."""
from jose import jwt
from invokeai.app.api.routers.auth import TOKEN_EXPIRATION_REMEMBER_ME
from invokeai.app.services.auth.token_service import ALGORITHM, get_jwt_secret
app = _create_test_app()
client = TestClient(app)
# Create a remember-me token with only 1 hour remaining (less than 24h)
token = _make_token(remember_me=True, expires_delta=timedelta(hours=1))
response = client.post("/test", headers={"Authorization": f"Bearer {token}"})
assert "X-Refreshed-Token" in response.headers
# Decode the refreshed token and check its expiry
refreshed_token = response.headers["X-Refreshed-Token"]
payload = jwt.decode(refreshed_token, get_jwt_secret(), algorithms=[ALGORITHM])
# The refreshed token should have ~7 days of remaining life, not ~1 day
from datetime import datetime, timezone
remaining_seconds = payload["exp"] - datetime.now(timezone.utc).timestamp()
remaining_days = remaining_seconds / 86400
# Should be close to TOKEN_EXPIRATION_REMEMBER_ME (7 days), not TOKEN_EXPIRATION_NORMAL (1 day)
assert remaining_days > TOKEN_EXPIRATION_REMEMBER_ME - 0.1, (
f"Remember-me token was downgraded: {remaining_days:.1f} days remaining, "
f"expected ~{TOKEN_EXPIRATION_REMEMBER_ME}"
)
def test_normal_token_refreshes_to_normal_duration(self):
"""A remember_me=False token refreshes with the normal duration."""
from jose import jwt
from invokeai.app.api.routers.auth import TOKEN_EXPIRATION_NORMAL
from invokeai.app.services.auth.token_service import ALGORITHM, get_jwt_secret
app = _create_test_app()
client = TestClient(app)
token = _make_token(remember_me=False)
response = client.post("/test", headers={"Authorization": f"Bearer {token}"})
refreshed_token = response.headers["X-Refreshed-Token"]
payload = jwt.decode(refreshed_token, get_jwt_secret(), algorithms=[ALGORITHM])
from datetime import datetime, timezone
remaining_seconds = payload["exp"] - datetime.now(timezone.utc).timestamp()
remaining_days = remaining_seconds / 86400
# Should be close to TOKEN_EXPIRATION_NORMAL (1 day), not TOKEN_EXPIRATION_REMEMBER_ME (7 days)
assert remaining_days < TOKEN_EXPIRATION_NORMAL + 0.1, (
f"Normal token got remember-me duration: {remaining_days:.1f} days"
)
assert remaining_days > TOKEN_EXPIRATION_NORMAL - 0.1, (
f"Normal token duration too short: {remaining_days:.1f} days"
)
def test_remember_me_claim_preserved_in_refreshed_token(self):
"""The remember_me claim is preserved when a token is refreshed."""
from invokeai.app.services.auth.token_service import verify_token
app = _create_test_app()
client = TestClient(app)
# Test with remember_me=True
token = _make_token(remember_me=True)
response = client.post("/test", headers={"Authorization": f"Bearer {token}"})
refreshed_data = verify_token(response.headers["X-Refreshed-Token"])
assert refreshed_data is not None
assert refreshed_data.remember_me is True
# Test with remember_me=False
token = _make_token(remember_me=False)
response = client.post("/test", headers={"Authorization": f"Bearer {token}"})
refreshed_data = verify_token(response.headers["X-Refreshed-Token"])
assert refreshed_data is not None
assert refreshed_data.remember_me is False