Tidy app entrypoint (#7668)

## Summary

Prior to this PR, most of the app setup was being done in `api_app.py`
at import time. This PR cleans this up, by:
- Splitting app setup into more modular functions
- Narrower responsibility for the `api_app.py` file - it just
initializes the `FastAPI` app

The main motivation for this changes is to make it easier to support an
upcoming torch configuration feature that requires more careful ordering
of app initialization steps.

## Related Issues / Discussions

N/A

## QA Instructions

- [x] Launch the app via invokeai-web.py and smoke test it.
- [ ] Launch the app via the installer and smoke test it.
- [x] Test that generate_openapi_schema.py produces the same result
before and after the change.
- [x] No regression in unit tests that directly interact with the app.
(test_images.py)

## Merge Plan

- [x] Check to see if there are any commercial implications to modifying
the app entrypoint.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
This commit is contained in:
Ryan Dick 2025-02-28 16:07:30 -05:00 committed by GitHub
commit 26730ca702
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 133 additions and 109 deletions

View File

@ -1,12 +1,8 @@
import asyncio
import logging
import mimetypes
import socket
from contextlib import asynccontextmanager
from pathlib import Path
import torch
import uvicorn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
@ -15,11 +11,7 @@ from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from torch.backends.mps import is_available as is_mps_available
# for PyCharm:
# noinspection PyUnresolvedReferences
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
@ -36,39 +28,15 @@ from invokeai.app.api.routers import (
workflows,
)
from invokeai.app.api.sockets import SocketIO
from invokeai.app.invocations.load_custom_nodes import load_custom_nodes
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.custom_openapi import get_openapi_func
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
app_config = get_config()
if is_mps_available():
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
logger = InvokeAILogger.get_logger(config=app_config)
# fix for windows mimetypes registry entries being borked
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css")
torch_device_name = TorchDevice.get_torch_device_name()
logger.info(f"Using torch device: {torch_device_name}")
loop = asyncio.new_event_loop()
# We may change the port if the default is in use, this global variable is used to store the port so that we can log
# the correct port when the server starts in the lifespan handler.
port = app_config.port
# Load custom nodes. This must be done after importing the Graph class, which itself imports all modules from the
# invocations module. The ordering here is implicit, but important - we want to load custom nodes after all the
# core nodes have been imported so that we can catch when a custom node clobbers a core node.
load_custom_nodes(custom_nodes_path=app_config.custom_nodes_path)
@asynccontextmanager
async def lifespan(app: FastAPI):
@ -77,7 +45,7 @@ async def lifespan(app: FastAPI):
# Log the server address when it starts - in case the network log level is not high enough to see the startup log
proto = "https" if app_config.ssl_certfile else "http"
msg = f"Invoke running on {proto}://{app_config.host}:{port} (Press CTRL+C to quit)"
msg = f"Invoke running on {proto}://{app_config.host}:{app_config.port} (Press CTRL+C to quit)"
# Logging this way ignores the logger's log level and _always_ logs the message
record = logger.makeRecord(
@ -192,73 +160,3 @@ except RuntimeError:
app.mount(
"/static", NoCacheStaticFiles(directory=Path(web_root_path, "static/")), name="static"
) # docs favicon is in here
def check_cudnn(logger: logging.Logger) -> None:
"""Check for cuDNN issues that could be causing degraded performance."""
if torch.backends.cudnn.is_available():
try:
# Note: At the time of writing (torch 2.2.1), torch.backends.cudnn.version() only raises an error the first
# time it is called. Subsequent calls will return the version number without complaining about a mismatch.
cudnn_version = torch.backends.cudnn.version()
logger.info(f"cuDNN version: {cudnn_version}")
except RuntimeError as e:
logger.warning(
"Encountered a cuDNN version issue. This may result in degraded performance. This issue is usually "
"caused by an incompatible cuDNN version installed in your python environment, or on the host "
f"system. Full error message:\n{e}"
)
def invoke_api() -> None:
def find_port(port: int) -> int:
"""Find a port not in use starting at given port"""
# Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon!
# https://github.com/WaylonWalker
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.settimeout(1)
if s.connect_ex(("localhost", port)) == 0:
return find_port(port=port + 1)
else:
return port
if app_config.dev_reload:
try:
import jurigged
except ImportError as e:
logger.error(
'Can\'t start `--dev_reload` because jurigged is not found; `pip install -e ".[dev]"` to include development dependencies.',
exc_info=e,
)
else:
jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info)
global port
port = find_port(app_config.port)
if port != app_config.port:
logger.warn(f"Port {app_config.port} in use, using port {port}")
check_cudnn(logger)
config = uvicorn.Config(
app=app,
host=app_config.host,
port=port,
loop="asyncio",
log_level=app_config.log_level_network,
ssl_certfile=app_config.ssl_certfile,
ssl_keyfile=app_config.ssl_keyfile,
)
server = uvicorn.Server(config)
# replace uvicorn's loggers with InvokeAI's for consistent appearance
uvicorn_logger = InvokeAILogger.get_logger("uvicorn")
uvicorn_logger.handlers.clear()
for hdlr in logger.handlers:
uvicorn_logger.addHandler(hdlr)
loop.run_until_complete(server.serve())
if __name__ == "__main__":
invoke_api()

View File

@ -1,12 +1,74 @@
"""This is a wrapper around the main app entrypoint, to allow for CLI args to be parsed before running the app."""
import uvicorn
from invokeai.app.invocations.load_custom_nodes import load_custom_nodes
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.startup_utils import (
apply_monkeypatches,
check_cudnn,
enable_dev_reload,
find_open_port,
register_mime_types,
)
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
def get_app():
"""Import the app and event loop. We wrap this in a function to more explicitly control when it happens, because
importing from api_app does a bunch of stuff - it's more like calling a function than importing a module.
"""
from invokeai.app.api_app import app, loop
return app, loop
def run_app() -> None:
# Before doing _anything_, parse CLI args!
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
"""The main entrypoint for the app."""
# Parse the CLI arguments.
InvokeAIArgs.parse_args()
from invokeai.app.api_app import invoke_api
# Load config.
app_config = get_config()
invoke_api()
logger = InvokeAILogger.get_logger(config=app_config)
# Find an open port, and modify the config accordingly.
orig_config_port = app_config.port
app_config.port = find_open_port(app_config.port)
if orig_config_port != app_config.port:
logger.warning(f"Port {orig_config_port} is already in use. Using port {app_config.port}.")
# Miscellaneous startup tasks.
apply_monkeypatches()
register_mime_types()
if app_config.dev_reload:
enable_dev_reload()
check_cudnn(logger)
# Initialize the app and event loop.
app, loop = get_app()
# Load custom nodes. This must be done after importing the Graph class, which itself imports all modules from the
# invocations module. The ordering here is implicit, but important - we want to load custom nodes after all the
# core nodes have been imported so that we can catch when a custom node clobbers a core node.
load_custom_nodes(custom_nodes_path=app_config.custom_nodes_path)
# Start the server.
config = uvicorn.Config(
app=app,
host=app_config.host,
port=app_config.port,
loop="asyncio",
log_level=app_config.log_level_network,
ssl_certfile=app_config.ssl_certfile,
ssl_keyfile=app_config.ssl_keyfile,
)
server = uvicorn.Server(config)
# replace uvicorn's loggers with InvokeAI's for consistent appearance
uvicorn_logger = InvokeAILogger.get_logger("uvicorn")
uvicorn_logger.handlers.clear()
for hdlr in logger.handlers:
uvicorn_logger.addHandler(hdlr)
loop.run_until_complete(server.serve())

View File

@ -0,0 +1,64 @@
import logging
import mimetypes
import socket
import torch
def find_open_port(port: int) -> int:
"""Find a port not in use starting at given port"""
# Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon!
# https://github.com/WaylonWalker
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.settimeout(1)
if s.connect_ex(("localhost", port)) == 0:
return find_open_port(port=port + 1)
else:
return port
def check_cudnn(logger: logging.Logger) -> None:
"""Check for cuDNN issues that could be causing degraded performance."""
if torch.backends.cudnn.is_available():
try:
# Note: At the time of writing (torch 2.2.1), torch.backends.cudnn.version() only raises an error the first
# time it is called. Subsequent calls will return the version number without complaining about a mismatch.
cudnn_version = torch.backends.cudnn.version()
logger.info(f"cuDNN version: {cudnn_version}")
except RuntimeError as e:
logger.warning(
"Encountered a cuDNN version issue. This may result in degraded performance. This issue is usually "
"caused by an incompatible cuDNN version installed in your python environment, or on the host "
f"system. Full error message:\n{e}"
)
def enable_dev_reload() -> None:
"""Enable hot reloading on python file changes during development."""
from invokeai.backend.util.logging import InvokeAILogger
try:
import jurigged
except ImportError as e:
raise RuntimeError(
'Can\'t start `--dev_reload` because jurigged is not found; `pip install -e ".[dev]"` to include development dependencies.'
) from e
else:
jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info)
def apply_monkeypatches() -> None:
"""Apply monkeypatches to fix issues with third-party libraries."""
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
if torch.backends.mps.is_available():
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
def register_mime_types() -> None:
"""Register additional mime types for windows."""
# Fix for windows mimetypes registry entries being borked.
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css")