diff --git a/sqlmesh/lsp/main.py b/sqlmesh/lsp/main.py index 71dc5e1e2b..b5623f3ff8 100755 --- a/sqlmesh/lsp/main.py +++ b/sqlmesh/lsp/main.py @@ -1074,8 +1074,8 @@ def _context_get_or_load(self, document_uri: t.Optional[URI] = None) -> LSPConte loaded_sqlmesh_message(self.server) else: self._ensure_context_for_document(document_uri) - if isinstance(state, ContextLoaded): - return state.lsp_context + if isinstance(self.context_state, ContextLoaded): + return self.context_state.lsp_context raise RuntimeError("Context failed to load") def _ensure_context_for_document( diff --git a/tests/lsp/test_context.py b/tests/lsp/test_context.py index b463a17139..976d30fef2 100644 --- a/tests/lsp/test_context.py +++ b/tests/lsp/test_context.py @@ -2,6 +2,7 @@ from sqlmesh.core.context import Context from sqlmesh.lsp.context import LSPContext, ModelTarget +from sqlmesh.lsp.main import ContextLoaded, NoContext, SQLMeshLanguageServer from sqlmesh.lsp.uri import URI @@ -61,3 +62,32 @@ def test_lsp_context_run_test(): # Check that the result is not None and has the expected properties assert result is not None assert result.success is True + + +def test_context_get_or_load_from_no_context_with_specified_paths(): + server = SQLMeshLanguageServer(context_class=Context) + server.server.show_message = lambda *args, **kwargs: None + server.specified_paths = [Path("examples/sushi")] + + assert isinstance(server.context_state, NoContext) + + lsp_context = server._context_get_or_load() + + assert isinstance(lsp_context, LSPContext) + assert isinstance(server.context_state, ContextLoaded) + assert server.context_state.lsp_context is lsp_context + + +def test_context_get_or_load_from_no_context_via_workspace_folder(): + server = SQLMeshLanguageServer(context_class=Context) + server.server.show_message = lambda *args, **kwargs: None + server.specified_paths = None + server.workspace_folders = [Path.cwd() / "examples" / "sushi"] + + assert isinstance(server.context_state, NoContext) + + lsp_context = server._context_get_or_load() + + assert isinstance(lsp_context, LSPContext) + assert isinstance(server.context_state, ContextLoaded) + assert server.context_state.lsp_context is lsp_context