mirror of
https://github.com/norohind/FDEV-CAPI-Handler.git
synced 2025-06-03 09:00:58 +03:00
WIP: full refactor
This commit is contained in:
parent
7a51aaeb67
commit
bbeec2e3fe
447
EDMCLogging.py
Normal file
447
EDMCLogging.py
Normal file
@ -0,0 +1,447 @@
|
|||||||
|
"""
|
||||||
|
Set up required logging for the application.
|
||||||
|
|
||||||
|
This module provides for a common logging-powered log facility.
|
||||||
|
Mostly it implements a logging.Filter() in order to get two extra
|
||||||
|
members on the logging.LogRecord instance for use in logging.Formatter()
|
||||||
|
strings.
|
||||||
|
|
||||||
|
If type checking, e.g. mypy, objects to `logging.trace(...)` then include this
|
||||||
|
stanza:
|
||||||
|
|
||||||
|
# See EDMCLogging.py docs.
|
||||||
|
# isort: off
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from logging import trace, TRACE # type: ignore # noqa: F401
|
||||||
|
# isort: on
|
||||||
|
|
||||||
|
This is needed because we add the TRACE level and the trace() function
|
||||||
|
ourselves at runtime.
|
||||||
|
|
||||||
|
To utilise logging in core code, or internal plugins, include this:
|
||||||
|
|
||||||
|
from EDMCLogging import get_main_logger
|
||||||
|
|
||||||
|
logger = get_main_logger()
|
||||||
|
|
||||||
|
To utilise logging in a 'found' (third-party) plugin, include this:
|
||||||
|
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
|
||||||
|
plugin_name = os.path.basename(os.path.dirname(__file__))
|
||||||
|
# plugin_name here *must* be the name of the folder the plugin resides in
|
||||||
|
# See, plug.py:load_plugins()
|
||||||
|
logger = logging.getLogger(f'{appname}.{plugin_name}')
|
||||||
|
"""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
import logging.handlers
|
||||||
|
from contextlib import suppress
|
||||||
|
from fnmatch import fnmatch
|
||||||
|
# So that any warning about accessing a protected member is only in one place.
|
||||||
|
from sys import _getframe as getframe
|
||||||
|
from threading import get_native_id as thread_native_id
|
||||||
|
from traceback import print_exc
|
||||||
|
from typing import TYPE_CHECKING, Tuple, cast
|
||||||
|
|
||||||
|
import config
|
||||||
|
|
||||||
|
# TODO: Tests:
|
||||||
|
#
|
||||||
|
# 1. Call from bare function in file.
|
||||||
|
# 2. Call from `if __name__ == "__main__":` section
|
||||||
|
#
|
||||||
|
# 3. Call from 1st level function in 1st level Class in file
|
||||||
|
# 4. Call from 2nd level function in 1st level Class in file
|
||||||
|
# 5. Call from 3rd level function in 1st level Class in file
|
||||||
|
#
|
||||||
|
# 6. Call from 1st level function in 2nd level Class in file
|
||||||
|
# 7. Call from 2nd level function in 2nd level Class in file
|
||||||
|
# 8. Call from 3rd level function in 2nd level Class in file
|
||||||
|
#
|
||||||
|
# 9. Call from 1st level function in 3rd level Class in file
|
||||||
|
# 10. Call from 2nd level function in 3rd level Class in file
|
||||||
|
# 11. Call from 3rd level function in 3rd level Class in file
|
||||||
|
#
|
||||||
|
# 12. Call from 2nd level file, all as above.
|
||||||
|
#
|
||||||
|
# 13. Call from *module*
|
||||||
|
#
|
||||||
|
# 14. Call from *package*
|
||||||
|
|
||||||
|
_default_loglevel = logging.DEBUG
|
||||||
|
|
||||||
|
# Define a TRACE level
|
||||||
|
LEVEL_TRACE = 5
|
||||||
|
LEVEL_TRACE_ALL = 3
|
||||||
|
logging.addLevelName(LEVEL_TRACE, "TRACE")
|
||||||
|
logging.addLevelName(LEVEL_TRACE_ALL, "TRACE_ALL")
|
||||||
|
logging.TRACE = LEVEL_TRACE # type: ignore
|
||||||
|
logging.TRACE_ALL = LEVEL_TRACE_ALL # type: ignore
|
||||||
|
logging.Logger.trace = lambda self, message, *args, **kwargs: self._log( # type: ignore
|
||||||
|
logging.TRACE, # type: ignore
|
||||||
|
message,
|
||||||
|
args,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _trace_if(self: logging.Logger, condition: str, message: str, *args, **kwargs) -> None:
|
||||||
|
if any(fnmatch(condition, p) for p in []):
|
||||||
|
self._log(logging.TRACE, message, args, **kwargs) # type: ignore # we added it
|
||||||
|
return
|
||||||
|
|
||||||
|
self._log(logging.TRACE_ALL, message, args, **kwargs) # type: ignore # we added it
|
||||||
|
|
||||||
|
|
||||||
|
logging.Logger.trace_if = _trace_if # type: ignore
|
||||||
|
|
||||||
|
# we cant hide this from `from xxx` imports and I'd really rather no-one other than `logging` had access to it
|
||||||
|
del _trace_if
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from types import FrameType
|
||||||
|
|
||||||
|
# Fake type that we can use here to tell type checkers that trace exists
|
||||||
|
|
||||||
|
class LoggerMixin(logging.Logger):
|
||||||
|
"""LoggerMixin is a fake class that tells type checkers that trace exists on a given type."""
|
||||||
|
|
||||||
|
def trace(self, message, *args, **kwargs) -> None:
|
||||||
|
"""See implementation above."""
|
||||||
|
...
|
||||||
|
|
||||||
|
def trace_if(self, condition: str, message, *args, **kwargs) -> None:
|
||||||
|
"""
|
||||||
|
Fake trace if method, traces only if condition exists in trace_on.
|
||||||
|
|
||||||
|
See implementation above.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class Logger:
|
||||||
|
"""
|
||||||
|
Wrapper class for all logging configuration and code.
|
||||||
|
|
||||||
|
Class instantiation requires the 'logger name' and optional loglevel.
|
||||||
|
It is intended that this 'logger name' be re-used in all files/modules
|
||||||
|
that need to log.
|
||||||
|
|
||||||
|
Users of this class should then call getLogger() to get the
|
||||||
|
logging.Logger instance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, logger_name: str, loglevel: int = _default_loglevel):
|
||||||
|
"""
|
||||||
|
Set up a `logging.Logger` with our preferred configuration.
|
||||||
|
|
||||||
|
This includes using an EDMCContextFilter to add 'class' and 'qualname'
|
||||||
|
expansions for logging.Formatter().
|
||||||
|
"""
|
||||||
|
self.logger = logging.getLogger(logger_name)
|
||||||
|
# Configure the logging.Logger
|
||||||
|
# This needs to always be TRACE in order to let TRACE level messages
|
||||||
|
# through to check the *handler* levels.
|
||||||
|
self.logger.setLevel(logging.TRACE) # type: ignore
|
||||||
|
|
||||||
|
# Set up filter for adding class name
|
||||||
|
self.logger_filter = EDMCContextFilter()
|
||||||
|
self.logger.addFilter(self.logger_filter)
|
||||||
|
|
||||||
|
# Our basic channel handling stdout
|
||||||
|
self.logger_channel = logging.StreamHandler()
|
||||||
|
# This should be affected by the user configured log level
|
||||||
|
self.logger_channel.setLevel(loglevel)
|
||||||
|
|
||||||
|
self.logger_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(process)d:%(thread)d:%(osthreadid)d %(module)s.%(qualname)s:%(lineno)d: %(message)s') # noqa: E501
|
||||||
|
self.logger_formatter.default_time_format = '%Y-%m-%d %H:%M:%S'
|
||||||
|
self.logger_formatter.default_msec_format = '%s.%03d'
|
||||||
|
|
||||||
|
self.logger_channel.setFormatter(self.logger_formatter)
|
||||||
|
self.logger.addHandler(self.logger_channel)
|
||||||
|
|
||||||
|
def get_logger(self) -> 'LoggerMixin':
|
||||||
|
"""
|
||||||
|
Obtain the self.logger of the class instance.
|
||||||
|
|
||||||
|
Not to be confused with logging.getLogger().
|
||||||
|
"""
|
||||||
|
return cast('LoggerMixin', self.logger)
|
||||||
|
|
||||||
|
def get_streamhandler(self) -> logging.Handler:
|
||||||
|
"""
|
||||||
|
Obtain the self.logger_channel StreamHandler instance.
|
||||||
|
|
||||||
|
:return: logging.StreamHandler
|
||||||
|
"""
|
||||||
|
return self.logger_channel
|
||||||
|
|
||||||
|
def set_channels_loglevel(self, level: int) -> None:
|
||||||
|
"""
|
||||||
|
Set the specified log level on the channels.
|
||||||
|
|
||||||
|
:param level: A valid `logging` level.
|
||||||
|
:return: None
|
||||||
|
"""
|
||||||
|
self.logger_channel.setLevel(level)
|
||||||
|
|
||||||
|
def set_console_loglevel(self, level: int) -> None:
|
||||||
|
"""
|
||||||
|
Set the specified log level on the console channel.
|
||||||
|
|
||||||
|
:param level: A valid `logging` level.
|
||||||
|
:return: None
|
||||||
|
"""
|
||||||
|
if self.logger_channel.level != logging.TRACE: # type: ignore
|
||||||
|
self.logger_channel.setLevel(level)
|
||||||
|
else:
|
||||||
|
logger.trace("Not changing log level because it's TRACE") # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class EDMCContextFilter(logging.Filter):
|
||||||
|
"""
|
||||||
|
Implements filtering to add extra format specifiers, and tweak others.
|
||||||
|
|
||||||
|
logging.Filter sub-class to place extra attributes of the calling site
|
||||||
|
into the record.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def filter(self, record: logging.LogRecord) -> bool:
|
||||||
|
"""
|
||||||
|
Attempt to set/change fields in the LogRecord.
|
||||||
|
|
||||||
|
1. class = class name(s) of the call site, if applicable
|
||||||
|
2. qualname = __qualname__ of the call site. This simplifies
|
||||||
|
logging.Formatter() as you can use just this no matter if there is
|
||||||
|
a class involved or not, so you get a nice clean:
|
||||||
|
<file/module>.<classA>[.classB....].<function>
|
||||||
|
3. osthreadid = OS level thread ID.
|
||||||
|
|
||||||
|
If we fail to be able to properly set either then:
|
||||||
|
|
||||||
|
1. Use print() to alert, to be SURE a message is seen.
|
||||||
|
2. But also return strings noting the error, so there'll be
|
||||||
|
something in the log output if it happens.
|
||||||
|
|
||||||
|
:param record: The LogRecord we're "filtering"
|
||||||
|
:return: bool - Always true in order for this record to be logged.
|
||||||
|
"""
|
||||||
|
(class_name, qualname, module_name) = self.caller_attributes(module_name=getattr(record, 'module'))
|
||||||
|
|
||||||
|
# Only set if we got a useful value
|
||||||
|
if module_name:
|
||||||
|
setattr(record, 'module', module_name)
|
||||||
|
|
||||||
|
# Only set if not already provided by logging itself
|
||||||
|
if getattr(record, 'class', None) is None:
|
||||||
|
setattr(record, 'class', class_name)
|
||||||
|
|
||||||
|
# Only set if not already provided by logging itself
|
||||||
|
if getattr(record, 'qualname', None) is None:
|
||||||
|
setattr(record, 'qualname', qualname)
|
||||||
|
|
||||||
|
setattr(record, 'osthreadid', thread_native_id())
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def caller_attributes(cls, module_name: str = '') -> Tuple[str, str, str]: # noqa: CCR001, E501, C901 # this is as refactored as is sensible
|
||||||
|
"""
|
||||||
|
Determine extra or changed fields for the caller.
|
||||||
|
|
||||||
|
1. qualname finds the relevant object and its __qualname__
|
||||||
|
2. caller_class_names is just the full class names of the calling
|
||||||
|
class if relevant.
|
||||||
|
3. module is munged if we detect the caller is an EDMC plugin,
|
||||||
|
whether internal or found.
|
||||||
|
|
||||||
|
:param module_name: The name of the calling module.
|
||||||
|
:return: Tuple[str, str, str] - class_name, qualname, module_name
|
||||||
|
"""
|
||||||
|
frame = cls.find_caller_frame()
|
||||||
|
|
||||||
|
caller_qualname = caller_class_names = ''
|
||||||
|
if frame:
|
||||||
|
# <https://stackoverflow.com/questions/2203424/python-how-to-retrieve-class-information-from-a-frame-object#2220759>
|
||||||
|
try:
|
||||||
|
frame_info = inspect.getframeinfo(frame)
|
||||||
|
# raise(IndexError) # TODO: Remove, only for testing
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
# Separate from the print below to guarantee we see at least this much.
|
||||||
|
print('EDMCLogging:EDMCContextFilter:caller_attributes(): Failed in `inspect.getframinfo(frame)`')
|
||||||
|
|
||||||
|
# We want to *attempt* to show something about the nature of 'frame',
|
||||||
|
# but at this point we can't trust it will work.
|
||||||
|
try:
|
||||||
|
print(f'frame: {frame}')
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# We've given up, so just return '??' to signal we couldn't get the info
|
||||||
|
return '??', '??', module_name
|
||||||
|
try:
|
||||||
|
args, _, _, value_dict = inspect.getargvalues(frame)
|
||||||
|
if len(args) and args[0] in ('self', 'cls'):
|
||||||
|
frame_class: 'object' = value_dict[args[0]]
|
||||||
|
|
||||||
|
if frame_class:
|
||||||
|
# See https://en.wikipedia.org/wiki/Name_mangling#Python for how name mangling works.
|
||||||
|
# For more detail, see _Py_Mangle in CPython's Python/compile.c.
|
||||||
|
name = frame_info.function
|
||||||
|
class_name = frame_class.__class__.__name__.lstrip("_")
|
||||||
|
if name.startswith("__") and not name.endswith("__") and class_name:
|
||||||
|
name = f'_{class_name}{frame_info.function}'
|
||||||
|
|
||||||
|
# Find __qualname__ of the caller
|
||||||
|
fn = inspect.getattr_static(frame_class, name, None)
|
||||||
|
if fn is None:
|
||||||
|
# For some reason getattr_static cant grab this. Try and grab it with getattr, bail out
|
||||||
|
# if we get a RecursionError indicating a property
|
||||||
|
try:
|
||||||
|
fn = getattr(frame_class, name, None)
|
||||||
|
except RecursionError:
|
||||||
|
print(
|
||||||
|
"EDMCLogging:EDMCContextFilter:caller_attributes():"
|
||||||
|
"Failed to get attribute for function info. Bailing out"
|
||||||
|
)
|
||||||
|
# class_name is better than nothing for __qualname__
|
||||||
|
return class_name, class_name, module_name
|
||||||
|
|
||||||
|
if fn is not None:
|
||||||
|
if isinstance(fn, property):
|
||||||
|
class_name = str(frame_class)
|
||||||
|
# If somehow you make your __class__ or __class__.__qualname__ recursive,
|
||||||
|
# I'll be impressed.
|
||||||
|
if hasattr(frame_class, '__class__') and hasattr(frame_class.__class__, "__qualname__"):
|
||||||
|
class_name = frame_class.__class__.__qualname__
|
||||||
|
caller_qualname = f"{class_name}.{name}(property)"
|
||||||
|
|
||||||
|
else:
|
||||||
|
caller_qualname = f"<property {name} on {class_name}>"
|
||||||
|
|
||||||
|
elif not hasattr(fn, '__qualname__'):
|
||||||
|
caller_qualname = name
|
||||||
|
|
||||||
|
elif hasattr(fn, '__qualname__') and fn.__qualname__:
|
||||||
|
caller_qualname = fn.__qualname__
|
||||||
|
|
||||||
|
# Find containing class name(s) of caller, if any
|
||||||
|
if (
|
||||||
|
frame_class.__class__ and hasattr(frame_class.__class__, '__qualname__')
|
||||||
|
and frame_class.__class__.__qualname__
|
||||||
|
):
|
||||||
|
caller_class_names = frame_class.__class__.__qualname__
|
||||||
|
|
||||||
|
# It's a call from the top level module file
|
||||||
|
elif frame_info.function == '<module>':
|
||||||
|
caller_class_names = '<none>'
|
||||||
|
caller_qualname = value_dict['__name__']
|
||||||
|
|
||||||
|
elif frame_info.function != '':
|
||||||
|
caller_class_names = '<none>'
|
||||||
|
caller_qualname = frame_info.function
|
||||||
|
|
||||||
|
module_name = cls.munge_module_name(frame_info, module_name)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print('ALERT! Something went VERY wrong in handling finding info to log')
|
||||||
|
print('ALERT! Information is as follows')
|
||||||
|
with suppress(Exception):
|
||||||
|
|
||||||
|
print(f'ALERT! {e=}')
|
||||||
|
print_exc()
|
||||||
|
print(f'ALERT! {frame=}')
|
||||||
|
with suppress(Exception):
|
||||||
|
print(f'ALERT! {fn=}') # type: ignore
|
||||||
|
with suppress(Exception):
|
||||||
|
print(f'ALERT! {cls=}')
|
||||||
|
|
||||||
|
finally: # Ensure this always happens
|
||||||
|
# https://docs.python.org/3.7/library/inspect.html#the-interpreter-stack
|
||||||
|
del frame
|
||||||
|
|
||||||
|
if caller_qualname == '':
|
||||||
|
print('ALERT! Something went wrong with finding caller qualname for logging!')
|
||||||
|
caller_qualname = '<ERROR in EDMCLogging.caller_class_and_qualname() for "qualname">'
|
||||||
|
|
||||||
|
if caller_class_names == '':
|
||||||
|
print('ALERT! Something went wrong with finding caller class name(s) for logging!')
|
||||||
|
caller_class_names = '<ERROR in EDMCLogging.caller_class_and_qualname() for "class">'
|
||||||
|
|
||||||
|
return caller_class_names, caller_qualname, module_name
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def find_caller_frame(cls):
|
||||||
|
"""
|
||||||
|
Find the stack frame of the logging caller.
|
||||||
|
|
||||||
|
:returns: 'frame' object such as from sys._getframe()
|
||||||
|
"""
|
||||||
|
# Go up through stack frames until we find the first with a
|
||||||
|
# type(f_locals.self) of logging.Logger. This should be the start
|
||||||
|
# of the frames internal to logging.
|
||||||
|
frame: 'FrameType' = getframe(0)
|
||||||
|
while frame:
|
||||||
|
if isinstance(frame.f_locals.get('self'), logging.Logger):
|
||||||
|
frame = cast('FrameType', frame.f_back) # Want to start on the next frame below
|
||||||
|
break
|
||||||
|
frame = cast('FrameType', frame.f_back)
|
||||||
|
# Now continue up through frames until we find the next one where
|
||||||
|
# that is *not* true, as it should be the call site of the logger
|
||||||
|
# call
|
||||||
|
while frame:
|
||||||
|
if not isinstance(frame.f_locals.get('self'), logging.Logger):
|
||||||
|
break # We've found the frame we want
|
||||||
|
frame = cast('FrameType', frame.f_back)
|
||||||
|
return frame
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def munge_module_name(cls, frame_info: inspect.Traceback, module_name: str) -> str:
|
||||||
|
"""
|
||||||
|
Adjust module_name based on the file path for the given frame.
|
||||||
|
|
||||||
|
We want to distinguish between other code and both our internal plugins
|
||||||
|
and the 'found' ones.
|
||||||
|
|
||||||
|
For internal plugins we want "plugins.<filename>".
|
||||||
|
For 'found' plugins we want "<plugins>.<plugin_name>...".
|
||||||
|
|
||||||
|
:param frame_info: The frame_info of the caller.
|
||||||
|
:param module_name: The module_name string to munge.
|
||||||
|
:return: The munged module_name.
|
||||||
|
"""
|
||||||
|
# file_name = pathlib.Path(frame_info.filename).expanduser()
|
||||||
|
# plugin_dir = pathlib.Path(config.plugin_dir_path).expanduser()
|
||||||
|
# internal_plugin_dir = pathlib.Path(config.internal_plugin_dir_path).expanduser()
|
||||||
|
# if internal_plugin_dir in file_name.parents:
|
||||||
|
# # its an internal plugin
|
||||||
|
# return f'plugins.{".".join(file_name.relative_to(internal_plugin_dir).parent.parts)}'
|
||||||
|
|
||||||
|
# elif plugin_dir in file_name.parents:
|
||||||
|
# return f'<plugin>.{".".join(file_name.relative_to(plugin_dir).parent.parts)}'
|
||||||
|
|
||||||
|
return module_name
|
||||||
|
|
||||||
|
|
||||||
|
def get_main_logger(sublogger_name: str = '') -> 'LoggerMixin':
|
||||||
|
"""Return the correct logger for how the program is being run. (outdated)"""
|
||||||
|
# if not os.getenv("EDMC_NO_UI"):
|
||||||
|
# # GUI app being run
|
||||||
|
# return cast('LoggerMixin', logging.getLogger(appname))
|
||||||
|
# else:
|
||||||
|
# # Must be the CLI
|
||||||
|
# return cast('LoggerMixin', logging.getLogger(appcmdname))
|
||||||
|
return cast('LoggerMixin', logging.getLogger(__name__))
|
||||||
|
|
||||||
|
# Singleton
|
||||||
|
loglevel = logging._nameToLevel.get(config.log_level, logging.DEBUG) # noqa:
|
||||||
|
|
||||||
|
base_logger_name = __name__
|
||||||
|
|
||||||
|
edmclogger = Logger(base_logger_name, loglevel=loglevel)
|
||||||
|
logger: 'LoggerMixin' = edmclogger.get_logger()
|
167
capi/__init__.py
Normal file
167
capi/__init__.py
Normal file
@ -0,0 +1,167 @@
|
|||||||
|
from . import model
|
||||||
|
from . import utils
|
||||||
|
from . import exceptions
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
import config
|
||||||
|
from EDMCLogging import get_main_logger
|
||||||
|
|
||||||
|
logger = get_main_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class CAPIAuthorizer:
|
||||||
|
def __init__(self, _model: model.Model):
|
||||||
|
self.model: model.Model = _model
|
||||||
|
|
||||||
|
def auth_init(self) -> str:
|
||||||
|
"""
|
||||||
|
Generates initial url for fdev auth
|
||||||
|
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
|
||||||
|
code_verifier = base64.urlsafe_b64encode(os.urandom(32))
|
||||||
|
code_challenge = utils.generate_challenge(code_verifier)
|
||||||
|
state_string = base64.urlsafe_b64encode(os.urandom(32)).decode("utf-8")
|
||||||
|
|
||||||
|
self.model.auth_init(code_verifier.decode('utf-8'), state_string)
|
||||||
|
|
||||||
|
redirect_user_to_fdev = f"{config.AUTH_URL}?" \
|
||||||
|
f"audience=all&" \
|
||||||
|
f"scope=capi&" \
|
||||||
|
f"response_type=code&" \
|
||||||
|
f"client_id={config.CLIENT_ID}&" \
|
||||||
|
f"code_challenge={code_challenge}&" \
|
||||||
|
f"code_challenge_method=S256&" \
|
||||||
|
f"state={state_string}&" \
|
||||||
|
f"redirect_uri={config.REDIRECT_URL}"
|
||||||
|
|
||||||
|
return redirect_user_to_fdev
|
||||||
|
|
||||||
|
def fdev_callback(self, code: str, state: str) -> dict[str, str]:
|
||||||
|
if type(code) is not str or type(state) != str:
|
||||||
|
raise TypeError('code and state must be strings')
|
||||||
|
|
||||||
|
code_verifier = self.model.get_verifier(state)
|
||||||
|
|
||||||
|
self.model.set_code(code, state)
|
||||||
|
|
||||||
|
token_request = utils.get_tokens_request(code, code_verifier)
|
||||||
|
|
||||||
|
try:
|
||||||
|
token_request.raise_for_status()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(
|
||||||
|
f'token_request failed. Status: {token_request.status_code!r}, text: {token_request.text!r}',
|
||||||
|
exc_info=e
|
||||||
|
)
|
||||||
|
|
||||||
|
# self.model.delete_row(state)
|
||||||
|
|
||||||
|
raise e
|
||||||
|
|
||||||
|
tokens = token_request.json()
|
||||||
|
|
||||||
|
access_token = tokens["access_token"]
|
||||||
|
refresh_token = tokens["refresh_token"]
|
||||||
|
expires_in = tokens["expires_in"]
|
||||||
|
timestamp_got_expires_in = int(time.time())
|
||||||
|
|
||||||
|
self.model.set_tokens(access_token, refresh_token, expires_in, timestamp_got_expires_in, state)
|
||||||
|
|
||||||
|
try:
|
||||||
|
nickname = utils.get_nickname(access_token)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Couldn't get nickname for state: {state!r}", exc_info=e)
|
||||||
|
raise KeyError(f"Couldn't get nickname for state: {state!r}")
|
||||||
|
|
||||||
|
msg = {'status': 'ok', 'description': '', 'state': ''}
|
||||||
|
|
||||||
|
if not self.model.set_nickname(nickname, state):
|
||||||
|
msg['description'] = 'Tokens updated'
|
||||||
|
|
||||||
|
else:
|
||||||
|
msg['description'] = 'Tokens saved'
|
||||||
|
|
||||||
|
msg['state'] = self.model.get_state_by_nickname(nickname)
|
||||||
|
|
||||||
|
return msg
|
||||||
|
|
||||||
|
def get_token_by_state(self, state: str) -> dict:
|
||||||
|
self.refresh_by_state(state)
|
||||||
|
row = self.model.get_token_for_user(state)
|
||||||
|
row['expires_over'] = int(row['expires_on']) - int(time.time())
|
||||||
|
return row
|
||||||
|
|
||||||
|
def refresh_by_state(self, state: str, force_refresh=False, failure_tolerance=True) -> dict:
|
||||||
|
"""
|
||||||
|
|
||||||
|
:param state:
|
||||||
|
:param force_refresh: if we should update token when its time hasn't come
|
||||||
|
:param failure_tolerance: if we shouldn't remove row in case of update's failure
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
|
||||||
|
msg = {'status': '', 'description': '', 'state': ''}
|
||||||
|
row = self.model.get_row(state)
|
||||||
|
|
||||||
|
if row is None:
|
||||||
|
# No such state in DB
|
||||||
|
msg['status'] = 'error'
|
||||||
|
msg['message'] = 'No such state in DB'
|
||||||
|
raise exceptions.RefreshFail(msg['description'], msg['status'], state)
|
||||||
|
|
||||||
|
msg['state'] = state
|
||||||
|
|
||||||
|
if int(time.time()) < int(row['timestamp_got_expires_in']) + int(row['expires_in']) and not force_refresh:
|
||||||
|
msg['status'] = 'ok'
|
||||||
|
msg['description'] = "Didn't refresh since it isn't required"
|
||||||
|
|
||||||
|
return msg # token isn't expired and we don't force updating
|
||||||
|
|
||||||
|
try:
|
||||||
|
refresh_request = utils.refresh_request(row['refresh_token'])
|
||||||
|
if refresh_request.status_code == 418: # Server's maintenance
|
||||||
|
logger.warning(f'FDEV maintenance 418, text: {refresh_request.text!r}')
|
||||||
|
msg['status'] = 'error'
|
||||||
|
msg['description'] = 'FDEV on maintenance'
|
||||||
|
raise exceptions.RefreshFail(msg['message'], msg['status'], state)
|
||||||
|
|
||||||
|
refresh_request.raise_for_status()
|
||||||
|
|
||||||
|
tokens = refresh_request.json()
|
||||||
|
tokens['timestamp_got_expires_in'] = int(time.time())
|
||||||
|
self.model.set_tokens(*tokens, state)
|
||||||
|
msg['status'] = 'ok'
|
||||||
|
msg['description'] = 'Token were successfully updated'
|
||||||
|
return msg
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# probably here something don't work
|
||||||
|
logger.warning(f'Fail on refreshing token for {state}, row:{row}')
|
||||||
|
msg['status'] = 'error'
|
||||||
|
|
||||||
|
self.model.increment_refresh_tries(state)
|
||||||
|
|
||||||
|
if not failure_tolerance:
|
||||||
|
if row['refresh_tries'] >= 5: # hardcoded limit
|
||||||
|
msg['description'] = "Refresh failed. You were removed from DB due to refresh rate limitings"
|
||||||
|
self.model.delete_row(state)
|
||||||
|
raise exceptions.RefreshFail(msg['description'], msg['status'], state)
|
||||||
|
|
||||||
|
msg['description'] = "Refresh failed. Try later"
|
||||||
|
raise exceptions.RefreshFail(msg['description'], msg['status'], state)
|
||||||
|
|
||||||
|
def delete_by_state(self, state: str) -> None:
|
||||||
|
self.model.delete_row(state)
|
||||||
|
|
||||||
|
def list_all_users(self) -> list[dict]:
|
||||||
|
return self.model.list_all_records()
|
||||||
|
|
||||||
|
|
||||||
|
capi_authorizer = CAPIAuthorizer(model.Model())
|
6
capi/exceptions.py
Normal file
6
capi/exceptions.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
class RefreshFail(Exception):
|
||||||
|
def __init__(self, message: str, status: str, state: str):
|
||||||
|
self.message = message
|
||||||
|
self.status = status
|
||||||
|
self.state = state
|
||||||
|
super().__init__(self.message + ' for ' + self.state)
|
117
capi/model.py
Normal file
117
capi/model.py
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
import sqlite3
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from . import sqlite_requests
|
||||||
|
from EDMCLogging import get_main_logger
|
||||||
|
|
||||||
|
logger = get_main_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class Model:
|
||||||
|
def __init__(self):
|
||||||
|
self.db: sqlite3.Connection = sqlite3.connect('companion-api.sqlite', check_same_thread=False)
|
||||||
|
self.db.row_factory = lambda c, r: dict(zip([col[0] for col in c.description], r))
|
||||||
|
with self.db:
|
||||||
|
self.db.execute(sqlite_requests.schema)
|
||||||
|
|
||||||
|
def auth_init(self, verifier: str, state: str) -> None:
|
||||||
|
with self.db:
|
||||||
|
self.db.execute(
|
||||||
|
sqlite_requests.insert_auth_init,
|
||||||
|
{
|
||||||
|
'code_verifier': verifier,
|
||||||
|
'state': state
|
||||||
|
})
|
||||||
|
|
||||||
|
def get_verifier(self, state: str) -> str:
|
||||||
|
code_verifier_req = self.db.execute(sqlite_requests.select_all_by_state, {'state': state}).fetchone()
|
||||||
|
|
||||||
|
if code_verifier_req is None:
|
||||||
|
# Somebody got here not by frontier redirect
|
||||||
|
raise KeyError('No state in DB found')
|
||||||
|
|
||||||
|
code_verifier: str = code_verifier_req['code_verifier']
|
||||||
|
|
||||||
|
return code_verifier
|
||||||
|
|
||||||
|
def set_code(self, code: str, state: str) -> None:
|
||||||
|
with self.db:
|
||||||
|
self.db.execute(sqlite_requests.set_code_state, {'code': code, 'state': state})
|
||||||
|
|
||||||
|
def delete_row(self, state: str) -> None:
|
||||||
|
with self.db:
|
||||||
|
self.db.execute(sqlite_requests.delete_by_state, {'state': state})
|
||||||
|
|
||||||
|
def set_tokens(self, access_token: str, refresh_token: str, expires_in: int, timestamp_got_expires_in: int,
|
||||||
|
state: str) -> None:
|
||||||
|
|
||||||
|
with self.db:
|
||||||
|
self.db.execute(
|
||||||
|
sqlite_requests.set_tokens_by_state,
|
||||||
|
{
|
||||||
|
'access_token': access_token,
|
||||||
|
'refresh_token': refresh_token,
|
||||||
|
'expires_in': expires_in,
|
||||||
|
'timestamp_got_expires_in': timestamp_got_expires_in,
|
||||||
|
'state': state
|
||||||
|
})
|
||||||
|
|
||||||
|
def set_nickname(self, nickname: str, state: str) -> bool:
|
||||||
|
"""
|
||||||
|
Return True if inserted successfully, False if catch sqlite3.IntegrityError
|
||||||
|
|
||||||
|
:param nickname:
|
||||||
|
:param state:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
with self.db:
|
||||||
|
self.db.execute(
|
||||||
|
sqlite_requests.set_nickname_by_state,
|
||||||
|
{'nickname': nickname, 'state': state}
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
except sqlite3.IntegrityError:
|
||||||
|
"""
|
||||||
|
let's migrate new received data to old state
|
||||||
|
1. Get old state by nickname
|
||||||
|
2. Remove row with old state
|
||||||
|
3. Set old state where new state
|
||||||
|
"""
|
||||||
|
state_to_set: str = self.get_state_by_nickname(nickname)
|
||||||
|
self.delete_row(state_to_set)
|
||||||
|
self.set_new_state(state_to_set, state)
|
||||||
|
with self.db:
|
||||||
|
self.db.execute(
|
||||||
|
sqlite_requests.set_nickname_by_state,
|
||||||
|
{'nickname': nickname, 'state': state_to_set}
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_state_by_nickname(self, nickname: str) -> Union[None, str]:
|
||||||
|
with self.db:
|
||||||
|
nickname_f1 = self.db.execute(sqlite_requests.get_state_by_nickname, {'nickname': nickname}).fetchone()
|
||||||
|
|
||||||
|
if nickname_f1 is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return nickname_f1['state']
|
||||||
|
|
||||||
|
def set_new_state(self, new_state: str, old_state: str) -> None:
|
||||||
|
with self.db:
|
||||||
|
self.db.execute(sqlite_requests.update_state_by_state, {'new_state': new_state, 'state': old_state})
|
||||||
|
|
||||||
|
def get_row(self, state: str) -> dict:
|
||||||
|
return self.db.execute(sqlite_requests.select_all_by_state, {'state': state}).fetchone()
|
||||||
|
|
||||||
|
def increment_refresh_tries(self, state: str) -> None:
|
||||||
|
with self.db:
|
||||||
|
self.db.execute(sqlite_requests.refresh_times_increment, {'state': state})
|
||||||
|
|
||||||
|
def get_token_for_user(self, state: str) -> dict:
|
||||||
|
return self.db.execute(sqlite_requests.get_token_for_user, {'state': state}).fetchone()
|
||||||
|
|
||||||
|
def list_all_records(self) -> list:
|
||||||
|
return self.db.execute(sqlite_requests.select_nickname_state_all).fetchall()
|
48
capi/sqlite_requests.py
Normal file
48
capi/sqlite_requests.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
schema = """create table if not exists authorizations (
|
||||||
|
code_verifier text ,
|
||||||
|
state text,
|
||||||
|
timestamp_init datetime default current_timestamp ,
|
||||||
|
code text,
|
||||||
|
access_token text,
|
||||||
|
refresh_token text,
|
||||||
|
expires_in text,
|
||||||
|
timestamp_got_expires_in text,
|
||||||
|
nickname text unique,
|
||||||
|
refresh_tries int default 0
|
||||||
|
);"""
|
||||||
|
|
||||||
|
insert_auth_init = """insert into authorizations
|
||||||
|
(code_verifier, state)
|
||||||
|
values
|
||||||
|
(:code_verifier, :state);"""
|
||||||
|
|
||||||
|
select_all_by_state = """select * from authorizations where state = :state;"""
|
||||||
|
|
||||||
|
set_code_state = """update authorizations set code = :code where state = :state;"""
|
||||||
|
|
||||||
|
delete_by_state = """delete from authorizations where state = :state;"""
|
||||||
|
|
||||||
|
set_tokens_by_state = """update authorizations
|
||||||
|
set
|
||||||
|
access_token = :access_token,
|
||||||
|
refresh_token = :refresh_token,
|
||||||
|
expires_in = :expires_in,
|
||||||
|
timestamp_got_expires_in = :timestamp_got_expires_in,
|
||||||
|
refresh_tries = 0
|
||||||
|
where state = :state;"""
|
||||||
|
|
||||||
|
set_nickname_by_state = """update authorizations set nickname = :nickname where state = :state;"""
|
||||||
|
|
||||||
|
get_state_by_nickname = """select state from authorizations where nickname = :nickname;"""
|
||||||
|
|
||||||
|
update_state_by_state = """update authorizations set state = :new_state where state = :state;"""
|
||||||
|
|
||||||
|
refresh_times_increment = """update authorizations set refresh_tries = refresh_tries + 1 where state = :state;"""
|
||||||
|
|
||||||
|
get_token_for_user = """select
|
||||||
|
access_token,
|
||||||
|
timestamp_got_expires_in + expires_in as expires_on,
|
||||||
|
nickname
|
||||||
|
from authorizations where state = :state;"""
|
||||||
|
|
||||||
|
select_nickname_state_all = """select nickname, state from authorizations where nickname is not null;"""
|
50
capi/utils.py
Normal file
50
capi/utils.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
import hashlib
|
||||||
|
import base64
|
||||||
|
import requests
|
||||||
|
import config
|
||||||
|
|
||||||
|
|
||||||
|
def get_tokens_request(code, code_verifier) -> requests.Response:
|
||||||
|
"""
|
||||||
|
Performs initial requesting access and refresh tokens
|
||||||
|
|
||||||
|
:param code:
|
||||||
|
:param code_verifier:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
token_request: requests.Response = requests.post(
|
||||||
|
url=config.TOKEN_URL,
|
||||||
|
headers={
|
||||||
|
'Content-Type': 'application/x-www-form-urlencoded',
|
||||||
|
'User-Agent': config.PROPER_USER_AGENT
|
||||||
|
},
|
||||||
|
data=f'redirect_uri={config.REDIRECT_URL}&'
|
||||||
|
f'code={code}&'
|
||||||
|
f'grant_type=authorization_code&'
|
||||||
|
f'code_verifier={code_verifier}&'
|
||||||
|
f'client_id={config.CLIENT_ID}')
|
||||||
|
|
||||||
|
return token_request
|
||||||
|
|
||||||
|
|
||||||
|
def get_nickname(access_token: str) -> str:
|
||||||
|
return requests.get(
|
||||||
|
url='https://companion.orerve.net/profile',
|
||||||
|
headers={'Authorization': f'Bearer {access_token}',
|
||||||
|
'User-Agent': config.PROPER_USER_AGENT}).json()["commander"]["name"]
|
||||||
|
|
||||||
|
|
||||||
|
def refresh_request(refresh_token: str) -> requests.Response:
|
||||||
|
return requests.post(
|
||||||
|
url=config.TOKEN_URL,
|
||||||
|
headers={'Content-Type': 'application/x-www-form-urlencoded'},
|
||||||
|
data=f'grant_type=refresh_token&client_id={config.CLIENT_ID}&refresh_token={refresh_token}')
|
||||||
|
|
||||||
|
|
||||||
|
def generate_challenge(_verifier: bytes) -> str:
|
||||||
|
"""
|
||||||
|
It takes code verifier and return code challenge
|
||||||
|
:param _verifier:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return base64.urlsafe_b64encode(hashlib.sha256(_verifier).digest())[:-1].decode('utf-8')
|
@ -4,10 +4,12 @@ from os import getenv
|
|||||||
CLIENT_ID = getenv('client_id')
|
CLIENT_ID = getenv('client_id')
|
||||||
assert CLIENT_ID, "No client_id in env"
|
assert CLIENT_ID, "No client_id in env"
|
||||||
|
|
||||||
|
log_level = 'DEBUG'
|
||||||
|
|
||||||
REDIRECT_URL = requests.utils.quote("http://127.0.0.1:9000/fdev-redirect")
|
REDIRECT_URL = requests.utils.quote("http://127.0.0.1:9000/fdev-redirect")
|
||||||
AUTH_URL = 'https://auth.frontierstore.net/auth'
|
AUTH_URL = 'https://auth.frontierstore.net/auth'
|
||||||
TOKEN_URL = 'https://auth.frontierstore.net/token'
|
TOKEN_URL = 'https://auth.frontierstore.net/token'
|
||||||
PROPER_USER_AGENT = 'EDCD-a31-0.1'
|
PROPER_USER_AGENT = 'EDCD-a31-0.2'
|
||||||
REDIRECT_HTML_TEMPLATE = """
|
REDIRECT_HTML_TEMPLATE = """
|
||||||
<!DOCTYPE HTML>
|
<!DOCTYPE HTML>
|
||||||
<html>
|
<html>
|
||||||
|
36
legacy/config.py
Normal file
36
legacy/config.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
import requests
|
||||||
|
from os import getenv
|
||||||
|
|
||||||
|
CLIENT_ID = getenv('client_id')
|
||||||
|
assert CLIENT_ID, "No client_id in env"
|
||||||
|
|
||||||
|
REDIRECT_URL = requests.utils.quote("http://127.0.0.1:9000/fdev-redirect")
|
||||||
|
AUTH_URL = 'https://auth.frontierstore.net/auth'
|
||||||
|
TOKEN_URL = 'https://auth.frontierstore.net/token'
|
||||||
|
PROPER_USER_AGENT = 'EDCD-a31-0.1'
|
||||||
|
REDIRECT_HTML_TEMPLATE = """
|
||||||
|
<!DOCTYPE HTML>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta http-equiv="refresh" content="0; url={link}">
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<a href="{link}">
|
||||||
|
You should be redirected shortly...
|
||||||
|
</a>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
"""
|
||||||
|
|
||||||
|
ADMIN_USERS_TEMPLATE = """
|
||||||
|
<!DOCTYPE HTML>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
{}
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
"""
|
||||||
|
|
||||||
|
ADMIN_USER_TEMPLATE = '<a href="{link}">{desc}</a>'
|
@ -37,7 +37,7 @@ logger.addHandler(stdout_handler)
|
|||||||
Typical workflow:
|
Typical workflow:
|
||||||
1. User open Authorize endpoint
|
1. User open Authorize endpoint
|
||||||
2. Authorize building link for FDEV's /auth, sending link to user
|
2. Authorize building link for FDEV's /auth, sending link to user
|
||||||
3. Authorize write code_verifier, state, timestamp_init to DB
|
3. Authorize write code_verifier, state, timestamp_init (by DBMS) to DB
|
||||||
4. User approving client, redirecting to FDEV_redirect endpoint
|
4. User approving client, redirecting to FDEV_redirect endpoint
|
||||||
5. Searching in DB if we have record with this state
|
5. Searching in DB if we have record with this state
|
||||||
if don't have:
|
if don't have:
|
29
web.py
Normal file
29
web.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
import falcon
|
||||||
|
import waitress
|
||||||
|
import json
|
||||||
|
|
||||||
|
from capi import capi_authorizer
|
||||||
|
import config
|
||||||
|
|
||||||
|
|
||||||
|
class AuthInit:
|
||||||
|
def on_get(self, req: falcon.request.Request, resp: falcon.response.Response) -> None:
|
||||||
|
resp.content_type = falcon.MEDIA_HTML
|
||||||
|
resp.text = config.REDIRECT_HTML_TEMPLATE.format(link=capi_authorizer.auth_init())
|
||||||
|
|
||||||
|
|
||||||
|
class FDEVCallback:
|
||||||
|
def on_get(self, req: falcon.request.Request, resp: falcon.response.Response) -> None:
|
||||||
|
code = req.get_param('code')
|
||||||
|
state = req.get_param('state')
|
||||||
|
msg = capi_authorizer.fdev_callback(code, state)
|
||||||
|
resp.content_type = falcon.MEDIA_JSON
|
||||||
|
resp.text = json.dumps(msg)
|
||||||
|
|
||||||
|
|
||||||
|
application = falcon.App()
|
||||||
|
application.add_route('/authorize', AuthInit())
|
||||||
|
application.add_route('/fdev-redirect', FDEVCallback())
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
waitress.serve(application, host='127.0.0.1', port=9000)
|
Loading…
x
Reference in New Issue
Block a user