From e34ef9e0d0a8f0be9669e5eaaf93bac711cff3b4 Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Fri, 19 Jun 2026 15:39:58 -0700 Subject: [PATCH] fix: make sure serialized result is deserialize-able --- .../examples/general/CustomConfigExample.java | 3 + .../amazon/lambda/durable/DurableConfig.java | 32 ++++ .../amazon/lambda/durable/DurableHandler.java | 1 + .../operation/ChildContextOperation.java | 8 +- .../SerializableDurableOperation.java | 18 ++- .../lambda/durable/DurableConfigTest.java | 19 +++ .../operation/ChildContextOperationTest.java | 65 +++++++- .../SerializableDurableOperationTest.java | 150 ++++++++++++++++++ 8 files changed, 288 insertions(+), 8 deletions(-) diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/general/CustomConfigExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/general/CustomConfigExample.java index 28d248395..5bbd902b0 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/general/CustomConfigExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/general/CustomConfigExample.java @@ -31,6 +31,7 @@ *
  • Automatic region detection with fallback to us-east-1 for testing environments *
  • Environment variable credentials provider *
  • Custom SerDes with snake_case property naming + *
  • Optional round-trip validation toggle for performance-sensitive workloads * */ public class CustomConfigExample extends DurableHandler { @@ -68,6 +69,8 @@ protected DurableConfig createConfiguration() { return DurableConfig.builder() .withDurableExecutionClient(durableClient) .withSerDes(customSerDes) + // Disable the extra deserialize pass if your workload is sensitive to the added validation cost. + .withSerializationRoundTripValidation(false) .build(); } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/DurableConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/DurableConfig.java index a21251e19..f3610082e 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/DurableConfig.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/DurableConfig.java @@ -98,6 +98,7 @@ public final class DurableConfig { private final LoggerConfig loggerConfig; private final PollingStrategy pollingStrategy; private final Duration checkpointDelay; + private final boolean validateSerializationRoundTrip; private final PluginRunner pluginRunner; private DurableConfig(Builder builder) { @@ -109,6 +110,7 @@ private DurableConfig(Builder builder) { this.loggerConfig = Objects.requireNonNullElseGet(builder.loggerConfig, LoggerConfig::defaults); this.pollingStrategy = Objects.requireNonNullElse(builder.pollingStrategy, PollingStrategies.Presets.DEFAULT); this.checkpointDelay = Objects.requireNonNullElseGet(builder.checkpointDelay, () -> Duration.ofSeconds(0)); + this.validateSerializationRoundTrip = builder.validateSerializationRoundTrip; this.pluginRunner = builder.plugins.isEmpty() ? PluginRunner.noOp() : new PluginRunner(builder.plugins); validateConfiguration(); @@ -186,6 +188,19 @@ public Duration getCheckpointDelay() { return checkpointDelay; } + /** + * Gets whether serialized operation data should be immediately deserialized to verify round-trip compatibility. + * + *

    When enabled, the SDK validates serialized operation results and exceptions before checkpointing them. This + * catches incompatible SerDes behavior early at the cost of an extra deserialize pass. Defaults to true, and custom + * SerDes implementations are still expected to be round-trip safe even if this validation is disabled. + * + * @return true when round-trip serialization validation is enabled + */ + public boolean shouldValidateSerializationRoundTrip() { + return validateSerializationRoundTrip; + } + /** * Gets the plugin runner that dispatches lifecycle events to registered plugins. * @@ -293,6 +308,7 @@ public static final class Builder { private LoggerConfig loggerConfig; private PollingStrategy pollingStrategy; private Duration checkpointDelay; + private boolean validateSerializationRoundTrip = true; private List plugins = new ArrayList<>(); public Builder() {} @@ -403,6 +419,22 @@ public Builder withCheckpointDelay(Duration duration) { return this; } + /** + * Controls whether the SDK immediately deserializes serialized results and exceptions to verify they can be + * read back before checkpointing. + * + *

    This validation is enabled by default. Disable it only to avoid the extra deserialize pass when the + * additional safety check is too expensive for your workload. Custom SerDes implementations are still expected + * to round-trip SDK-managed values correctly. + * + * @param validateSerializationRoundTrip true to validate serialized data with an immediate deserialize pass + * @return This builder + */ + public Builder withSerializationRoundTripValidation(boolean validateSerializationRoundTrip) { + this.validateSerializationRoundTrip = validateSerializationRoundTrip; + return this; + } + /** * Registers one or more plugins for lifecycle event instrumentation. * diff --git a/sdk/src/main/java/software/amazon/lambda/durable/DurableHandler.java b/sdk/src/main/java/software/amazon/lambda/durable/DurableHandler.java index a31c94d00..6fceee077 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/DurableHandler.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/DurableHandler.java @@ -117,6 +117,7 @@ public DurableConfig getConfiguration() { * .withDurableExecutionClient(durableClient) * .withSerDes(customSerDes) // Optional: custom SerDes for user data * .withExecutorService(customExecutor) // Optional: custom thread pool + * .withSerializationRoundTripValidation(false) // Optional: skip extra validation deserialize pass * .build(); * } * } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java index acdda283b..e5f287fed 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java @@ -149,6 +149,8 @@ private void executeChildContext() { } private void handleChildContextSuccess(T result) { + var serialized = serializeResult(result); + if (replayChildren.get() || isVirtual || parentOperation != null && parentOperation.isOperationCompleted()) { // Skip checkpointing if // - parent ConcurrencyOperation has already completed, preventing race conditions where a child finishes @@ -159,13 +161,11 @@ private void handleChildContextSuccess(T result) { cachedOperationResult.set(DeserializedOperationResult.succeeded(result)); markAlreadyCompleted(); } else { - checkpointSuccess(result); + checkpointSuccess(result, serialized); } } - private void checkpointSuccess(T result) { - var serialized = serializeResult(result); - + private void checkpointSuccess(T result, String serialized) { if (serialized == null || serialized.getBytes(StandardCharsets.UTF_8).length < LARGE_RESULT_THRESHOLD) { sendOperationUpdate( OperationUpdate.builder().action(OperationAction.SUCCEED).payload(serialized)); diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/SerializableDurableOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/SerializableDurableOperation.java index 6ccf24e0f..5eeae4f6b 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/SerializableDurableOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/SerializableDurableOperation.java @@ -101,7 +101,11 @@ protected T deserializeResult(String result) { * @return the serialized string */ protected String serializeResult(T result) { - return resultSerDes.serialize(result); + var serialized = resultSerDes.serialize(result); + if (shouldValidateSerializationRoundTrip()) { + deserializeResult(serialized); + } + return serialized; } /** @@ -110,8 +114,18 @@ protected String serializeResult(T result) { * @param throwable the exception to serialize * @return the serialized error object */ + @SuppressWarnings("ThrowableNotThrown") protected ErrorObject serializeException(Throwable throwable) { - return ExceptionHelper.buildErrorObject(throwable, resultSerDes); + var error = ExceptionHelper.buildErrorObject(throwable, resultSerDes); + if (shouldValidateSerializationRoundTrip()) { + deserializeException(error); + } + return error; + } + + private boolean shouldValidateSerializationRoundTrip() { + var config = getContext().getDurableConfig(); + return config == null || config.shouldValidateSerializationRoundTrip(); } /** diff --git a/sdk/src/test/java/software/amazon/lambda/durable/DurableConfigTest.java b/sdk/src/test/java/software/amazon/lambda/durable/DurableConfigTest.java index 94d13f5ef..9f1dceddf 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/DurableConfigTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/DurableConfigTest.java @@ -87,6 +87,24 @@ void testBuilder_WithCustomExecutorService() { assertNotNull(config.getSerDes()); } + @Test + void testBuilder_SerializationRoundTripValidationDefaultsToTrue() { + var config = + DurableConfig.builder().withDurableExecutionClient(mockClient).build(); + + assertTrue(config.shouldValidateSerializationRoundTrip()); + } + + @Test + void testBuilder_WithSerializationRoundTripValidationDisabled() { + var config = DurableConfig.builder() + .withDurableExecutionClient(mockClient) + .withSerializationRoundTripValidation(false) + .build(); + + assertFalse(config.shouldValidateSerializationRoundTrip()); + } + @Test void testBuilder_WithAllCustomComponents() { var config = DurableConfig.builder() @@ -131,6 +149,7 @@ void testBuilder_FluentAPI() { assertSame(builder, builder.withDurableExecutionClient(mockClient)); assertSame(builder, builder.withSerDes(mockSerDes)); assertSame(builder, builder.withExecutorService(mockExecutor)); + assertSame(builder, builder.withSerializationRoundTripValidation(false)); } @Test diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/ChildContextOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/ChildContextOperationTest.java index ac56262a4..098ecb1f6 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/ChildContextOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/ChildContextOperationTest.java @@ -24,16 +24,30 @@ import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.exception.ChildContextFailedException; import software.amazon.lambda.durable.exception.NonDeterministicExecutionException; +import software.amazon.lambda.durable.exception.SerDesException; import software.amazon.lambda.durable.execution.ExecutionManager; import software.amazon.lambda.durable.execution.ThreadContext; import software.amazon.lambda.durable.execution.ThreadType; import software.amazon.lambda.durable.model.OperationIdentifier; import software.amazon.lambda.durable.model.OperationSubType; import software.amazon.lambda.durable.serde.JacksonSerDes; +import software.amazon.lambda.durable.serde.SerDes; /** Unit tests for ChildContextOperation. */ class ChildContextOperationTest { + private static final class SerializationOnlySerDes implements SerDes { + @Override + public String serialize(Object value) { + return "\"serialized\""; + } + + @Override + public T deserialize(String data, TypeToken typeToken) { + throw new SerDesException("cannot deserialize"); + } + } + private static final JacksonSerDes SERDES = new JacksonSerDes(); private DurableContextImpl durableContext; @@ -49,8 +63,13 @@ void setUp() { } private DurableConfig createConfig() { + return createConfig(true); + } + + private DurableConfig createConfig(boolean validateSerializationRoundTrip) { return DurableConfig.builder() .withExecutorService(Executors.newCachedThreadPool()) + .withSerializationRoundTripValidation(validateSerializationRoundTrip) .build(); } @@ -58,20 +77,28 @@ private DurableConfig createConfig() { OperationIdentifier.of("1", "test-context", OperationSubType.RUN_IN_CHILD_CONTEXT); private ChildContextOperation createOperation(Function func) { + return createOperation(func, SERDES); + } + + private ChildContextOperation createOperation(Function func, SerDes serDes) { return new ChildContextOperation<>( OPERATION_IDENTIFIER, func, TypeToken.get(String.class), - RunInChildContextConfig.builder().serDes(SERDES).build(), + RunInChildContextConfig.builder().serDes(serDes).build(), durableContext); } private ChildContextOperation createVirtualOperation(Function func) { + return createVirtualOperation(func, SERDES); + } + + private ChildContextOperation createVirtualOperation(Function func, SerDes serDes) { return new ChildContextOperation<>( OPERATION_IDENTIFIER, func, TypeToken.get(String.class), - RunInChildContextConfig.builder().serDes(SERDES).isVirtual(true).build(), + RunInChildContextConfig.builder().serDes(serDes).isVirtual(true).build(), durableContext); } @@ -311,6 +338,40 @@ void childSkipsSuccessCheckpointWhenParentAlreadyCompleted() throws Exception { .sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED)); } + /** Virtual child still validates result round-trip before skipping a success checkpoint. */ + @Test + void virtualChildFailsWhenResultCannotBeDeserialized() throws Exception { + when(executionManager.getOperationAndUpdateReplayState("1")).thenReturn(null); + + var operation = createVirtualOperation(ctx -> "result", new SerializationOnlySerDes()); + operation.execute(); + Thread.sleep(200); + + var thrown = assertThrows(ChildContextFailedException.class, operation::get); + assertTrue(thrown.getMessage().contains(SerDesException.class.getName())); + verify(executionManager, never()) + .sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED)); + verify(executionManager, never()) + .sendOperationUpdate(argThat(update -> update.action() == OperationAction.FAIL)); + } + + /** Virtual child can skip round-trip validation when disabled in DurableConfig. */ + @Test + void virtualChildSucceedsWhenResultValidationDisabled() throws Exception { + when(executionManager.getOperationAndUpdateReplayState("1")).thenReturn(null); + when(durableContext.getDurableConfig()).thenReturn(createConfig(false)); + + var operation = createVirtualOperation(ctx -> "result", new SerializationOnlySerDes()); + operation.execute(); + Thread.sleep(200); + + assertEquals("result", operation.get()); + verify(executionManager, never()) + .sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED)); + verify(executionManager, never()) + .sendOperationUpdate(argThat(update -> update.action() == OperationAction.FAIL)); + } + /** Child skips failure checkpoint when parent operation has already completed. */ @Test void childSkipsFailureCheckpointWhenParentAlreadyCompleted() throws Exception { diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/SerializableDurableOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/SerializableDurableOperationTest.java index 195636fd8..f9599801b 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/SerializableDurableOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/SerializableDurableOperationTest.java @@ -20,6 +20,7 @@ import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import software.amazon.awssdk.services.lambda.model.ErrorObject; @@ -27,7 +28,9 @@ import software.amazon.awssdk.services.lambda.model.OperationStatus; import software.amazon.awssdk.services.lambda.model.OperationType; import software.amazon.awssdk.services.lambda.model.OperationUpdate; +import software.amazon.lambda.durable.DurableConfig; import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.client.DurableExecutionClient; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.exception.IllegalDurableOperationException; import software.amazon.lambda.durable.exception.NonDeterministicExecutionException; @@ -42,6 +45,32 @@ class SerializableDurableOperationTest { + private static final class TrackingSerDes extends JacksonSerDes { + private final AtomicInteger deserializeCount = new AtomicInteger(0); + + @Override + public T deserialize(String data, TypeToken typeToken) { + deserializeCount.incrementAndGet(); + return super.deserialize(data, typeToken); + } + + int getDeserializeCount() { + return deserializeCount.get(); + } + } + + private static final class SerializationOnlySerDes implements SerDes { + @Override + public String serialize(Object value) { + return "\"serialized\""; + } + + @Override + public T deserialize(String data, TypeToken typeToken) { + throw new SerDesException("cannot deserialize"); + } + } + private static final String OPERATION_ID = "1"; private static final String CONTEXT_ID = "1-step"; private static final String OPERATION_NAME = "name"; @@ -315,6 +344,73 @@ public String get() { op.get(); } + @Test + void serializeResultValidatesRoundTrip() { + var serDes = new TrackingSerDes(); + SerializableDurableOperation op = + new SerializableDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, serDes, durableContext) { + @Override + protected void start() {} + + @Override + protected void replay(Operation existing) {} + + @Override + public String get() { + assertEquals("\"abc\"", serializeResult("abc")); + assertEquals(1, serDes.getDeserializeCount()); + return RESULT; + } + }; + + op.get(); + } + + @Test + void serializeResultThrowsWhenRoundTripFails() { + var serDes = new SerializationOnlySerDes(); + SerializableDurableOperation op = + new SerializableDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, serDes, durableContext) { + @Override + protected void start() {} + + @Override + protected void replay(Operation existing) {} + + @Override + public String get() { + var thrown = assertThrows(SerDesException.class, () -> serializeResult("abc")); + assertEquals("cannot deserialize", thrown.getMessage()); + return RESULT; + } + }; + + op.get(); + } + + @Test + void serializeResultSkipsRoundTripValidationWhenDisabled() { + when(durableContext.getDurableConfig()).thenReturn(configWithSerializationValidation(false)); + + var serDes = new SerializationOnlySerDes(); + SerializableDurableOperation op = + new SerializableDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, serDes, durableContext) { + @Override + protected void start() {} + + @Override + protected void replay(Operation existing) {} + + @Override + public String get() { + assertEquals("\"serialized\"", serializeResult("abc")); + return RESULT; + } + }; + + op.get(); + } + @Test void deserializeException() { SerializableDurableOperation op = @@ -341,6 +437,53 @@ public String get() { op.get(); } + @Test + void serializeExceptionValidatesRoundTrip() { + var serDes = new TrackingSerDes(); + SerializableDurableOperation op = + new SerializableDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, serDes, durableContext) { + @Override + protected void start() {} + + @Override + protected void replay(Operation existing) {} + + @Override + public String get() { + var error = serializeException(new RuntimeException("test exception")); + assertEquals(RuntimeException.class.getName(), error.errorType()); + assertEquals(1, serDes.getDeserializeCount()); + return RESULT; + } + }; + + op.get(); + } + + @Test + void serializeExceptionSkipsRoundTripValidationWhenDisabled() { + when(durableContext.getDurableConfig()).thenReturn(configWithSerializationValidation(false)); + + var serDes = new SerializationOnlySerDes(); + SerializableDurableOperation op = + new SerializableDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, serDes, durableContext) { + @Override + protected void start() {} + + @Override + protected void replay(Operation existing) {} + + @Override + public String get() { + var error = serializeException(new RuntimeException("test exception")); + assertEquals(RuntimeException.class.getName(), error.errorType()); + return RESULT; + } + }; + + op.get(); + } + @Test void polling() { SerializableDurableOperation op = @@ -386,4 +529,11 @@ public String get() { op.execute(); verify(executionManager, times(1)).sendOperationUpdate(update.build()); } + + private DurableConfig configWithSerializationValidation(boolean validateSerializationRoundTrip) { + return DurableConfig.builder() + .withDurableExecutionClient(mock(DurableExecutionClient.class)) + .withSerializationRoundTripValidation(validateSerializationRoundTrip) + .build(); + } }