From cdfe5fd282f63904d7050ade1d296964da4f3bd4 Mon Sep 17 00:00:00 2001 From: Antonio Aranda <102337110+arandito@users.noreply.github.com> Date: Fri, 19 Jun 2026 13:19:10 -0400 Subject: [PATCH] Add initial support assume role credential resolver --- .../aws/codegen/AwsAuthIntegration.java | 6 +- .../codegen/generators/SetupGenerator.java | 2 +- ...ture-cc329ec01bf84347ad01a29057da7f0d.json | 4 + packages/smithy-aws-core/pyproject.toml | 3 + .../_private/nested_clients/__init__.py | 2 + .../nested_clients/aws_sdk_sts/__init__.py | 3 + .../aws_sdk_sts/_private/__init__.py | 1 + .../aws_sdk_sts/_private/schemas.py | 465 ++++++++++ .../nested_clients/aws_sdk_sts/auth.py | 29 + .../nested_clients/aws_sdk_sts/client.py | 94 ++ .../nested_clients/aws_sdk_sts/config.py | 169 ++++ .../nested_clients/aws_sdk_sts/models.py | 812 ++++++++++++++++++ .../nested_clients/aws_sdk_sts/user_agent.py | 17 + .../src/smithy_aws_core/identity/__init__.py | 2 + .../src/smithy_aws_core/identity/sts.py | 138 +++ .../tests/unit/identity/test_sts.py | 286 ++++++ pyproject.toml | 6 +- uv.lock | 8 +- 18 files changed, 2041 insertions(+), 6 deletions(-) create mode 100644 packages/smithy-aws-core/.changes/next-release/smithy-aws-core-feature-cc329ec01bf84347ad01a29057da7f0d.json create mode 100644 packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/__init__.py create mode 100644 packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/aws_sdk_sts/__init__.py create mode 100644 packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/aws_sdk_sts/_private/__init__.py create mode 100644 packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/aws_sdk_sts/_private/schemas.py create mode 100644 packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/aws_sdk_sts/auth.py create mode 100644 packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/aws_sdk_sts/client.py create mode 100644 packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/aws_sdk_sts/config.py create mode 100644 packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/aws_sdk_sts/models.py create mode 100644 packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/aws_sdk_sts/user_agent.py create mode 100644 packages/smithy-aws-core/src/smithy_aws_core/identity/sts.py create mode 100644 packages/smithy-aws-core/tests/unit/identity/test_sts.py diff --git a/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsAuthIntegration.java b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsAuthIntegration.java index 8358c54be..b717ebde4 100644 --- a/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsAuthIntegration.java +++ b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsAuthIntegration.java @@ -50,7 +50,8 @@ public List getClientPlugins(GenerationContext context) { .namespace("smithy_core.aio.interfaces.identity", ".") .build()) .addReference(Symbol.builder() - .addDependency(AwsPythonDependency.SMITHY_AWS_CORE) + .addDependency(AwsPythonDependency.SMITHY_AWS_CORE + .withOptionalDependencies("assume-role")) .name("AWSCredentialsIdentity") .namespace("smithy_aws_core.identity", ".") .build()) @@ -152,7 +153,8 @@ public Symbol getAuthSchemeSymbol(GenerationContext context) { return Symbol.builder() .name("SigV4AuthScheme") .namespace("smithy_aws_core.auth", ".") - .addDependency(AwsPythonDependency.SMITHY_AWS_CORE) + .addDependency(AwsPythonDependency.SMITHY_AWS_CORE + .withOptionalDependencies("assume-role")) .build(); } diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/SetupGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/SetupGenerator.java index 6e88bae27..54010710c 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/SetupGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/SetupGenerator.java @@ -219,7 +219,7 @@ private static List getOptionalDependencies(SymbolDependency dependency) }) .orElse(Collections.emptyList()); try { - return optionals; + return optionals.stream().sorted().toList(); } catch (Exception e) { return Collections.emptyList(); } diff --git a/packages/smithy-aws-core/.changes/next-release/smithy-aws-core-feature-cc329ec01bf84347ad01a29057da7f0d.json b/packages/smithy-aws-core/.changes/next-release/smithy-aws-core-feature-cc329ec01bf84347ad01a29057da7f0d.json new file mode 100644 index 000000000..10ca1eb0d --- /dev/null +++ b/packages/smithy-aws-core/.changes/next-release/smithy-aws-core-feature-cc329ec01bf84347ad01a29057da7f0d.json @@ -0,0 +1,4 @@ +{ + "type": "feature", + "description": "Added support for Assume Role credential resolution." +} \ No newline at end of file diff --git a/packages/smithy-aws-core/pyproject.toml b/packages/smithy-aws-core/pyproject.toml index 6a72200b0..f676faf89 100644 --- a/packages/smithy-aws-core/pyproject.toml +++ b/packages/smithy-aws-core/pyproject.toml @@ -54,6 +54,9 @@ json = [ xml = [ "smithy-xml~=0.1.0" ] +assume-role = [ + "smithy-aws-core[xml]" +] [tool.hatch.build] exclude = [ diff --git a/packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/__init__.py b/packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/__init__.py new file mode 100644 index 000000000..33cbe867a --- /dev/null +++ b/packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/__init__.py @@ -0,0 +1,2 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/aws_sdk_sts/__init__.py b/packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/aws_sdk_sts/__init__.py new file mode 100644 index 000000000..e1ee04907 --- /dev/null +++ b/packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/aws_sdk_sts/__init__.py @@ -0,0 +1,3 @@ +# Code generated by smithy-python-codegen DO NOT EDIT. + +__version__: str = "0.1.0" diff --git a/packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/aws_sdk_sts/_private/__init__.py b/packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/aws_sdk_sts/_private/__init__.py new file mode 100644 index 000000000..247be3e3d --- /dev/null +++ b/packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/aws_sdk_sts/_private/__init__.py @@ -0,0 +1 @@ +# Code generated by smithy-python-codegen DO NOT EDIT. diff --git a/packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/aws_sdk_sts/_private/schemas.py b/packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/aws_sdk_sts/_private/schemas.py new file mode 100644 index 000000000..256603584 --- /dev/null +++ b/packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/aws_sdk_sts/_private/schemas.py @@ -0,0 +1,465 @@ +# Code generated by smithy-python-codegen DO NOT EDIT. + +from types import MappingProxyType + +from smithy_core.schemas import Schema +from smithy_core.shapes import ShapeID, ShapeType +from smithy_core.traits import Trait + + +ACCESS_KEY_ID_TYPE = Schema( + id=ShapeID("com.amazonaws.sts#accessKeyIdType"), + shape_type=ShapeType.STRING, + traits=[ + Trait.new( + id=ShapeID("smithy.api#length"), + value=MappingProxyType({"min": 16, "max": 128}), + ), + Trait.new(id=ShapeID("smithy.api#pattern"), value="^[\\w]*$"), + ], +) + +ACCESS_KEY_SECRET_TYPE = Schema( + id=ShapeID("com.amazonaws.sts#accessKeySecretType"), + shape_type=ShapeType.STRING, + traits=[Trait.new(id=ShapeID("smithy.api#sensitive"))], +) + +ARN_TYPE = Schema( + id=ShapeID("com.amazonaws.sts#arnType"), + shape_type=ShapeType.STRING, + traits=[ + Trait.new( + id=ShapeID("smithy.api#length"), + value=MappingProxyType({"min": 20, "max": 2048}), + ), + Trait.new( + id=ShapeID("smithy.api#pattern"), + value="^[\\u0009\\u000A\\u000D\\u0020-\\u007E\\u0085\\u00A0-\\uD7FF\\uE000-\\uFFFD\\u10000-\\u10FFFF]+$", + ), + ], +) + +ASSUMED_ROLE_ID_TYPE = Schema( + id=ShapeID("com.amazonaws.sts#assumedRoleIdType"), + shape_type=ShapeType.STRING, + traits=[ + Trait.new( + id=ShapeID("smithy.api#length"), + value=MappingProxyType({"min": 2, "max": 193}), + ), + Trait.new(id=ShapeID("smithy.api#pattern"), value="^[\\w+=,.@:-]*$"), + ], +) + +ASSUMED_ROLE_USER = Schema.collection( + id=ShapeID("com.amazonaws.sts#AssumedRoleUser"), + members={ + "AssumedRoleId": { + "target": ASSUMED_ROLE_ID_TYPE, + "traits": [Trait.new(id=ShapeID("smithy.api#required"))], + }, + "Arn": { + "target": ARN_TYPE, + "traits": [Trait.new(id=ShapeID("smithy.api#required"))], + }, + }, +) + +ROLE_DURATION_SECONDS_TYPE = Schema( + id=ShapeID("com.amazonaws.sts#roleDurationSecondsType"), + shape_type=ShapeType.INTEGER, + traits=[ + Trait.new(id=ShapeID("smithy.api#box")), + Trait.new( + id=ShapeID("smithy.api#range"), + value=MappingProxyType({"min": 900, "max": 43200}), + ), + ], +) + +EXTERNAL_ID_TYPE = Schema( + id=ShapeID("com.amazonaws.sts#externalIdType"), + shape_type=ShapeType.STRING, + traits=[ + Trait.new( + id=ShapeID("smithy.api#length"), + value=MappingProxyType({"min": 2, "max": 1224}), + ), + Trait.new(id=ShapeID("smithy.api#pattern"), value="^[\\w+=,.@:\\/-]*$"), + ], +) + +UNRESTRICTED_SESSION_POLICY_DOCUMENT_TYPE = Schema( + id=ShapeID("com.amazonaws.sts#unrestrictedSessionPolicyDocumentType"), + shape_type=ShapeType.STRING, + traits=[ + Trait.new(id=ShapeID("smithy.api#length"), value=MappingProxyType({"min": 1})), + Trait.new( + id=ShapeID("smithy.api#pattern"), + value="^[\\u0009\\u000A\\u000D\\u0020-\\u00FF]+$", + ), + ], +) + +POLICY_DESCRIPTOR_TYPE = Schema.collection( + id=ShapeID("com.amazonaws.sts#PolicyDescriptorType"), + members={"arn": {"target": ARN_TYPE}}, +) + +POLICY_DESCRIPTOR_LIST_TYPE = Schema.collection( + id=ShapeID("com.amazonaws.sts#policyDescriptorListType"), + shape_type=ShapeType.LIST, + members={"member": {"target": POLICY_DESCRIPTOR_TYPE}}, +) + +CONTEXT_ASSERTION_TYPE = Schema( + id=ShapeID("com.amazonaws.sts#contextAssertionType"), + shape_type=ShapeType.STRING, + traits=[ + Trait.new( + id=ShapeID("smithy.api#length"), + value=MappingProxyType({"min": 4, "max": 2048}), + ) + ], +) + +PROVIDED_CONTEXT = Schema.collection( + id=ShapeID("com.amazonaws.sts#ProvidedContext"), + members={ + "ProviderArn": {"target": ARN_TYPE}, + "ContextAssertion": {"target": CONTEXT_ASSERTION_TYPE}, + }, +) + +PROVIDED_CONTEXTS_LIST_TYPE = Schema.collection( + id=ShapeID("com.amazonaws.sts#ProvidedContextsListType"), + shape_type=ShapeType.LIST, + traits=[ + Trait.new( + id=ShapeID("smithy.api#length"), + value=MappingProxyType({"min": 1, "max": 5}), + ) + ], + members={"member": {"target": PROVIDED_CONTEXT}}, +) + +ROLE_SESSION_NAME_TYPE = Schema( + id=ShapeID("com.amazonaws.sts#roleSessionNameType"), + shape_type=ShapeType.STRING, + traits=[ + Trait.new( + id=ShapeID("smithy.api#length"), + value=MappingProxyType({"min": 2, "max": 64}), + ), + Trait.new(id=ShapeID("smithy.api#pattern"), value="^[\\w+=,.@-]*$"), + ], +) + +SERIAL_NUMBER_TYPE = Schema( + id=ShapeID("com.amazonaws.sts#serialNumberType"), + shape_type=ShapeType.STRING, + traits=[ + Trait.new( + id=ShapeID("smithy.api#length"), + value=MappingProxyType({"min": 9, "max": 256}), + ), + Trait.new(id=ShapeID("smithy.api#pattern"), value="^[\\w+=/:,.@-]*$"), + ], +) + +SOURCE_IDENTITY_TYPE = Schema( + id=ShapeID("com.amazonaws.sts#sourceIdentityType"), + shape_type=ShapeType.STRING, + traits=[ + Trait.new( + id=ShapeID("smithy.api#length"), + value=MappingProxyType({"min": 2, "max": 64}), + ), + Trait.new(id=ShapeID("smithy.api#pattern"), value="^[\\w+=,.@-]*$"), + ], +) + +TAG_KEY_TYPE = Schema( + id=ShapeID("com.amazonaws.sts#tagKeyType"), + shape_type=ShapeType.STRING, + traits=[ + Trait.new( + id=ShapeID("smithy.api#length"), + value=MappingProxyType({"min": 1, "max": 128}), + ), + Trait.new( + id=ShapeID("smithy.api#pattern"), value="^[\\p{L}\\p{Z}\\p{N}_.:/=+\\-@]+$" + ), + ], +) + +TAG_VALUE_TYPE = Schema( + id=ShapeID("com.amazonaws.sts#tagValueType"), + shape_type=ShapeType.STRING, + traits=[ + Trait.new( + id=ShapeID("smithy.api#length"), + value=MappingProxyType({"min": 0, "max": 256}), + ), + Trait.new( + id=ShapeID("smithy.api#pattern"), value="^[\\p{L}\\p{Z}\\p{N}_.:/=+\\-@]*$" + ), + ], +) + +TAG = Schema.collection( + id=ShapeID("com.amazonaws.sts#Tag"), + members={ + "Key": { + "target": TAG_KEY_TYPE, + "traits": [Trait.new(id=ShapeID("smithy.api#required"))], + }, + "Value": { + "target": TAG_VALUE_TYPE, + "traits": [Trait.new(id=ShapeID("smithy.api#required"))], + }, + }, +) + +TAG_LIST_TYPE = Schema.collection( + id=ShapeID("com.amazonaws.sts#tagListType"), + shape_type=ShapeType.LIST, + traits=[ + Trait.new( + id=ShapeID("smithy.api#length"), + value=MappingProxyType({"min": 0, "max": 50}), + ) + ], + members={"member": {"target": TAG}}, +) + +TOKEN_CODE_TYPE = Schema( + id=ShapeID("com.amazonaws.sts#tokenCodeType"), + shape_type=ShapeType.STRING, + traits=[ + Trait.new( + id=ShapeID("smithy.api#length"), + value=MappingProxyType({"min": 6, "max": 6}), + ), + Trait.new(id=ShapeID("smithy.api#pattern"), value="^[\\d]*$"), + ], +) + +TAG_KEY_LIST_TYPE = Schema.collection( + id=ShapeID("com.amazonaws.sts#tagKeyListType"), + shape_type=ShapeType.LIST, + traits=[ + Trait.new( + id=ShapeID("smithy.api#length"), + value=MappingProxyType({"min": 0, "max": 50}), + ) + ], + members={"member": {"target": TAG_KEY_TYPE}}, +) + +ASSUME_ROLE_INPUT = Schema.collection( + id=ShapeID("com.amazonaws.sts#AssumeRoleInput"), + traits=[ + Trait.new( + id=ShapeID("smithy.synthetic#originalShapeId"), + value="com.amazonaws.sts#AssumeRoleRequest", + ), + Trait.new(id=ShapeID("smithy.api#input")), + ], + members={ + "RoleArn": { + "target": ARN_TYPE, + "traits": [Trait.new(id=ShapeID("smithy.api#required"))], + }, + "RoleSessionName": { + "target": ROLE_SESSION_NAME_TYPE, + "traits": [Trait.new(id=ShapeID("smithy.api#required"))], + }, + "PolicyArns": {"target": POLICY_DESCRIPTOR_LIST_TYPE}, + "Policy": {"target": UNRESTRICTED_SESSION_POLICY_DOCUMENT_TYPE}, + "DurationSeconds": {"target": ROLE_DURATION_SECONDS_TYPE}, + "Tags": {"target": TAG_LIST_TYPE}, + "TransitiveTagKeys": {"target": TAG_KEY_LIST_TYPE}, + "ExternalId": {"target": EXTERNAL_ID_TYPE}, + "SerialNumber": {"target": SERIAL_NUMBER_TYPE}, + "TokenCode": {"target": TOKEN_CODE_TYPE}, + "SourceIdentity": {"target": SOURCE_IDENTITY_TYPE}, + "ProvidedContexts": {"target": PROVIDED_CONTEXTS_LIST_TYPE}, + }, +) + +DATE_TYPE = Schema( + id=ShapeID("com.amazonaws.sts#dateType"), shape_type=ShapeType.TIMESTAMP +) + +TOKEN_TYPE = Schema( + id=ShapeID("com.amazonaws.sts#tokenType"), shape_type=ShapeType.STRING +) + +CREDENTIALS = Schema.collection( + id=ShapeID("com.amazonaws.sts#Credentials"), + members={ + "AccessKeyId": { + "target": ACCESS_KEY_ID_TYPE, + "traits": [Trait.new(id=ShapeID("smithy.api#required"))], + }, + "SecretAccessKey": { + "target": ACCESS_KEY_SECRET_TYPE, + "traits": [Trait.new(id=ShapeID("smithy.api#required"))], + }, + "SessionToken": { + "target": TOKEN_TYPE, + "traits": [Trait.new(id=ShapeID("smithy.api#required"))], + }, + "Expiration": { + "target": DATE_TYPE, + "traits": [Trait.new(id=ShapeID("smithy.api#required"))], + }, + }, +) + +NON_NEGATIVE_INTEGER_TYPE = Schema( + id=ShapeID("com.amazonaws.sts#nonNegativeIntegerType"), + shape_type=ShapeType.INTEGER, + traits=[ + Trait.new(id=ShapeID("smithy.api#box")), + Trait.new(id=ShapeID("smithy.api#range"), value=MappingProxyType({"min": 0})), + ], +) + +ASSUME_ROLE_OUTPUT = Schema.collection( + id=ShapeID("com.amazonaws.sts#AssumeRoleOutput"), + traits=[ + Trait.new( + id=ShapeID("smithy.synthetic#originalShapeId"), + value="com.amazonaws.sts#AssumeRoleResponse", + ), + Trait.new(id=ShapeID("smithy.api#output")), + ], + members={ + "Credentials": {"target": CREDENTIALS}, + "AssumedRoleUser": {"target": ASSUMED_ROLE_USER}, + "PackedPolicySize": {"target": NON_NEGATIVE_INTEGER_TYPE}, + "SourceIdentity": {"target": SOURCE_IDENTITY_TYPE}, + }, +) + +EXPIRED_IDENTITY_TOKEN_MESSAGE = Schema( + id=ShapeID("com.amazonaws.sts#expiredIdentityTokenMessage"), + shape_type=ShapeType.STRING, +) + +EXPIRED_TOKEN_EXCEPTION = Schema.collection( + id=ShapeID("com.amazonaws.sts#ExpiredTokenException"), + traits=[ + Trait.new(id=ShapeID("smithy.api#error"), value="client"), + Trait.new(id=ShapeID("smithy.api#httpError"), value=400), + Trait.new( + id=ShapeID("aws.protocols#awsQueryError"), + value=MappingProxyType( + {"code": "ExpiredTokenException", "httpResponseCode": 400} + ), + ), + ], + members={"message": {"target": EXPIRED_IDENTITY_TOKEN_MESSAGE}}, +) + +MALFORMED_POLICY_DOCUMENT_MESSAGE = Schema( + id=ShapeID("com.amazonaws.sts#malformedPolicyDocumentMessage"), + shape_type=ShapeType.STRING, +) + +MALFORMED_POLICY_DOCUMENT_EXCEPTION = Schema.collection( + id=ShapeID("com.amazonaws.sts#MalformedPolicyDocumentException"), + traits=[ + Trait.new(id=ShapeID("smithy.api#error"), value="client"), + Trait.new(id=ShapeID("smithy.api#httpError"), value=400), + Trait.new( + id=ShapeID("aws.protocols#awsQueryError"), + value=MappingProxyType( + {"code": "MalformedPolicyDocument", "httpResponseCode": 400} + ), + ), + ], + members={"message": {"target": MALFORMED_POLICY_DOCUMENT_MESSAGE}}, +) + +PACKED_POLICY_TOO_LARGE_MESSAGE = Schema( + id=ShapeID("com.amazonaws.sts#packedPolicyTooLargeMessage"), + shape_type=ShapeType.STRING, +) + +PACKED_POLICY_TOO_LARGE_EXCEPTION = Schema.collection( + id=ShapeID("com.amazonaws.sts#PackedPolicyTooLargeException"), + traits=[ + Trait.new(id=ShapeID("smithy.api#error"), value="client"), + Trait.new(id=ShapeID("smithy.api#httpError"), value=400), + Trait.new( + id=ShapeID("aws.protocols#awsQueryError"), + value=MappingProxyType( + {"code": "PackedPolicyTooLarge", "httpResponseCode": 400} + ), + ), + ], + members={"message": {"target": PACKED_POLICY_TOO_LARGE_MESSAGE}}, +) + +REGION_DISABLED_MESSAGE = Schema( + id=ShapeID("com.amazonaws.sts#regionDisabledMessage"), shape_type=ShapeType.STRING +) + +REGION_DISABLED_EXCEPTION = Schema.collection( + id=ShapeID("com.amazonaws.sts#RegionDisabledException"), + traits=[ + Trait.new(id=ShapeID("smithy.api#error"), value="client"), + Trait.new(id=ShapeID("smithy.api#httpError"), value=403), + Trait.new( + id=ShapeID("aws.protocols#awsQueryError"), + value=MappingProxyType( + {"code": "RegionDisabledException", "httpResponseCode": 403} + ), + ), + ], + members={"message": {"target": REGION_DISABLED_MESSAGE}}, +) + +ASSUME_ROLE = Schema( + id=ShapeID("com.amazonaws.sts#AssumeRole"), shape_type=ShapeType.OPERATION +) + +AWS_SECURITY_TOKEN_SERVICE_V20110615 = Schema( + id=ShapeID("com.amazonaws.sts#AWSSecurityTokenServiceV20110615"), + shape_type=ShapeType.SERVICE, + traits=[ + Trait.new( + id=ShapeID("aws.auth#sigv4"), value=MappingProxyType({"name": "sts"}) + ), + Trait.new(id=ShapeID("smithy.api#title"), value="AWS Security Token Service"), + Trait.new( + id=ShapeID("aws.auth#sigv4a"), value=MappingProxyType({"name": "sts"}) + ), + Trait.new(id=ShapeID("aws.protocols#awsQuery")), + Trait.new( + id=ShapeID("smithy.api#auth"), value=("aws.auth#sigv4", "aws.auth#sigv4a") + ), + Trait.new( + id=ShapeID("aws.api#service"), + value=MappingProxyType( + { + "sdkId": "STS", + "arnNamespace": "sts", + "cloudFormationName": "STS", + "cloudTrailEventSource": "sts.amazonaws.com", + "endpointPrefix": "sts", + } + ), + ), + Trait.new( + id=ShapeID("smithy.api#xmlNamespace"), + value=MappingProxyType( + {"uri": "https://sts.amazonaws.com/doc/2011-06-15/"} + ), + ), + ], +) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/aws_sdk_sts/auth.py b/packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/aws_sdk_sts/auth.py new file mode 100644 index 000000000..fcc404301 --- /dev/null +++ b/packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/aws_sdk_sts/auth.py @@ -0,0 +1,29 @@ +# Code generated by smithy-python-codegen DO NOT EDIT. + +from typing import Any + +from smithy_core.auth import AuthOption, AuthParams +from smithy_core.interfaces.auth import AuthOption as AuthOptionProtocol +from smithy_core.shapes import ShapeID + + +class HTTPAuthSchemeResolver: + def resolve_auth_scheme( + self, auth_parameters: AuthParams[Any, Any] + ) -> list[AuthOptionProtocol]: + auth_options: list[AuthOptionProtocol] = [] + + if (option := _generate_sigv4_option(auth_parameters)) is not None: + auth_options.append(option) + + return auth_options + + +def _generate_sigv4_option( + auth_params: AuthParams[Any, Any], +) -> AuthOptionProtocol | None: + return AuthOption( + scheme_id=ShapeID("aws.auth#sigv4"), + identity_properties={}, # type: ignore + signer_properties={}, # type: ignore + ) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/aws_sdk_sts/client.py b/packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/aws_sdk_sts/client.py new file mode 100644 index 000000000..b7872ed44 --- /dev/null +++ b/packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/aws_sdk_sts/client.py @@ -0,0 +1,94 @@ +# Code generated by smithy-python-codegen DO NOT EDIT. + +from copy import deepcopy +import logging + +from smithy_core.aio.client import ClientCall, RequestPipeline +from smithy_core.exceptions import ExpectationNotMetError +from smithy_core.interceptors import InterceptorChain +from smithy_core.retries import RetryStrategyResolver +from smithy_core.types import TypedProperties +from smithy_http.plugins import user_agent_plugin + +from .config import Config, Plugin +from .models import ASSUME_ROLE, AssumeRoleInput, AssumeRoleOutput +from .user_agent import aws_user_agent_plugin + + +logger = logging.getLogger(__name__) + + +class STSClient: + """Client for AWSSecurityTokenServiceV20110615""" + + def __init__( + self, config: Config | None = None, plugins: list[Plugin] | None = None + ): + """ + Constructor for `STSClient`. + + Args: + config: + Optional configuration for the client. Here you can set things like + the endpoint for HTTP services or auth credentials. + plugins: + A list of callables that modify the configuration dynamically. These + can be used to set defaults, for example. + """ + self._config = config or Config() + + client_plugins: list[Plugin] = [aws_user_agent_plugin, user_agent_plugin] + if plugins: + client_plugins.extend(plugins) + + for plugin in client_plugins: + plugin(self._config) + + self._retry_strategy_resolver = RetryStrategyResolver() + + async def assume_role( + self, input: AssumeRoleInput, plugins: list[Plugin] | None = None + ) -> AssumeRoleOutput: + """ + Invokes the AssumeRole operation. + + Args: + input: + An instance of `AssumeRoleInput`. + plugins: + A list of callables that modify the configuration dynamically. + Changes made by these plugins only apply for the duration of the + operation execution and will not affect any other operation + invocations. + + Returns: + An instance of `AssumeRoleOutput`. + """ + operation_plugins: list[Plugin] = [] + if plugins: + operation_plugins.extend(plugins) + config = deepcopy(self._config) + for plugin in operation_plugins: + plugin(config) + if config.protocol is None or config.transport is None: + raise ExpectationNotMetError( + "protocol and transport MUST be set on the config to make calls." + ) + + retry_strategy = await self._retry_strategy_resolver.resolve_retry_strategy( + retry_strategy=config.retry_strategy + ) + + pipeline = RequestPipeline(protocol=config.protocol, transport=config.transport) + call = ClientCall( + input=input, + operation=ASSUME_ROLE, + context=TypedProperties({"config": config}), + interceptor=InterceptorChain(config.interceptors), + auth_scheme_resolver=config.auth_scheme_resolver, + supported_auth_schemes=config.auth_schemes, + endpoint_resolver=config.endpoint_resolver, + retry_strategy=retry_strategy, + ) + + return await pipeline(call) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/aws_sdk_sts/config.py b/packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/aws_sdk_sts/config.py new file mode 100644 index 000000000..8343b675a --- /dev/null +++ b/packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/aws_sdk_sts/config.py @@ -0,0 +1,169 @@ +# Code generated by smithy-python-codegen DO NOT EDIT. + +from dataclasses import dataclass +from typing import Any, Callable, TypeAlias, Union + +from smithy_aws_core.aio.protocols import AwsQueryClientProtocol +from smithy_aws_core.auth import SigV4AuthScheme +from smithy_aws_core.endpoints.standard_regional import ( + StandardRegionalEndpointsResolver as _RegionalResolver, +) +from smithy_aws_core.identity import AWSCredentialsIdentity, AWSIdentityProperties +from smithy_core.aio.interfaces import ( + ClientProtocol, + ClientTransport, + EndpointResolver as _EndpointResolver, +) +from smithy_core.aio.interfaces.auth import AuthScheme +from smithy_core.aio.interfaces.identity import IdentityResolver +from smithy_core.interceptors import Interceptor +from smithy_core.interfaces import URI +from smithy_core.interfaces.retries import RetryStrategy +from smithy_core.retries import RetryStrategyOptions +from smithy_core.shapes import ShapeID +from smithy_http.aio.crt import AWSCRTHTTPClient +from smithy_http.interfaces import HTTPRequestConfiguration + +from ._private.schemas import ( + AWS_SECURITY_TOKEN_SERVICE_V20110615 as _SCHEMA_AWS_SECURITY_TOKEN_SERVICE_V20110615, +) +from .auth import HTTPAuthSchemeResolver +from .models import AssumeRoleInput, AssumeRoleOutput + + +_ServiceInterceptor = Union[Interceptor[AssumeRoleInput, AssumeRoleOutput, Any, Any]] + + +@dataclass(init=False) +class Config: + """Configuration for STS.""" + + auth_scheme_resolver: HTTPAuthSchemeResolver + """ + An auth scheme resolver that determines the auth scheme for each + operation. + """ + + auth_schemes: dict[ShapeID, AuthScheme[Any, Any, Any, Any]] + """A map of auth scheme ids to auth schemes.""" + + aws_access_key_id: str | None + """The identifier for a secret access key.""" + + aws_credentials_identity_resolver: ( + IdentityResolver[AWSCredentialsIdentity, AWSIdentityProperties] | None + ) + """Resolves AWS Credentials. Required for operations that use Sigv4 Auth.""" + + aws_secret_access_key: str | None + """A secret access key that can be used to sign requests.""" + + aws_session_token: str | None + """An access key ID that identifies temporary security credentials.""" + + endpoint_resolver: _EndpointResolver + """ + The endpoint resolver used to resolve the final endpoint per-operation + based on the configuration. + """ + + endpoint_uri: str | URI | None + """A static URI to route requests to.""" + + http_request_config: HTTPRequestConfiguration | None + """Configuration for individual HTTP requests.""" + + interceptors: list[_ServiceInterceptor] + """ + The list of interceptors, which are hooks that are called during the + execution of a request. + """ + + protocol: ClientProtocol[Any, Any] | None + """The protocol to serialize and deserialize requests with.""" + + region: str | None + """ + The AWS region to connect to. The configured region is used to determine + the service endpoint. + """ + + retry_strategy: RetryStrategy | RetryStrategyOptions | None + """ + The retry strategy or options for configuring retry behavior. Can be + either a configured RetryStrategy or RetryStrategyOptions to create one. + """ + + sdk_ua_app_id: str | None + """ + A unique and opaque application ID that is appended to the User-Agent + header. + """ + + transport: ClientTransport[Any, Any] | None + """The transport to use to send requests (e.g. an HTTP client).""" + + user_agent_extra: str | None + """Additional suffix to be added to the User-Agent header.""" + + def __init__( + self, + *, + auth_scheme_resolver: HTTPAuthSchemeResolver | None = None, + auth_schemes: dict[ShapeID, AuthScheme[Any, Any, Any, Any]] | None = None, + aws_access_key_id: str | None = None, + aws_credentials_identity_resolver: IdentityResolver[ + AWSCredentialsIdentity, AWSIdentityProperties + ] + | None = None, + aws_secret_access_key: str | None = None, + aws_session_token: str | None = None, + endpoint_resolver: _EndpointResolver | None = None, + endpoint_uri: str | URI | None = None, + http_request_config: HTTPRequestConfiguration | None = None, + interceptors: list[_ServiceInterceptor] | None = None, + protocol: ClientProtocol[Any, Any] | None = None, + region: str | None = None, + retry_strategy: RetryStrategy | RetryStrategyOptions | None = None, + sdk_ua_app_id: str | None = None, + transport: ClientTransport[Any, Any] | None = None, + user_agent_extra: str | None = None, + ): + self.auth_scheme_resolver = auth_scheme_resolver or HTTPAuthSchemeResolver() + self.auth_schemes = auth_schemes or { + ShapeID("aws.auth#sigv4"): SigV4AuthScheme(service="sts") + } + self.aws_access_key_id = aws_access_key_id + self.aws_credentials_identity_resolver = aws_credentials_identity_resolver + self.aws_secret_access_key = aws_secret_access_key + self.aws_session_token = aws_session_token + self.endpoint_resolver = endpoint_resolver or _RegionalResolver( + endpoint_prefix="sts" + ) + self.endpoint_uri = endpoint_uri + self.http_request_config = http_request_config + self.interceptors = interceptors or [] + self.protocol = protocol or AwsQueryClientProtocol( + _SCHEMA_AWS_SECURITY_TOKEN_SERVICE_V20110615, "2011-06-15" + ) + self.region = region + self.retry_strategy = retry_strategy + self.sdk_ua_app_id = sdk_ua_app_id + self.transport = transport or AWSCRTHTTPClient() + self.user_agent_extra = user_agent_extra + + def set_auth_scheme(self, scheme: AuthScheme[Any, Any, Any, Any]) -> None: + """ + Sets the implementation of an auth scheme. + + Using this method ensures the correct key is used. + + Args: + scheme: + The auth scheme to add. + """ + self.auth_schemes[scheme.scheme_id] = scheme + + +Plugin: TypeAlias = Callable[[Config], None] +"""A callable that allows customizing the config object on each request.""" diff --git a/packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/aws_sdk_sts/models.py b/packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/aws_sdk_sts/models.py new file mode 100644 index 000000000..db973fe60 --- /dev/null +++ b/packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/aws_sdk_sts/models.py @@ -0,0 +1,812 @@ +# Code generated by smithy-python-codegen DO NOT EDIT. + +from dataclasses import dataclass, field +from datetime import datetime +import logging +from typing import Any, Literal, Self + +from smithy_core.deserializers import ShapeDeserializer +from smithy_core.documents import TypeRegistry +from smithy_core.exceptions import ModeledError +from smithy_core.schemas import APIOperation, Schema +from smithy_core.serializers import ShapeSerializer +from smithy_core.shapes import ShapeID + +from ._private.schemas import ( + ASSUMED_ROLE_USER as _SCHEMA_ASSUMED_ROLE_USER, + ASSUME_ROLE as _SCHEMA_ASSUME_ROLE, + ASSUME_ROLE_INPUT as _SCHEMA_ASSUME_ROLE_INPUT, + ASSUME_ROLE_OUTPUT as _SCHEMA_ASSUME_ROLE_OUTPUT, + CREDENTIALS as _SCHEMA_CREDENTIALS, + EXPIRED_TOKEN_EXCEPTION as _SCHEMA_EXPIRED_TOKEN_EXCEPTION, + MALFORMED_POLICY_DOCUMENT_EXCEPTION as _SCHEMA_MALFORMED_POLICY_DOCUMENT_EXCEPTION, + PACKED_POLICY_TOO_LARGE_EXCEPTION as _SCHEMA_PACKED_POLICY_TOO_LARGE_EXCEPTION, + POLICY_DESCRIPTOR_TYPE as _SCHEMA_POLICY_DESCRIPTOR_TYPE, + PROVIDED_CONTEXT as _SCHEMA_PROVIDED_CONTEXT, + REGION_DISABLED_EXCEPTION as _SCHEMA_REGION_DISABLED_EXCEPTION, + TAG as _SCHEMA_TAG, +) + + +logger = logging.getLogger(__name__) + + +class ServiceError(ModeledError): + """ + Base error for all errors in the service. + + Some exceptions do not extend from this class, including + synthetic, implicit, and shared exception types. + """ + + +@dataclass(kw_only=True) +class AssumedRoleUser: + """Dataclass for AssumedRoleUser structure.""" + + assumed_role_id: str + + arn: str + + def serialize(self, serializer: ShapeSerializer): + serializer.write_struct(_SCHEMA_ASSUMED_ROLE_USER, self) + + def serialize_members(self, serializer: ShapeSerializer): + serializer.write_string( + _SCHEMA_ASSUMED_ROLE_USER.members["AssumedRoleId"], self.assumed_role_id + ) + serializer.write_string(_SCHEMA_ASSUMED_ROLE_USER.members["Arn"], self.arn) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + return cls(**cls.deserialize_kwargs(deserializer)) + + @classmethod + def deserialize_kwargs(cls, deserializer: ShapeDeserializer) -> dict[str, Any]: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["assumed_role_id"] = de.read_string( + _SCHEMA_ASSUMED_ROLE_USER.members["AssumedRoleId"] + ) + + case 1: + kwargs["arn"] = de.read_string( + _SCHEMA_ASSUMED_ROLE_USER.members["Arn"] + ) + + case _: + logger.debug("Unexpected member schema: %s", schema) + + deserializer.read_struct(_SCHEMA_ASSUMED_ROLE_USER, consumer=_consumer) + return kwargs + + +@dataclass(kw_only=True) +class PolicyDescriptorType: + """Dataclass for PolicyDescriptorType structure.""" + + arn: str | None = None + + def serialize(self, serializer: ShapeSerializer): + serializer.write_struct(_SCHEMA_POLICY_DESCRIPTOR_TYPE, self) + + def serialize_members(self, serializer: ShapeSerializer): + if self.arn is not None: + serializer.write_string( + _SCHEMA_POLICY_DESCRIPTOR_TYPE.members["arn"], self.arn + ) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + return cls(**cls.deserialize_kwargs(deserializer)) + + @classmethod + def deserialize_kwargs(cls, deserializer: ShapeDeserializer) -> dict[str, Any]: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["arn"] = de.read_string( + _SCHEMA_POLICY_DESCRIPTOR_TYPE.members["arn"] + ) + + case _: + logger.debug("Unexpected member schema: %s", schema) + + deserializer.read_struct(_SCHEMA_POLICY_DESCRIPTOR_TYPE, consumer=_consumer) + return kwargs + + +def _serialize_policy_descriptor_list_type( + serializer: ShapeSerializer, schema: Schema, value: list[PolicyDescriptorType] +) -> None: + member_schema = schema.members["member"] + with serializer.begin_list(schema, len(value)) as ls: + for e in value: + ls.write_struct(member_schema, e) + + +def _deserialize_policy_descriptor_list_type( + deserializer: ShapeDeserializer, schema: Schema +) -> list[PolicyDescriptorType]: + result: list[PolicyDescriptorType] = [] + + def _read_value(d: ShapeDeserializer): + if d.is_null(): + d.read_null() + + else: + result.append(PolicyDescriptorType.deserialize(d)) + + deserializer.read_list(schema, _read_value) + return result + + +@dataclass(kw_only=True) +class ProvidedContext: + """Dataclass for ProvidedContext structure.""" + + provider_arn: str | None = None + + context_assertion: str | None = None + + def serialize(self, serializer: ShapeSerializer): + serializer.write_struct(_SCHEMA_PROVIDED_CONTEXT, self) + + def serialize_members(self, serializer: ShapeSerializer): + if self.provider_arn is not None: + serializer.write_string( + _SCHEMA_PROVIDED_CONTEXT.members["ProviderArn"], self.provider_arn + ) + + if self.context_assertion is not None: + serializer.write_string( + _SCHEMA_PROVIDED_CONTEXT.members["ContextAssertion"], + self.context_assertion, + ) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + return cls(**cls.deserialize_kwargs(deserializer)) + + @classmethod + def deserialize_kwargs(cls, deserializer: ShapeDeserializer) -> dict[str, Any]: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["provider_arn"] = de.read_string( + _SCHEMA_PROVIDED_CONTEXT.members["ProviderArn"] + ) + + case 1: + kwargs["context_assertion"] = de.read_string( + _SCHEMA_PROVIDED_CONTEXT.members["ContextAssertion"] + ) + + case _: + logger.debug("Unexpected member schema: %s", schema) + + deserializer.read_struct(_SCHEMA_PROVIDED_CONTEXT, consumer=_consumer) + return kwargs + + +def _serialize_provided_contexts_list_type( + serializer: ShapeSerializer, schema: Schema, value: list[ProvidedContext] +) -> None: + member_schema = schema.members["member"] + with serializer.begin_list(schema, len(value)) as ls: + for e in value: + ls.write_struct(member_schema, e) + + +def _deserialize_provided_contexts_list_type( + deserializer: ShapeDeserializer, schema: Schema +) -> list[ProvidedContext]: + result: list[ProvidedContext] = [] + + def _read_value(d: ShapeDeserializer): + if d.is_null(): + d.read_null() + + else: + result.append(ProvidedContext.deserialize(d)) + + deserializer.read_list(schema, _read_value) + return result + + +@dataclass(kw_only=True) +class Tag: + """Dataclass for Tag structure.""" + + key: str + + value: str + + def serialize(self, serializer: ShapeSerializer): + serializer.write_struct(_SCHEMA_TAG, self) + + def serialize_members(self, serializer: ShapeSerializer): + serializer.write_string(_SCHEMA_TAG.members["Key"], self.key) + serializer.write_string(_SCHEMA_TAG.members["Value"], self.value) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + return cls(**cls.deserialize_kwargs(deserializer)) + + @classmethod + def deserialize_kwargs(cls, deserializer: ShapeDeserializer) -> dict[str, Any]: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["key"] = de.read_string(_SCHEMA_TAG.members["Key"]) + + case 1: + kwargs["value"] = de.read_string(_SCHEMA_TAG.members["Value"]) + + case _: + logger.debug("Unexpected member schema: %s", schema) + + deserializer.read_struct(_SCHEMA_TAG, consumer=_consumer) + return kwargs + + +def _serialize_tag_list_type( + serializer: ShapeSerializer, schema: Schema, value: list[Tag] +) -> None: + member_schema = schema.members["member"] + with serializer.begin_list(schema, len(value)) as ls: + for e in value: + ls.write_struct(member_schema, e) + + +def _deserialize_tag_list_type( + deserializer: ShapeDeserializer, schema: Schema +) -> list[Tag]: + result: list[Tag] = [] + + def _read_value(d: ShapeDeserializer): + if d.is_null(): + d.read_null() + + else: + result.append(Tag.deserialize(d)) + + deserializer.read_list(schema, _read_value) + return result + + +def _serialize_tag_key_list_type( + serializer: ShapeSerializer, schema: Schema, value: list[str] +) -> None: + member_schema = schema.members["member"] + with serializer.begin_list(schema, len(value)) as ls: + for e in value: + ls.write_string(member_schema, e) + + +def _deserialize_tag_key_list_type( + deserializer: ShapeDeserializer, schema: Schema +) -> list[str]: + result: list[str] = [] + member_schema = schema.members["member"] + + def _read_value(d: ShapeDeserializer): + if d.is_null(): + d.read_null() + + else: + result.append(d.read_string(member_schema)) + + deserializer.read_list(schema, _read_value) + return result + + +@dataclass(kw_only=True) +class AssumeRoleInput: + """Dataclass for AssumeRoleInput structure.""" + + role_arn: str | None = None + + role_session_name: str | None = None + + policy_arns: list[PolicyDescriptorType] | None = None + + policy: str | None = None + + duration_seconds: int | None = None + + tags: list[Tag] | None = None + + transitive_tag_keys: list[str] | None = None + + external_id: str | None = None + + serial_number: str | None = None + + token_code: str | None = None + + source_identity: str | None = None + + provided_contexts: list[ProvidedContext] | None = None + + def serialize(self, serializer: ShapeSerializer): + serializer.write_struct(_SCHEMA_ASSUME_ROLE_INPUT, self) + + def serialize_members(self, serializer: ShapeSerializer): + if self.role_arn is not None: + serializer.write_string( + _SCHEMA_ASSUME_ROLE_INPUT.members["RoleArn"], self.role_arn + ) + + if self.role_session_name is not None: + serializer.write_string( + _SCHEMA_ASSUME_ROLE_INPUT.members["RoleSessionName"], + self.role_session_name, + ) + + if self.policy_arns is not None: + _serialize_policy_descriptor_list_type( + serializer, + _SCHEMA_ASSUME_ROLE_INPUT.members["PolicyArns"], + self.policy_arns, + ) + + if self.policy is not None: + serializer.write_string( + _SCHEMA_ASSUME_ROLE_INPUT.members["Policy"], self.policy + ) + + if self.duration_seconds is not None: + serializer.write_integer( + _SCHEMA_ASSUME_ROLE_INPUT.members["DurationSeconds"], + self.duration_seconds, + ) + + if self.tags is not None: + _serialize_tag_list_type( + serializer, _SCHEMA_ASSUME_ROLE_INPUT.members["Tags"], self.tags + ) + + if self.transitive_tag_keys is not None: + _serialize_tag_key_list_type( + serializer, + _SCHEMA_ASSUME_ROLE_INPUT.members["TransitiveTagKeys"], + self.transitive_tag_keys, + ) + + if self.external_id is not None: + serializer.write_string( + _SCHEMA_ASSUME_ROLE_INPUT.members["ExternalId"], self.external_id + ) + + if self.serial_number is not None: + serializer.write_string( + _SCHEMA_ASSUME_ROLE_INPUT.members["SerialNumber"], self.serial_number + ) + + if self.token_code is not None: + serializer.write_string( + _SCHEMA_ASSUME_ROLE_INPUT.members["TokenCode"], self.token_code + ) + + if self.source_identity is not None: + serializer.write_string( + _SCHEMA_ASSUME_ROLE_INPUT.members["SourceIdentity"], + self.source_identity, + ) + + if self.provided_contexts is not None: + _serialize_provided_contexts_list_type( + serializer, + _SCHEMA_ASSUME_ROLE_INPUT.members["ProvidedContexts"], + self.provided_contexts, + ) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + return cls(**cls.deserialize_kwargs(deserializer)) + + @classmethod + def deserialize_kwargs(cls, deserializer: ShapeDeserializer) -> dict[str, Any]: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["role_arn"] = de.read_string( + _SCHEMA_ASSUME_ROLE_INPUT.members["RoleArn"] + ) + + case 1: + kwargs["role_session_name"] = de.read_string( + _SCHEMA_ASSUME_ROLE_INPUT.members["RoleSessionName"] + ) + + case 2: + kwargs["policy_arns"] = _deserialize_policy_descriptor_list_type( + de, _SCHEMA_ASSUME_ROLE_INPUT.members["PolicyArns"] + ) + + case 3: + kwargs["policy"] = de.read_string( + _SCHEMA_ASSUME_ROLE_INPUT.members["Policy"] + ) + + case 4: + kwargs["duration_seconds"] = de.read_integer( + _SCHEMA_ASSUME_ROLE_INPUT.members["DurationSeconds"] + ) + + case 5: + kwargs["tags"] = _deserialize_tag_list_type( + de, _SCHEMA_ASSUME_ROLE_INPUT.members["Tags"] + ) + + case 6: + kwargs["transitive_tag_keys"] = _deserialize_tag_key_list_type( + de, _SCHEMA_ASSUME_ROLE_INPUT.members["TransitiveTagKeys"] + ) + + case 7: + kwargs["external_id"] = de.read_string( + _SCHEMA_ASSUME_ROLE_INPUT.members["ExternalId"] + ) + + case 8: + kwargs["serial_number"] = de.read_string( + _SCHEMA_ASSUME_ROLE_INPUT.members["SerialNumber"] + ) + + case 9: + kwargs["token_code"] = de.read_string( + _SCHEMA_ASSUME_ROLE_INPUT.members["TokenCode"] + ) + + case 10: + kwargs["source_identity"] = de.read_string( + _SCHEMA_ASSUME_ROLE_INPUT.members["SourceIdentity"] + ) + + case 11: + kwargs["provided_contexts"] = ( + _deserialize_provided_contexts_list_type( + de, _SCHEMA_ASSUME_ROLE_INPUT.members["ProvidedContexts"] + ) + ) + + case _: + logger.debug("Unexpected member schema: %s", schema) + + deserializer.read_struct(_SCHEMA_ASSUME_ROLE_INPUT, consumer=_consumer) + return kwargs + + +@dataclass(kw_only=True) +class Credentials: + """Dataclass for Credentials structure.""" + + access_key_id: str + + secret_access_key: str = field(repr=False) + + session_token: str + + expiration: datetime + + def serialize(self, serializer: ShapeSerializer): + serializer.write_struct(_SCHEMA_CREDENTIALS, self) + + def serialize_members(self, serializer: ShapeSerializer): + serializer.write_string( + _SCHEMA_CREDENTIALS.members["AccessKeyId"], self.access_key_id + ) + serializer.write_string( + _SCHEMA_CREDENTIALS.members["SecretAccessKey"], self.secret_access_key + ) + serializer.write_string( + _SCHEMA_CREDENTIALS.members["SessionToken"], self.session_token + ) + serializer.write_timestamp( + _SCHEMA_CREDENTIALS.members["Expiration"], self.expiration + ) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + return cls(**cls.deserialize_kwargs(deserializer)) + + @classmethod + def deserialize_kwargs(cls, deserializer: ShapeDeserializer) -> dict[str, Any]: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["access_key_id"] = de.read_string( + _SCHEMA_CREDENTIALS.members["AccessKeyId"] + ) + + case 1: + kwargs["secret_access_key"] = de.read_string( + _SCHEMA_CREDENTIALS.members["SecretAccessKey"] + ) + + case 2: + kwargs["session_token"] = de.read_string( + _SCHEMA_CREDENTIALS.members["SessionToken"] + ) + + case 3: + kwargs["expiration"] = de.read_timestamp( + _SCHEMA_CREDENTIALS.members["Expiration"] + ) + + case _: + logger.debug("Unexpected member schema: %s", schema) + + deserializer.read_struct(_SCHEMA_CREDENTIALS, consumer=_consumer) + return kwargs + + +@dataclass(kw_only=True) +class AssumeRoleOutput: + """Dataclass for AssumeRoleOutput structure.""" + + credentials: Credentials | None = None + + assumed_role_user: AssumedRoleUser | None = None + + packed_policy_size: int | None = None + + source_identity: str | None = None + + def serialize(self, serializer: ShapeSerializer): + serializer.write_struct(_SCHEMA_ASSUME_ROLE_OUTPUT, self) + + def serialize_members(self, serializer: ShapeSerializer): + if self.credentials is not None: + serializer.write_struct( + _SCHEMA_ASSUME_ROLE_OUTPUT.members["Credentials"], self.credentials + ) + + if self.assumed_role_user is not None: + serializer.write_struct( + _SCHEMA_ASSUME_ROLE_OUTPUT.members["AssumedRoleUser"], + self.assumed_role_user, + ) + + if self.packed_policy_size is not None: + serializer.write_integer( + _SCHEMA_ASSUME_ROLE_OUTPUT.members["PackedPolicySize"], + self.packed_policy_size, + ) + + if self.source_identity is not None: + serializer.write_string( + _SCHEMA_ASSUME_ROLE_OUTPUT.members["SourceIdentity"], + self.source_identity, + ) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + return cls(**cls.deserialize_kwargs(deserializer)) + + @classmethod + def deserialize_kwargs(cls, deserializer: ShapeDeserializer) -> dict[str, Any]: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["credentials"] = Credentials.deserialize(de) + + case 1: + kwargs["assumed_role_user"] = AssumedRoleUser.deserialize(de) + + case 2: + kwargs["packed_policy_size"] = de.read_integer( + _SCHEMA_ASSUME_ROLE_OUTPUT.members["PackedPolicySize"] + ) + + case 3: + kwargs["source_identity"] = de.read_string( + _SCHEMA_ASSUME_ROLE_OUTPUT.members["SourceIdentity"] + ) + + case _: + logger.debug("Unexpected member schema: %s", schema) + + deserializer.read_struct(_SCHEMA_ASSUME_ROLE_OUTPUT, consumer=_consumer) + return kwargs + + +@dataclass(kw_only=True) +class ExpiredTokenException(ServiceError): + """Dataclass for ExpiredTokenException structure.""" + + fault: Literal["client", "server"] | None = "client" + + def serialize(self, serializer: ShapeSerializer): + serializer.write_struct(_SCHEMA_EXPIRED_TOKEN_EXCEPTION, self) + + def serialize_members(self, serializer: ShapeSerializer): + if self.message is not None: + serializer.write_string( + _SCHEMA_EXPIRED_TOKEN_EXCEPTION.members["message"], self.message + ) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + return cls(**cls.deserialize_kwargs(deserializer)) + + @classmethod + def deserialize_kwargs(cls, deserializer: ShapeDeserializer) -> dict[str, Any]: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["message"] = de.read_string( + _SCHEMA_EXPIRED_TOKEN_EXCEPTION.members["message"] + ) + + case _: + logger.debug("Unexpected member schema: %s", schema) + + deserializer.read_struct(_SCHEMA_EXPIRED_TOKEN_EXCEPTION, consumer=_consumer) + return kwargs + + +@dataclass(kw_only=True) +class MalformedPolicyDocumentException(ServiceError): + """Dataclass for MalformedPolicyDocumentException structure.""" + + fault: Literal["client", "server"] | None = "client" + + def serialize(self, serializer: ShapeSerializer): + serializer.write_struct(_SCHEMA_MALFORMED_POLICY_DOCUMENT_EXCEPTION, self) + + def serialize_members(self, serializer: ShapeSerializer): + if self.message is not None: + serializer.write_string( + _SCHEMA_MALFORMED_POLICY_DOCUMENT_EXCEPTION.members["message"], + self.message, + ) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + return cls(**cls.deserialize_kwargs(deserializer)) + + @classmethod + def deserialize_kwargs(cls, deserializer: ShapeDeserializer) -> dict[str, Any]: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["message"] = de.read_string( + _SCHEMA_MALFORMED_POLICY_DOCUMENT_EXCEPTION.members["message"] + ) + + case _: + logger.debug("Unexpected member schema: %s", schema) + + deserializer.read_struct( + _SCHEMA_MALFORMED_POLICY_DOCUMENT_EXCEPTION, consumer=_consumer + ) + return kwargs + + +@dataclass(kw_only=True) +class PackedPolicyTooLargeException(ServiceError): + """Dataclass for PackedPolicyTooLargeException structure.""" + + fault: Literal["client", "server"] | None = "client" + + def serialize(self, serializer: ShapeSerializer): + serializer.write_struct(_SCHEMA_PACKED_POLICY_TOO_LARGE_EXCEPTION, self) + + def serialize_members(self, serializer: ShapeSerializer): + if self.message is not None: + serializer.write_string( + _SCHEMA_PACKED_POLICY_TOO_LARGE_EXCEPTION.members["message"], + self.message, + ) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + return cls(**cls.deserialize_kwargs(deserializer)) + + @classmethod + def deserialize_kwargs(cls, deserializer: ShapeDeserializer) -> dict[str, Any]: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["message"] = de.read_string( + _SCHEMA_PACKED_POLICY_TOO_LARGE_EXCEPTION.members["message"] + ) + + case _: + logger.debug("Unexpected member schema: %s", schema) + + deserializer.read_struct( + _SCHEMA_PACKED_POLICY_TOO_LARGE_EXCEPTION, consumer=_consumer + ) + return kwargs + + +@dataclass(kw_only=True) +class RegionDisabledException(ServiceError): + """Dataclass for RegionDisabledException structure.""" + + fault: Literal["client", "server"] | None = "client" + + def serialize(self, serializer: ShapeSerializer): + serializer.write_struct(_SCHEMA_REGION_DISABLED_EXCEPTION, self) + + def serialize_members(self, serializer: ShapeSerializer): + if self.message is not None: + serializer.write_string( + _SCHEMA_REGION_DISABLED_EXCEPTION.members["message"], self.message + ) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + return cls(**cls.deserialize_kwargs(deserializer)) + + @classmethod + def deserialize_kwargs(cls, deserializer: ShapeDeserializer) -> dict[str, Any]: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["message"] = de.read_string( + _SCHEMA_REGION_DISABLED_EXCEPTION.members["message"] + ) + + case _: + logger.debug("Unexpected member schema: %s", schema) + + deserializer.read_struct(_SCHEMA_REGION_DISABLED_EXCEPTION, consumer=_consumer) + return kwargs + + +ASSUME_ROLE = APIOperation( + input=AssumeRoleInput, + output=AssumeRoleOutput, + schema=_SCHEMA_ASSUME_ROLE, + input_schema=_SCHEMA_ASSUME_ROLE_INPUT, + output_schema=_SCHEMA_ASSUME_ROLE_OUTPUT, + error_registry=TypeRegistry( + { + ShapeID("com.amazonaws.sts#ExpiredTokenException"): ExpiredTokenException, + ShapeID( + "com.amazonaws.sts#MalformedPolicyDocumentException" + ): MalformedPolicyDocumentException, + ShapeID( + "com.amazonaws.sts#PackedPolicyTooLargeException" + ): PackedPolicyTooLargeException, + ShapeID( + "com.amazonaws.sts#RegionDisabledException" + ): RegionDisabledException, + } + ), + effective_auth_schemes=[ShapeID("aws.auth#sigv4"), ShapeID("aws.auth#sigv4a")], + error_schemas=[ + _SCHEMA_EXPIRED_TOKEN_EXCEPTION, + _SCHEMA_MALFORMED_POLICY_DOCUMENT_EXCEPTION, + _SCHEMA_PACKED_POLICY_TOO_LARGE_EXCEPTION, + _SCHEMA_REGION_DISABLED_EXCEPTION, + ], +) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/aws_sdk_sts/user_agent.py b/packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/aws_sdk_sts/user_agent.py new file mode 100644 index 000000000..34c32fbbc --- /dev/null +++ b/packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/aws_sdk_sts/user_agent.py @@ -0,0 +1,17 @@ +# Code generated by smithy-python-codegen DO NOT EDIT. + +from smithy_aws_core.interceptors.user_agent import UserAgentInterceptor + +from . import __version__ +from .config import Config + + +def aws_user_agent_plugin(config: Config): + config.interceptors.append( + UserAgentInterceptor( + ua_suffix=config.user_agent_extra, + ua_app_id=config.sdk_ua_app_id, + sdk_version=__version__, + service_id="STS", + ) + ) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/identity/__init__.py b/packages/smithy-aws-core/src/smithy_aws_core/identity/__init__.py index 1b310e5bd..87d136d2d 100644 --- a/packages/smithy-aws-core/src/smithy_aws_core/identity/__init__.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/identity/__init__.py @@ -12,11 +12,13 @@ from .environment import EnvironmentCredentialsResolver from .imds import IMDSCredentialsResolver from .static import StaticCredentialsResolver +from .sts import AssumeRoleCredentialsResolver __all__ = ( "AWSCredentialsIdentity", "AWSCredentialsResolver", "AWSIdentityProperties", + "AssumeRoleCredentialsResolver", "ContainerCredentialsResolver", "EnvironmentCredentialsResolver", "IMDSCredentialsResolver", diff --git a/packages/smithy-aws-core/src/smithy_aws_core/identity/sts.py b/packages/smithy-aws-core/src/smithy_aws_core/identity/sts.py new file mode 100644 index 000000000..c3f69af31 --- /dev/null +++ b/packages/smithy-aws-core/src/smithy_aws_core/identity/sts.py @@ -0,0 +1,138 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import asyncio +import uuid +from collections.abc import Awaitable, Callable +from datetime import UTC, datetime +from typing import TYPE_CHECKING, Any + +from smithy_core.aio.interfaces import ClientTransport +from smithy_core.aio.interfaces.identity import IdentityResolver +from smithy_core.exceptions import SmithyIdentityError + +from .components import ( + AWSCredentialsIdentity, + AWSCredentialsResolver, + AWSIdentityProperties, +) + +if TYPE_CHECKING: + from .._private.nested_clients.aws_sdk_sts.client import STSClient + +DEFAULT_STS_REGION = "us-east-1" + +type MfaCodeProvider = Callable[[str], Awaitable[str]] +"""An async callback to provide MFA token codes. + +It receives the MFA device's serial number and returns the current token code +(e.g. from a prompt, TOTP generator, or secrets store). It must return a fresh, +single-use code each time. +""" + + +def _account_id_from_arn(arn: str | None) -> str | None: + """Extract account ID from an ARN.""" + if arn is None: + return None + parts = arn.split(":") + if len(parts) < 5 or not parts[4]: + return None + return parts[4] + + +class AssumeRoleCredentialsResolver( + IdentityResolver[AWSCredentialsIdentity, AWSIdentityProperties] +): + """Resolves AWS credentials from an STS ``AssumeRole`` call.""" + + def __init__( + self, + source_resolver: AWSCredentialsResolver, + role_arn: str, + role_session_name: str | None = None, + external_id: str | None = None, + duration_seconds: int | None = None, + region: str | None = None, + http_client: ClientTransport[Any, Any] | None = None, + mfa_serial: str | None = None, + mfa_code_provider: MfaCodeProvider | None = None, + ) -> None: + self._source_resolver = source_resolver + self._role_arn = role_arn + self._role_session_name = ( + role_session_name or f"aws-sdk-python-{uuid.uuid4().hex[:8]}" + ) + self._external_id = external_id + self._duration_seconds = duration_seconds + self._region = region or DEFAULT_STS_REGION + self._http_client = http_client + if mfa_serial is not None and mfa_code_provider is None: + raise ValueError("mfa_code_provider is required when mfa_serial is set.") + self._mfa_serial = mfa_serial + self._mfa_code_provider = mfa_code_provider + self._credentials: AWSCredentialsIdentity | None = None + self._sts_client: STSClient | None = None + self._refresh_lock = asyncio.Lock() + + async def get_identity( + self, *, properties: AWSIdentityProperties + ) -> AWSCredentialsIdentity: + if self._credentials is not None and self._is_fresh(self._credentials): + return self._credentials + async with self._refresh_lock: + if self._credentials is not None and self._is_fresh(self._credentials): + return self._credentials + self._credentials = await self._call_assume_role() + return self._credentials + + def _is_fresh(self, credentials: AWSCredentialsIdentity) -> bool: + return ( + credentials.expiration is not None + and datetime.now(UTC) < credentials.expiration + ) + + async def _call_assume_role(self) -> AWSCredentialsIdentity: + # Lazy import to avoid a circular import during module initialization + from .._private.nested_clients.aws_sdk_sts.client import STSClient + from .._private.nested_clients.aws_sdk_sts.config import Config + from .._private.nested_clients.aws_sdk_sts.models import AssumeRoleInput + + if self._sts_client is None: + self._sts_client = STSClient( + config=Config( + region=self._region, + aws_credentials_identity_resolver=self._source_resolver, + transport=self._http_client, + ) + ) + + token_code = None + if self._mfa_serial is not None and self._mfa_code_provider is not None: + token_code = await self._mfa_code_provider(self._mfa_serial) + + response = await self._sts_client.assume_role( + AssumeRoleInput( + role_arn=self._role_arn, + role_session_name=self._role_session_name, + duration_seconds=self._duration_seconds, + external_id=self._external_id, + serial_number=self._mfa_serial, + token_code=token_code, + ) + ) + + creds = response.credentials + if creds is None: + raise SmithyIdentityError("STS AssumeRole response missing Credentials") + + account_id = None + if response.assumed_role_user is not None: + account_id = _account_id_from_arn(response.assumed_role_user.arn) + + return AWSCredentialsIdentity( + access_key_id=creds.access_key_id, + secret_access_key=creds.secret_access_key, + session_token=creds.session_token, + expiration=creds.expiration, + account_id=account_id, + ) diff --git a/packages/smithy-aws-core/tests/unit/identity/test_sts.py b/packages/smithy-aws-core/tests/unit/identity/test_sts.py new file mode 100644 index 000000000..290b3bccf --- /dev/null +++ b/packages/smithy-aws-core/tests/unit/identity/test_sts.py @@ -0,0 +1,286 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +# pyright: reportPrivateUsage=false +from datetime import UTC, datetime, timedelta +from unittest.mock import AsyncMock + +import pytest +from smithy_aws_core._private.nested_clients.aws_sdk_sts.models import ( + AssumedRoleUser, + AssumeRoleOutput, + Credentials, +) +from smithy_aws_core.identity.environment import EnvironmentCredentialsResolver +from smithy_aws_core.identity.sts import ( + AssumeRoleCredentialsResolver, + _account_id_from_arn, +) +from smithy_core.exceptions import SmithyIdentityError +from smithy_http.testing import MockHTTPClient + +ROLE_ARN = "arn:aws:iam::123456789012:role/MyRole" +ASSUMED_ROLE_ARN = "arn:aws:sts::123456789012:assumed-role/MyRole/session" +MFA_SERIAL = "arn:aws:iam::123456789012:mfa/device" +ACCESS_KEY_ID = "test-access-key" +SECRET_ACCESS_KEY = "test-secret-key" +SESSION_TOKEN = "test-session-token" + + +def _future_expiry() -> datetime: + return datetime.now(UTC) + timedelta(hours=1) + + +def _past_expiry() -> datetime: + return datetime.now(UTC) - timedelta(hours=1) + + +def _valid_output( + *, access_key_id: str = ACCESS_KEY_ID, expiration: datetime | None = None +) -> AssumeRoleOutput: + """An AssumeRole response with valid credentials and assumed-role user""" + return AssumeRoleOutput( + credentials=Credentials( + access_key_id=access_key_id, + secret_access_key=SECRET_ACCESS_KEY, + session_token=SESSION_TOKEN, + expiration=expiration or _future_expiry(), + ), + assumed_role_user=AssumedRoleUser(assumed_role_id="id", arn=ASSUMED_ROLE_ARN), + ) + + +def _mock_sts_client( + resolver: AssumeRoleCredentialsResolver, *responses: AssumeRoleOutput +) -> AsyncMock: + """Attach a mock STS client returning one response per AssumeRole call""" + client = AsyncMock() + client.assume_role.side_effect = list(responses) + resolver._sts_client = client + return client + + +def _assume_role_response_body() -> bytes: + return ( + "" + "" + "sts-akid" + "sts-secret" + "sts-token" + "2030-01-01T00:00:00Z" + "" + "" + "id:session" + f"{ASSUMED_ROLE_ARN}" + "" + "" + ).encode() + + +@pytest.mark.parametrize( + "arn,expected", + [ + (ASSUMED_ROLE_ARN, "123456789012"), + ("arn:aws:sts:::assumed-role/MyRole/session", None), # empty account field + ("not-an-arn", None), # too few segments + (None, None), + ], +) +def test_account_id_from_arn(arn: str | None, expected: str | None): + assert _account_id_from_arn(arn) == expected + + +async def test_resolves_identity_from_assume_role(): + expiration = _future_expiry() + resolver = AssumeRoleCredentialsResolver( + source_resolver=AsyncMock(), role_arn=ROLE_ARN + ) + _mock_sts_client(resolver, _valid_output(expiration=expiration)) + + identity = await resolver.get_identity(properties={}) + + assert identity.access_key_id == ACCESS_KEY_ID + assert identity.secret_access_key == SECRET_ACCESS_KEY + assert identity.session_token == SESSION_TOKEN + assert identity.expiration == expiration + assert identity.account_id == "123456789012" + + +async def test_assume_role_signed_with_source_credentials( + monkeypatch: pytest.MonkeyPatch, +): + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "source-akid") + monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "source-secret") + + http_client = MockHTTPClient() + http_client.add_response(body=_assume_role_response_body()) + + resolver = AssumeRoleCredentialsResolver( + source_resolver=EnvironmentCredentialsResolver(), + role_arn=ROLE_ARN, + region="us-east-1", + http_client=http_client, + ) + + identity = await resolver.get_identity(properties={}) + + # The resolved identity is the STS-issued credential + assert identity.access_key_id == "sts-akid" + assert identity.session_token == "sts-token" + + # The request was signed with the source credentials + [request] = http_client.captured_requests + authorization = request.fields["Authorization"].as_string() + assert "Credential=source-akid/" in authorization + + +async def test_missing_credentials_raises(): + resolver = AssumeRoleCredentialsResolver( + source_resolver=AsyncMock(), role_arn=ROLE_ARN + ) + _mock_sts_client(resolver, AssumeRoleOutput(credentials=None)) + + with pytest.raises( + SmithyIdentityError, match="STS AssumeRole response missing Credentials" + ): + await resolver.get_identity(properties={}) + + +async def test_valid_credentials_reused(): + resolver = AssumeRoleCredentialsResolver( + source_resolver=AsyncMock(), role_arn=ROLE_ARN + ) + sts_client = _mock_sts_client( + resolver, + _valid_output(access_key_id="test-access-key-1"), + _valid_output(access_key_id="test-access-key-2"), + ) + + identity_one = await resolver.get_identity(properties={}) + identity_two = await resolver.get_identity(properties={}) + + # The cached identity is returned without a second STS call + assert identity_one is identity_two + assert sts_client.assume_role.call_count == 1 + + +async def test_expired_credentials_refreshed(): + resolver = AssumeRoleCredentialsResolver( + source_resolver=AsyncMock(), role_arn=ROLE_ARN + ) + sts_client = _mock_sts_client( + resolver, + _valid_output(access_key_id="test-access-key-1", expiration=_past_expiry()), + _valid_output(access_key_id="test-access-key-2"), + ) + + identity_one = await resolver.get_identity(properties={}) + identity_two = await resolver.get_identity(properties={}) + + # The cached identity is refreshed with a second STS call + assert identity_one is not identity_two + assert identity_one.access_key_id == "test-access-key-1" + assert identity_two.access_key_id == "test-access-key-2" + assert sts_client.assume_role.call_count == 2 + + +async def test_assume_role_request_uses_settings(): + resolver = AssumeRoleCredentialsResolver( + source_resolver=AsyncMock(), + role_arn=ROLE_ARN, + role_session_name="test-session-name", + external_id="test-external-id", + duration_seconds=1000, + ) + sts_client = _mock_sts_client(resolver, _valid_output()) + + await resolver.get_identity(properties={}) + + request = sts_client.assume_role.call_args.args[0] + assert request.role_arn == ROLE_ARN + assert request.role_session_name == "test-session-name" + assert request.external_id == "test-external-id" + assert request.duration_seconds == 1000 + + +async def test_role_session_name_generated_when_unset(): + resolver = AssumeRoleCredentialsResolver( + source_resolver=AsyncMock(), role_arn=ROLE_ARN + ) + sts_client = _mock_sts_client(resolver, _valid_output()) + + await resolver.get_identity(properties={}) + + request = sts_client.assume_role.call_args.args[0] + assert request.role_session_name.startswith("aws-sdk-python-") + + +async def test_role_session_name_stable_across_refreshes(): + resolver = AssumeRoleCredentialsResolver( + source_resolver=AsyncMock(), role_arn=ROLE_ARN + ) + sts_client = _mock_sts_client( + resolver, + _valid_output(expiration=_past_expiry()), + _valid_output(), + ) + + await resolver.get_identity(properties={}) + await resolver.get_identity(properties={}) + + first, second = sts_client.assume_role.call_args_list + assert first.args[0].role_session_name == second.args[0].role_session_name + + +async def test_mfa_serial_and_token_code_sent(): + code_provider = AsyncMock(return_value="111111") + resolver = AssumeRoleCredentialsResolver( + source_resolver=AsyncMock(), + role_arn=ROLE_ARN, + mfa_serial=MFA_SERIAL, + mfa_code_provider=code_provider, + ) + sts_client = _mock_sts_client(resolver, _valid_output()) + + await resolver.get_identity(properties={}) + + request = sts_client.assume_role.call_args.args[0] + assert request.serial_number == MFA_SERIAL + assert request.token_code == "111111" + code_provider.assert_awaited_once_with(MFA_SERIAL) + + +async def test_mfa_code_provider_invoked_on_each_refresh(): + code_provider = AsyncMock(side_effect=["111111", "222222"]) + resolver = AssumeRoleCredentialsResolver( + source_resolver=AsyncMock(), + role_arn=ROLE_ARN, + mfa_serial=MFA_SERIAL, + mfa_code_provider=code_provider, + ) + sts_client = _mock_sts_client( + resolver, + _valid_output(expiration=_past_expiry()), + _valid_output(), + ) + + await resolver.get_identity(properties={}) + await resolver.get_identity(properties={}) + + # A fresh, single-use code is fetched for each assume call + assert code_provider.await_count == 2 + first, second = sts_client.assume_role.call_args_list + assert first.args[0].token_code == "111111" + assert second.args[0].token_code == "222222" + + +def test_mfa_serial_without_provider_raises(): + with pytest.raises( + ValueError, + match="mfa_code_provider is required when mfa_serial is set", + ): + AssumeRoleCredentialsResolver( + source_resolver=AsyncMock(), + role_arn=ROLE_ARN, + mfa_serial=MFA_SERIAL, + ) diff --git a/pyproject.toml b/pyproject.toml index 152cf5f61..fb0db85bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ aws_sdk_signers = {workspace = true } [tool.pyright] typeCheckingMode = "strict" enableExperimentalFeatures = true +exclude = ["packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/**"] [tool.pytest.ini_options] asyncio_mode = "auto" # makes pytest run async tests without having to be marked with the @pytest.mark.asyncio decorator @@ -58,7 +59,10 @@ target-version = "py312" # probably not, a lot of work: DOC, D, PL, TRY select = [ "ASYNC", "C4", "E1", "E4", "E7", "E9", "F", "FURB", "G", "I", "LOG", "PIE", "RUF", "S", "T", "UP" ] -exclude = [ "packages/smithy-core/src/smithy_core/rfc3986.py" ] +exclude = [ + "packages/smithy-core/src/smithy_core/rfc3986.py", + "packages/smithy-aws-core/src/smithy_aws_core/_private/nested_clients/**", +] [tool.ruff.lint.isort] classes = ["URI"] diff --git a/uv.lock b/uv.lock index 90e8cea62..a29bca749 100644 --- a/uv.lock +++ b/uv.lock @@ -690,6 +690,9 @@ dependencies = [ ] [package.optional-dependencies] +assume-role = [ + { name = "smithy-xml" }, +] eventstream = [ { name = "smithy-aws-event-stream" }, ] @@ -707,9 +710,10 @@ requires-dist = [ { name = "smithy-core", editable = "packages/smithy-core" }, { name = "smithy-http", editable = "packages/smithy-http" }, { name = "smithy-json", marker = "extra == 'json'", editable = "packages/smithy-json" }, + { name = "smithy-xml", marker = "extra == 'assume-role'", editable = "packages/smithy-xml" }, { name = "smithy-xml", marker = "extra == 'xml'", editable = "packages/smithy-xml" }, ] -provides-extras = ["eventstream", "json", "xml"] +provides-extras = ["assume-role", "eventstream", "json", "xml"] [[package]] name = "smithy-aws-event-stream" @@ -753,7 +757,7 @@ awscrt = [ [package.metadata] requires-dist = [ - { name = "aiohttp", marker = "extra == 'aiohttp'", specifier = ">=3.14.0,<4.0" }, + { name = "aiohttp", marker = "extra == 'aiohttp'", specifier = ">=3.11.12,<4.0" }, { name = "awscrt", marker = "extra == 'awscrt'", specifier = "~=0.32.0" }, { name = "smithy-core", editable = "packages/smithy-core" }, { name = "yarl", marker = "extra == 'aiohttp'" },