diff --git a/pyiceberg/catalog/rest/__init__.py b/pyiceberg/catalog/rest/__init__.py index d085c6fd87..a98dc0faf9 100644 --- a/pyiceberg/catalog/rest/__init__.py +++ b/pyiceberg/catalog/rest/__init__.py @@ -32,6 +32,7 @@ from pyiceberg import __version__ from pyiceberg.catalog import BOTOCORE_SESSION, TOKEN, URI, WAREHOUSE_LOCATION, Catalog, PropertiesUpdateSummary from pyiceberg.catalog.rest.auth import AUTH_MANAGER, AuthManager, AuthManagerAdapter, AuthManagerFactory, LegacyOAuth2AuthManager +from pyiceberg.catalog.rest.credentials_provider import REFRESH_CREDENTIALS_ENABLED, VendedCredentialsProvider from pyiceberg.catalog.rest.response import _handle_non_200_response from pyiceberg.catalog.rest.scan_planning import ( FetchScanTasksRequest, @@ -484,7 +485,10 @@ def _load_file_io(self, properties: Properties = EMPTY_DICT, location: str | Non merged_properties = {**self.properties, **properties} if self._auth_manager: merged_properties[AUTH_MANAGER] = self._auth_manager - return load_file_io(merged_properties, location) + file_io = load_file_io(merged_properties, location) + if property_as_bool(merged_properties, REFRESH_CREDENTIALS_ENABLED, False): + file_io.set_credentials_provider(VendedCredentialsProvider(self._session, merged_properties)) + return file_io @override def supports_server_side_planning(self) -> bool: diff --git a/pyiceberg/catalog/rest/credentials_provider.py b/pyiceberg/catalog/rest/credentials_provider.py new file mode 100644 index 0000000000..87416cadc5 --- /dev/null +++ b/pyiceberg/catalog/rest/credentials_provider.py @@ -0,0 +1,126 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# +from datetime import datetime + +from pydantic import Field +from requests import HTTPError, Session + +from pyiceberg.catalog import URI +from pyiceberg.catalog.rest.response import _handle_non_200_response +from pyiceberg.catalog.rest.scan_planning import StorageCredential +from pyiceberg.exceptions import ValidationError, ValidationException +from pyiceberg.io import ( + AWS_ACCESS_KEY_ID, + AWS_SECRET_ACCESS_KEY, + AWS_SESSION_TOKEN, + S3_ACCESS_KEY_ID, + S3_SECRET_ACCESS_KEY, + S3_SESSION_TOKEN, +) +from pyiceberg.typedef import IcebergBaseModel, Properties +from pyiceberg.utils.properties import get_first_property_value + +S3_SESSION_TOKEN_EXPIRES_AT_MS = "s3.session-token-expires-at-ms" +CREDENTIALS_ENDPOINT = "client.refresh-credentials-endpoint" +REFRESH_CREDENTIALS_ENABLED = "client.refresh-credentials-enabled" + + +class LoadCredentialsResponse(IcebergBaseModel): + credentials: list[StorageCredential] = Field(alias="storage-credentials") + + +class VendedCredentialsProvider: + _session: Session + _properties: Properties + + def __init__(self, session: Session, properties: Properties): + self._session = session + self._properties = properties + + def _extract_s3_credentials_from(self, props: Properties) -> tuple[str | None, str | None, str | None, str | None]: + """Extract only S3 credentials from properties.""" + access_key = get_first_property_value(props, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID) + secret_key = get_first_property_value(props, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY) + session_token = get_first_property_value(props, S3_SESSION_TOKEN, AWS_SESSION_TOKEN) + expiry = get_first_property_value(props, S3_SESSION_TOKEN_EXPIRES_AT_MS) + + return access_key, secret_key, session_token, expiry + + def _to_credentials_property_map( + self, access_key: str | None, secret_key: str | None, session_token: str | None, expiry: str | None + ) -> Properties: + return { + S3_ACCESS_KEY_ID: access_key, + S3_SECRET_ACCESS_KEY: secret_key, + S3_SESSION_TOKEN: session_token, + S3_SESSION_TOKEN_EXPIRES_AT_MS: expiry, + } + + def needs_refresh(self) -> bool: + """Return True if the S3 session token expires within 300s.""" + expiry = get_first_property_value(self._properties, S3_SESSION_TOKEN_EXPIRES_AT_MS) + if expiry is None: + return False + expires_at = datetime.fromtimestamp(int(expiry) / 1000) + seconds_remaining = (expires_at - datetime.now()).total_seconds() + return seconds_remaining < 300 + + def _build_refresh_endpoint(self) -> str: + """Build credential refresh endpoint from properties.""" + catalog_uri = get_first_property_value(self._properties, URI) + credentials_path = get_first_property_value(self._properties, CREDENTIALS_ENDPOINT) + + if catalog_uri is None: + raise ValidationException("Invalid catalog endpoint: None") + + if credentials_path is None: + raise ValidationException("Invalid credentials endpoint: None") + + return str(catalog_uri).rstrip("/") + "/" + str(credentials_path).lstrip("/") + + def _get_new_credentials(self) -> LoadCredentialsResponse | None: + try: + http_response = self._session.get(self._build_refresh_endpoint()) + http_response.raise_for_status() + return LoadCredentialsResponse.model_validate_json(http_response.text) + except HTTPError as exc: + _handle_non_200_response(exc, {}) + return None + + def get_credentials(self) -> Properties: + """Retrieve current S3 credentials, refreshing from the endpoint if near expiry.""" + access_key, secret_key, session_token, expiry = self._extract_s3_credentials_from(self._properties) + + if not self.needs_refresh(): + return self._to_credentials_property_map(access_key, secret_key, session_token, expiry) + + creds = self._get_new_credentials() + + if creds is None: + raise ValidationError("Load credential response is None") + if not creds.credentials: + raise ValueError("Invalid S3 Credentials: empty") + if len(creds.credentials) > 1: + raise ValueError("Invalid S3 Credentials: only one S3 credential should exists") + + updated_creds = self._extract_s3_credentials_from(creds.credentials[0].config) + updated_map = self._to_credentials_property_map(*updated_creds) + + # Update internal properties with new credentials + self._properties = {**self._properties, **updated_map} + + return updated_map diff --git a/pyiceberg/io/__init__.py b/pyiceberg/io/__init__.py index 7dbc651214..255da19b21 100644 --- a/pyiceberg/io/__init__.py +++ b/pyiceberg/io/__init__.py @@ -32,9 +32,13 @@ from io import SEEK_SET from types import TracebackType from typing import ( + TYPE_CHECKING, Protocol, runtime_checkable, ) + +if TYPE_CHECKING: + from pyiceberg.catalog.rest.credentials_provider import VendedCredentialsProvider from urllib.parse import urlparse from pyiceberg.typedef import EMPTY_DICT, Properties @@ -291,6 +295,13 @@ def delete(self, location: str | InputFile | OutputFile) -> None: FileNotFoundError: When the file at the provided location does not exist. """ + def set_credentials_provider(self, provider: VendedCredentialsProvider) -> None: # noqa: B027 + """Inject a credentials provider for refreshing vended storage credentials. + + Args: + provider (VendedCredentialsProvider): A concrete type of VendedCredentialsProvider (e.g S3VendedCredentialsProvider) + """ + LOCATION = "location" WAREHOUSE = "warehouse" diff --git a/pyiceberg/io/fsspec.py b/pyiceberg/io/fsspec.py index 7749268ff5..e6409441fb 100644 --- a/pyiceberg/io/fsspec.py +++ b/pyiceberg/io/fsspec.py @@ -39,6 +39,7 @@ from pyiceberg.catalog import TOKEN, URI from pyiceberg.catalog.rest.auth import AUTH_MANAGER +from pyiceberg.catalog.rest.credentials_provider import VendedCredentialsProvider from pyiceberg.exceptions import SignError from pyiceberg.io import ( ADLS_ACCOUNT_HOST, @@ -166,9 +167,12 @@ def _file(_: Properties) -> LocalFileSystem: return LocalFileSystem(auto_mkdir=True) -def _s3(properties: Properties) -> AbstractFileSystem: +def _s3(properties: Properties, cred_provider: VendedCredentialsProvider | None) -> AbstractFileSystem: from s3fs import S3FileSystem + if cred_provider is not None and cred_provider.needs_refresh(): + properties = {**properties, **cred_provider.get_credentials()} + client_kwargs = { "endpoint_url": properties.get(S3_ENDPOINT), "aws_access_key_id": get_first_property_value(properties, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID), @@ -319,6 +323,7 @@ def _hf(properties: Properties) -> AbstractFileSystem: } _ADLS_SCHEMES = frozenset({"abfs", "abfss", "wasb", "wasbs"}) +_S3_SCHEMES = frozenset({"s3", "s3a", "s3n"}) class FsspecInputFile(InputFile): @@ -430,8 +435,12 @@ class FsspecFileIO(FileIO): def __init__(self, properties: Properties): self._scheme_to_fs: dict[str, Callable[..., AbstractFileSystem]] = dict(SCHEME_TO_FS) self._thread_locals = threading.local() + self._credentials_provider: VendedCredentialsProvider | None = None super().__init__(properties=properties) + def set_credentials_provider(self, provider: VendedCredentialsProvider) -> None: + self._credentials_provider = provider + @override def new_input(self, location: str) -> FsspecInputFile: """Get an FsspecInputFile instance to read bytes from the file at the given location. @@ -486,9 +495,12 @@ def _get_fs_from_uri(self, uri: "ParseResult") -> AbstractFileSystem: def get_fs(self, scheme: str, hostname: str | None = None) -> AbstractFileSystem: """Get a filesystem for a specific scheme, cached per thread.""" - if not hasattr(self._thread_locals, "get_fs_cached"): - self._thread_locals.get_fs_cached = lru_cache(self._get_fs) + # If we have available a CredentialProvider and we detect that the tokens need to be refreshed + # then invalidate the cached fileio in order to get a new fileio with the fresh credentials + needs_refresh = self._credentials_provider and self._credentials_provider.needs_refresh() + if not hasattr(self._thread_locals, "get_fs_cached") or needs_refresh: + self._thread_locals.get_fs_cached = lru_cache(self._get_fs) return self._thread_locals.get_fs_cached(scheme, hostname) def _get_fs(self, scheme: str, hostname: str | None = None) -> AbstractFileSystem: @@ -499,6 +511,9 @@ def _get_fs(self, scheme: str, hostname: str | None = None) -> AbstractFileSyste if scheme in _ADLS_SCHEMES: return _adls(self.properties, hostname) + if scheme in _S3_SCHEMES: + return _s3(self.properties, self._credentials_provider) + return self._scheme_to_fs[scheme](self.properties) def __getstate__(self) -> dict[str, Any]: diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 4ec7a73afe..8937ce1dc6 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -187,6 +187,7 @@ from pyiceberg.utils.truncate import truncate_upper_bound_binary_string, truncate_upper_bound_text_string if TYPE_CHECKING: + from pyiceberg.catalog.rest.credentials_provider import VendedCredentialsProvider from pyiceberg.table import FileScanTask, WriteTask logger = logging.getLogger(__name__) @@ -394,8 +395,20 @@ class PyArrowFileIO(FileIO): def __init__(self, properties: Properties = EMPTY_DICT): self.fs_by_scheme: Callable[[str, str | None], FileSystem] = lru_cache(self._initialize_fs) + self._credentials_provider: VendedCredentialsProvider | None = None super().__init__(properties=properties) + def set_credentials_provider(self, provider: VendedCredentialsProvider) -> None: + self._credentials_provider = provider + + def _get_fs(self, scheme: str, netloc: str | None) -> FileSystem: + # If we have available a CredentialProvider and we detect that the tokens need to be refreshed + # then invalidate the cached fileio in order to get a new fileio with the fresh credentials + if self._credentials_provider and self._credentials_provider.needs_refresh(): + self.properties = {**self.properties, **self._credentials_provider.get_credentials()} + self.fs_by_scheme = lru_cache(self._initialize_fs) + return self.fs_by_scheme(scheme, netloc) + @staticmethod def parse_location(location: str, properties: Properties = EMPTY_DICT) -> tuple[str, str, str]: """Return (scheme, netloc, path) for the given location. @@ -628,7 +641,7 @@ def new_input(self, location: str) -> PyArrowFile: """ scheme, netloc, path = self.parse_location(location, self.properties) return PyArrowFile( - fs=self.fs_by_scheme(scheme, netloc), + fs=self._get_fs(scheme, netloc), location=location, path=path, buffer_size=int(self.properties.get(BUFFER_SIZE, ONE_MEGABYTE)), @@ -646,7 +659,7 @@ def new_output(self, location: str) -> PyArrowFile: """ scheme, netloc, path = self.parse_location(location, self.properties) return PyArrowFile( - fs=self.fs_by_scheme(scheme, netloc), + fs=self._get_fs(scheme, netloc), location=location, path=path, buffer_size=int(self.properties.get(BUFFER_SIZE, ONE_MEGABYTE)), diff --git a/tests/catalog/test_credentials_provider.py b/tests/catalog/test_credentials_provider.py new file mode 100644 index 0000000000..a828af263d --- /dev/null +++ b/tests/catalog/test_credentials_provider.py @@ -0,0 +1,185 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import time +from unittest.mock import MagicMock + +import pytest + +from pyiceberg.catalog.rest.credentials_provider import ( + CREDENTIALS_ENDPOINT, + LoadCredentialsResponse, + VendedCredentialsProvider, +) +from pyiceberg.catalog.rest.scan_planning import StorageCredential + +CATALOG_URI = "http://localhost:8181" +CREDENTIALS_PATH = "v1/credentials" + +BASE_PROPS = { + "uri": CATALOG_URI, + CREDENTIALS_ENDPOINT: CREDENTIALS_PATH, + "s3.access-key-id": "initial-key", + "s3.secret-access-key": "initial-secret", + "s3.session-token": "initial-token", +} + +REFRESH_RESPONSE = LoadCredentialsResponse( + credentials=[ + StorageCredential( + prefix="s3://", + config={ + "s3.access-key-id": "refreshed-key", + "s3.secret-access-key": "refreshed-secret", + "s3.session-token": "refreshed-token", + }, + ) + ] +) + + +def _make_session(response: LoadCredentialsResponse = REFRESH_RESPONSE) -> MagicMock: + session = MagicMock() + mock_response = MagicMock() + mock_response.text = response.model_dump_json(by_alias=True) + mock_response.raise_for_status.return_value = None + session.get.return_value = mock_response + return session + + +def test_get_credentials_no_expiry_returns_static_creds() -> None: + """When no expiry is set, credentials are returned from properties without an HTTP call.""" + session = _make_session() + provider = VendedCredentialsProvider(session, BASE_PROPS) + creds = provider.get_credentials() + + session.get.assert_not_called() + assert creds["s3.access-key-id"] == "initial-key" + assert creds["s3.secret-access-key"] == "initial-secret" + assert creds["s3.session-token"] == "initial-token" + + +def test_get_credentials_far_expiry_returns_static_creds() -> None: + """When expiry is far in the future (>300s), no refresh is triggered.""" + far_future_ms = str(int((time.time() + 3600) * 1000)) # expires in 1 hour + props = {**BASE_PROPS, "s3.session-token-expires-at-ms": far_future_ms} + session = _make_session() + provider = VendedCredentialsProvider(session, props) + creds = provider.get_credentials() + + session.get.assert_not_called() + assert creds["s3.access-key-id"] == "initial-key" + + +def test_get_credentials_near_expiry_calls_refresh_endpoint() -> None: + """When expiry is within 300s, the refresh endpoint is called and new creds returned.""" + near_expiry_ms = str(int((time.time() + 60) * 1000)) # expires in 60s + props = {**BASE_PROPS, "s3.session-token-expires-at-ms": near_expiry_ms} + session = _make_session() + provider = VendedCredentialsProvider(session, props) + creds = provider.get_credentials() + + session.get.assert_called_once_with(f"{CATALOG_URI}/{CREDENTIALS_PATH}") + assert creds["s3.access-key-id"] == "refreshed-key" + assert creds["s3.secret-access-key"] == "refreshed-secret" + assert creds["s3.session-token"] == "refreshed-token" + + +def test_get_credentials_raises_on_empty_credentials() -> None: + """An empty credentials list in the refresh response raises ValueError.""" + near_expiry_ms = str(int((time.time() + 60) * 1000)) + props = {**BASE_PROPS, "s3.session-token-expires-at-ms": near_expiry_ms} + empty_response = LoadCredentialsResponse(credentials=[]) + provider = VendedCredentialsProvider(_make_session(empty_response), props) + + with pytest.raises(ValueError, match="empty"): + provider.get_credentials() + + +def test_get_credentials_raises_on_multiple_credentials() -> None: + """More than one credential in the refresh response raises ValueError.""" + near_expiry_ms = str(int((time.time() + 60) * 1000)) + props = {**BASE_PROPS, "s3.session-token-expires-at-ms": near_expiry_ms} + multi_response = LoadCredentialsResponse( + credentials=[ + StorageCredential(prefix="s3://", config={}), + StorageCredential(prefix="s3://b", config={}), + ] + ) + provider = VendedCredentialsProvider(_make_session(multi_response), props) + + with pytest.raises(ValueError, match="only one"): + provider.get_credentials() + + +def test_build_refresh_endpoint_strips_trailing_slash() -> None: + props = {**BASE_PROPS, "uri": "http://localhost:8181/"} + provider = VendedCredentialsProvider(MagicMock(), props) + assert provider._build_refresh_endpoint() == f"http://localhost:8181/{CREDENTIALS_PATH}" + + +def test_build_refresh_endpoint_raises_without_uri() -> None: + props = {CREDENTIALS_ENDPOINT: CREDENTIALS_PATH} + provider = VendedCredentialsProvider(MagicMock(), props) + + from pyiceberg.exceptions import ValidationException + + with pytest.raises(ValidationException): + provider._build_refresh_endpoint() + + +def test_needs_refresh_true_when_near_expiry() -> None: + near_expiry_ms = str(int((time.time() + 60) * 1000)) + provider = VendedCredentialsProvider(MagicMock(), {**BASE_PROPS, "s3.session-token-expires-at-ms": near_expiry_ms}) + assert provider.needs_refresh() is True + + +def test_needs_refresh_false_when_far_expiry() -> None: + far_expiry_ms = str(int((time.time() + 3600) * 1000)) + provider = VendedCredentialsProvider(MagicMock(), {**BASE_PROPS, "s3.session-token-expires-at-ms": far_expiry_ms}) + assert provider.needs_refresh() is False + + +def test_needs_refresh_false_when_no_expiry() -> None: + provider = VendedCredentialsProvider(MagicMock(), BASE_PROPS) + assert provider.needs_refresh() is False + + +def test_get_credentials_updates_internal_properties_after_refresh() -> None: + """After a refresh, _properties holds the new expiry so needs_refresh() sees the updated state.""" + far_future_ms = str(int((time.time() + 3600) * 1000)) + refreshed_response = LoadCredentialsResponse( + credentials=[ + StorageCredential( + prefix="s3://", + config={ + "s3.access-key-id": "new-key", + "s3.secret-access-key": "new-secret", + "s3.session-token": "new-token", + "s3.session-token-expires-at-ms": far_future_ms, + }, + ) + ] + ) + near_expiry_ms = str(int((time.time() + 60) * 1000)) + props = {**BASE_PROPS, "s3.session-token-expires-at-ms": near_expiry_ms} + provider = VendedCredentialsProvider(_make_session(refreshed_response), props) + + assert provider.needs_refresh() is True + provider.get_credentials() + assert provider.needs_refresh() is False + assert provider._properties["s3.session-token-expires-at-ms"] == far_future_ms diff --git a/tests/io/test_fsspec.py b/tests/io/test_fsspec.py index 8739a5964d..a2ad258b1d 100644 --- a/tests/io/test_fsspec.py +++ b/tests/io/test_fsspec.py @@ -29,6 +29,7 @@ from requests_mock import Mocker from pyiceberg.catalog.rest.auth import AUTH_MANAGER +from pyiceberg.catalog.rest.credentials_provider import VendedCredentialsProvider from pyiceberg.exceptions import SignError from pyiceberg.io import fsspec from pyiceberg.io.fsspec import FsspecFileIO, S3V4RestSigner @@ -1105,3 +1106,75 @@ def auth_header(self) -> str: assert requests_mock.last_request is not None assert requests_mock.last_request.headers["Authorization"] == "Bearer via-manager" assert request.url == new_uri + + +def test_fsspec_credentials_provider_bypasses_cache() -> None: + provider = mock.MagicMock(spec=VendedCredentialsProvider) + provider.needs_refresh.return_value = True + provider.get_credentials.return_value = { + "s3.access-key-id": "refreshed-key", + "s3.secret-access-key": "refreshed-secret", + "s3.session-token": "refreshed-token", + } + + s3_fileio = FsspecFileIO(properties={"s3.endpoint": "http://localhost:9000"}) + s3_fileio.set_credentials_provider(provider) + + with mock.patch("s3fs.S3FileSystem"): + s3_fileio.new_input("s3://bucket/key1") + s3_fileio.new_input("s3://bucket/key2") + + assert provider.get_credentials.call_count == 2 + + +def test_fsspec_credentials_provider_uses_cache_when_fresh() -> None: + provider = mock.MagicMock(spec=VendedCredentialsProvider) + provider.needs_refresh.return_value = False + + s3_fileio = FsspecFileIO(properties={"s3.endpoint": "http://localhost:9000"}) + s3_fileio.set_credentials_provider(provider) + + with mock.patch("s3fs.S3FileSystem") as mock_s3fs: + s3_fileio.new_input("s3://bucket/key1") + s3_fileio.new_input("s3://bucket/key2") + + provider.get_credentials.assert_not_called() + assert mock_s3fs.call_count == 1 + + +def test_fsspec_credentials_provider_merges_fresh_creds() -> None: + provider = mock.MagicMock(spec=VendedCredentialsProvider) + provider.needs_refresh.return_value = True + provider.get_credentials.return_value = { + "s3.access-key-id": "refreshed-key", + "s3.secret-access-key": "refreshed-secret", + "s3.session-token": "refreshed-token", + } + + s3_fileio = FsspecFileIO( + properties={ + "s3.endpoint": "http://localhost:9000", + "s3.access-key-id": "stale-key", + "s3.secret-access-key": "stale-secret", + "s3.session-token": "stale-token", + "s3.region": "us-east-1", + } + ) + s3_fileio.set_credentials_provider(provider) + + with mock.patch("s3fs.S3FileSystem") as mock_s3fs: + s3_fileio.new_input("s3://bucket/key") + call_kwargs = mock_s3fs.call_args[1]["client_kwargs"] + assert call_kwargs["aws_access_key_id"] == "refreshed-key" + assert call_kwargs["aws_secret_access_key"] == "refreshed-secret" + assert call_kwargs["aws_session_token"] == "refreshed-token" + + +def test_fsspec_no_credentials_provider_uses_lru_cache() -> None: + s3_fileio = FsspecFileIO(properties={"s3.endpoint": "http://localhost:9000"}) + + with mock.patch("s3fs.S3FileSystem") as mock_s3fs: + s3_fileio.new_input("s3://bucket/key1") + s3_fileio.new_input("s3://bucket/key2") + + assert mock_s3fs.call_count == 1 diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index 2f36661a1f..d9187615ce 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -35,6 +35,7 @@ from packaging import version from pyarrow.fs import AwsDefaultS3RetryStrategy, FileType, LocalFileSystem, S3FileSystem +from pyiceberg.catalog.rest.credentials_provider import VendedCredentialsProvider from pyiceberg.exceptions import ResolveError from pyiceberg.expressions import ( AlwaysFalse, @@ -5103,3 +5104,121 @@ def test_partition_column_projection_with_schema_evolution(catalog: InMemoryCata result_sorted = result.sort_by("name") assert result_sorted["name"].to_pylist() == ["Alice", "Bob", "Charlie", "David"] assert result_sorted["new_column"].to_pylist() == [None, None, "new1", "new2"] + + +def test_pyarrow_credentials_provider_merges_fresh_creds() -> None: + """_initialize_s3_fs() uses fresh credentials from the provider when needs_refresh is True.""" + provider = MagicMock(spec=VendedCredentialsProvider) + provider.needs_refresh.return_value = True + provider.get_credentials.return_value = { + "s3.access-key-id": "refreshed-key", + "s3.secret-access-key": "refreshed-secret", + "s3.session-token": "refreshed-token", + } + + s3_fileio = PyArrowFileIO( + properties={ + "s3.endpoint": "http://localhost:9000", + "s3.access-key-id": "stale-key", + "s3.secret-access-key": "stale-secret", + "s3.session-token": "stale-token", + "s3.region": "us-east-1", + } + ) + s3_fileio.set_credentials_provider(provider) + + with patch("pyarrow.fs.S3FileSystem") as mock_s3fs, patch("pyarrow.fs.resolve_s3_region", side_effect=OSError): + s3_fileio.new_input("s3://bucket/key") + call_kwargs = mock_s3fs.call_args[1] + assert call_kwargs["access_key"] == "refreshed-key" + assert call_kwargs["secret_key"] == "refreshed-secret" + assert call_kwargs["session_token"] == "refreshed-token" + + +def test_pyarrow_credentials_provider_bypasses_lru_cache() -> None: + """With needs_refresh=True, every file open calls get_credentials() — no stale fs served.""" + provider = MagicMock(spec=VendedCredentialsProvider) + provider.needs_refresh.return_value = True + provider.get_credentials.return_value = { + "s3.access-key-id": "refreshed-key", + "s3.secret-access-key": "refreshed-secret", + "s3.session-token": "refreshed-token", + } + + s3_fileio = PyArrowFileIO(properties={"s3.endpoint": "http://localhost:9000", "s3.region": "us-east-1"}) + s3_fileio.set_credentials_provider(provider) + + with patch("pyarrow.fs.S3FileSystem"), patch("pyarrow.fs.resolve_s3_region", side_effect=OSError): + s3_fileio.new_input("s3://bucket/key1") + s3_fileio.new_input("s3://bucket/key2") + + assert provider.get_credentials.call_count == 2 + + +def test_pyarrow_no_provider_preserves_lru_cache() -> None: + """Without a provider, fs_by_scheme is lru_cache-wrapped — S3FileSystem instantiated once.""" + from unittest.mock import patch + + s3_fileio = PyArrowFileIO(properties={"s3.endpoint": "http://localhost:9000", "s3.region": "us-east-1"}) + + with patch("pyarrow.fs.S3FileSystem") as mock_s3fs, patch("pyarrow.fs.resolve_s3_region", side_effect=OSError): + s3_fileio.new_input("s3://bucket/key1") + s3_fileio.new_input("s3://bucket/key2") + + assert mock_s3fs.call_count == 1 + + +def test_pyarrow_credentials_provider_updates_properties_after_refresh() -> None: + """When needs_refresh=True, self.properties is updated with refreshed creds (intentional caching).""" + from unittest.mock import MagicMock, patch + + from pyiceberg.catalog.rest.credentials_provider import VendedCredentialsProvider + + provider = MagicMock(spec=VendedCredentialsProvider) + provider.needs_refresh.return_value = True + provider.get_credentials.return_value = { + "s3.access-key-id": "refreshed-key", + "s3.secret-access-key": "refreshed-secret", + "s3.session-token": "refreshed-token", + } + + s3_fileio = PyArrowFileIO( + properties={ + "s3.endpoint": "http://localhost:9000", + "s3.access-key-id": "original-key", + "s3.region": "us-east-1", + } + ) + s3_fileio.set_credentials_provider(provider) + + with patch("pyarrow.fs.S3FileSystem"), patch("pyarrow.fs.resolve_s3_region", side_effect=OSError): + s3_fileio.new_input("s3://bucket/key") + + assert s3_fileio.properties["s3.access-key-id"] == "refreshed-key" + assert s3_fileio.properties["s3.endpoint"] == "http://localhost:9000" + + +def test_pyarrow_credentials_provider_skips_refresh_when_fresh() -> None: + """When needs_refresh=False, self.properties is not modified and get_credentials is not called.""" + from unittest.mock import MagicMock, patch + + from pyiceberg.catalog.rest.credentials_provider import VendedCredentialsProvider + + provider = MagicMock(spec=VendedCredentialsProvider) + provider.needs_refresh.return_value = False + + s3_fileio = PyArrowFileIO( + properties={ + "s3.endpoint": "http://localhost:9000", + "s3.access-key-id": "original-key", + "s3.region": "us-east-1", + } + ) + s3_fileio.set_credentials_provider(provider) + + with patch("pyarrow.fs.S3FileSystem"), patch("pyarrow.fs.resolve_s3_region", side_effect=OSError): + s3_fileio.new_input("s3://bucket/key1") + s3_fileio.new_input("s3://bucket/key2") + + provider.get_credentials.assert_not_called() + assert s3_fileio.properties["s3.access-key-id"] == "original-key"