| |
"""
Centralised OAuth 2.0 helper for GitHub, Google Drive and Slack.

Usage
-----
    from auth import oauth_manager

    # 1) Redirect the user to the provider's consent page
    auth_url, state = oauth_manager.get_authorization_url(
        provider="github",
        redirect_uri="https://your-app.com/callback",
    )

    # 2) In your callback handler, exchange the code for a token
    token = oauth_manager.fetch_token(
        provider="github",
        redirect_uri="https://your-app.com/callback",
        authorization_response=request.url,  # full URL with ?code=...
    )
"""
|
|
| from __future__ import annotations |
|
|
| import os |
| from dataclasses import dataclass |
| from typing import Dict, Optional, Tuple |
|
|
| from authlib.integrations.requests_client import OAuth2Session |
|
|
|
|
| |
| |
| |
@dataclass(frozen=True)
class ProviderConfig:
    """Immutable OAuth 2.0 settings for a single provider.

    The client credentials are optional (they may be ``None`` when not
    configured); the endpoint URLs and scope string are always required.
    """

    # OAuth application credentials; either may be missing.
    client_id: Optional[str]
    client_secret: Optional[str]
    # URL the user is redirected to in order to grant consent.
    authorize_url: str
    # URL at which the authorization code is exchanged for a token.
    token_url: str
    # Scope string sent with the authorization request (space-separated).
    scope: str
|
|
|
|
| def _env(name: str) -> str | None: |
| """Shorthand for os.getenv with strip().""" |
| val = os.getenv(name) |
| return val.strip() if val else None |
|
|
|
|
# Static per-provider settings, keyed by provider name. The credentials are
# read from the environment once, when this module is imported.
PROVIDERS: Dict[str, ProviderConfig] = dict(
    github=ProviderConfig(
        client_id=_env("GITHUB_CLIENT_ID"),
        client_secret=_env("GITHUB_CLIENT_SECRET"),
        authorize_url="https://github.com/login/oauth/authorize",
        token_url="https://github.com/login/oauth/access_token",
        scope="repo read:org",
    ),
    google=ProviderConfig(
        client_id=_env("GOOGLE_CLIENT_ID"),
        client_secret=_env("GOOGLE_CLIENT_SECRET"),
        authorize_url="https://accounts.google.com/o/oauth2/auth",
        token_url="https://oauth2.googleapis.com/token",
        scope="openid email profile https://www.googleapis.com/auth/drive.readonly",
    ),
    slack=ProviderConfig(
        client_id=_env("SLACK_CLIENT_ID"),
        client_secret=_env("SLACK_CLIENT_SECRET"),
        authorize_url="https://slack.com/oauth/v2/authorize",
        token_url="https://slack.com/api/oauth.v2.access",
        scope="channels:read chat:write",
    ),
)
|
|
|
|
| |
| |
| |
class OAuthManager:
    """Tiny wrapper around Authlib's OAuth2Session for several providers.

    Holds a mapping of provider name -> provider settings and builds a fresh
    ``OAuth2Session`` for every authorization-URL or token request.
    """

    def __init__(self, providers: Dict[str, ProviderConfig]):
        # Provider name (e.g. "github") -> its static OAuth settings.
        self.providers = providers

    def _session(self, provider: str, redirect_uri: str) -> OAuth2Session:
        """Build a new OAuth2Session for *provider* bound to *redirect_uri*.

        Raises:
            KeyError: *provider* is not configured in ``self.providers``.
            RuntimeError: the provider's client id or client secret is
                missing or empty.
        """
        cfg = self.providers.get(provider)
        if not cfg:
            raise KeyError(f"Unsupported provider '{provider}'.")
        if not (cfg.client_id and cfg.client_secret):
            raise RuntimeError(
                f"OAuth credentials for '{provider}' are missing. "
                "Set the *_CLIENT_ID and *_CLIENT_SECRET env‑vars."
            )
        return OAuth2Session(
            client_id=cfg.client_id,
            client_secret=cfg.client_secret,
            scope=cfg.scope,
            redirect_uri=redirect_uri,
        )

    def get_authorization_url(
        self, provider: str, redirect_uri: str, state: Optional[str] = None
    ) -> Tuple[str, str]:
        """Return ``(auth_url, state)`` for the provider's consent page.

        If *state* is ``None``, Authlib generates one. Store the returned
        state (e.g. in the user's session) and pass it back via
        ``fetch_token(..., state=state)`` so the callback is checked against
        CSRF.

        Raises:
            KeyError: unknown *provider*.
            RuntimeError: missing client credentials.
        """
        sess = self._session(provider, redirect_uri)
        cfg = self.providers[provider]
        auth_url, final_state = sess.create_authorization_url(
            cfg.authorize_url, state=state
        )
        return auth_url, final_state

    def fetch_token(
        self,
        provider: str,
        redirect_uri: str,
        authorization_response: str,
        *,
        state: Optional[str] = None,
    ) -> Dict:
        """Exchange the ``?code=...`` in *authorization_response* for a token.

        Args:
            provider: configured provider name.
            redirect_uri: must match the URI used for the authorization URL.
            authorization_response: full callback URL, including the query.
            state: the state returned by ``get_authorization_url``. When
                given, Authlib compares it with the ``state`` query parameter
                of *authorization_response* and raises on mismatch. When
                omitted (the previous behaviour), no state check is done.

        Returns:
            The token dict from Authlib (access_token, and depending on the
            provider refresh_token, expires_in, etc.).

        Raises:
            KeyError: unknown *provider*.
            RuntimeError: missing client credentials.
        """
        sess = self._session(provider, redirect_uri)
        cfg = self.providers[provider]
        return sess.fetch_token(
            cfg.token_url,
            authorization_response=authorization_response,
            # NOTE(review): the session already carries client_secret; this
            # extra kwarg is probably forwarded into the token request body
            # as well. Kept for compatibility with providers that expect it
            # there -- confirm per provider before removing.
            client_secret=cfg.client_secret,
            state=state,
        )
|
|
|
|
| |
# Shared module-level manager, configured from the PROVIDERS table.
oauth_manager: OAuthManager = OAuthManager(providers=PROVIDERS)
|
|