mirror of
https://github.com/norohind/FDEV-CAPI-Handler.git
synced 2025-04-12 05:50:00 +03:00
WIP: full refactor
This commit is contained in:
parent
7a51aaeb67
commit
bbeec2e3fe
447
EDMCLogging.py
Normal file
447
EDMCLogging.py
Normal file
@ -0,0 +1,447 @@
|
||||
"""
|
||||
Set up required logging for the application.
|
||||
|
||||
This module provides for a common logging-powered log facility.
|
||||
Mostly it implements a logging.Filter() in order to get two extra
|
||||
members on the logging.LogRecord instance for use in logging.Formatter()
|
||||
strings.
|
||||
|
||||
If type checking, e.g. mypy, objects to `logging.trace(...)` then include this
|
||||
stanza:
|
||||
|
||||
# See EDMCLogging.py docs.
|
||||
# isort: off
|
||||
if TYPE_CHECKING:
|
||||
from logging import trace, TRACE # type: ignore # noqa: F401
|
||||
# isort: on
|
||||
|
||||
This is needed because we add the TRACE level and the trace() function
|
||||
ourselves at runtime.
|
||||
|
||||
To utilise logging in core code, or internal plugins, include this:
|
||||
|
||||
from EDMCLogging import get_main_logger
|
||||
|
||||
logger = get_main_logger()
|
||||
|
||||
To utilise logging in a 'found' (third-party) plugin, include this:
|
||||
|
||||
import os
|
||||
import logging
|
||||
|
||||
plugin_name = os.path.basename(os.path.dirname(__file__))
|
||||
# plugin_name here *must* be the name of the folder the plugin resides in
|
||||
# See, plug.py:load_plugins()
|
||||
logger = logging.getLogger(f'{appname}.{plugin_name}')
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
import logging.handlers
|
||||
from contextlib import suppress
|
||||
from fnmatch import fnmatch
|
||||
# So that any warning about accessing a protected member is only in one place.
|
||||
from sys import _getframe as getframe
|
||||
from threading import get_native_id as thread_native_id
|
||||
from traceback import print_exc
|
||||
from typing import TYPE_CHECKING, Tuple, cast
|
||||
|
||||
import config
|
||||
|
||||
# TODO: Tests:
|
||||
#
|
||||
# 1. Call from bare function in file.
|
||||
# 2. Call from `if __name__ == "__main__":` section
|
||||
#
|
||||
# 3. Call from 1st level function in 1st level Class in file
|
||||
# 4. Call from 2nd level function in 1st level Class in file
|
||||
# 5. Call from 3rd level function in 1st level Class in file
|
||||
#
|
||||
# 6. Call from 1st level function in 2nd level Class in file
|
||||
# 7. Call from 2nd level function in 2nd level Class in file
|
||||
# 8. Call from 3rd level function in 2nd level Class in file
|
||||
#
|
||||
# 9. Call from 1st level function in 3rd level Class in file
|
||||
# 10. Call from 2nd level function in 3rd level Class in file
|
||||
# 11. Call from 3rd level function in 3rd level Class in file
|
||||
#
|
||||
# 12. Call from 2nd level file, all as above.
|
||||
#
|
||||
# 13. Call from *module*
|
||||
#
|
||||
# 14. Call from *package*
|
||||
|
||||
_default_loglevel = logging.DEBUG
|
||||
|
||||
# Define a TRACE level
|
||||
LEVEL_TRACE = 5
|
||||
LEVEL_TRACE_ALL = 3
|
||||
logging.addLevelName(LEVEL_TRACE, "TRACE")
|
||||
logging.addLevelName(LEVEL_TRACE_ALL, "TRACE_ALL")
|
||||
logging.TRACE = LEVEL_TRACE # type: ignore
|
||||
logging.TRACE_ALL = LEVEL_TRACE_ALL # type: ignore
|
||||
logging.Logger.trace = lambda self, message, *args, **kwargs: self._log( # type: ignore
|
||||
logging.TRACE, # type: ignore
|
||||
message,
|
||||
args,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
||||
def _trace_if(self: logging.Logger, condition: str, message: str, *args, **kwargs) -> None:
|
||||
if any(fnmatch(condition, p) for p in []):
|
||||
self._log(logging.TRACE, message, args, **kwargs) # type: ignore # we added it
|
||||
return
|
||||
|
||||
self._log(logging.TRACE_ALL, message, args, **kwargs) # type: ignore # we added it
|
||||
|
||||
|
||||
logging.Logger.trace_if = _trace_if # type: ignore
|
||||
|
||||
# we cant hide this from `from xxx` imports and I'd really rather no-one other than `logging` had access to it
|
||||
del _trace_if
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import FrameType
|
||||
|
||||
# Fake type that we can use here to tell type checkers that trace exists
|
||||
|
||||
class LoggerMixin(logging.Logger):
|
||||
"""LoggerMixin is a fake class that tells type checkers that trace exists on a given type."""
|
||||
|
||||
def trace(self, message, *args, **kwargs) -> None:
|
||||
"""See implementation above."""
|
||||
...
|
||||
|
||||
def trace_if(self, condition: str, message, *args, **kwargs) -> None:
|
||||
"""
|
||||
Fake trace if method, traces only if condition exists in trace_on.
|
||||
|
||||
See implementation above.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class Logger:
|
||||
"""
|
||||
Wrapper class for all logging configuration and code.
|
||||
|
||||
Class instantiation requires the 'logger name' and optional loglevel.
|
||||
It is intended that this 'logger name' be re-used in all files/modules
|
||||
that need to log.
|
||||
|
||||
Users of this class should then call getLogger() to get the
|
||||
logging.Logger instance.
|
||||
"""
|
||||
|
||||
def __init__(self, logger_name: str, loglevel: int = _default_loglevel):
|
||||
"""
|
||||
Set up a `logging.Logger` with our preferred configuration.
|
||||
|
||||
This includes using an EDMCContextFilter to add 'class' and 'qualname'
|
||||
expansions for logging.Formatter().
|
||||
"""
|
||||
self.logger = logging.getLogger(logger_name)
|
||||
# Configure the logging.Logger
|
||||
# This needs to always be TRACE in order to let TRACE level messages
|
||||
# through to check the *handler* levels.
|
||||
self.logger.setLevel(logging.TRACE) # type: ignore
|
||||
|
||||
# Set up filter for adding class name
|
||||
self.logger_filter = EDMCContextFilter()
|
||||
self.logger.addFilter(self.logger_filter)
|
||||
|
||||
# Our basic channel handling stdout
|
||||
self.logger_channel = logging.StreamHandler()
|
||||
# This should be affected by the user configured log level
|
||||
self.logger_channel.setLevel(loglevel)
|
||||
|
||||
self.logger_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(process)d:%(thread)d:%(osthreadid)d %(module)s.%(qualname)s:%(lineno)d: %(message)s') # noqa: E501
|
||||
self.logger_formatter.default_time_format = '%Y-%m-%d %H:%M:%S'
|
||||
self.logger_formatter.default_msec_format = '%s.%03d'
|
||||
|
||||
self.logger_channel.setFormatter(self.logger_formatter)
|
||||
self.logger.addHandler(self.logger_channel)
|
||||
|
||||
def get_logger(self) -> 'LoggerMixin':
|
||||
"""
|
||||
Obtain the self.logger of the class instance.
|
||||
|
||||
Not to be confused with logging.getLogger().
|
||||
"""
|
||||
return cast('LoggerMixin', self.logger)
|
||||
|
||||
def get_streamhandler(self) -> logging.Handler:
|
||||
"""
|
||||
Obtain the self.logger_channel StreamHandler instance.
|
||||
|
||||
:return: logging.StreamHandler
|
||||
"""
|
||||
return self.logger_channel
|
||||
|
||||
def set_channels_loglevel(self, level: int) -> None:
|
||||
"""
|
||||
Set the specified log level on the channels.
|
||||
|
||||
:param level: A valid `logging` level.
|
||||
:return: None
|
||||
"""
|
||||
self.logger_channel.setLevel(level)
|
||||
|
||||
def set_console_loglevel(self, level: int) -> None:
|
||||
"""
|
||||
Set the specified log level on the console channel.
|
||||
|
||||
:param level: A valid `logging` level.
|
||||
:return: None
|
||||
"""
|
||||
if self.logger_channel.level != logging.TRACE: # type: ignore
|
||||
self.logger_channel.setLevel(level)
|
||||
else:
|
||||
logger.trace("Not changing log level because it's TRACE") # type: ignore
|
||||
|
||||
|
||||
class EDMCContextFilter(logging.Filter):
|
||||
"""
|
||||
Implements filtering to add extra format specifiers, and tweak others.
|
||||
|
||||
logging.Filter sub-class to place extra attributes of the calling site
|
||||
into the record.
|
||||
"""
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
"""
|
||||
Attempt to set/change fields in the LogRecord.
|
||||
|
||||
1. class = class name(s) of the call site, if applicable
|
||||
2. qualname = __qualname__ of the call site. This simplifies
|
||||
logging.Formatter() as you can use just this no matter if there is
|
||||
a class involved or not, so you get a nice clean:
|
||||
<file/module>.<classA>[.classB....].<function>
|
||||
3. osthreadid = OS level thread ID.
|
||||
|
||||
If we fail to be able to properly set either then:
|
||||
|
||||
1. Use print() to alert, to be SURE a message is seen.
|
||||
2. But also return strings noting the error, so there'll be
|
||||
something in the log output if it happens.
|
||||
|
||||
:param record: The LogRecord we're "filtering"
|
||||
:return: bool - Always true in order for this record to be logged.
|
||||
"""
|
||||
(class_name, qualname, module_name) = self.caller_attributes(module_name=getattr(record, 'module'))
|
||||
|
||||
# Only set if we got a useful value
|
||||
if module_name:
|
||||
setattr(record, 'module', module_name)
|
||||
|
||||
# Only set if not already provided by logging itself
|
||||
if getattr(record, 'class', None) is None:
|
||||
setattr(record, 'class', class_name)
|
||||
|
||||
# Only set if not already provided by logging itself
|
||||
if getattr(record, 'qualname', None) is None:
|
||||
setattr(record, 'qualname', qualname)
|
||||
|
||||
setattr(record, 'osthreadid', thread_native_id())
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def caller_attributes(cls, module_name: str = '') -> Tuple[str, str, str]: # noqa: CCR001, E501, C901 # this is as refactored as is sensible
|
||||
"""
|
||||
Determine extra or changed fields for the caller.
|
||||
|
||||
1. qualname finds the relevant object and its __qualname__
|
||||
2. caller_class_names is just the full class names of the calling
|
||||
class if relevant.
|
||||
3. module is munged if we detect the caller is an EDMC plugin,
|
||||
whether internal or found.
|
||||
|
||||
:param module_name: The name of the calling module.
|
||||
:return: Tuple[str, str, str] - class_name, qualname, module_name
|
||||
"""
|
||||
frame = cls.find_caller_frame()
|
||||
|
||||
caller_qualname = caller_class_names = ''
|
||||
if frame:
|
||||
# <https://stackoverflow.com/questions/2203424/python-how-to-retrieve-class-information-from-a-frame-object#2220759>
|
||||
try:
|
||||
frame_info = inspect.getframeinfo(frame)
|
||||
# raise(IndexError) # TODO: Remove, only for testing
|
||||
|
||||
except Exception:
|
||||
# Separate from the print below to guarantee we see at least this much.
|
||||
print('EDMCLogging:EDMCContextFilter:caller_attributes(): Failed in `inspect.getframinfo(frame)`')
|
||||
|
||||
# We want to *attempt* to show something about the nature of 'frame',
|
||||
# but at this point we can't trust it will work.
|
||||
try:
|
||||
print(f'frame: {frame}')
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# We've given up, so just return '??' to signal we couldn't get the info
|
||||
return '??', '??', module_name
|
||||
try:
|
||||
args, _, _, value_dict = inspect.getargvalues(frame)
|
||||
if len(args) and args[0] in ('self', 'cls'):
|
||||
frame_class: 'object' = value_dict[args[0]]
|
||||
|
||||
if frame_class:
|
||||
# See https://en.wikipedia.org/wiki/Name_mangling#Python for how name mangling works.
|
||||
# For more detail, see _Py_Mangle in CPython's Python/compile.c.
|
||||
name = frame_info.function
|
||||
class_name = frame_class.__class__.__name__.lstrip("_")
|
||||
if name.startswith("__") and not name.endswith("__") and class_name:
|
||||
name = f'_{class_name}{frame_info.function}'
|
||||
|
||||
# Find __qualname__ of the caller
|
||||
fn = inspect.getattr_static(frame_class, name, None)
|
||||
if fn is None:
|
||||
# For some reason getattr_static cant grab this. Try and grab it with getattr, bail out
|
||||
# if we get a RecursionError indicating a property
|
||||
try:
|
||||
fn = getattr(frame_class, name, None)
|
||||
except RecursionError:
|
||||
print(
|
||||
"EDMCLogging:EDMCContextFilter:caller_attributes():"
|
||||
"Failed to get attribute for function info. Bailing out"
|
||||
)
|
||||
# class_name is better than nothing for __qualname__
|
||||
return class_name, class_name, module_name
|
||||
|
||||
if fn is not None:
|
||||
if isinstance(fn, property):
|
||||
class_name = str(frame_class)
|
||||
# If somehow you make your __class__ or __class__.__qualname__ recursive,
|
||||
# I'll be impressed.
|
||||
if hasattr(frame_class, '__class__') and hasattr(frame_class.__class__, "__qualname__"):
|
||||
class_name = frame_class.__class__.__qualname__
|
||||
caller_qualname = f"{class_name}.{name}(property)"
|
||||
|
||||
else:
|
||||
caller_qualname = f"<property {name} on {class_name}>"
|
||||
|
||||
elif not hasattr(fn, '__qualname__'):
|
||||
caller_qualname = name
|
||||
|
||||
elif hasattr(fn, '__qualname__') and fn.__qualname__:
|
||||
caller_qualname = fn.__qualname__
|
||||
|
||||
# Find containing class name(s) of caller, if any
|
||||
if (
|
||||
frame_class.__class__ and hasattr(frame_class.__class__, '__qualname__')
|
||||
and frame_class.__class__.__qualname__
|
||||
):
|
||||
caller_class_names = frame_class.__class__.__qualname__
|
||||
|
||||
# It's a call from the top level module file
|
||||
elif frame_info.function == '<module>':
|
||||
caller_class_names = '<none>'
|
||||
caller_qualname = value_dict['__name__']
|
||||
|
||||
elif frame_info.function != '':
|
||||
caller_class_names = '<none>'
|
||||
caller_qualname = frame_info.function
|
||||
|
||||
module_name = cls.munge_module_name(frame_info, module_name)
|
||||
|
||||
except Exception as e:
|
||||
print('ALERT! Something went VERY wrong in handling finding info to log')
|
||||
print('ALERT! Information is as follows')
|
||||
with suppress(Exception):
|
||||
|
||||
print(f'ALERT! {e=}')
|
||||
print_exc()
|
||||
print(f'ALERT! {frame=}')
|
||||
with suppress(Exception):
|
||||
print(f'ALERT! {fn=}') # type: ignore
|
||||
with suppress(Exception):
|
||||
print(f'ALERT! {cls=}')
|
||||
|
||||
finally: # Ensure this always happens
|
||||
# https://docs.python.org/3.7/library/inspect.html#the-interpreter-stack
|
||||
del frame
|
||||
|
||||
if caller_qualname == '':
|
||||
print('ALERT! Something went wrong with finding caller qualname for logging!')
|
||||
caller_qualname = '<ERROR in EDMCLogging.caller_class_and_qualname() for "qualname">'
|
||||
|
||||
if caller_class_names == '':
|
||||
print('ALERT! Something went wrong with finding caller class name(s) for logging!')
|
||||
caller_class_names = '<ERROR in EDMCLogging.caller_class_and_qualname() for "class">'
|
||||
|
||||
return caller_class_names, caller_qualname, module_name
|
||||
|
||||
@classmethod
|
||||
def find_caller_frame(cls):
|
||||
"""
|
||||
Find the stack frame of the logging caller.
|
||||
|
||||
:returns: 'frame' object such as from sys._getframe()
|
||||
"""
|
||||
# Go up through stack frames until we find the first with a
|
||||
# type(f_locals.self) of logging.Logger. This should be the start
|
||||
# of the frames internal to logging.
|
||||
frame: 'FrameType' = getframe(0)
|
||||
while frame:
|
||||
if isinstance(frame.f_locals.get('self'), logging.Logger):
|
||||
frame = cast('FrameType', frame.f_back) # Want to start on the next frame below
|
||||
break
|
||||
frame = cast('FrameType', frame.f_back)
|
||||
# Now continue up through frames until we find the next one where
|
||||
# that is *not* true, as it should be the call site of the logger
|
||||
# call
|
||||
while frame:
|
||||
if not isinstance(frame.f_locals.get('self'), logging.Logger):
|
||||
break # We've found the frame we want
|
||||
frame = cast('FrameType', frame.f_back)
|
||||
return frame
|
||||
|
||||
@classmethod
|
||||
def munge_module_name(cls, frame_info: inspect.Traceback, module_name: str) -> str:
|
||||
"""
|
||||
Adjust module_name based on the file path for the given frame.
|
||||
|
||||
We want to distinguish between other code and both our internal plugins
|
||||
and the 'found' ones.
|
||||
|
||||
For internal plugins we want "plugins.<filename>".
|
||||
For 'found' plugins we want "<plugins>.<plugin_name>...".
|
||||
|
||||
:param frame_info: The frame_info of the caller.
|
||||
:param module_name: The module_name string to munge.
|
||||
:return: The munged module_name.
|
||||
"""
|
||||
# file_name = pathlib.Path(frame_info.filename).expanduser()
|
||||
# plugin_dir = pathlib.Path(config.plugin_dir_path).expanduser()
|
||||
# internal_plugin_dir = pathlib.Path(config.internal_plugin_dir_path).expanduser()
|
||||
# if internal_plugin_dir in file_name.parents:
|
||||
# # its an internal plugin
|
||||
# return f'plugins.{".".join(file_name.relative_to(internal_plugin_dir).parent.parts)}'
|
||||
|
||||
# elif plugin_dir in file_name.parents:
|
||||
# return f'<plugin>.{".".join(file_name.relative_to(plugin_dir).parent.parts)}'
|
||||
|
||||
return module_name
|
||||
|
||||
|
||||
def get_main_logger(sublogger_name: str = '') -> 'LoggerMixin':
|
||||
"""Return the correct logger for how the program is being run. (outdated)"""
|
||||
# if not os.getenv("EDMC_NO_UI"):
|
||||
# # GUI app being run
|
||||
# return cast('LoggerMixin', logging.getLogger(appname))
|
||||
# else:
|
||||
# # Must be the CLI
|
||||
# return cast('LoggerMixin', logging.getLogger(appcmdname))
|
||||
return cast('LoggerMixin', logging.getLogger(__name__))
|
||||
|
||||
# Singleton
|
||||
loglevel = logging._nameToLevel.get(config.log_level, logging.DEBUG) # noqa:
|
||||
|
||||
base_logger_name = __name__
|
||||
|
||||
edmclogger = Logger(base_logger_name, loglevel=loglevel)
|
||||
logger: 'LoggerMixin' = edmclogger.get_logger()
|
167
capi/__init__.py
Normal file
167
capi/__init__.py
Normal file
@ -0,0 +1,167 @@
|
||||
from . import model
|
||||
from . import utils
|
||||
from . import exceptions
|
||||
|
||||
import base64
|
||||
import os
|
||||
import time
|
||||
|
||||
import config
|
||||
from EDMCLogging import get_main_logger
|
||||
|
||||
logger = get_main_logger()
|
||||
|
||||
|
||||
class CAPIAuthorizer:
|
||||
def __init__(self, _model: model.Model):
|
||||
self.model: model.Model = _model
|
||||
|
||||
def auth_init(self) -> str:
|
||||
"""
|
||||
Generates initial url for fdev auth
|
||||
|
||||
:return:
|
||||
"""
|
||||
|
||||
code_verifier = base64.urlsafe_b64encode(os.urandom(32))
|
||||
code_challenge = utils.generate_challenge(code_verifier)
|
||||
state_string = base64.urlsafe_b64encode(os.urandom(32)).decode("utf-8")
|
||||
|
||||
self.model.auth_init(code_verifier.decode('utf-8'), state_string)
|
||||
|
||||
redirect_user_to_fdev = f"{config.AUTH_URL}?" \
|
||||
f"audience=all&" \
|
||||
f"scope=capi&" \
|
||||
f"response_type=code&" \
|
||||
f"client_id={config.CLIENT_ID}&" \
|
||||
f"code_challenge={code_challenge}&" \
|
||||
f"code_challenge_method=S256&" \
|
||||
f"state={state_string}&" \
|
||||
f"redirect_uri={config.REDIRECT_URL}"
|
||||
|
||||
return redirect_user_to_fdev
|
||||
|
||||
def fdev_callback(self, code: str, state: str) -> dict[str, str]:
|
||||
if type(code) is not str or type(state) != str:
|
||||
raise TypeError('code and state must be strings')
|
||||
|
||||
code_verifier = self.model.get_verifier(state)
|
||||
|
||||
self.model.set_code(code, state)
|
||||
|
||||
token_request = utils.get_tokens_request(code, code_verifier)
|
||||
|
||||
try:
|
||||
token_request.raise_for_status()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f'token_request failed. Status: {token_request.status_code!r}, text: {token_request.text!r}',
|
||||
exc_info=e
|
||||
)
|
||||
|
||||
# self.model.delete_row(state)
|
||||
|
||||
raise e
|
||||
|
||||
tokens = token_request.json()
|
||||
|
||||
access_token = tokens["access_token"]
|
||||
refresh_token = tokens["refresh_token"]
|
||||
expires_in = tokens["expires_in"]
|
||||
timestamp_got_expires_in = int(time.time())
|
||||
|
||||
self.model.set_tokens(access_token, refresh_token, expires_in, timestamp_got_expires_in, state)
|
||||
|
||||
try:
|
||||
nickname = utils.get_nickname(access_token)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Couldn't get nickname for state: {state!r}", exc_info=e)
|
||||
raise KeyError(f"Couldn't get nickname for state: {state!r}")
|
||||
|
||||
msg = {'status': 'ok', 'description': '', 'state': ''}
|
||||
|
||||
if not self.model.set_nickname(nickname, state):
|
||||
msg['description'] = 'Tokens updated'
|
||||
|
||||
else:
|
||||
msg['description'] = 'Tokens saved'
|
||||
|
||||
msg['state'] = self.model.get_state_by_nickname(nickname)
|
||||
|
||||
return msg
|
||||
|
||||
def get_token_by_state(self, state: str) -> dict:
|
||||
self.refresh_by_state(state)
|
||||
row = self.model.get_token_for_user(state)
|
||||
row['expires_over'] = int(row['expires_on']) - int(time.time())
|
||||
return row
|
||||
|
||||
def refresh_by_state(self, state: str, force_refresh=False, failure_tolerance=True) -> dict:
|
||||
"""
|
||||
|
||||
:param state:
|
||||
:param force_refresh: if we should update token when its time hasn't come
|
||||
:param failure_tolerance: if we shouldn't remove row in case of update's failure
|
||||
:return:
|
||||
"""
|
||||
|
||||
msg = {'status': '', 'description': '', 'state': ''}
|
||||
row = self.model.get_row(state)
|
||||
|
||||
if row is None:
|
||||
# No such state in DB
|
||||
msg['status'] = 'error'
|
||||
msg['message'] = 'No such state in DB'
|
||||
raise exceptions.RefreshFail(msg['description'], msg['status'], state)
|
||||
|
||||
msg['state'] = state
|
||||
|
||||
if int(time.time()) < int(row['timestamp_got_expires_in']) + int(row['expires_in']) and not force_refresh:
|
||||
msg['status'] = 'ok'
|
||||
msg['description'] = "Didn't refresh since it isn't required"
|
||||
|
||||
return msg # token isn't expired and we don't force updating
|
||||
|
||||
try:
|
||||
refresh_request = utils.refresh_request(row['refresh_token'])
|
||||
if refresh_request.status_code == 418: # Server's maintenance
|
||||
logger.warning(f'FDEV maintenance 418, text: {refresh_request.text!r}')
|
||||
msg['status'] = 'error'
|
||||
msg['description'] = 'FDEV on maintenance'
|
||||
raise exceptions.RefreshFail(msg['message'], msg['status'], state)
|
||||
|
||||
refresh_request.raise_for_status()
|
||||
|
||||
tokens = refresh_request.json()
|
||||
tokens['timestamp_got_expires_in'] = int(time.time())
|
||||
self.model.set_tokens(*tokens, state)
|
||||
msg['status'] = 'ok'
|
||||
msg['description'] = 'Token were successfully updated'
|
||||
return msg
|
||||
|
||||
except Exception as e:
|
||||
# probably here something don't work
|
||||
logger.warning(f'Fail on refreshing token for {state}, row:{row}')
|
||||
msg['status'] = 'error'
|
||||
|
||||
self.model.increment_refresh_tries(state)
|
||||
|
||||
if not failure_tolerance:
|
||||
if row['refresh_tries'] >= 5: # hardcoded limit
|
||||
msg['description'] = "Refresh failed. You were removed from DB due to refresh rate limitings"
|
||||
self.model.delete_row(state)
|
||||
raise exceptions.RefreshFail(msg['description'], msg['status'], state)
|
||||
|
||||
msg['description'] = "Refresh failed. Try later"
|
||||
raise exceptions.RefreshFail(msg['description'], msg['status'], state)
|
||||
|
||||
def delete_by_state(self, state: str) -> None:
|
||||
self.model.delete_row(state)
|
||||
|
||||
def list_all_users(self) -> list[dict]:
|
||||
return self.model.list_all_records()
|
||||
|
||||
|
||||
capi_authorizer = CAPIAuthorizer(model.Model())
|
6
capi/exceptions.py
Normal file
6
capi/exceptions.py
Normal file
@ -0,0 +1,6 @@
|
||||
class RefreshFail(Exception):
|
||||
def __init__(self, message: str, status: str, state: str):
|
||||
self.message = message
|
||||
self.status = status
|
||||
self.state = state
|
||||
super().__init__(self.message + ' for ' + self.state)
|
117
capi/model.py
Normal file
117
capi/model.py
Normal file
@ -0,0 +1,117 @@
|
||||
import sqlite3
|
||||
from typing import Union
|
||||
|
||||
from . import sqlite_requests
|
||||
from EDMCLogging import get_main_logger
|
||||
|
||||
logger = get_main_logger()
|
||||
|
||||
|
||||
class Model:
|
||||
def __init__(self):
|
||||
self.db: sqlite3.Connection = sqlite3.connect('companion-api.sqlite', check_same_thread=False)
|
||||
self.db.row_factory = lambda c, r: dict(zip([col[0] for col in c.description], r))
|
||||
with self.db:
|
||||
self.db.execute(sqlite_requests.schema)
|
||||
|
||||
def auth_init(self, verifier: str, state: str) -> None:
|
||||
with self.db:
|
||||
self.db.execute(
|
||||
sqlite_requests.insert_auth_init,
|
||||
{
|
||||
'code_verifier': verifier,
|
||||
'state': state
|
||||
})
|
||||
|
||||
def get_verifier(self, state: str) -> str:
|
||||
code_verifier_req = self.db.execute(sqlite_requests.select_all_by_state, {'state': state}).fetchone()
|
||||
|
||||
if code_verifier_req is None:
|
||||
# Somebody got here not by frontier redirect
|
||||
raise KeyError('No state in DB found')
|
||||
|
||||
code_verifier: str = code_verifier_req['code_verifier']
|
||||
|
||||
return code_verifier
|
||||
|
||||
def set_code(self, code: str, state: str) -> None:
|
||||
with self.db:
|
||||
self.db.execute(sqlite_requests.set_code_state, {'code': code, 'state': state})
|
||||
|
||||
def delete_row(self, state: str) -> None:
|
||||
with self.db:
|
||||
self.db.execute(sqlite_requests.delete_by_state, {'state': state})
|
||||
|
||||
def set_tokens(self, access_token: str, refresh_token: str, expires_in: int, timestamp_got_expires_in: int,
|
||||
state: str) -> None:
|
||||
|
||||
with self.db:
|
||||
self.db.execute(
|
||||
sqlite_requests.set_tokens_by_state,
|
||||
{
|
||||
'access_token': access_token,
|
||||
'refresh_token': refresh_token,
|
||||
'expires_in': expires_in,
|
||||
'timestamp_got_expires_in': timestamp_got_expires_in,
|
||||
'state': state
|
||||
})
|
||||
|
||||
def set_nickname(self, nickname: str, state: str) -> bool:
|
||||
"""
|
||||
Return True if inserted successfully, False if catch sqlite3.IntegrityError
|
||||
|
||||
:param nickname:
|
||||
:param state:
|
||||
:return:
|
||||
"""
|
||||
|
||||
try:
|
||||
with self.db:
|
||||
self.db.execute(
|
||||
sqlite_requests.set_nickname_by_state,
|
||||
{'nickname': nickname, 'state': state}
|
||||
)
|
||||
return True
|
||||
|
||||
except sqlite3.IntegrityError:
|
||||
"""
|
||||
let's migrate new received data to old state
|
||||
1. Get old state by nickname
|
||||
2. Remove row with old state
|
||||
3. Set old state where new state
|
||||
"""
|
||||
state_to_set: str = self.get_state_by_nickname(nickname)
|
||||
self.delete_row(state_to_set)
|
||||
self.set_new_state(state_to_set, state)
|
||||
with self.db:
|
||||
self.db.execute(
|
||||
sqlite_requests.set_nickname_by_state,
|
||||
{'nickname': nickname, 'state': state_to_set}
|
||||
)
|
||||
return False
|
||||
|
||||
def get_state_by_nickname(self, nickname: str) -> Union[None, str]:
|
||||
with self.db:
|
||||
nickname_f1 = self.db.execute(sqlite_requests.get_state_by_nickname, {'nickname': nickname}).fetchone()
|
||||
|
||||
if nickname_f1 is None:
|
||||
return None
|
||||
|
||||
return nickname_f1['state']
|
||||
|
||||
def set_new_state(self, new_state: str, old_state: str) -> None:
|
||||
with self.db:
|
||||
self.db.execute(sqlite_requests.update_state_by_state, {'new_state': new_state, 'state': old_state})
|
||||
|
||||
def get_row(self, state: str) -> dict:
|
||||
return self.db.execute(sqlite_requests.select_all_by_state, {'state': state}).fetchone()
|
||||
|
||||
def increment_refresh_tries(self, state: str) -> None:
|
||||
with self.db:
|
||||
self.db.execute(sqlite_requests.refresh_times_increment, {'state': state})
|
||||
|
||||
def get_token_for_user(self, state: str) -> dict:
|
||||
return self.db.execute(sqlite_requests.get_token_for_user, {'state': state}).fetchone()
|
||||
|
||||
def list_all_records(self) -> list:
|
||||
return self.db.execute(sqlite_requests.select_nickname_state_all).fetchall()
|
48
capi/sqlite_requests.py
Normal file
48
capi/sqlite_requests.py
Normal file
@ -0,0 +1,48 @@
|
||||
schema = """create table if not exists authorizations (
|
||||
code_verifier text ,
|
||||
state text,
|
||||
timestamp_init datetime default current_timestamp ,
|
||||
code text,
|
||||
access_token text,
|
||||
refresh_token text,
|
||||
expires_in text,
|
||||
timestamp_got_expires_in text,
|
||||
nickname text unique,
|
||||
refresh_tries int default 0
|
||||
);"""
|
||||
|
||||
insert_auth_init = """insert into authorizations
|
||||
(code_verifier, state)
|
||||
values
|
||||
(:code_verifier, :state);"""
|
||||
|
||||
select_all_by_state = """select * from authorizations where state = :state;"""
|
||||
|
||||
set_code_state = """update authorizations set code = :code where state = :state;"""
|
||||
|
||||
delete_by_state = """delete from authorizations where state = :state;"""
|
||||
|
||||
set_tokens_by_state = """update authorizations
|
||||
set
|
||||
access_token = :access_token,
|
||||
refresh_token = :refresh_token,
|
||||
expires_in = :expires_in,
|
||||
timestamp_got_expires_in = :timestamp_got_expires_in,
|
||||
refresh_tries = 0
|
||||
where state = :state;"""
|
||||
|
||||
set_nickname_by_state = """update authorizations set nickname = :nickname where state = :state;"""
|
||||
|
||||
get_state_by_nickname = """select state from authorizations where nickname = :nickname;"""
|
||||
|
||||
update_state_by_state = """update authorizations set state = :new_state where state = :state;"""
|
||||
|
||||
refresh_times_increment = """update authorizations set refresh_tries = refresh_tries + 1 where state = :state;"""
|
||||
|
||||
get_token_for_user = """select
|
||||
access_token,
|
||||
timestamp_got_expires_in + expires_in as expires_on,
|
||||
nickname
|
||||
from authorizations where state = :state;"""
|
||||
|
||||
select_nickname_state_all = """select nickname, state from authorizations where nickname is not null;"""
|
50
capi/utils.py
Normal file
50
capi/utils.py
Normal file
@ -0,0 +1,50 @@
|
||||
import hashlib
|
||||
import base64
|
||||
import requests
|
||||
import config
|
||||
|
||||
|
||||
def get_tokens_request(code, code_verifier) -> requests.Response:
|
||||
"""
|
||||
Performs initial requesting access and refresh tokens
|
||||
|
||||
:param code:
|
||||
:param code_verifier:
|
||||
:return:
|
||||
"""
|
||||
token_request: requests.Response = requests.post(
|
||||
url=config.TOKEN_URL,
|
||||
headers={
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
'User-Agent': config.PROPER_USER_AGENT
|
||||
},
|
||||
data=f'redirect_uri={config.REDIRECT_URL}&'
|
||||
f'code={code}&'
|
||||
f'grant_type=authorization_code&'
|
||||
f'code_verifier={code_verifier}&'
|
||||
f'client_id={config.CLIENT_ID}')
|
||||
|
||||
return token_request
|
||||
|
||||
|
||||
def get_nickname(access_token: str) -> str:
|
||||
return requests.get(
|
||||
url='https://companion.orerve.net/profile',
|
||||
headers={'Authorization': f'Bearer {access_token}',
|
||||
'User-Agent': config.PROPER_USER_AGENT}).json()["commander"]["name"]
|
||||
|
||||
|
||||
def refresh_request(refresh_token: str) -> requests.Response:
|
||||
return requests.post(
|
||||
url=config.TOKEN_URL,
|
||||
headers={'Content-Type': 'application/x-www-form-urlencoded'},
|
||||
data=f'grant_type=refresh_token&client_id={config.CLIENT_ID}&refresh_token={refresh_token}')
|
||||
|
||||
|
||||
def generate_challenge(_verifier: bytes) -> str:
|
||||
"""
|
||||
It takes code verifier and return code challenge
|
||||
:param _verifier:
|
||||
:return:
|
||||
"""
|
||||
return base64.urlsafe_b64encode(hashlib.sha256(_verifier).digest())[:-1].decode('utf-8')
|
@ -4,10 +4,12 @@ from os import getenv
|
||||
CLIENT_ID = getenv('client_id')
|
||||
assert CLIENT_ID, "No client_id in env"
|
||||
|
||||
log_level = 'DEBUG'
|
||||
|
||||
REDIRECT_URL = requests.utils.quote("http://127.0.0.1:9000/fdev-redirect")
|
||||
AUTH_URL = 'https://auth.frontierstore.net/auth'
|
||||
TOKEN_URL = 'https://auth.frontierstore.net/token'
|
||||
PROPER_USER_AGENT = 'EDCD-a31-0.1'
|
||||
PROPER_USER_AGENT = 'EDCD-a31-0.2'
|
||||
REDIRECT_HTML_TEMPLATE = """
|
||||
<!DOCTYPE HTML>
|
||||
<html>
|
||||
|
36
legacy/config.py
Normal file
36
legacy/config.py
Normal file
@ -0,0 +1,36 @@
|
||||
import requests
|
||||
from os import getenv
|
||||
|
||||
CLIENT_ID = getenv('client_id')
|
||||
assert CLIENT_ID, "No client_id in env"
|
||||
|
||||
REDIRECT_URL = requests.utils.quote("http://127.0.0.1:9000/fdev-redirect")
|
||||
AUTH_URL = 'https://auth.frontierstore.net/auth'
|
||||
TOKEN_URL = 'https://auth.frontierstore.net/token'
|
||||
PROPER_USER_AGENT = 'EDCD-a31-0.1'
|
||||
REDIRECT_HTML_TEMPLATE = """
|
||||
<!DOCTYPE HTML>
|
||||
<html>
|
||||
<head>
|
||||
<meta http-equiv="refresh" content="0; url={link}">
|
||||
</head>
|
||||
<body>
|
||||
<a href="{link}">
|
||||
You should be redirected shortly...
|
||||
</a>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
ADMIN_USERS_TEMPLATE = """
|
||||
<!DOCTYPE HTML>
|
||||
<html>
|
||||
<head>
|
||||
</head>
|
||||
<body>
|
||||
{}
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
ADMIN_USER_TEMPLATE = '<a href="{link}">{desc}</a>'
|
@ -37,7 +37,7 @@ logger.addHandler(stdout_handler)
|
||||
Typical workflow:
|
||||
1. User open Authorize endpoint
|
||||
2. Authorize building link for FDEV's /auth, sending link to user
|
||||
3. Authorize write code_verifier, state, timestamp_init to DB
|
||||
3. Authorize write code_verifier, state, timestamp_init (by DBMS) to DB
|
||||
4. User approving client, redirecting to FDEV_redirect endpoint
|
||||
5. Searching in DB if we have record with this state
|
||||
if don't have:
|
29
web.py
Normal file
29
web.py
Normal file
@ -0,0 +1,29 @@
|
||||
import falcon
|
||||
import waitress
|
||||
import json
|
||||
|
||||
from capi import capi_authorizer
|
||||
import config
|
||||
|
||||
|
||||
class AuthInit:
|
||||
def on_get(self, req: falcon.request.Request, resp: falcon.response.Response) -> None:
|
||||
resp.content_type = falcon.MEDIA_HTML
|
||||
resp.text = config.REDIRECT_HTML_TEMPLATE.format(link=capi_authorizer.auth_init())
|
||||
|
||||
|
||||
class FDEVCallback:
|
||||
def on_get(self, req: falcon.request.Request, resp: falcon.response.Response) -> None:
|
||||
code = req.get_param('code')
|
||||
state = req.get_param('state')
|
||||
msg = capi_authorizer.fdev_callback(code, state)
|
||||
resp.content_type = falcon.MEDIA_JSON
|
||||
resp.text = json.dumps(msg)
|
||||
|
||||
|
||||
application = falcon.App()
|
||||
application.add_route('/authorize', AuthInit())
|
||||
application.add_route('/fdev-redirect', FDEVCallback())
|
||||
|
||||
if __name__ == '__main__':
|
||||
waitress.serve(application, host='127.0.0.1', port=9000)
|
Loading…
x
Reference in New Issue
Block a user