diff --git a/webapi/fastapi_of_letcoing/controllers/auth_controller.py b/webapi/fastapi_of_letcoing/controllers/auth_controller.py index 99228da42f10daf9d2024e42e3dcf33464fb5b52..8b085fea20567a0fcbbd865df69305284542ca49 100644 --- a/webapi/fastapi_of_letcoing/controllers/auth_controller.py +++ b/webapi/fastapi_of_letcoing/controllers/auth_controller.py @@ -1,6 +1,6 @@ from flask import request, jsonify, url_for from flask_restx import Resource, Namespace, fields -from interfaces.service_interfaces import IConfigService, ILoggerService, IJWTService, IOIDCService +from interfaces.service_interfaces import IConfigService, ILoggerService, IJWTService, IOIDCService, IUserService from core.di_container import inject from models.auth_models import ( LoginRequest, LoginResponse, AuthCallbackRequest, @@ -87,11 +87,12 @@ class AuthCallbackController(Resource): @api.doc('auth_callback') @api.response(200, 'Success', auth_result_model) @api.response(400, 'Bad Request') - def get(self, provider: str): - """处理认证回调""" + async def get(self, provider: str): + """处理认证回调,实现注册和登录集成""" # 注入服务 oidc_service = inject(IOIDCService) jwt_service = inject(IJWTService) + user_service = inject(IUserService) # 验证提供商 if not oidc_service.validate_provider(provider): @@ -125,14 +126,29 @@ class AuthCallbackController(Resource): if not user_info_data: return {'success': False, 'error': '获取用户信息失败'}, 500 + # 获取provider_id + provider_id = user_info_data.get('id') + if not provider_id: + return {'success': False, 'error': '无法获取提供商用户ID'}, 500 + + # 查找或创建用户(支持注册和登录集成) + try: + user_data = await user_service.find_or_create_user( + provider=provider, + provider_id=provider_id, + user_info=user_info_data + ) + except Exception as e: + return {'success': False, 'error': f'用户处理失败: {str(e)}'}, 500 + # 创建用户信息对象 user_info = UserInfo( - id=user_info_data['id'], - username=user_info_data.get('username', ''), - email=user_info_data.get('email', ''), - name=user_info_data.get('name', ''), - avatar_url=user_info_data.get('avatar_url', ''), - provider=user_info_data.get('provider', provider) + id=user_data['id'], + username=user_data.get('username', ''), + email=user_data.get('email', ''), + name=user_data.get('name', ''), + avatar_url=user_data.get('avatar_url', ''), + provider=provider ) # 生成JWT令牌 @@ -150,7 +166,10 @@ class AuthCallbackController(Resource): return { 'success': True, 'user_info': user_info.to_dict(), - 'tokens': token_response.to_dict() + 'tokens': token_response.to_dict(), + 'is_new_user': not user_data.get('last_login') or ( + user_data.get('created_at') == user_data.get('updated_at') + ) }, 200 diff --git a/webapi/fastapi_of_letcoing/interfaces/service_interfaces.py b/webapi/fastapi_of_letcoing/interfaces/service_interfaces.py index f5a39e7694d75310601e29040c26c8351d59ce97..c264dc31c48f8e4b6b836124d225a7ac5f9905c2 100644 --- a/webapi/fastapi_of_letcoing/interfaces/service_interfaces.py +++ b/webapi/fastapi_of_letcoing/interfaces/service_interfaces.py @@ -271,4 +271,23 @@ class IJWTService(ABC): Returns: 是否成功 """ + pass + + +class IUserService(ABC): + """用户服务接口""" + + @abstractmethod + async def find_or_create_user(self, provider: str, provider_id: str, user_info: Dict[str, Any]) -> Dict[str, Any]: + """ + 根据第三方登录信息查找或创建用户 + + Args: + provider: 登录提供商 (如 'github', 'api.xauat.site/sso/signup') + provider_id: 提供商的用户ID + user_info: 用户信息字典 + + Returns: + 用户信息字典 + """ pass \ No newline at end of file diff --git a/webapi/fastapi_of_letcoing/main.py b/webapi/fastapi_of_letcoing/main.py index e2c237fe0b04bd762d0ce58ca5e8f7e8085e5a75..b0c3eb14a1919fe9e0edf209af21f516c96524b4 100644 --- a/webapi/fastapi_of_letcoing/main.py +++ b/webapi/fastapi_of_letcoing/main.py @@ -55,12 +55,11 @@ else: setup_services(app.config) # 初始化 OIDC 服务 -oidc_service = setup_services.__globals__.get('oidc_service') -if not oidc_service: - from core.di_container import get_container, inject - container = get_container() - oidc_service = container.resolve(IOIDCService) - oidc_service.initialize_oauth(app) +from core.di_container import get_container +container = get_container() +oidc_service = container.resolve(IOIDCService) +oidc_service.initialize_oauth(app) +print("OIDC 服务初始化成功") # 创建 API 实例 api = Api( diff --git a/webapi/fastapi_of_letcoing/models/db_models.py b/webapi/fastapi_of_letcoing/models/db_models.py index 00474830040bda6bbf2fdc75bbcbf81aef4f26bc..026eeb28a53df6764fb4feee59bf7af54c5dadbe 100644 --- a/webapi/fastapi_of_letcoing/models/db_models.py +++ b/webapi/fastapi_of_letcoing/models/db_models.py @@ -83,11 +83,14 @@ class User(BaseModel): """用户模型""" id = AutoField(primary_key=True, verbose_name="用户ID") - username = CharField(max_length=50, unique=True, null=False, verbose_name="用户名") + username = CharField(max_length=50, unique=True, null=True, verbose_name="用户名") email = CharField(max_length=100, unique=True, null=True, verbose_name="邮箱") - password_hash = CharField(max_length=255, null=False, verbose_name="密码哈希") + password_hash = CharField(max_length=255, null=True, verbose_name="密码") is_active = BooleanField(default=True, verbose_name="是否激活") last_login = DateTimeField(null=True, verbose_name="最后登录时间") + provider = CharField(max_length=50, null=True, verbose_name="登录提供商") + provider_id = CharField(max_length=255, null=True, verbose_name="提供商用户ID") + avatar_url = CharField(max_length=500, null=True, verbose_name="头像URL") class Meta: table_name = "users" diff --git a/webapi/fastapi_of_letcoing/requirements.txt b/webapi/fastapi_of_letcoing/requirements.txt index f8faadb678421c27df89a361bf0ecadcf1262055..032e878f96951b73f06714612324309d514751d3 100644 --- a/webapi/fastapi_of_letcoing/requirements.txt +++ b/webapi/fastapi_of_letcoing/requirements.txt @@ -34,4 +34,4 @@ requests==2.32.5 rpds-py==0.30.0 urllib3==2.6.2 Werkzeug==3.1.4 -yarl==1.22.0 +yarl==1.22.0 \ No newline at end of file diff --git a/webapi/fastapi_of_letcoing/services/oidc_service.py b/webapi/fastapi_of_letcoing/services/oidc_service.py index b9b6297d0882853679d05507d259448209eb3df8..c3d2138bc4090eaf72b372e50a6e4595e48279dc 100644 --- a/webapi/fastapi_of_letcoing/services/oidc_service.py +++ b/webapi/fastapi_of_letcoing/services/oidc_service.py @@ -175,9 +175,19 @@ class OIDCService(Injectable, IOIDCService): 'provider': provider } else: - # 标准 OIDC 用户信息端点 + # 标准 OIDC 用户信息端点(支持 api.xauat.site/sso/signup 等自定义提供商) resp = client.get('userinfo', token=token) - return resp.json() + user_data = resp.json() + + # 标准化用户信息格式 + return { + 'id': user_data.get('sub') or user_data.get('id'), + 'username': user_data.get('preferred_username') or user_data.get('username') or user_data.get('name'), + 'name': user_data.get('name', ''), + 'email': user_data.get('email', ''), + 'avatar_url': user_data.get('picture') or user_data.get('avatar_url', ''), + 'provider': provider + } except Exception as ex: self._logger_service.error(f"获取用户信息失败: {provider}", ex) diff --git a/webapi/fastapi_of_letcoing/services/user_service.py b/webapi/fastapi_of_letcoing/services/user_service.py index 5c742d0a032389b9c9ee514e8b7fddb5604dac42..1a6f6458cfbbbae3e3dc2eef27660a3af5a5ef26 100644 --- a/webapi/fastapi_of_letcoing/services/user_service.py +++ b/webapi/fastapi_of_letcoing/services/user_service.py @@ -162,4 +162,75 @@ class UserService(DatabaseService, Injectable): except DoesNotExist: return None except Exception as e: - raise RuntimeError(f"获取用户认证信息时发生错误: {e}") \ No newline at end of file + raise RuntimeError(f"获取用户认证信息时发生错误: {e}") + + async def find_or_create_user(self, provider: str, provider_id: str, user_info: Dict[str, Any]) -> Dict[str, Any]: + """ + 根据第三方登录信息查找或创建用户 + + Args: + provider: 登录提供商 (如 'github', 'api.xauat.site/sso/signup') + provider_id: 提供商的用户ID + user_info: 用户信息字典 + + Returns: + 用户信息字典 + """ + try: + # 尝试根据provider和provider_id查找用户 + try: + user = User.get( + (User.provider == provider) & + (User.provider_id == provider_id) + ) + + # 更新用户信息 + if user_info.get('username') and user.username != user_info['username']: + user.username = user_info['username'] + if user_info.get('email') and user.email != user_info['email']: + user.email = user_info['email'] + if user_info.get('avatar_url') and user.avatar_url != user_info['avatar_url']: + user.avatar_url = user_info['avatar_url'] + + user.last_login = datetime.now() + user.save() + + print(f"用户登录成功: {user.id} ({provider})") + return user.to_dict() + + except DoesNotExist: + # 用户不存在,创建新用户 + username = user_info.get('username') or f"{provider}_{provider_id}" + email = user_info.get('email') + + # 检查用户名是否已存在 + try: + existing_user = User.get(User.username == username) + username = f"{username}_{provider_id}" + except DoesNotExist: + pass + + # 检查邮箱是否已存在 + if email: + try: + existing_user = User.get(User.email == email) + email = None + except DoesNotExist: + pass + + user = User.create( + username=username, + email=email, + password_hash=None, + provider=provider, + provider_id=provider_id, + avatar_url=user_info.get('avatar_url'), + is_active=True, + last_login=datetime.now() + ) + + print(f"新用户注册成功: {user.id} ({provider})") + return user.to_dict() + + except Exception as e: + raise RuntimeError(f"查找或创建用户时发生错误: {e}") \ No newline at end of file diff --git a/webapi/fastapi_of_letcoing/simple_test.py b/webapi/fastapi_of_letcoing/simple_test.py new file mode 100644 index 0000000000000000000000000000000000000000..90eec64db750d290729970856a8a3bf8148764fb --- /dev/null +++ b/webapi/fastapi_of_letcoing/simple_test.py @@ -0,0 +1,12 @@ +#!/usr/bin/env python3 +print("Starting simple test...") + +# 测试基本导入 +try: + print("Importing Flask...") + from flask import Flask + print("Flask imported successfully!") +except Exception as e: + print(f"Error importing Flask: {e}") + +print("Test completed!")