Skip to content

Commit 86b1ae7

Browse files
feat(#1424): self.upstream property for pre-restricted ancestor access
Implements T2.2.b of the provenance trinity. Inside make(), self.upstream exposes a pre-constructed Diagram.trace(self & key) so users can read declared ancestors with provenance-safe, ergonomic syntax. Branch stacked on feat/1423-diagram-trace (#1471) for Diagram.trace(). What's added: - src/datajoint/autopopulate.py: - AutoPopulate._upstream class attribute (default None) — instance storage for the per-make() trace. - AutoPopulate.upstream property — returns the trace if set, raises DataJointError with a clear "only available inside make()" message otherwise. The error includes the fallback pattern (dj.Diagram.trace(self & key)) so the user knows the escape hatch. - In _populate_one, set self._upstream = Diagram.trace(self & dict(key)) immediately before the make() invocation block. Construction is lazy at this layer (graph copy only); the SQL fetch fires when the user accesses self.upstream[T].fetch(...). - The existing `finally` block (line 716) that resets _allow_insert now also resets _upstream to None, so subsequent attribute access raises a clear error rather than silently returning a stale trace from the previous make() call. What's not changed: - make() signature: unchanged — key remains a dict, make_kwargs work as before. self.upstream is a new attribute on self, not a new parameter. - Tripartite make pattern: self._upstream is set once before all three make() invocations, so all three phases see the same upstream view. Tests in tests/integration/test_autopopulate.py (5 new): - test_upstream_provides_pre_restricted_ancestor — basic case: make() reads self.upstream[Subject].fetch1("name") and the value is correctly pre-restricted to the current key. - test_upstream_rejects_non_ancestor — self.upstream[Unrelated] raises DataJointError ("not in this trace"). Inherited from Diagram.__getitem__. - test_upstream_unset_outside_make — accessing the property outside of make() raises with the helpful "only available inside make()" message. - test_upstream_cleared_after_make — after populate() completes, accessing the property on a fresh instance still raises (verifies the finally cleanup; no stale state). - test_upstream_seen_across_tripartite_make — both make_fetch / make_compute / make_insert see the same self.upstream value. Full regression: 17/17 autopopulate tests pass on MySQL. Slated for DataJoint 2.3. Blocked on #1471 (Diagram.trace) merging; T2.2.c (strict_provenance) is stacked on this PR.
1 parent 8d9d242 commit 86b1ae7

2 files changed

Lines changed: 220 additions & 0 deletions

File tree

src/datajoint/autopopulate.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,46 @@ class AutoPopulate:
8989
_key_source = None
9090
_allow_insert = False
9191
_jobs = None
92+
_upstream = None # set per-make() by _populate_one; see `upstream` property below
93+
94+
@property
95+
def upstream(self):
96+
"""
97+
Pre-restricted ancestor view for the current ``make(self, key)`` call.
98+
99+
Inside ``make()``, ``self.upstream`` is a ``Diagram`` constructed via
100+
:meth:`Diagram.trace(self & key) <datajoint.Diagram.trace>`. Use
101+
``self.upstream[T]`` to obtain a pre-restricted ``QueryExpression``
102+
(or ``FreeTable``, when indexed by a string) for any ancestor of
103+
``self``.
104+
105+
Reading via ``self.upstream`` is the provenance-safe pattern: the
106+
framework guarantees the restriction matches the current ``key``,
107+
and indexing a non-ancestor table raises ``DataJointError``. See
108+
:doc:`reference/specs/provenance` for the contract.
109+
110+
Raises
111+
------
112+
DataJointError
113+
If accessed outside ``make()`` execution. To construct a trace
114+
explicitly, use ``dj.Diagram.trace(self & key)``.
115+
116+
Examples
117+
--------
118+
::
119+
120+
def make(self, key):
121+
date = self.upstream[Session].fetch1("session_date")
122+
traces = self.upstream[ExtractTraces].to_arrays("trace")
123+
self.insert1({**key, "summary": compute(traces, date)})
124+
"""
125+
if self._upstream is None:
126+
raise DataJointError(
127+
"self.upstream is only available inside make(). "
128+
"Outside make(), construct a trace explicitly: "
129+
"dj.Diagram.trace(self & key)."
130+
)
131+
return self._upstream
92132

93133
class _JobsDescriptor:
94134
"""Descriptor allowing jobs access on both class and instance."""
@@ -611,6 +651,13 @@ def _populate1(
611651
logger.jobs(f"Making {key} -> {self.full_table_name}")
612652
self.__class__._allow_insert = True
613653

654+
# Pre-construct the upstream view for this make() call. Lazy — only
655+
# `dj.Diagram.trace(self & key)` runs here (graph copy); the
656+
# expensive SQL fetch fires when the user accesses self.upstream[T].
657+
from .diagram import Diagram
658+
659+
self._upstream = Diagram.trace(self & dict(key))
660+
614661
try:
615662
if not is_generator:
616663
make(dict(key), **(make_kwargs or {}))
@@ -668,6 +715,10 @@ def _populate1(
668715
return True
669716
finally:
670717
self.__class__._allow_insert = False
718+
# Clear the per-make() upstream view so subsequent attribute
719+
# access raises a clear error rather than silently using a
720+
# stale trace from the previous make() call.
721+
self._upstream = None
671722

672723
def progress(self, *restrictions: Any, display: bool = False) -> tuple[int, int]:
673724
"""

tests/integration/test_autopopulate.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,175 @@ def make_insert(self, key, result, scale):
354354
assert row["result"] == 1000 # 200 * 5
355355

356356

357+
# =========================================================================
358+
# #1424: self.upstream pre-restricted ancestor access in make()
359+
# =========================================================================
360+
361+
362+
def test_upstream_provides_pre_restricted_ancestor(prefix, connection_test):
363+
"""make() can read self.upstream[Ancestor] and get pre-restricted data."""
364+
schema = dj.Schema(f"{prefix}_upstream_basic", connection=connection_test)
365+
366+
@schema
367+
class Subject(dj.Lookup):
368+
definition = """
369+
subject_id : int32
370+
---
371+
name : varchar(64)
372+
"""
373+
contents = [(1, "alice"), (2, "bob")]
374+
375+
@schema
376+
class Greeting(dj.Computed):
377+
definition = """
378+
-> Subject
379+
---
380+
greeting : varchar(128)
381+
"""
382+
383+
def make(self, key):
384+
# Provenance-safe read: self.upstream pre-restricted to current key
385+
name = self.upstream[Subject].fetch1("name")
386+
self.insert1({**key, "greeting": f"Hello, {name}!"})
387+
388+
Greeting.populate()
389+
assert (Greeting & {"subject_id": 1}).fetch1("greeting") == "Hello, alice!"
390+
assert (Greeting & {"subject_id": 2}).fetch1("greeting") == "Hello, bob!"
391+
392+
393+
def test_upstream_rejects_non_ancestor(prefix, connection_test):
394+
"""self.upstream[T] for a non-ancestor table raises inside make()."""
395+
schema = dj.Schema(f"{prefix}_upstream_non_ancestor", connection=connection_test)
396+
397+
@schema
398+
class Subject(dj.Lookup):
399+
definition = """
400+
subject_id : int32
401+
"""
402+
contents = [(1,)]
403+
404+
@schema
405+
class Unrelated(dj.Lookup):
406+
definition = """
407+
u_id : int32
408+
"""
409+
contents = [(99,)]
410+
411+
captured_errors: list[Exception] = []
412+
413+
@schema
414+
class Bad(dj.Computed):
415+
definition = """
416+
-> Subject
417+
---
418+
ok : tinyint
419+
"""
420+
421+
def make(self, key):
422+
try:
423+
self.upstream[Unrelated]
424+
except DataJointError as exc:
425+
captured_errors.append(exc)
426+
# Insert anyway so populate doesn't fail
427+
self.insert1({**key, "ok": 1})
428+
429+
Bad.populate()
430+
assert len(captured_errors) == 1
431+
assert "not in this trace" in str(captured_errors[0]).lower()
432+
433+
434+
def test_upstream_unset_outside_make(prefix, connection_test):
435+
"""Accessing self.upstream outside of make() raises a clear error."""
436+
schema = dj.Schema(f"{prefix}_upstream_outside_make", connection=connection_test)
437+
438+
@schema
439+
class Source(dj.Lookup):
440+
definition = """
441+
source_id : int32
442+
"""
443+
contents = [(1,)]
444+
445+
@schema
446+
class Derived(dj.Computed):
447+
definition = """
448+
-> Source
449+
---
450+
val : int32
451+
"""
452+
453+
def make(self, key):
454+
self.insert1({**key, "val": 0})
455+
456+
with pytest.raises(DataJointError, match="only available inside make"):
457+
Derived().upstream
458+
459+
460+
def test_upstream_cleared_after_make(prefix, connection_test):
461+
"""After a make() call completes, self.upstream is reset (no stale state)."""
462+
schema = dj.Schema(f"{prefix}_upstream_cleared", connection=connection_test)
463+
464+
@schema
465+
class Source(dj.Lookup):
466+
definition = """
467+
source_id : int32
468+
"""
469+
contents = [(1,)]
470+
471+
@schema
472+
class Derived(dj.Computed):
473+
definition = """
474+
-> Source
475+
---
476+
val : int32
477+
"""
478+
479+
def make(self, key):
480+
self.insert1({**key, "val": 0})
481+
482+
Derived.populate()
483+
# The class attribute defaults to None; the per-instance _upstream
484+
# set during make() must have been cleared by the finally block.
485+
# Probe via the public property — should raise the "outside make" error.
486+
with pytest.raises(DataJointError, match="only available inside make"):
487+
Derived().upstream
488+
489+
490+
def test_upstream_seen_across_tripartite_make(prefix, connection_test):
491+
"""The tripartite make() invocation pattern sees the same self.upstream
492+
across all three phases (fetch / compute / insert)."""
493+
schema = dj.Schema(f"{prefix}_upstream_tripartite", connection=connection_test)
494+
495+
@schema
496+
class Source(dj.Lookup):
497+
definition = """
498+
source_id : int32
499+
---
500+
value : int32
501+
"""
502+
contents = [(1, 100), (2, 200)]
503+
504+
@schema
505+
class TriComputed(dj.Computed):
506+
definition = """
507+
-> Source
508+
---
509+
result : int32
510+
"""
511+
512+
def make_fetch(self, key):
513+
return (self.upstream[Source].fetch1("value"),)
514+
515+
def make_compute(self, key, value):
516+
return (value * 2,)
517+
518+
def make_insert(self, key, doubled):
519+
self.insert1({**key, "result": doubled})
520+
521+
TriComputed.populate()
522+
assert (TriComputed & {"source_id": 1}).fetch1("result") == 200
523+
assert (TriComputed & {"source_id": 2}).fetch1("result") == 400
524+
525+
357526
def test_populate_reserve_jobs_respects_restrictions(clean_autopopulate, subject, experiment):
358527
"""Regression test for #1413: populate() with reserve_jobs=True must honour restrictions.
359528

0 commit comments

Comments
 (0)