diff --git a/EDMCLogging.py b/EDMCLogging.py new file mode 100644 index 0000000..e15f645 --- /dev/null +++ b/EDMCLogging.py @@ -0,0 +1,447 @@ +""" +Set up required logging for the application. + +This module provides for a common logging-powered log facility. +Mostly it implements a logging.Filter() in order to get two extra +members on the logging.LogRecord instance for use in logging.Formatter() +strings. + +If type checking, e.g. mypy, objects to `logging.trace(...)` then include this +stanza: + + # See EDMCLogging.py docs. + # isort: off + if TYPE_CHECKING: + from logging import trace, TRACE # type: ignore # noqa: F401 + # isort: on + +This is needed because we add the TRACE level and the trace() function +ourselves at runtime. + +To utilise logging in core code, or internal plugins, include this: + + from EDMCLogging import get_main_logger + + logger = get_main_logger() + +To utilise logging in a 'found' (third-party) plugin, include this: + + import os + import logging + + plugin_name = os.path.basename(os.path.dirname(__file__)) + # plugin_name here *must* be the name of the folder the plugin resides in + # See, plug.py:load_plugins() + logger = logging.getLogger(f'{appname}.{plugin_name}') +""" + +import inspect +import logging +import logging.handlers +from contextlib import suppress +from fnmatch import fnmatch +# So that any warning about accessing a protected member is only in one place. +from sys import _getframe as getframe +from threading import get_native_id as thread_native_id +from traceback import print_exc +from typing import TYPE_CHECKING, Tuple, cast + +import config + +# TODO: Tests: +# +# 1. Call from bare function in file. +# 2. Call from `if __name__ == "__main__":` section +# +# 3. Call from 1st level function in 1st level Class in file +# 4. Call from 2nd level function in 1st level Class in file +# 5. Call from 3rd level function in 1st level Class in file +# +# 6. Call from 1st level function in 2nd level Class in file +# 7. Call from 2nd level function in 2nd level Class in file +# 8. Call from 3rd level function in 2nd level Class in file +# +# 9. Call from 1st level function in 3rd level Class in file +# 10. Call from 2nd level function in 3rd level Class in file +# 11. Call from 3rd level function in 3rd level Class in file +# +# 12. Call from 2nd level file, all as above. +# +# 13. Call from *module* +# +# 14. Call from *package* + +_default_loglevel = logging.DEBUG + +# Define a TRACE level +LEVEL_TRACE = 5 +LEVEL_TRACE_ALL = 3 +logging.addLevelName(LEVEL_TRACE, "TRACE") +logging.addLevelName(LEVEL_TRACE_ALL, "TRACE_ALL") +logging.TRACE = LEVEL_TRACE # type: ignore +logging.TRACE_ALL = LEVEL_TRACE_ALL # type: ignore +logging.Logger.trace = lambda self, message, *args, **kwargs: self._log( # type: ignore + logging.TRACE, # type: ignore + message, + args, + **kwargs +) + + +def _trace_if(self: logging.Logger, condition: str, message: str, *args, **kwargs) -> None: + if any(fnmatch(condition, p) for p in []): + self._log(logging.TRACE, message, args, **kwargs) # type: ignore # we added it + return + + self._log(logging.TRACE_ALL, message, args, **kwargs) # type: ignore # we added it + + +logging.Logger.trace_if = _trace_if # type: ignore + +# we cant hide this from `from xxx` imports and I'd really rather no-one other than `logging` had access to it +del _trace_if + +if TYPE_CHECKING: + from types import FrameType + + # Fake type that we can use here to tell type checkers that trace exists + + class LoggerMixin(logging.Logger): + """LoggerMixin is a fake class that tells type checkers that trace exists on a given type.""" + + def trace(self, message, *args, **kwargs) -> None: + """See implementation above.""" + ... + + def trace_if(self, condition: str, message, *args, **kwargs) -> None: + """ + Fake trace if method, traces only if condition exists in trace_on. + + See implementation above. + """ + ... + + +class Logger: + """ + Wrapper class for all logging configuration and code. + + Class instantiation requires the 'logger name' and optional loglevel. + It is intended that this 'logger name' be re-used in all files/modules + that need to log. + + Users of this class should then call getLogger() to get the + logging.Logger instance. + """ + + def __init__(self, logger_name: str, loglevel: int = _default_loglevel): + """ + Set up a `logging.Logger` with our preferred configuration. + + This includes using an EDMCContextFilter to add 'class' and 'qualname' + expansions for logging.Formatter(). + """ + self.logger = logging.getLogger(logger_name) + # Configure the logging.Logger + # This needs to always be TRACE in order to let TRACE level messages + # through to check the *handler* levels. + self.logger.setLevel(logging.TRACE) # type: ignore + + # Set up filter for adding class name + self.logger_filter = EDMCContextFilter() + self.logger.addFilter(self.logger_filter) + + # Our basic channel handling stdout + self.logger_channel = logging.StreamHandler() + # This should be affected by the user configured log level + self.logger_channel.setLevel(loglevel) + + self.logger_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(process)d:%(thread)d:%(osthreadid)d %(module)s.%(qualname)s:%(lineno)d: %(message)s') # noqa: E501 + self.logger_formatter.default_time_format = '%Y-%m-%d %H:%M:%S' + self.logger_formatter.default_msec_format = '%s.%03d' + + self.logger_channel.setFormatter(self.logger_formatter) + self.logger.addHandler(self.logger_channel) + + def get_logger(self) -> 'LoggerMixin': + """ + Obtain the self.logger of the class instance. + + Not to be confused with logging.getLogger(). + """ + return cast('LoggerMixin', self.logger) + + def get_streamhandler(self) -> logging.Handler: + """ + Obtain the self.logger_channel StreamHandler instance. + + :return: logging.StreamHandler + """ + return self.logger_channel + + def set_channels_loglevel(self, level: int) -> None: + """ + Set the specified log level on the channels. + + :param level: A valid `logging` level. + :return: None + """ + self.logger_channel.setLevel(level) + + def set_console_loglevel(self, level: int) -> None: + """ + Set the specified log level on the console channel. + + :param level: A valid `logging` level. + :return: None + """ + if self.logger_channel.level != logging.TRACE: # type: ignore + self.logger_channel.setLevel(level) + else: + logger.trace("Not changing log level because it's TRACE") # type: ignore + + +class EDMCContextFilter(logging.Filter): + """ + Implements filtering to add extra format specifiers, and tweak others. + + logging.Filter sub-class to place extra attributes of the calling site + into the record. + """ + + def filter(self, record: logging.LogRecord) -> bool: + """ + Attempt to set/change fields in the LogRecord. + + 1. class = class name(s) of the call site, if applicable + 2. qualname = __qualname__ of the call site. This simplifies + logging.Formatter() as you can use just this no matter if there is + a class involved or not, so you get a nice clean: + .[.classB....]. + 3. osthreadid = OS level thread ID. + + If we fail to be able to properly set either then: + + 1. Use print() to alert, to be SURE a message is seen. + 2. But also return strings noting the error, so there'll be + something in the log output if it happens. + + :param record: The LogRecord we're "filtering" + :return: bool - Always true in order for this record to be logged. + """ + (class_name, qualname, module_name) = self.caller_attributes(module_name=getattr(record, 'module')) + + # Only set if we got a useful value + if module_name: + setattr(record, 'module', module_name) + + # Only set if not already provided by logging itself + if getattr(record, 'class', None) is None: + setattr(record, 'class', class_name) + + # Only set if not already provided by logging itself + if getattr(record, 'qualname', None) is None: + setattr(record, 'qualname', qualname) + + setattr(record, 'osthreadid', thread_native_id()) + + return True + + @classmethod + def caller_attributes(cls, module_name: str = '') -> Tuple[str, str, str]: # noqa: CCR001, E501, C901 # this is as refactored as is sensible + """ + Determine extra or changed fields for the caller. + + 1. qualname finds the relevant object and its __qualname__ + 2. caller_class_names is just the full class names of the calling + class if relevant. + 3. module is munged if we detect the caller is an EDMC plugin, + whether internal or found. + + :param module_name: The name of the calling module. + :return: Tuple[str, str, str] - class_name, qualname, module_name + """ + frame = cls.find_caller_frame() + + caller_qualname = caller_class_names = '' + if frame: + # + try: + frame_info = inspect.getframeinfo(frame) + # raise(IndexError) # TODO: Remove, only for testing + + except Exception: + # Separate from the print below to guarantee we see at least this much. + print('EDMCLogging:EDMCContextFilter:caller_attributes(): Failed in `inspect.getframinfo(frame)`') + + # We want to *attempt* to show something about the nature of 'frame', + # but at this point we can't trust it will work. + try: + print(f'frame: {frame}') + + except Exception: + pass + + # We've given up, so just return '??' to signal we couldn't get the info + return '??', '??', module_name + try: + args, _, _, value_dict = inspect.getargvalues(frame) + if len(args) and args[0] in ('self', 'cls'): + frame_class: 'object' = value_dict[args[0]] + + if frame_class: + # See https://en.wikipedia.org/wiki/Name_mangling#Python for how name mangling works. + # For more detail, see _Py_Mangle in CPython's Python/compile.c. + name = frame_info.function + class_name = frame_class.__class__.__name__.lstrip("_") + if name.startswith("__") and not name.endswith("__") and class_name: + name = f'_{class_name}{frame_info.function}' + + # Find __qualname__ of the caller + fn = inspect.getattr_static(frame_class, name, None) + if fn is None: + # For some reason getattr_static cant grab this. Try and grab it with getattr, bail out + # if we get a RecursionError indicating a property + try: + fn = getattr(frame_class, name, None) + except RecursionError: + print( + "EDMCLogging:EDMCContextFilter:caller_attributes():" + "Failed to get attribute for function info. Bailing out" + ) + # class_name is better than nothing for __qualname__ + return class_name, class_name, module_name + + if fn is not None: + if isinstance(fn, property): + class_name = str(frame_class) + # If somehow you make your __class__ or __class__.__qualname__ recursive, + # I'll be impressed. + if hasattr(frame_class, '__class__') and hasattr(frame_class.__class__, "__qualname__"): + class_name = frame_class.__class__.__qualname__ + caller_qualname = f"{class_name}.{name}(property)" + + else: + caller_qualname = f"" + + elif not hasattr(fn, '__qualname__'): + caller_qualname = name + + elif hasattr(fn, '__qualname__') and fn.__qualname__: + caller_qualname = fn.__qualname__ + + # Find containing class name(s) of caller, if any + if ( + frame_class.__class__ and hasattr(frame_class.__class__, '__qualname__') + and frame_class.__class__.__qualname__ + ): + caller_class_names = frame_class.__class__.__qualname__ + + # It's a call from the top level module file + elif frame_info.function == '': + caller_class_names = '' + caller_qualname = value_dict['__name__'] + + elif frame_info.function != '': + caller_class_names = '' + caller_qualname = frame_info.function + + module_name = cls.munge_module_name(frame_info, module_name) + + except Exception as e: + print('ALERT! Something went VERY wrong in handling finding info to log') + print('ALERT! Information is as follows') + with suppress(Exception): + + print(f'ALERT! {e=}') + print_exc() + print(f'ALERT! {frame=}') + with suppress(Exception): + print(f'ALERT! {fn=}') # type: ignore + with suppress(Exception): + print(f'ALERT! {cls=}') + + finally: # Ensure this always happens + # https://docs.python.org/3.7/library/inspect.html#the-interpreter-stack + del frame + + if caller_qualname == '': + print('ALERT! Something went wrong with finding caller qualname for logging!') + caller_qualname = '' + + if caller_class_names == '': + print('ALERT! Something went wrong with finding caller class name(s) for logging!') + caller_class_names = '' + + return caller_class_names, caller_qualname, module_name + + @classmethod + def find_caller_frame(cls): + """ + Find the stack frame of the logging caller. + + :returns: 'frame' object such as from sys._getframe() + """ + # Go up through stack frames until we find the first with a + # type(f_locals.self) of logging.Logger. This should be the start + # of the frames internal to logging. + frame: 'FrameType' = getframe(0) + while frame: + if isinstance(frame.f_locals.get('self'), logging.Logger): + frame = cast('FrameType', frame.f_back) # Want to start on the next frame below + break + frame = cast('FrameType', frame.f_back) + # Now continue up through frames until we find the next one where + # that is *not* true, as it should be the call site of the logger + # call + while frame: + if not isinstance(frame.f_locals.get('self'), logging.Logger): + break # We've found the frame we want + frame = cast('FrameType', frame.f_back) + return frame + + @classmethod + def munge_module_name(cls, frame_info: inspect.Traceback, module_name: str) -> str: + """ + Adjust module_name based on the file path for the given frame. + + We want to distinguish between other code and both our internal plugins + and the 'found' ones. + + For internal plugins we want "plugins.". + For 'found' plugins we want "....". + + :param frame_info: The frame_info of the caller. + :param module_name: The module_name string to munge. + :return: The munged module_name. + """ + # file_name = pathlib.Path(frame_info.filename).expanduser() + # plugin_dir = pathlib.Path(config.plugin_dir_path).expanduser() + # internal_plugin_dir = pathlib.Path(config.internal_plugin_dir_path).expanduser() + # if internal_plugin_dir in file_name.parents: + # # its an internal plugin + # return f'plugins.{".".join(file_name.relative_to(internal_plugin_dir).parent.parts)}' + + # elif plugin_dir in file_name.parents: + # return f'.{".".join(file_name.relative_to(plugin_dir).parent.parts)}' + + return module_name + + +def get_main_logger(sublogger_name: str = '') -> 'LoggerMixin': + """Return the correct logger for how the program is being run. (outdated)""" + # if not os.getenv("EDMC_NO_UI"): + # # GUI app being run + # return cast('LoggerMixin', logging.getLogger(appname)) + # else: + # # Must be the CLI + # return cast('LoggerMixin', logging.getLogger(appcmdname)) + return cast('LoggerMixin', logging.getLogger(__name__)) + +# Singleton +loglevel = logging._nameToLevel.get(config.log_level, logging.DEBUG) # noqa: + +base_logger_name = __name__ + +edmclogger = Logger(base_logger_name, loglevel=loglevel) +logger: 'LoggerMixin' = edmclogger.get_logger() diff --git a/capi/__init__.py b/capi/__init__.py new file mode 100644 index 0000000..bf90cd2 --- /dev/null +++ b/capi/__init__.py @@ -0,0 +1,167 @@ +from . import model +from . import utils +from . import exceptions + +import base64 +import os +import time + +import config +from EDMCLogging import get_main_logger + +logger = get_main_logger() + + +class CAPIAuthorizer: + def __init__(self, _model: model.Model): + self.model: model.Model = _model + + def auth_init(self) -> str: + """ + Generates initial url for fdev auth + + :return: + """ + + code_verifier = base64.urlsafe_b64encode(os.urandom(32)) + code_challenge = utils.generate_challenge(code_verifier) + state_string = base64.urlsafe_b64encode(os.urandom(32)).decode("utf-8") + + self.model.auth_init(code_verifier.decode('utf-8'), state_string) + + redirect_user_to_fdev = f"{config.AUTH_URL}?" \ + f"audience=all&" \ + f"scope=capi&" \ + f"response_type=code&" \ + f"client_id={config.CLIENT_ID}&" \ + f"code_challenge={code_challenge}&" \ + f"code_challenge_method=S256&" \ + f"state={state_string}&" \ + f"redirect_uri={config.REDIRECT_URL}" + + return redirect_user_to_fdev + + def fdev_callback(self, code: str, state: str) -> dict[str, str]: + if type(code) is not str or type(state) != str: + raise TypeError('code and state must be strings') + + code_verifier = self.model.get_verifier(state) + + self.model.set_code(code, state) + + token_request = utils.get_tokens_request(code, code_verifier) + + try: + token_request.raise_for_status() + + except Exception as e: + logger.exception( + f'token_request failed. Status: {token_request.status_code!r}, text: {token_request.text!r}', + exc_info=e + ) + + # self.model.delete_row(state) + + raise e + + tokens = token_request.json() + + access_token = tokens["access_token"] + refresh_token = tokens["refresh_token"] + expires_in = tokens["expires_in"] + timestamp_got_expires_in = int(time.time()) + + self.model.set_tokens(access_token, refresh_token, expires_in, timestamp_got_expires_in, state) + + try: + nickname = utils.get_nickname(access_token) + + except Exception as e: + logger.warning(f"Couldn't get nickname for state: {state!r}", exc_info=e) + raise KeyError(f"Couldn't get nickname for state: {state!r}") + + msg = {'status': 'ok', 'description': '', 'state': ''} + + if not self.model.set_nickname(nickname, state): + msg['description'] = 'Tokens updated' + + else: + msg['description'] = 'Tokens saved' + + msg['state'] = self.model.get_state_by_nickname(nickname) + + return msg + + def get_token_by_state(self, state: str) -> dict: + self.refresh_by_state(state) + row = self.model.get_token_for_user(state) + row['expires_over'] = int(row['expires_on']) - int(time.time()) + return row + + def refresh_by_state(self, state: str, force_refresh=False, failure_tolerance=True) -> dict: + """ + + :param state: + :param force_refresh: if we should update token when its time hasn't come + :param failure_tolerance: if we shouldn't remove row in case of update's failure + :return: + """ + + msg = {'status': '', 'description': '', 'state': ''} + row = self.model.get_row(state) + + if row is None: + # No such state in DB + msg['status'] = 'error' + msg['message'] = 'No such state in DB' + raise exceptions.RefreshFail(msg['description'], msg['status'], state) + + msg['state'] = state + + if int(time.time()) < int(row['timestamp_got_expires_in']) + int(row['expires_in']) and not force_refresh: + msg['status'] = 'ok' + msg['description'] = "Didn't refresh since it isn't required" + + return msg # token isn't expired and we don't force updating + + try: + refresh_request = utils.refresh_request(row['refresh_token']) + if refresh_request.status_code == 418: # Server's maintenance + logger.warning(f'FDEV maintenance 418, text: {refresh_request.text!r}') + msg['status'] = 'error' + msg['description'] = 'FDEV on maintenance' + raise exceptions.RefreshFail(msg['message'], msg['status'], state) + + refresh_request.raise_for_status() + + tokens = refresh_request.json() + tokens['timestamp_got_expires_in'] = int(time.time()) + self.model.set_tokens(*tokens, state) + msg['status'] = 'ok' + msg['description'] = 'Token were successfully updated' + return msg + + except Exception as e: + # probably here something don't work + logger.warning(f'Fail on refreshing token for {state}, row:{row}') + msg['status'] = 'error' + + self.model.increment_refresh_tries(state) + + if not failure_tolerance: + if row['refresh_tries'] >= 5: # hardcoded limit + msg['description'] = "Refresh failed. You were removed from DB due to refresh rate limitings" + self.model.delete_row(state) + raise exceptions.RefreshFail(msg['description'], msg['status'], state) + + msg['description'] = "Refresh failed. Try later" + raise exceptions.RefreshFail(msg['description'], msg['status'], state) + + def delete_by_state(self, state: str) -> None: + self.model.delete_row(state) + + def list_all_users(self) -> list[dict]: + return self.model.list_all_records() + + +capi_authorizer = CAPIAuthorizer(model.Model()) diff --git a/capi/exceptions.py b/capi/exceptions.py new file mode 100644 index 0000000..565ad9a --- /dev/null +++ b/capi/exceptions.py @@ -0,0 +1,6 @@ +class RefreshFail(Exception): + def __init__(self, message: str, status: str, state: str): + self.message = message + self.status = status + self.state = state + super().__init__(self.message + ' for ' + self.state) diff --git a/capi/model.py b/capi/model.py new file mode 100644 index 0000000..fbf65bb --- /dev/null +++ b/capi/model.py @@ -0,0 +1,117 @@ +import sqlite3 +from typing import Union + +from . import sqlite_requests +from EDMCLogging import get_main_logger + +logger = get_main_logger() + + +class Model: + def __init__(self): + self.db: sqlite3.Connection = sqlite3.connect('companion-api.sqlite', check_same_thread=False) + self.db.row_factory = lambda c, r: dict(zip([col[0] for col in c.description], r)) + with self.db: + self.db.execute(sqlite_requests.schema) + + def auth_init(self, verifier: str, state: str) -> None: + with self.db: + self.db.execute( + sqlite_requests.insert_auth_init, + { + 'code_verifier': verifier, + 'state': state + }) + + def get_verifier(self, state: str) -> str: + code_verifier_req = self.db.execute(sqlite_requests.select_all_by_state, {'state': state}).fetchone() + + if code_verifier_req is None: + # Somebody got here not by frontier redirect + raise KeyError('No state in DB found') + + code_verifier: str = code_verifier_req['code_verifier'] + + return code_verifier + + def set_code(self, code: str, state: str) -> None: + with self.db: + self.db.execute(sqlite_requests.set_code_state, {'code': code, 'state': state}) + + def delete_row(self, state: str) -> None: + with self.db: + self.db.execute(sqlite_requests.delete_by_state, {'state': state}) + + def set_tokens(self, access_token: str, refresh_token: str, expires_in: int, timestamp_got_expires_in: int, + state: str) -> None: + + with self.db: + self.db.execute( + sqlite_requests.set_tokens_by_state, + { + 'access_token': access_token, + 'refresh_token': refresh_token, + 'expires_in': expires_in, + 'timestamp_got_expires_in': timestamp_got_expires_in, + 'state': state + }) + + def set_nickname(self, nickname: str, state: str) -> bool: + """ + Return True if inserted successfully, False if catch sqlite3.IntegrityError + + :param nickname: + :param state: + :return: + """ + + try: + with self.db: + self.db.execute( + sqlite_requests.set_nickname_by_state, + {'nickname': nickname, 'state': state} + ) + return True + + except sqlite3.IntegrityError: + """ + let's migrate new received data to old state + 1. Get old state by nickname + 2. Remove row with old state + 3. Set old state where new state + """ + state_to_set: str = self.get_state_by_nickname(nickname) + self.delete_row(state_to_set) + self.set_new_state(state_to_set, state) + with self.db: + self.db.execute( + sqlite_requests.set_nickname_by_state, + {'nickname': nickname, 'state': state_to_set} + ) + return False + + def get_state_by_nickname(self, nickname: str) -> Union[None, str]: + with self.db: + nickname_f1 = self.db.execute(sqlite_requests.get_state_by_nickname, {'nickname': nickname}).fetchone() + + if nickname_f1 is None: + return None + + return nickname_f1['state'] + + def set_new_state(self, new_state: str, old_state: str) -> None: + with self.db: + self.db.execute(sqlite_requests.update_state_by_state, {'new_state': new_state, 'state': old_state}) + + def get_row(self, state: str) -> dict: + return self.db.execute(sqlite_requests.select_all_by_state, {'state': state}).fetchone() + + def increment_refresh_tries(self, state: str) -> None: + with self.db: + self.db.execute(sqlite_requests.refresh_times_increment, {'state': state}) + + def get_token_for_user(self, state: str) -> dict: + return self.db.execute(sqlite_requests.get_token_for_user, {'state': state}).fetchone() + + def list_all_records(self) -> list: + return self.db.execute(sqlite_requests.select_nickname_state_all).fetchall() diff --git a/capi/sqlite_requests.py b/capi/sqlite_requests.py new file mode 100644 index 0000000..ac40e8f --- /dev/null +++ b/capi/sqlite_requests.py @@ -0,0 +1,48 @@ +schema = """create table if not exists authorizations ( + code_verifier text , + state text, + timestamp_init datetime default current_timestamp , + code text, + access_token text, + refresh_token text, + expires_in text, + timestamp_got_expires_in text, + nickname text unique, + refresh_tries int default 0 +);""" + +insert_auth_init = """insert into authorizations + (code_verifier, state) + values + (:code_verifier, :state);""" + +select_all_by_state = """select * from authorizations where state = :state;""" + +set_code_state = """update authorizations set code = :code where state = :state;""" + +delete_by_state = """delete from authorizations where state = :state;""" + +set_tokens_by_state = """update authorizations +set + access_token = :access_token, + refresh_token = :refresh_token, + expires_in = :expires_in, + timestamp_got_expires_in = :timestamp_got_expires_in, + refresh_tries = 0 +where state = :state;""" + +set_nickname_by_state = """update authorizations set nickname = :nickname where state = :state;""" + +get_state_by_nickname = """select state from authorizations where nickname = :nickname;""" + +update_state_by_state = """update authorizations set state = :new_state where state = :state;""" + +refresh_times_increment = """update authorizations set refresh_tries = refresh_tries + 1 where state = :state;""" + +get_token_for_user = """select + access_token, + timestamp_got_expires_in + expires_in as expires_on, + nickname +from authorizations where state = :state;""" + +select_nickname_state_all = """select nickname, state from authorizations where nickname is not null;""" diff --git a/capi/utils.py b/capi/utils.py new file mode 100644 index 0000000..cf94c98 --- /dev/null +++ b/capi/utils.py @@ -0,0 +1,50 @@ +import hashlib +import base64 +import requests +import config + + +def get_tokens_request(code, code_verifier) -> requests.Response: + """ + Performs initial requesting access and refresh tokens + + :param code: + :param code_verifier: + :return: + """ + token_request: requests.Response = requests.post( + url=config.TOKEN_URL, + headers={ + 'Content-Type': 'application/x-www-form-urlencoded', + 'User-Agent': config.PROPER_USER_AGENT + }, + data=f'redirect_uri={config.REDIRECT_URL}&' + f'code={code}&' + f'grant_type=authorization_code&' + f'code_verifier={code_verifier}&' + f'client_id={config.CLIENT_ID}') + + return token_request + + +def get_nickname(access_token: str) -> str: + return requests.get( + url='https://companion.orerve.net/profile', + headers={'Authorization': f'Bearer {access_token}', + 'User-Agent': config.PROPER_USER_AGENT}).json()["commander"]["name"] + + +def refresh_request(refresh_token: str) -> requests.Response: + return requests.post( + url=config.TOKEN_URL, + headers={'Content-Type': 'application/x-www-form-urlencoded'}, + data=f'grant_type=refresh_token&client_id={config.CLIENT_ID}&refresh_token={refresh_token}') + + +def generate_challenge(_verifier: bytes) -> str: + """ + It takes code verifier and return code challenge + :param _verifier: + :return: + """ + return base64.urlsafe_b64encode(hashlib.sha256(_verifier).digest())[:-1].decode('utf-8') diff --git a/config.py b/config.py index 64c578c..b932e15 100644 --- a/config.py +++ b/config.py @@ -4,10 +4,12 @@ from os import getenv CLIENT_ID = getenv('client_id') assert CLIENT_ID, "No client_id in env" +log_level = 'DEBUG' + REDIRECT_URL = requests.utils.quote("http://127.0.0.1:9000/fdev-redirect") AUTH_URL = 'https://auth.frontierstore.net/auth' TOKEN_URL = 'https://auth.frontierstore.net/token' -PROPER_USER_AGENT = 'EDCD-a31-0.1' +PROPER_USER_AGENT = 'EDCD-a31-0.2' REDIRECT_HTML_TEMPLATE = """ diff --git a/legacy/config.py b/legacy/config.py new file mode 100644 index 0000000..64c578c --- /dev/null +++ b/legacy/config.py @@ -0,0 +1,36 @@ +import requests +from os import getenv + +CLIENT_ID = getenv('client_id') +assert CLIENT_ID, "No client_id in env" + +REDIRECT_URL = requests.utils.quote("http://127.0.0.1:9000/fdev-redirect") +AUTH_URL = 'https://auth.frontierstore.net/auth' +TOKEN_URL = 'https://auth.frontierstore.net/token' +PROPER_USER_AGENT = 'EDCD-a31-0.1' +REDIRECT_HTML_TEMPLATE = """ + + + + + + + + You should be redirected shortly... + + + +""" + +ADMIN_USERS_TEMPLATE = """ + + + + + + {} + + +""" + +ADMIN_USER_TEMPLATE = '{desc}' diff --git a/main.py b/legacy/main.py similarity index 96% rename from main.py rename to legacy/main.py index 556aac6..00f5987 100644 --- a/main.py +++ b/legacy/main.py @@ -37,7 +37,7 @@ logger.addHandler(stdout_handler) Typical workflow: 1. User open Authorize endpoint 2. Authorize building link for FDEV's /auth, sending link to user -3. Authorize write code_verifier, state, timestamp_init to DB +3. Authorize write code_verifier, state, timestamp_init (by DBMS) to DB 4. User approving client, redirecting to FDEV_redirect endpoint 5. Searching in DB if we have record with this state if don't have: diff --git a/refresher.py b/legacy/refresher.py similarity index 100% rename from refresher.py rename to legacy/refresher.py diff --git a/web.py b/web.py new file mode 100644 index 0000000..0b816cc --- /dev/null +++ b/web.py @@ -0,0 +1,29 @@ +import falcon +import waitress +import json + +from capi import capi_authorizer +import config + + +class AuthInit: + def on_get(self, req: falcon.request.Request, resp: falcon.response.Response) -> None: + resp.content_type = falcon.MEDIA_HTML + resp.text = config.REDIRECT_HTML_TEMPLATE.format(link=capi_authorizer.auth_init()) + + +class FDEVCallback: + def on_get(self, req: falcon.request.Request, resp: falcon.response.Response) -> None: + code = req.get_param('code') + state = req.get_param('state') + msg = capi_authorizer.fdev_callback(code, state) + resp.content_type = falcon.MEDIA_JSON + resp.text = json.dumps(msg) + + +application = falcon.App() +application.add_route('/authorize', AuthInit()) +application.add_route('/fdev-redirect', FDEVCallback()) + +if __name__ == '__main__': + waitress.serve(application, host='127.0.0.1', port=9000) \ No newline at end of file