Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
* <li>Automatic region detection with fallback to us-east-1 for testing environments
* <li>Environment variable credentials provider
* <li>Custom SerDes with snake_case property naming
* <li>Optional round-trip validation toggle for performance-sensitive workloads
* </ul>
*/
public class CustomConfigExample extends DurableHandler<String, String> {
Expand Down Expand Up @@ -68,6 +69,8 @@ protected DurableConfig createConfiguration() {
return DurableConfig.builder()
.withDurableExecutionClient(durableClient)
.withSerDes(customSerDes)
// Disable the extra deserialize pass if your workload is sensitive to the added validation cost.
.withSerializationRoundTripValidation(false)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ public final class DurableConfig {
private final LoggerConfig loggerConfig;
private final PollingStrategy pollingStrategy;
private final Duration checkpointDelay;
private final boolean validateSerializationRoundTrip;
private final PluginRunner pluginRunner;

private DurableConfig(Builder builder) {
Expand All @@ -109,6 +110,7 @@ private DurableConfig(Builder builder) {
this.loggerConfig = Objects.requireNonNullElseGet(builder.loggerConfig, LoggerConfig::defaults);
this.pollingStrategy = Objects.requireNonNullElse(builder.pollingStrategy, PollingStrategies.Presets.DEFAULT);
this.checkpointDelay = Objects.requireNonNullElseGet(builder.checkpointDelay, () -> Duration.ofSeconds(0));
this.validateSerializationRoundTrip = builder.validateSerializationRoundTrip;
this.pluginRunner = builder.plugins.isEmpty() ? PluginRunner.noOp() : new PluginRunner(builder.plugins);

validateConfiguration();
Expand Down Expand Up @@ -186,6 +188,19 @@ public Duration getCheckpointDelay() {
return checkpointDelay;
}

/**
* Gets whether serialized operation data should be immediately deserialized to verify round-trip compatibility.
*
* <p>When enabled, the SDK validates serialized operation results and exceptions before checkpointing them. This
* catches incompatible SerDes behavior early at the cost of an extra deserialize pass. Defaults to true, and custom
* SerDes implementations are still expected to be round-trip safe even if this validation is disabled.
*
* @return true when round-trip serialization validation is enabled
*/
public boolean shouldValidateSerializationRoundTrip() {
return validateSerializationRoundTrip;
}

/**
* Gets the plugin runner that dispatches lifecycle events to registered plugins.
*
Expand Down Expand Up @@ -293,6 +308,7 @@ public static final class Builder {
private LoggerConfig loggerConfig;
private PollingStrategy pollingStrategy;
private Duration checkpointDelay;
private boolean validateSerializationRoundTrip = true;
private List<DurableExecutionPlugin> plugins = new ArrayList<>();

public Builder() {}
Expand Down Expand Up @@ -403,6 +419,22 @@ public Builder withCheckpointDelay(Duration duration) {
return this;
}

/**
* Controls whether the SDK immediately deserializes serialized results and exceptions to verify they can be
* read back before checkpointing.
*
* <p>This validation is enabled by default. Disable it only to avoid the extra deserialize pass when the
* additional safety check is too expensive for your workload. Custom SerDes implementations are still expected
* to round-trip SDK-managed values correctly.
*
* @param validateSerializationRoundTrip true to validate serialized data with an immediate deserialize pass
* @return This builder
*/
public Builder withSerializationRoundTripValidation(boolean validateSerializationRoundTrip) {
this.validateSerializationRoundTrip = validateSerializationRoundTrip;
return this;
}

/**
* Registers one or more plugins for lifecycle event instrumentation.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ public DurableConfig getConfiguration() {
* .withDurableExecutionClient(durableClient)
* .withSerDes(customSerDes) // Optional: custom SerDes for user data
* .withExecutorService(customExecutor) // Optional: custom thread pool
* .withSerializationRoundTripValidation(false) // Optional: skip extra validation deserialize pass
* .build();
* }
* }</pre>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ private void executeChildContext() {
}

private void handleChildContextSuccess(T result) {
var serialized = serializeResult(result);

if (replayChildren.get() || isVirtual || parentOperation != null && parentOperation.isOperationCompleted()) {
// Skip checkpointing if
// - parent ConcurrencyOperation has already completed, preventing race conditions where a child finishes
Expand All @@ -159,13 +161,11 @@ private void handleChildContextSuccess(T result) {
cachedOperationResult.set(DeserializedOperationResult.succeeded(result));
markAlreadyCompleted();
} else {
checkpointSuccess(result);
checkpointSuccess(result, serialized);
}
}

private void checkpointSuccess(T result) {
var serialized = serializeResult(result);

private void checkpointSuccess(T result, String serialized) {
if (serialized == null || serialized.getBytes(StandardCharsets.UTF_8).length < LARGE_RESULT_THRESHOLD) {
sendOperationUpdate(
OperationUpdate.builder().action(OperationAction.SUCCEED).payload(serialized));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,11 @@ protected T deserializeResult(String result) {
* @return the serialized string
*/
protected String serializeResult(T result) {
return resultSerDes.serialize(result);
var serialized = resultSerDes.serialize(result);
if (shouldValidateSerializationRoundTrip()) {
deserializeResult(serialized);
}
return serialized;
}

/**
Expand All @@ -110,8 +114,18 @@ protected String serializeResult(T result) {
* @param throwable the exception to serialize
* @return the serialized error object
*/
@SuppressWarnings("ThrowableNotThrown")
protected ErrorObject serializeException(Throwable throwable) {
return ExceptionHelper.buildErrorObject(throwable, resultSerDes);
var error = ExceptionHelper.buildErrorObject(throwable, resultSerDes);
if (shouldValidateSerializationRoundTrip()) {
deserializeException(error);
}
return error;
}

private boolean shouldValidateSerializationRoundTrip() {
var config = getContext().getDurableConfig();
return config == null || config.shouldValidateSerializationRoundTrip();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,24 @@ void testBuilder_WithCustomExecutorService() {
assertNotNull(config.getSerDes());
}

@Test
void testBuilder_SerializationRoundTripValidationDefaultsToTrue() {
var config =
DurableConfig.builder().withDurableExecutionClient(mockClient).build();

assertTrue(config.shouldValidateSerializationRoundTrip());
}

@Test
void testBuilder_WithSerializationRoundTripValidationDisabled() {
var config = DurableConfig.builder()
.withDurableExecutionClient(mockClient)
.withSerializationRoundTripValidation(false)
.build();

assertFalse(config.shouldValidateSerializationRoundTrip());
}

@Test
void testBuilder_WithAllCustomComponents() {
var config = DurableConfig.builder()
Expand Down Expand Up @@ -131,6 +149,7 @@ void testBuilder_FluentAPI() {
assertSame(builder, builder.withDurableExecutionClient(mockClient));
assertSame(builder, builder.withSerDes(mockSerDes));
assertSame(builder, builder.withExecutorService(mockExecutor));
assertSame(builder, builder.withSerializationRoundTripValidation(false));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,30 @@
import software.amazon.lambda.durable.context.DurableContextImpl;
import software.amazon.lambda.durable.exception.ChildContextFailedException;
import software.amazon.lambda.durable.exception.NonDeterministicExecutionException;
import software.amazon.lambda.durable.exception.SerDesException;
import software.amazon.lambda.durable.execution.ExecutionManager;
import software.amazon.lambda.durable.execution.ThreadContext;
import software.amazon.lambda.durable.execution.ThreadType;
import software.amazon.lambda.durable.model.OperationIdentifier;
import software.amazon.lambda.durable.model.OperationSubType;
import software.amazon.lambda.durable.serde.JacksonSerDes;
import software.amazon.lambda.durable.serde.SerDes;

/** Unit tests for ChildContextOperation. */
class ChildContextOperationTest {

private static final class SerializationOnlySerDes implements SerDes {
@Override
public String serialize(Object value) {
return "\"serialized\"";
}

@Override
public <T> T deserialize(String data, TypeToken<T> typeToken) {
throw new SerDesException("cannot deserialize");
}
}

private static final JacksonSerDes SERDES = new JacksonSerDes();

private DurableContextImpl durableContext;
Expand All @@ -49,29 +63,42 @@ void setUp() {
}

private DurableConfig createConfig() {
return createConfig(true);
}

private DurableConfig createConfig(boolean validateSerializationRoundTrip) {
return DurableConfig.builder()
.withExecutorService(Executors.newCachedThreadPool())
.withSerializationRoundTripValidation(validateSerializationRoundTrip)
.build();
}

private static final OperationIdentifier OPERATION_IDENTIFIER =
OperationIdentifier.of("1", "test-context", OperationSubType.RUN_IN_CHILD_CONTEXT);

private ChildContextOperation<String> createOperation(Function<DurableContext, String> func) {
return createOperation(func, SERDES);
}

private ChildContextOperation<String> createOperation(Function<DurableContext, String> func, SerDes serDes) {
return new ChildContextOperation<>(
OPERATION_IDENTIFIER,
func,
TypeToken.get(String.class),
RunInChildContextConfig.builder().serDes(SERDES).build(),
RunInChildContextConfig.builder().serDes(serDes).build(),
durableContext);
}

private ChildContextOperation<String> createVirtualOperation(Function<DurableContext, String> func) {
return createVirtualOperation(func, SERDES);
}

private ChildContextOperation<String> createVirtualOperation(Function<DurableContext, String> func, SerDes serDes) {
return new ChildContextOperation<>(
OPERATION_IDENTIFIER,
func,
TypeToken.get(String.class),
RunInChildContextConfig.builder().serDes(SERDES).isVirtual(true).build(),
RunInChildContextConfig.builder().serDes(serDes).isVirtual(true).build(),
durableContext);
}

Expand Down Expand Up @@ -311,6 +338,40 @@ void childSkipsSuccessCheckpointWhenParentAlreadyCompleted() throws Exception {
.sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED));
}

/** Virtual child still validates result round-trip before skipping a success checkpoint. */
@Test
void virtualChildFailsWhenResultCannotBeDeserialized() throws Exception {
when(executionManager.getOperationAndUpdateReplayState("1")).thenReturn(null);

var operation = createVirtualOperation(ctx -> "result", new SerializationOnlySerDes());
operation.execute();
Thread.sleep(200);

var thrown = assertThrows(ChildContextFailedException.class, operation::get);
assertTrue(thrown.getMessage().contains(SerDesException.class.getName()));
verify(executionManager, never())
.sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED));
verify(executionManager, never())
.sendOperationUpdate(argThat(update -> update.action() == OperationAction.FAIL));
}

/** Virtual child can skip round-trip validation when disabled in DurableConfig. */
@Test
void virtualChildSucceedsWhenResultValidationDisabled() throws Exception {
when(executionManager.getOperationAndUpdateReplayState("1")).thenReturn(null);
when(durableContext.getDurableConfig()).thenReturn(createConfig(false));

var operation = createVirtualOperation(ctx -> "result", new SerializationOnlySerDes());
operation.execute();
Thread.sleep(200);

assertEquals("result", operation.get());
verify(executionManager, never())
.sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED));
verify(executionManager, never())
.sendOperationUpdate(argThat(update -> update.action() == OperationAction.FAIL));
}

/** Child skips failure checkpoint when parent operation has already completed. */
@Test
void childSkipsFailureCheckpointWhenParentAlreadyCompleted() throws Exception {
Expand Down
Loading
Loading