feat: add ruff linter/formatter and fix Python 3.10 test compatibility

- Add ruff configuration to pyproject.toml with sensible defaults
- Add pre-commit hooks for ruff check and format
- Update CI to run ruff check and format verification on tests/
- Fix Python 3.10 mock.patch compatibility in CLI tests
  - Use importlib to get actual modules, bypassing shadowed click groups
  - Add get_cli_module() helper and update patch_client_for_module()
- Apply ruff formatting to all source and test files
- Move .claude/skills to gitignore (local-only skills)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Teng Lin 2026-01-09 14:59:47 -05:00
parent 892de6f76d
commit 9bd715af0d
81 changed files with 1548 additions and 1358 deletions

View file

@@ -1,57 +0,0 @@
---
name: matrix
description: Run tests across multiple Python versions using Docker containers
---
# Multi-Version Test Runner
Run the test suite across Python 3.10-3.14 using Docker containers in parallel.
## Quick Start
```bash
./dev/test-versions.sh
```
## Usage Examples
```bash
# Run all versions (3.10, 3.11, 3.12, 3.13, 3.14)
./dev/test-versions.sh
# Run specific versions only
./dev/test-versions.sh 3.12 3.13
# Include readonly e2e tests (requires auth)
./dev/test-versions.sh -r
# Pass pytest arguments (after --)
./dev/test-versions.sh -- -k test_encoder -v
# Combine options
./dev/test-versions.sh -r 3.12 -- -k test_encoder
```
## When to Use
- Before committing changes that might affect Python version compatibility
- When testing syntax or feature compatibility across versions
- To validate all supported Python versions pass locally before CI
## Requirements
- Docker must be running
- First run pulls Python images (~50MB each)
- Subsequent runs use cached pip packages for speed (~30s per version)
## Output
Shows pass/fail status for each Python version with test summary.
On failure, displays last 30 lines of output for debugging.
## Cache Management
To clear pip cache volumes:
```bash
docker volume ls | grep notebooklm-pip-cache | xargs docker volume rm
```

View file

@@ -29,7 +29,10 @@ jobs:
pip install -e ".[all]"
- name: Run linting
run: ruff check src/
run: ruff check src/ tests/
- name: Check formatting
run: ruff format --check src/ tests/
- name: Run type checking
run: mypy src/notebooklm --ignore-missing-imports

1
.gitignore vendored
View file

@@ -17,3 +17,4 @@ captured_rpcs/
.worktrees/
.worktree/
.sisyphus/
.claude/

7
.pre-commit-config.yaml Normal file
View file

@@ -0,0 +1,7 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.6
hooks:
- id: ruff
args: [--fix]
- id: ruff-format

View file

@@ -89,3 +89,31 @@ exclude = ["tests/"]
module = "notebooklm.cli.*"
warn_return_any = false
strict_optional = true
[tool.ruff]
target-version = "py310"
line-length = 100
src = ["src", "tests"]
[tool.ruff.lint]
select = [
"E", # pycodestyle errors
"W", # pycodestyle warnings
"F", # pyflakes
"I", # isort
"B", # flake8-bugbear
"C4", # flake8-comprehensions
"UP", # pyupgrade
"SIM", # flake8-simplify
]
ignore = [
"E501", # line too long (handled by formatter)
"B008", # function call in default argument (Click uses this)
]
[tool.ruff.lint.isort]
known-first-party = ["notebooklm"]
[tool.ruff.format]
quote-style = "double"
indent-style = "space"

View file

@@ -16,52 +16,52 @@ Note:
__version__ = "0.1.1"
# Public API: Authentication
from .auth import AuthTokens, DEFAULT_STORAGE_PATH
from .auth import DEFAULT_STORAGE_PATH, AuthTokens
# Public API: Client
from .client import NotebookLMClient
# Public API: Types and dataclasses
from .types import (
Notebook,
NotebookDescription,
SuggestedTopic,
Source,
Artifact,
GenerationStatus,
ReportSuggestion,
Note,
ConversationTurn,
AskResult,
ChatMode,
# Exceptions
SourceError,
SourceProcessingError,
SourceTimeoutError,
SourceNotFoundError,
# Enums for configuration
StudioContentType,
AudioFormat,
AudioLength,
VideoFormat,
VideoStyle,
QuizQuantity,
QuizDifficulty,
InfographicOrientation,
InfographicDetail,
SlideDeckFormat,
SlideDeckLength,
ReportFormat,
ChatGoal,
ChatResponseLength,
DriveMimeType,
ExportType,
SourceStatus,
)
# Public API: RPC errors (needed for exception handling)
from .rpc import RPCError
# Public API: Types and dataclasses
from .types import (
Artifact,
AskResult,
AudioFormat,
AudioLength,
ChatGoal,
ChatMode,
ChatResponseLength,
ConversationTurn,
DriveMimeType,
ExportType,
GenerationStatus,
InfographicDetail,
InfographicOrientation,
Note,
Notebook,
NotebookDescription,
QuizDifficulty,
QuizQuantity,
ReportFormat,
ReportSuggestion,
SlideDeckFormat,
SlideDeckLength,
Source,
# Exceptions
SourceError,
SourceNotFoundError,
SourceProcessingError,
SourceStatus,
SourceTimeoutError,
# Enums for configuration
StudioContentType,
SuggestedTopic,
VideoFormat,
VideoStyle,
)
__all__ = [
"__version__",
# Client (main entry point)

View file

@ -6,30 +6,31 @@ Quizzes, Flashcards, Infographics, Slide Decks, Data Tables, and Mind Maps.
"""
import asyncio
import builtins
import logging
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
from typing import TYPE_CHECKING, Any
import httpx
from ._core import ClientCore
from .auth import load_httpx_cookies
from .rpc import (
RPCMethod,
RPCError,
StudioContentType,
ArtifactStatus,
AudioFormat,
AudioLength,
VideoFormat,
VideoStyle,
QuizQuantity,
QuizDifficulty,
InfographicOrientation,
ExportType,
InfographicDetail,
InfographicOrientation,
QuizDifficulty,
QuizQuantity,
ReportFormat,
RPCError,
RPCMethod,
SlideDeckFormat,
SlideDeckLength,
ReportFormat,
ExportType,
StudioContentType,
VideoFormat,
VideoStyle,
)
from .types import Artifact, GenerationStatus, ReportSuggestion
@ -73,9 +74,7 @@ class ArtifactsAPI:
# List/Get Operations
# =========================================================================
async def list(
self, notebook_id: str, artifact_type: Optional[int] = None
) -> List[Artifact]:
async def list(self, notebook_id: str, artifact_type: int | None = None) -> list[Artifact]:
"""List all artifacts in a notebook, including mind maps.
This returns all AI-generated content: Audio Overviews, Video Overviews,
@ -93,7 +92,7 @@ class ArtifactsAPI:
Returns:
List of Artifact objects.
"""
artifacts: List[Artifact] = []
artifacts: list[Artifact] = []
# Fetch studio artifacts (audio, video, reports, etc.)
params = [[2], notebook_id, 'NOT artifact.status = "ARTIFACT_STATUS_SUGGESTED"']
@ -121,7 +120,10 @@ class ArtifactsAPI:
for mm_data in mind_maps:
mind_map_artifact = Artifact.from_mind_map(mm_data)
if mind_map_artifact is not None: # None means deleted (status=2)
if artifact_type is None or mind_map_artifact.artifact_type == artifact_type:
if (
artifact_type is None
or mind_map_artifact.artifact_type == artifact_type
):
artifacts.append(mind_map_artifact)
except (RPCError, httpx.HTTPError) as e:
# Network/API errors - log and continue with studio artifacts
@ -131,7 +133,7 @@ class ArtifactsAPI:
return artifacts
async def get(self, notebook_id: str, artifact_id: str) -> Optional[Artifact]:
async def get(self, notebook_id: str, artifact_id: str) -> Artifact | None:
"""Get a specific artifact by ID.
Args:
@ -147,37 +149,37 @@ class ArtifactsAPI:
return artifact
return None
async def list_audio(self, notebook_id: str) -> List[Artifact]:
async def list_audio(self, notebook_id: str) -> builtins.list[Artifact]:
"""List audio overview artifacts."""
return await self.list(notebook_id, StudioContentType.AUDIO.value)
async def list_video(self, notebook_id: str) -> List[Artifact]:
async def list_video(self, notebook_id: str) -> builtins.list[Artifact]:
"""List video overview artifacts."""
return await self.list(notebook_id, StudioContentType.VIDEO.value)
async def list_reports(self, notebook_id: str) -> List[Artifact]:
async def list_reports(self, notebook_id: str) -> builtins.list[Artifact]:
"""List report artifacts (Briefing Doc, Study Guide, Blog Post)."""
return await self.list(notebook_id, StudioContentType.REPORT.value)
async def list_quizzes(self, notebook_id: str) -> List[Artifact]:
async def list_quizzes(self, notebook_id: str) -> builtins.list[Artifact]:
"""List quiz artifacts (type 4 with variant 2)."""
all_type4 = await self.list(notebook_id, StudioContentType.QUIZ_FLASHCARD.value)
return [a for a in all_type4 if a.is_quiz]
async def list_flashcards(self, notebook_id: str) -> List[Artifact]:
async def list_flashcards(self, notebook_id: str) -> builtins.list[Artifact]:
"""List flashcard artifacts (type 4 with variant 1)."""
all_type4 = await self.list(notebook_id, StudioContentType.QUIZ_FLASHCARD.value)
return [a for a in all_type4 if a.is_flashcards]
async def list_infographics(self, notebook_id: str) -> List[Artifact]:
async def list_infographics(self, notebook_id: str) -> builtins.list[Artifact]:
"""List infographic artifacts."""
return await self.list(notebook_id, StudioContentType.INFOGRAPHIC.value)
async def list_slide_decks(self, notebook_id: str) -> List[Artifact]:
async def list_slide_decks(self, notebook_id: str) -> builtins.list[Artifact]:
"""List slide deck artifacts."""
return await self.list(notebook_id, StudioContentType.SLIDE_DECK.value)
async def list_data_tables(self, notebook_id: str) -> List[Artifact]:
async def list_data_tables(self, notebook_id: str) -> builtins.list[Artifact]:
"""List data table artifacts."""
return await self.list(notebook_id, StudioContentType.DATA_TABLE.value)
@ -188,11 +190,11 @@ class ArtifactsAPI:
async def generate_audio(
self,
notebook_id: str,
source_ids: Optional[List[str]] = None,
source_ids: builtins.list[str] | None = None,
language: str = "en",
instructions: Optional[str] = None,
audio_format: Optional[AudioFormat] = None,
audio_length: Optional[AudioLength] = None,
instructions: str | None = None,
audio_format: AudioFormat | None = None,
audio_length: AudioLength | None = None,
) -> GenerationStatus:
"""Generate an Audio Overview (podcast).
@ -245,11 +247,11 @@ class ArtifactsAPI:
async def generate_video(
self,
notebook_id: str,
source_ids: Optional[List[str]] = None,
source_ids: builtins.list[str] | None = None,
language: str = "en",
instructions: Optional[str] = None,
video_format: Optional[VideoFormat] = None,
video_style: Optional[VideoStyle] = None,
instructions: str | None = None,
video_format: VideoFormat | None = None,
video_style: VideoStyle | None = None,
) -> GenerationStatus:
"""Generate a Video Overview.
@ -305,9 +307,9 @@ class ArtifactsAPI:
self,
notebook_id: str,
report_format: ReportFormat = ReportFormat.BRIEFING_DOC,
source_ids: Optional[List[str]] = None,
source_ids: builtins.list[str] | None = None,
language: str = "en",
custom_prompt: Optional[str] = None,
custom_prompt: str | None = None,
) -> GenerationStatus:
"""Generate a report artifact.
@ -395,7 +397,7 @@ class ArtifactsAPI:
async def generate_study_guide(
self,
notebook_id: str,
source_ids: Optional[List[str]] = None,
source_ids: builtins.list[str] | None = None,
language: str = "en",
) -> GenerationStatus:
"""Generate a study guide report.
@ -420,10 +422,10 @@ class ArtifactsAPI:
async def generate_quiz(
self,
notebook_id: str,
source_ids: Optional[List[str]] = None,
instructions: Optional[str] = None,
quantity: Optional[QuizQuantity] = None,
difficulty: Optional[QuizDifficulty] = None,
source_ids: builtins.list[str] | None = None,
instructions: str | None = None,
quantity: QuizQuantity | None = None,
difficulty: QuizDifficulty | None = None,
) -> GenerationStatus:
"""Generate a quiz.
@ -477,10 +479,10 @@ class ArtifactsAPI:
async def generate_flashcards(
self,
notebook_id: str,
source_ids: Optional[List[str]] = None,
instructions: Optional[str] = None,
quantity: Optional[QuizQuantity] = None,
difficulty: Optional[QuizDifficulty] = None,
source_ids: builtins.list[str] | None = None,
instructions: str | None = None,
quantity: QuizQuantity | None = None,
difficulty: QuizDifficulty | None = None,
) -> GenerationStatus:
"""Generate flashcards.
@ -533,11 +535,11 @@ class ArtifactsAPI:
async def generate_infographic(
self,
notebook_id: str,
source_ids: Optional[List[str]] = None,
source_ids: builtins.list[str] | None = None,
language: str = "en",
instructions: Optional[str] = None,
orientation: Optional[InfographicOrientation] = None,
detail_level: Optional[InfographicDetail] = None,
instructions: str | None = None,
orientation: InfographicOrientation | None = None,
detail_level: InfographicDetail | None = None,
) -> GenerationStatus:
"""Generate an infographic.
@ -585,11 +587,11 @@ class ArtifactsAPI:
async def generate_slide_deck(
self,
notebook_id: str,
source_ids: Optional[List[str]] = None,
source_ids: builtins.list[str] | None = None,
language: str = "en",
instructions: Optional[str] = None,
slide_format: Optional[SlideDeckFormat] = None,
slide_length: Optional[SlideDeckLength] = None,
instructions: str | None = None,
slide_format: SlideDeckFormat | None = None,
slide_length: SlideDeckLength | None = None,
) -> GenerationStatus:
"""Generate a slide deck.
@ -639,9 +641,9 @@ class ArtifactsAPI:
async def generate_data_table(
self,
notebook_id: str,
source_ids: Optional[List[str]] = None,
source_ids: builtins.list[str] | None = None,
language: str = "en",
instructions: Optional[str] = None,
instructions: str | None = None,
) -> GenerationStatus:
"""Generate a data table.
@ -689,7 +691,7 @@ class ArtifactsAPI:
async def generate_mind_map(
self,
notebook_id: str,
source_ids: Optional[List[str]] = None,
source_ids: builtins.list[str] | None = None,
) -> dict[str, Any]:
"""Generate an interactive mind map.
@ -766,7 +768,7 @@ class ArtifactsAPI:
# =========================================================================
async def download_audio(
self, notebook_id: str, output_path: str, artifact_id: Optional[str] = None
self, notebook_id: str, output_path: str, artifact_id: str | None = None
) -> str:
"""Download an Audio Overview to a file.
@ -782,8 +784,10 @@ class ArtifactsAPI:
# Filter for completed audio artifacts
audio_candidates = [
a for a in artifacts_data
if isinstance(a, list) and len(a) > 4
a
for a in artifacts_data
if isinstance(a, list)
and len(a) > 4
and a[2] == StudioContentType.AUDIO
and a[4] == ArtifactStatus.COMPLETED
]
@ -826,7 +830,7 @@ class ArtifactsAPI:
raise ValueError(f"Failed to parse audio artifact structure: {e}")
async def download_video(
self, notebook_id: str, output_path: str, artifact_id: Optional[str] = None
self, notebook_id: str, output_path: str, artifact_id: str | None = None
) -> str:
"""Download a Video Overview to a file.
@ -842,8 +846,10 @@ class ArtifactsAPI:
# Filter for completed video artifacts
video_candidates = [
a for a in artifacts_data
if isinstance(a, list) and len(a) > 4
a
for a in artifacts_data
if isinstance(a, list)
and len(a) > 4
and a[2] == StudioContentType.VIDEO
and a[4] == ArtifactStatus.COMPLETED
]
@ -902,7 +908,7 @@ class ArtifactsAPI:
raise ValueError(f"Failed to parse video artifact structure: {e}")
async def download_infographic(
self, notebook_id: str, output_path: str, artifact_id: Optional[str] = None
self, notebook_id: str, output_path: str, artifact_id: str | None = None
) -> str:
"""Download an Infographic to a file.
@ -918,8 +924,10 @@ class ArtifactsAPI:
# Filter for completed infographic artifacts
info_candidates = [
a for a in artifacts_data
if isinstance(a, list) and len(a) > 4
a
for a in artifacts_data
if isinstance(a, list)
and len(a) > 4
and a[2] == StudioContentType.INFOGRAPHIC
and a[4] == ArtifactStatus.COMPLETED
]
@ -962,7 +970,7 @@ class ArtifactsAPI:
raise ValueError(f"Failed to parse infographic structure: {e}")
async def download_slide_deck(
self, notebook_id: str, output_path: str, artifact_id: Optional[str] = None
self, notebook_id: str, output_path: str, artifact_id: str | None = None
) -> str:
"""Download a slide deck as a PDF file.
@ -978,8 +986,10 @@ class ArtifactsAPI:
# Filter for completed slide deck artifacts
slide_candidates = [
a for a in artifacts_data
if isinstance(a, list) and len(a) > 4
a
for a in artifacts_data
if isinstance(a, list)
and len(a) > 4
and a[2] == StudioContentType.SLIDE_DECK
and a[4] == ArtifactStatus.COMPLETED
]
@ -1036,9 +1046,7 @@ class ArtifactsAPI:
)
return True
async def rename(
self, notebook_id: str, artifact_id: str, new_title: str
) -> None:
async def rename(self, notebook_id: str, artifact_id: str, new_title: str) -> None:
"""Rename an artifact.
Args:
@ -1095,7 +1103,7 @@ class ArtifactsAPI:
initial_interval: float = 2.0,
max_interval: float = 10.0,
timeout: float = 300.0,
poll_interval: Optional[float] = None, # Deprecated, use initial_interval
poll_interval: float | None = None, # Deprecated, use initial_interval
) -> GenerationStatus:
"""Wait for a generation task to complete.
@ -1118,6 +1126,7 @@ class ArtifactsAPI:
# Backward compatibility: poll_interval overrides initial_interval
if poll_interval is not None:
import warnings
warnings.warn(
"poll_interval is deprecated, use initial_interval instead",
DeprecationWarning,
@ -1204,8 +1213,8 @@ class ArtifactsAPI:
async def export(
self,
notebook_id: str,
artifact_id: Optional[str] = None,
content: Optional[str] = None,
artifact_id: str | None = None,
content: str | None = None,
title: str = "Export",
export_type: ExportType = ExportType.DOCS,
) -> Any:
@ -1238,8 +1247,8 @@ class ArtifactsAPI:
async def suggest_reports(
self,
notebook_id: str,
source_ids: Optional[List[str]] = None,
) -> List[ReportSuggestion]:
source_ids: builtins.list[str] | None = None,
) -> builtins.list[ReportSuggestion]:
"""Get AI-suggested report formats for a notebook.
Args:
@ -1276,12 +1285,14 @@ class ArtifactsAPI:
if result and isinstance(result, list):
for item in result:
if isinstance(item, list) and len(item) >= 5:
suggestions.append(ReportSuggestion(
title=item[0] if isinstance(item[0], str) else "",
description=item[1] if isinstance(item[1], str) else "",
prompt=item[4] if len(item) > 4 and isinstance(item[4], str) else "",
audience_level=item[5] if len(item) > 5 else 2,
))
suggestions.append(
ReportSuggestion(
title=item[0] if isinstance(item[0], str) else "",
description=item[1] if isinstance(item[1], str) else "",
prompt=item[4] if len(item) > 4 and isinstance(item[4], str) else "",
audience_level=item[5] if len(item) > 5 else 2,
)
)
return suggestions
@ -1289,7 +1300,9 @@ class ArtifactsAPI:
# Private Helpers
# =========================================================================
async def _call_generate(self, notebook_id: str, params: List[Any]) -> GenerationStatus:
async def _call_generate(
self, notebook_id: str, params: builtins.list[Any]
) -> GenerationStatus:
"""Make a generation RPC call with error handling.
Wraps the RPC call to handle UserDisplayableError (rate limiting/quota)
@ -1320,7 +1333,7 @@ class ArtifactsAPI:
)
raise
async def _list_raw(self, notebook_id: str) -> List[Any]:
async def _list_raw(self, notebook_id: str) -> builtins.list[Any]:
"""Get raw artifact list data."""
params = [[2], notebook_id, 'NOT artifact.status = "ARTIFACT_STATUS_SUGGESTED"']
result = await self._core.rpc_call(
@ -1333,7 +1346,7 @@ class ArtifactsAPI:
return result[0] if isinstance(result[0], list) else result
return []
async def _get_source_ids(self, notebook_id: str) -> List[str]:
async def _get_source_ids(self, notebook_id: str) -> builtins.list[str]:
"""Extract source IDs from notebook data."""
params = [notebook_id, None, [2], None, 0]
notebook_data = await self._core.rpc_call(
@ -1364,8 +1377,8 @@ class ArtifactsAPI:
return source_ids
async def _download_urls_batch(
self, urls_and_paths: List[Tuple[str, str]]
) -> List[str]:
self, urls_and_paths: builtins.list[tuple[str, str]]
) -> builtins.list[str]:
"""Download multiple files using httpx with proper cookie handling.
Args:
@ -1376,7 +1389,7 @@ class ArtifactsAPI:
"""
from pathlib import Path
downloaded: List[str] = []
downloaded: list[str] = []
# Load cookies with domain info for cross-domain redirect handling
cookies = load_httpx_cookies()
@ -1450,11 +1463,27 @@ class ArtifactsAPI:
"""Parse generation API result into GenerationStatus."""
if result and isinstance(result, list) and len(result) > 0:
artifact_data = result[0]
artifact_id = artifact_data[0] if isinstance(artifact_data, list) and len(artifact_data) > 0 else None
status_code = artifact_data[4] if isinstance(artifact_data, list) and len(artifact_data) > 4 else None
artifact_id = (
artifact_data[0]
if isinstance(artifact_data, list) and len(artifact_data) > 0
else None
)
status_code = (
artifact_data[4]
if isinstance(artifact_data, list) and len(artifact_data) > 4
else None
)
if artifact_id:
status = "in_progress" if status_code == 1 else "completed" if status_code == 3 else "pending"
status = (
"in_progress"
if status_code == 1
else "completed"
if status_code == 3
else "pending"
)
return GenerationStatus(task_id=artifact_id, status=status)
return GenerationStatus(task_id="", status="failed", error="Generation failed - no artifact_id returned")
return GenerationStatus(
task_id="", status="failed", error="Generation failed - no artifact_id returned"
)

View file

@ -8,11 +8,11 @@ import json
import logging
import os
import uuid
from typing import Any, Optional
from urllib.parse import urlencode, quote
from typing import Any
from urllib.parse import quote, urlencode
from ._core import ClientCore
from .rpc import RPCMethod, QUERY_URL
from .rpc import QUERY_URL, RPCMethod
from .types import AskResult, ConversationTurn
logger = logging.getLogger(__name__)
@ -50,8 +50,8 @@ class ChatAPI:
self,
notebook_id: str,
question: str,
source_ids: Optional[list[str]] = None,
conversation_id: Optional[str] = None,
source_ids: list[str] | None = None,
conversation_id: str | None = None,
) -> AskResult:
"""Ask the notebook a question.
@ -114,9 +114,7 @@ class ChatAPI:
self._core._reqid_counter += 100000
url_params = {
"bl": os.environ.get(
"NOTEBOOKLM_BL", "boq_labs-tailwind-frontend_20251221.14_p0"
),
"bl": os.environ.get("NOTEBOOKLM_BL", "boq_labs-tailwind-frontend_20251221.14_p0"),
"hl": "en",
"_reqid": str(self._core._reqid_counter),
"rt": "c",
@ -136,9 +134,7 @@ class ChatAPI:
if answer_text:
turns = self._core.get_cached_conversation(conversation_id)
turn_number = len(turns) + 1
self._core.cache_conversation_turn(
conversation_id, question, answer_text, turn_number
)
self._core.cache_conversation_turn(conversation_id, question, answer_text, turn_number)
else:
turns = self._core.get_cached_conversation(conversation_id)
turn_number = len(turns)
@ -187,7 +183,7 @@ class ChatAPI:
for turn in cached
]
def clear_cache(self, conversation_id: Optional[str] = None) -> bool:
def clear_cache(self, conversation_id: str | None = None) -> bool:
"""Clear conversation cache.
Args:
@ -201,9 +197,9 @@ class ChatAPI:
async def configure(
self,
notebook_id: str,
goal: Optional[Any] = None,
response_length: Optional[Any] = None,
custom_prompt: Optional[str] = None,
goal: Any | None = None,
response_length: Any | None = None,
custom_prompt: str | None = None,
) -> None:
"""Configure chat persona and response settings for a notebook.
@ -298,7 +294,7 @@ class ChatAPI:
return source_ids
def _build_conversation_history(self, conversation_id: str) -> Optional[list]:
def _build_conversation_history(self, conversation_id: str) -> list | None:
"""Build conversation history for follow-up requests."""
turns = self._core.get_cached_conversation(conversation_id)
if not turns:
@ -347,7 +343,7 @@ class ChatAPI:
)
return longest_answer
def _extract_answer_from_chunk(self, json_str: str) -> tuple[Optional[str], bool]:
def _extract_answer_from_chunk(self, json_str: str) -> tuple[str | None, bool]:
"""Extract answer text from a response chunk."""
try:
data = json.loads(json_str)

View file

@ -3,19 +3,19 @@
import logging
import os
from collections import OrderedDict
from typing import Any, Optional
from typing import Any
from urllib.parse import urlencode
import httpx
from .auth import AuthTokens
from .rpc import (
RPCMethod,
RPCError,
BATCHEXECUTE_URL,
encode_rpc_request,
RPCError,
RPCMethod,
build_request_body,
decode_response,
encode_rpc_request,
)
# Enable RPC debug output via environment variable
@ -57,7 +57,7 @@ class ClientCore:
"""
self.auth = auth
self._timeout = timeout
self._http_client: Optional[httpx.AsyncClient] = None
self._http_client: httpx.AsyncClient | None = None
# Request ID counter for chat API (must be unique per request)
self._reqid_counter: int = 100000
# OrderedDict for FIFO eviction when cache exceeds MAX_CONVERSATION_CACHE_SIZE
@ -216,11 +216,13 @@ class ClientCore:
self._conversation_cache.popitem(last=False)
self._conversation_cache[conversation_id] = []
self._conversation_cache[conversation_id].append({
"query": query,
"answer": answer,
"turn_number": turn_number,
})
self._conversation_cache[conversation_id].append(
{
"query": query,
"answer": answer,
"turn_number": turn_number,
}
)
def get_cached_conversation(self, conversation_id: str) -> list[dict[str, Any]]:
"""Get cached conversation turns.
@ -233,7 +235,7 @@ class ClientCore:
"""
return self._conversation_cache.get(conversation_id, [])
def clear_conversation_cache(self, conversation_id: Optional[str] = None) -> bool:
def clear_conversation_cache(self, conversation_id: str | None = None) -> bool:
"""Clear conversation cache.
Args:

View file

@ -1,6 +1,6 @@
"""Notebook operations API."""
from typing import Any, Optional
from typing import Any
from ._core import ClientCore
from .rpc import RPCMethod
@ -169,10 +169,12 @@ class NotebooksAPI:
topics_list = result[1][0] if isinstance(result[1][0], list) else []
for topic in topics_list:
if isinstance(topic, list) and len(topic) >= 2:
suggested_topics.append(SuggestedTopic(
question=topic[0] if isinstance(topic[0], str) else "",
prompt=topic[1] if isinstance(topic[1], str) else "",
))
suggested_topics.append(
SuggestedTopic(
question=topic[0] if isinstance(topic[0], str) else "",
prompt=topic[1] if isinstance(topic[1], str) else "",
)
)
return NotebookDescription(summary=summary, suggested_topics=suggested_topics)
@ -209,7 +211,7 @@ class NotebooksAPI:
)
async def share(
self, notebook_id: str, public: bool = True, artifact_id: Optional[str] = None
self, notebook_id: str, public: bool = True, artifact_id: str | None = None
) -> dict:
"""Toggle notebook sharing.
@ -252,9 +254,7 @@ class NotebooksAPI:
"artifact_id": artifact_id,
}
def get_share_url(
self, notebook_id: str, artifact_id: Optional[str] = None
) -> str:
def get_share_url(self, notebook_id: str, artifact_id: str | None = None) -> str:
"""Get share URL for a notebook or artifact.
This does NOT toggle sharing - it just returns the URL format.

View file

@ -5,7 +5,8 @@ user-created notes in notebooks. Notes are distinct from artifacts -
they are user-created content, not AI-generated.
"""
from typing import Any, List, Optional
import builtins
from typing import Any
from ._core import ClientCore
from .rpc import RPCMethod
@ -37,7 +38,7 @@ class NotesAPI:
"""
self._core = core
async def list(self, notebook_id: str) -> List[Note]:
async def list(self, notebook_id: str) -> list[Note]:
"""List all text notes in the notebook.
This excludes:
@ -59,15 +60,13 @@ class NotesAPI:
continue
content = self._extract_content(item)
is_mind_map = content and (
'"children":' in content or '"nodes":' in content
)
is_mind_map = content and ('"children":' in content or '"nodes":' in content)
if not is_mind_map:
notes.append(self._parse_note(item, notebook_id))
return notes
async def get(self, notebook_id: str, note_id: str) -> Optional[Note]:
async def get(self, notebook_id: str, note_id: str) -> Note | None:
"""Get a specific note by ID.
Args:
@ -173,7 +172,7 @@ class NotesAPI:
)
return True
async def list_mind_maps(self, notebook_id: str) -> List[Any]:
async def list_mind_maps(self, notebook_id: str) -> builtins.list[Any]:
"""List all mind maps in the notebook.
Mind maps are stored in the same internal structure as notes but
@ -227,7 +226,7 @@ class NotesAPI:
# Private Helpers
# =========================================================================
async def _get_all_notes_and_mind_maps(self, notebook_id: str) -> List[Any]:
async def _get_all_notes_and_mind_maps(self, notebook_id: str) -> builtins.list[Any]:
"""Fetch all notes and mind maps from the API."""
params = [notebook_id]
result = await self._core.rpc_call(
@ -236,25 +235,16 @@ class NotesAPI:
source_path=f"/notebook/{notebook_id}",
allow_null=True,
)
if (
result
and isinstance(result, list)
and len(result) > 0
and isinstance(result[0], list)
):
if result and isinstance(result, list) and len(result) > 0 and isinstance(result[0], list):
notes_list = result[0]
valid_notes = []
for item in notes_list:
if (
isinstance(item, list)
and len(item) > 0
and isinstance(item[0], str)
):
if isinstance(item, list) and len(item) > 0 and isinstance(item[0], str):
valid_notes.append(item)
return valid_notes
return []
def _is_deleted(self, item: List[Any]) -> bool:
def _is_deleted(self, item: builtins.list[Any]) -> bool:
"""Check if a note/mind map item is deleted (status=2).
Deleted items have structure: ['id', None, 2]
@ -270,7 +260,7 @@ class NotesAPI:
return False
return item[1] is None and item[2] == 2
def _extract_content(self, item: List[Any]) -> Optional[str]:
def _extract_content(self, item: builtins.list[Any]) -> str | None:
"""Extract content string from note/mind map item."""
if len(item) <= 1:
return None
@ -281,7 +271,7 @@ class NotesAPI:
return item[1][1]
return None
def _parse_note(self, item: List[Any], notebook_id: str) -> Note:
def _parse_note(self, item: builtins.list[Any], notebook_id: str) -> Note:
"""Parse a raw note item into a Note object."""
note_id = item[0] if len(item) > 0 else ""

View file

@ -4,7 +4,7 @@ Provides operations for starting research sessions, polling for results,
and importing discovered sources into notebooks.
"""
from typing import Any, Optional
from typing import Any
from ._core import ClientCore
from .rpc import RPCMethod
@ -44,7 +44,7 @@ class ResearchAPI:
query: str,
source: str = "web",
mode: str = "fast",
) -> Optional[dict[str, Any]]:
) -> dict[str, Any] | None:
"""Start a research session.
Args:
@ -117,11 +117,7 @@ class ResearchAPI:
return {"status": "no_research"}
# Unwrap if needed
if (
isinstance(result[0], list)
and len(result[0]) > 0
and isinstance(result[0][0], list)
):
if isinstance(result[0], list) and len(result[0]) > 0 and isinstance(result[0][0], list):
result = result[0]
# Find most recent task
@ -145,13 +141,9 @@ class ResearchAPI:
if isinstance(sources_and_summary, list) and len(sources_and_summary) >= 1:
sources_data = (
sources_and_summary[0]
if isinstance(sources_and_summary[0], list)
else []
sources_and_summary[0] if isinstance(sources_and_summary[0], list) else []
)
if len(sources_and_summary) >= 2 and isinstance(
sources_and_summary[1], str
):
if len(sources_and_summary) >= 2 and isinstance(sources_and_summary[1], str):
summary = sources_and_summary[1]
parsed_sources = []
@ -246,9 +238,7 @@ class ResearchAPI:
for src_data in result:
if isinstance(src_data, list) and len(src_data) >= 2:
src_id = (
src_data[0][0]
if src_data[0] and isinstance(src_data[0], list)
else None
src_data[0][0] if src_data[0] and isinstance(src_data[0], list) else None
)
if src_id:
imported.append({"id": src_id, "title": src_data[1]})

View file

@ -1,17 +1,18 @@
"""Source operations API."""
import asyncio
import builtins
import logging
import re
from datetime import datetime
from pathlib import Path
from time import monotonic
from typing import Any, Dict, List, Optional, Union
from typing import Any
import httpx
from ._core import ClientCore
from .rpc import RPCMethod, UPLOAD_URL
from .rpc import UPLOAD_URL, RPCMethod
from .rpc.types import SourceStatus
from .types import (
Source,
@ -91,22 +92,27 @@ class SourcesAPI:
if isinstance(url_list, list) and len(url_list) > 0:
url = url_list[0]
# Detect YouTube vs other URLs
if 'youtube.com' in url or 'youtu.be' in url:
if "youtube.com" in url or "youtu.be" in url:
source_type = "youtube"
else:
source_type = "url"
# Extract file info if no URL
if not url and title:
if title.endswith('.pdf'):
if title.endswith(".pdf"):
source_type = "pdf"
elif title.endswith(('.txt', '.md', '.doc', '.docx')):
elif title.endswith((".txt", ".md", ".doc", ".docx")):
source_type = "text_file"
elif title.endswith(('.xls', '.xlsx', '.csv')):
elif title.endswith((".xls", ".xlsx", ".csv")):
source_type = "spreadsheet"
# Check for file upload indicator
if source_type == "text" and len(src) > 2 and isinstance(src[2], list) and len(src[2]) > 1:
if (
source_type == "text"
and len(src) > 2
and isinstance(src[2], list)
and len(src[2]) > 1
):
if isinstance(src[2][1], int) and src[2][1] > 0:
source_type = "upload"
@ -132,18 +138,20 @@ class SourcesAPI:
):
status = status_code
sources.append(Source(
id=str(src_id),
title=title,
url=url,
source_type=source_type,
created_at=created_at,
status=status,
))
sources.append(
Source(
id=str(src_id),
title=title,
url=url,
source_type=source_type,
created_at=created_at,
status=status,
)
)
return sources
async def get(self, notebook_id: str, source_id: str) -> Optional[Source]:
async def get(self, notebook_id: str, source_id: str) -> Source | None:
"""Get details of a specific source.
Args:
@ -200,7 +208,7 @@ class SourcesAPI:
"""
start = monotonic()
interval = initial_interval
last_status: Optional[int] = None
last_status: int | None = None
while True:
# Check timeout before each poll
@ -233,10 +241,10 @@ class SourcesAPI:
async def wait_for_sources(
self,
notebook_id: str,
source_ids: List[str],
source_ids: builtins.list[str],
timeout: float = 120.0,
**kwargs: Any,
) -> List[Source]:
) -> builtins.list[Source]:
"""Wait for multiple sources to become ready in parallel.
Args:
@ -263,8 +271,7 @@ class SourcesAPI:
)
"""
tasks = [
self.wait_until_ready(notebook_id, sid, timeout=timeout, **kwargs)
for sid in source_ids
self.wait_until_ready(notebook_id, sid, timeout=timeout, **kwargs) for sid in source_ids
]
return list(await asyncio.gather(*tasks))
@ -353,8 +360,8 @@ class SourcesAPI:
async def add_file(
self,
notebook_id: str,
file_path: Union[str, Path],
mime_type: Optional[str] = None,
file_path: str | Path,
mime_type: str | None = None,
wait: bool = False,
wait_timeout: float = 120.0,
) -> Source:
@ -397,9 +404,7 @@ class SourcesAPI:
source_id = await self._register_file_source(notebook_id, filename)
# Step 2: Start resumable upload with the SOURCE_ID from step 1
upload_url = await self._start_resumable_upload(
notebook_id, filename, file_size, source_id
)
upload_url = await self._start_resumable_upload(notebook_id, filename, file_size, source_id)
# Step 3: Stream upload file content (memory-efficient)
await self._upload_file_streaming(upload_url, file_path)
@ -452,7 +457,15 @@ class SourcesAPI:
# Drive source structure: [[file_id, mime_type, 1, title], null x9, 1]
source_data = [
[file_id, mime_type, 1, title],
None, None, None, None, None, None, None, None, None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
1,
]
params = [
@ -552,7 +565,7 @@ class SourcesAPI:
# False means stale, True means fresh
return result is True
async def get_guide(self, notebook_id: str, source_id: str) -> Dict[str, Any]:
async def get_guide(self, notebook_id: str, source_id: str) -> dict[str, Any]:
"""Get AI-generated summary and keywords for a specific source.
This is the "Source Guide" feature shown when clicking on a source
@ -596,22 +609,18 @@ class SourcesAPI:
# Private helper methods
# =========================================================================
def _extract_youtube_video_id(self, url: str) -> Optional[str]:
def _extract_youtube_video_id(self, url: str) -> str | None:
"""Extract YouTube video ID from various URL formats."""
# Short URLs: youtu.be/VIDEO_ID
match = re.match(r"https?://youtu\.be/([a-zA-Z0-9_-]+)", url)
if match:
return match.group(1)
# Standard watch URLs: youtube.com/watch?v=VIDEO_ID
match = re.match(
r"https?://(?:www\.)?youtube\.com/watch\?v=([a-zA-Z0-9_-]+)", url
)
match = re.match(r"https?://(?:www\.)?youtube\.com/watch\?v=([a-zA-Z0-9_-]+)", url)
if match:
return match.group(1)
# Shorts URLs: youtube.com/shorts/VIDEO_ID
match = re.match(
r"https?://(?:www\.)?youtube\.com/shorts/([a-zA-Z0-9_-]+)", url
)
match = re.match(r"https?://(?:www\.)?youtube\.com/shorts/([a-zA-Z0-9_-]+)", url)
if match:
return match.group(1)
return None
@ -666,6 +675,7 @@ class SourcesAPI:
# Parse SOURCE_ID from response - handle various nesting formats
# API returns different structures: [[[[id]]]], [[[id]]], [[id]], etc.
if result and isinstance(result, list):
def extract_id(data):
"""Recursively extract first string from nested lists."""
if isinstance(data, str):
@ -704,11 +714,13 @@ class SourcesAPI:
"x-goog-upload-protocol": "resumable",
}
body = json.dumps({
"PROJECT_ID": notebook_id,
"SOURCE_NAME": filename,
"SOURCE_ID": source_id,
})
body = json.dumps(
{
"PROJECT_ID": notebook_id,
"SOURCE_NAME": filename,
"SOURCE_ID": source_id,
}
)
async with httpx.AsyncClient(timeout=60.0) as client:
response = await client.post(url, headers=headers, content=body)

View file

@ -32,7 +32,7 @@ import os
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional
from typing import Any
import httpx
@ -78,7 +78,7 @@ class AuthTokens:
return "; ".join(f"{k}={v}" for k, v in self.cookies.items())
@classmethod
async def from_storage(cls, path: Optional[Path] = None) -> "AuthTokens":
async def from_storage(cls, path: Path | None = None) -> "AuthTokens":
"""Create AuthTokens from Playwright storage state file.
This is the recommended way to create AuthTokens for programmatic use.
@ -134,8 +134,7 @@ def extract_cookies_from_storage(storage_state: dict[str, Any]) -> dict[str, str
missing = MINIMUM_REQUIRED_COOKIES - set(cookies.keys())
if missing:
raise ValueError(
f"Missing required cookies: {missing}\n"
f"Run 'notebooklm login' to authenticate."
f"Missing required cookies: {missing}\nRun 'notebooklm login' to authenticate."
)
return cookies
@ -164,8 +163,7 @@ def extract_csrf_from_html(html: str, final_url: str = "") -> str:
# Check if we were redirected to login page
if "accounts.google.com" in final_url or "accounts.google.com" in html:
raise ValueError(
"Authentication expired or invalid. "
"Run 'notebooklm login' to re-authenticate."
"Authentication expired or invalid. Run 'notebooklm login' to re-authenticate."
)
raise ValueError(
f"CSRF token not found in HTML. Final URL: {final_url}\n"
@ -196,8 +194,7 @@ def extract_session_id_from_html(html: str, final_url: str = "") -> str:
if not match:
if "accounts.google.com" in final_url or "accounts.google.com" in html:
raise ValueError(
"Authentication expired or invalid. "
"Run 'notebooklm login' to re-authenticate."
"Authentication expired or invalid. Run 'notebooklm login' to re-authenticate."
)
raise ValueError(
f"Session ID not found in HTML. Final URL: {final_url}\n"
@ -206,7 +203,7 @@ def extract_session_id_from_html(html: str, final_url: str = "") -> str:
return match.group(1)
def _load_storage_state(path: Optional[Path] = None) -> dict[str, Any]:
def _load_storage_state(path: Path | None = None) -> dict[str, Any]:
"""Load Playwright storage state from file or environment variable.
This is a shared helper used by load_auth_from_storage() and load_httpx_cookies()
@ -231,8 +228,7 @@ def _load_storage_state(path: Optional[Path] = None) -> dict[str, Any]:
if path:
if not path.exists():
raise FileNotFoundError(
f"Storage file not found: {path}\n"
f"Run 'notebooklm login' to authenticate first."
f"Storage file not found: {path}\nRun 'notebooklm login' to authenticate first."
)
return json.loads(path.read_text())
@ -257,7 +253,7 @@ def _load_storage_state(path: Optional[Path] = None) -> dict[str, Any]:
raise ValueError(
"NOTEBOOKLM_AUTH_JSON must contain valid Playwright storage state "
"with a 'cookies' key.\n"
"Expected format: {\"cookies\": [{\"name\": \"SID\", \"value\": \"...\", ...}]}"
'Expected format: {"cookies": [{"name": "SID", "value": "...", ...}]}'
)
return storage_state
@ -266,14 +262,13 @@ def _load_storage_state(path: Optional[Path] = None) -> dict[str, Any]:
if not storage_path.exists():
raise FileNotFoundError(
f"Storage file not found: {storage_path}\n"
f"Run 'notebooklm login' to authenticate first."
f"Storage file not found: {storage_path}\nRun 'notebooklm login' to authenticate first."
)
return json.loads(storage_path.read_text())
def load_auth_from_storage(path: Optional[Path] = None) -> dict[str, str]:
def load_auth_from_storage(path: Path | None = None) -> dict[str, str]:
"""Load Google cookies from storage.
Loads authentication cookies with the following precedence:
@ -336,7 +331,7 @@ def _is_allowed_cookie_domain(domain: str) -> bool:
return False
def load_httpx_cookies(path: Optional[Path] = None) -> "httpx.Cookies":
def load_httpx_cookies(path: Path | None = None) -> "httpx.Cookies":
"""Load cookies as an httpx.Cookies object for authenticated downloads.
Unlike load_auth_from_storage() which returns a simple dict, this function

View file

@ -16,67 +16,65 @@ Re-exports from helpers for backward compatibility with tests.
"""
# Command groups (subcommand style)
from .source import source
from .artifact import artifact
from .generate import generate
from .chat import register_chat_commands
from .download import download
from .generate import generate
from .helpers import (
# Display
ARTIFACT_TYPE_DISPLAY,
ARTIFACT_TYPE_MAP,
BROWSER_PROFILE_DIR,
# Context
CONTEXT_FILE,
clear_context,
# Console
console,
detect_source_type,
get_artifact_type_display,
get_auth_tokens,
# Auth
get_client,
get_current_conversation,
get_current_notebook,
get_source_type_display,
handle_auth_error,
# Errors
handle_error,
json_error_response,
# Output
json_output_response,
require_notebook,
resolve_artifact_id,
resolve_notebook_id,
resolve_source_id,
# Async
run_async,
set_current_conversation,
set_current_notebook,
# Decorators
with_client,
)
from .note import note
from .skill import skill
from .notebook import register_notebook_commands
from .options import (
artifact_option,
generate_options,
json_option,
# Individual option decorators
notebook_option,
output_option,
source_option,
# Composite decorators
standard_options,
wait_option,
)
from .research import research
# Register functions (top-level command style)
from .session import register_session_commands
from .notebook import register_notebook_commands
from .chat import register_chat_commands
from .helpers import (
# Console
console,
# Async
run_async,
# Auth
get_client,
get_auth_tokens,
# Context
CONTEXT_FILE,
BROWSER_PROFILE_DIR,
get_current_notebook,
set_current_notebook,
clear_context,
get_current_conversation,
set_current_conversation,
require_notebook,
resolve_notebook_id,
resolve_source_id,
resolve_artifact_id,
# Errors
handle_error,
handle_auth_error,
# Decorators
with_client,
# Output
json_output_response,
json_error_response,
# Display
ARTIFACT_TYPE_DISPLAY,
ARTIFACT_TYPE_MAP,
get_artifact_type_display,
detect_source_type,
get_source_type_display,
)
from .options import (
# Individual option decorators
notebook_option,
json_option,
wait_option,
source_option,
artifact_option,
output_option,
# Composite decorators
standard_options,
generate_options,
)
from .skill import skill
from .source import source
__all__ = [
# Command groups (subcommand style)

View file

@ -19,13 +19,13 @@ from rich.table import Table
from ..client import NotebookLMClient
from ..rpc import ExportType
from .helpers import (
ARTIFACT_TYPE_MAP,
console,
get_artifact_type_display,
json_output_response,
require_notebook,
resolve_artifact_id,
with_client,
json_output_response,
get_artifact_type_display,
ARTIFACT_TYPE_MAP,
)
@ -83,9 +83,7 @@ def artifact():
def artifact_list(ctx, notebook_id, artifact_type, json_output, client_auth):
"""List artifacts in a notebook."""
nb_id = require_notebook(notebook_id)
type_filter = (
None if artifact_type == "all" else ARTIFACT_TYPE_MAP.get(artifact_type)
)
type_filter = None if artifact_type == "all" else ARTIFACT_TYPE_MAP.get(artifact_type)
async def _run():
async with NotebookLMClient(client_auth) as client:
@ -97,6 +95,7 @@ def artifact_list(ctx, notebook_id, artifact_type, json_output, client_auth):
nb = await client.notebooks.get(nb_id)
if json_output:
def _get_status_str(art):
if art.is_completed:
return "completed"
@ -118,9 +117,7 @@ def artifact_list(ctx, notebook_id, artifact_type, json_output, client_auth):
"type_id": art.artifact_type,
"status": _get_status_str(art),
"status_id": art.status,
"created_at": art.created_at.isoformat()
if art.created_at
else None,
"created_at": art.created_at.isoformat() if art.created_at else None,
}
for i, art in enumerate(artifacts, 1)
],
@ -144,9 +141,7 @@ def artifact_list(ctx, notebook_id, artifact_type, json_output, client_auth):
type_display = get_artifact_type_display(
art.artifact_type, art.variant, art.report_subtype
)
created = (
art.created_at.strftime("%Y-%m-%d %H:%M") if art.created_at else "-"
)
created = art.created_at.strftime("%Y-%m-%d %H:%M") if art.created_at else "-"
status = (
"completed"
if art.is_completed
@ -290,9 +285,7 @@ def artifact_delete(ctx, artifact_id, notebook_id, yes, client_auth):
help="Notebook ID (uses current if not set). Supports partial IDs.",
)
@click.option("--title", required=True, help="Title for exported document")
@click.option(
"--type", "export_type", type=click.Choice(["docs", "sheets"]), default="docs"
)
@click.option("--type", "export_type", type=click.Choice(["docs", "sheets"]), default="docs")
@with_client
def artifact_export(ctx, artifact_id, notebook_id, title, export_type, client_auth):
"""Export artifact to Google Docs/Sheets.
@ -384,7 +377,8 @@ def artifact_wait(ctx, artifact_id, notebook_id, timeout, interval, json_output,
try:
status = await client.artifacts.wait_for_completion(
nb_id, resolved_id,
nb_id,
resolved_id,
poll_interval=float(interval),
timeout=float(timeout),
)
@ -410,11 +404,13 @@ def artifact_wait(ctx, artifact_id, notebook_id, timeout, interval, json_output,
except TimeoutError:
if json_output:
json_output_response({
"artifact_id": resolved_id,
"status": "timeout",
"error": f"Timed out after {timeout} seconds",
})
json_output_response(
{
"artifact_id": resolved_id,
"status": "timeout",
"error": f"Timed out after {timeout} seconds",
}
)
else:
console.print(f"[red]✗ Timeout after {timeout}s[/red]")
raise SystemExit(1)
@ -463,10 +459,6 @@ def artifact_suggestions(ctx, notebook_id, source_ids, json_output, client_auth)
table.add_row(str(i), suggestion.title, suggestion.description)
console.print(table)
console.print(
'\n[dim]Use the prompt with: notebooklm generate report "<prompt>"[/dim]'
)
console.print('\n[dim]Use the prompt with: notebooklm generate report "<prompt>"[/dim]')
return _run()

View file

@ -13,10 +13,10 @@ from ..client import NotebookLMClient
from ..types import ChatMode
from .helpers import (
console,
require_notebook,
with_client,
get_current_conversation,
require_notebook,
set_current_conversation,
with_client,
)
@ -32,12 +32,8 @@ def register_chat_commands(cli):
default=None,
help="Notebook ID (uses current if not set)",
)
@click.option(
"--conversation-id", "-c", default=None, help="Continue a specific conversation"
)
@click.option(
"--new", "new_conversation", is_flag=True, help="Start a new conversation"
)
@click.option("--conversation-id", "-c", default=None, help="Continue a specific conversation")
@click.option("--new", "new_conversation", is_flag=True, help="Start a new conversation")
@with_client
def ask_cmd(ctx, question, notebook_id, conversation_id, new_conversation, client_auth):
"""Ask a notebook a question.
@ -68,9 +64,7 @@ def register_chat_commands(cli):
if history and history[0]:
last_conv = history[0][-1]
effective_conv_id = (
last_conv[0]
if isinstance(last_conv, list)
else str(last_conv)
last_conv[0] if isinstance(last_conv, list) else str(last_conv)
)
console.print(
f"[dim]Continuing conversation {effective_conv_id[:8]}...[/dim]"
@ -78,9 +72,7 @@ def register_chat_commands(cli):
except Exception:
pass
result = await client.chat.ask(
nb_id, question, conversation_id=effective_conv_id
)
result = await client.chat.ask(nb_id, question, conversation_id=effective_conv_id)
if result.conversation_id:
set_current_conversation(result.conversation_id)
@ -111,9 +103,7 @@ def register_chat_commands(cli):
default=None,
help="Predefined chat mode",
)
@click.option(
"--persona", default=None, help="Custom persona prompt (up to 10,000 chars)"
)
@click.option("--persona", default=None, help="Custom persona prompt (up to 10,000 chars)")
@click.option(
"--response-length",
type=click.Choice(["default", "longer", "shorter"]),
@ -206,6 +196,7 @@ def register_chat_commands(cli):
notebooklm history -n nb123 # Show history for specific notebook
notebooklm history --clear # Clear local cache
"""
async def _run():
async with NotebookLMClient(client_auth) as client:
if clear_cache:
@ -228,9 +219,7 @@ def register_chat_commands(cli):
table.add_column("#", style="dim")
table.add_column("Conversation ID", style="cyan")
for i, conv in enumerate(conversations, 1):
conv_id = (
conv[0] if isinstance(conv, list) and conv else str(conv)
)
conv_id = conv[0] if isinstance(conv, list) and conv else str(conv)
table.add_row(str(i), conv_id)
console.print(table)
console.print(

View file

@ -13,16 +13,16 @@ from typing import Any
import click
from ..auth import AuthTokens, load_auth_from_storage, fetch_tokens
from ..auth import AuthTokens, fetch_tokens, load_auth_from_storage
from ..client import NotebookLMClient
from ..types import Artifact
from .download_helpers import ArtifactDict, artifact_title_to_filename, select_artifact
from .helpers import (
console,
run_async,
require_notebook,
handle_error,
require_notebook,
run_async,
)
from .download_helpers import select_artifact, artifact_title_to_filename, ArtifactDict
@click.group()
@ -166,9 +166,7 @@ async def _download_artifacts_generic(
# Handle --all flag
if download_all:
output_dir = (
Path(output_path) if output_path else Path(default_output_dir)
)
output_dir = Path(output_path) if output_path else Path(default_output_dir)
if dry_run:
return {
@ -199,9 +197,7 @@ async def _download_artifacts_generic(
for i, artifact in enumerate(type_artifacts, 1):
# Progress indicator
if not json_output:
console.print(
f"[dim]Downloading {i}/{total}:[/dim] {artifact['title']}"
)
console.print(f"[dim]Downloading {i}/{total}:[/dim] {artifact['title']}")
# Generate safe name
item_name = artifact_title_to_filename(
@ -220,7 +216,10 @@ async def _download_artifacts_generic(
"id": artifact["id"],
"title": artifact["title"],
"filename": item_name,
**(skip_info or {"status": "skipped", "reason": "conflict resolution failed"}),
**(
skip_info
or {"status": "skipped", "reason": "conflict resolution failed"}
),
}
)
continue
@ -232,9 +231,7 @@ async def _download_artifacts_generic(
# Download
try:
# Download using dispatch
await download_fn(
nb_id, str(item_path), artifact_id=str(artifact["id"])
)
await download_fn(nb_id, str(item_path), artifact_id=str(artifact["id"]))
results.append(
{

View file

@ -1,7 +1,7 @@
"""Helper functions for download commands."""
import re
from typing import Optional, TypedDict
from typing import TypedDict
# Reserve space for " (999)" suffix when handling duplicate filenames
DUPLICATE_SUFFIX_RESERVE = 7
@ -19,8 +19,8 @@ def select_artifact(
artifacts: list[ArtifactDict],
latest: bool = True,
earliest: bool = False,
name: Optional[str] = None,
artifact_id: Optional[str] = None,
name: str | None = None,
artifact_id: str | None = None,
) -> tuple[ArtifactDict, str]:
"""
Select an artifact from a list based on criteria.
@ -106,10 +106,10 @@ def artifact_title_to_filename(
"""
# Sanitize: replace invalid chars with underscore
# Invalid chars: / \ : * ? " < > |
sanitized = re.sub(r'[/\\:*?"<>|]', '_', title)
sanitized = re.sub(r'[/\\:*?"<>|]', "_", title)
# Remove leading/trailing whitespace and dots
sanitized = sanitized.strip('. ')
sanitized = sanitized.strip(". ")
# Fallback for empty titles
if not sanitized:
@ -120,7 +120,7 @@ def artifact_title_to_filename(
# Truncate if too long
if len(sanitized) > effective_max:
sanitized = sanitized[:effective_max].rstrip('. ')
sanitized = sanitized[:effective_max].rstrip(". ")
# Build initial filename
base = sanitized

View file

@ -12,7 +12,7 @@ Commands:
report Generate report
"""
from typing import Any, Optional
from typing import Any
import click
@ -20,23 +20,23 @@ from ..client import NotebookLMClient
from ..types import (
AudioFormat,
AudioLength,
VideoFormat,
VideoStyle,
QuizQuantity,
QuizDifficulty,
InfographicOrientation,
GenerationStatus,
InfographicDetail,
InfographicOrientation,
QuizDifficulty,
QuizQuantity,
ReportFormat,
SlideDeckFormat,
SlideDeckLength,
ReportFormat,
GenerationStatus,
VideoFormat,
VideoStyle,
)
from .helpers import (
console,
json_error_response,
json_output_response,
require_notebook,
with_client,
json_output_response,
json_error_response,
)
@ -48,7 +48,7 @@ async def handle_generation_result(
wait: bool = False,
json_output: bool = False,
timeout: float = 300.0,
) -> Optional[GenerationStatus]:
) -> GenerationStatus | None:
"""Handle generation result with optional waiting and output formatting.
Consolidates common pattern across all generate commands:
@ -98,12 +98,8 @@ async def handle_generation_result(
# Wait for completion if requested
if wait and task_id:
if not json_output:
console.print(
f"[yellow]Generating {artifact_type}...[/yellow] Task: {task_id}"
)
status = await client.artifacts.wait_for_completion(
notebook_id, task_id, timeout=timeout
)
console.print(f"[yellow]Generating {artifact_type}...[/yellow] Task: {task_id}")
status = await client.artifacts.wait_for_completion(notebook_id, task_id, timeout=timeout)
# Output status
_output_generation_status(status, artifact_type, json_output)
@ -111,17 +107,17 @@ async def handle_generation_result(
return status if isinstance(status, GenerationStatus) else None
def _output_generation_status(
status: Any, artifact_type: str, json_output: bool
) -> None:
def _output_generation_status(status: Any, artifact_type: str, json_output: bool) -> None:
"""Output generation status in appropriate format."""
if json_output:
if hasattr(status, "is_complete") and status.is_complete:
json_output_response({
"artifact_id": getattr(status, "task_id", None),
"status": "completed",
"url": getattr(status, "url", None),
})
json_output_response(
{
"artifact_id": getattr(status, "task_id", None),
"status": "completed",
"url": getattr(status, "url", None),
}
)
elif hasattr(status, "is_failed") and status.is_failed:
json_error_response(
"GENERATION_FAILED",
@ -205,9 +201,7 @@ def generate():
default="default",
)
@click.option("--language", default="en")
@click.option(
"--wait/--no-wait", default=False, help="Wait for completion (default: no-wait)"
)
@click.option("--wait/--no-wait", default=False, help="Wait for completion (default: no-wait)")
@click.option("--json", "json_output", is_flag=True, help="Output as JSON")
@with_client
def generate_audio(
@ -250,9 +244,7 @@ def generate_audio(
audio_format=format_map[audio_format],
audio_length=length_map[audio_length],
)
await handle_generation_result(
client, nb_id, result, "audio", wait, json_output
)
await handle_generation_result(client, nb_id, result, "audio", wait, json_output)
return _run()
@ -290,9 +282,7 @@ def generate_audio(
default="auto",
)
@click.option("--language", default="en")
@click.option(
"--wait/--no-wait", default=False, help="Wait for completion (default: no-wait)"
)
@click.option("--wait/--no-wait", default=False, help="Wait for completion (default: no-wait)")
@click.option("--json", "json_output", is_flag=True, help="Output as JSON")
@with_client
def generate_video(
@ -358,9 +348,7 @@ def generate_video(
default="default",
)
@click.option("--language", default="en")
@click.option(
"--wait/--no-wait", default=False, help="Wait for completion (default: no-wait)"
)
@click.option("--wait/--no-wait", default=False, help="Wait for completion (default: no-wait)")
@with_client
def generate_slide_deck(
ctx, description, notebook_id, deck_format, deck_length, language, wait, client_auth
@ -391,9 +379,7 @@ def generate_slide_deck(
slide_format=format_map[deck_format],
slide_length=length_map[deck_length],
)
await handle_generation_result(
client, nb_id, result, "slide deck", wait
)
await handle_generation_result(client, nb_id, result, "slide deck", wait)
return _run()
@ -407,15 +393,9 @@ def generate_slide_deck(
default=None,
help="Notebook ID (uses current if not set)",
)
@click.option(
"--quantity", type=click.Choice(["fewer", "standard", "more"]), default="standard"
)
@click.option(
"--difficulty", type=click.Choice(["easy", "medium", "hard"]), default="medium"
)
@click.option(
"--wait/--no-wait", default=False, help="Wait for completion (default: no-wait)"
)
@click.option("--quantity", type=click.Choice(["fewer", "standard", "more"]), default="standard")
@click.option("--difficulty", type=click.Choice(["easy", "medium", "hard"]), default="medium")
@click.option("--wait/--no-wait", default=False, help="Wait for completion (default: no-wait)")
@with_client
def generate_quiz(ctx, description, notebook_id, quantity, difficulty, wait, client_auth):
"""Generate quiz.
@ -459,15 +439,9 @@ def generate_quiz(ctx, description, notebook_id, quantity, difficulty, wait, cli
default=None,
help="Notebook ID (uses current if not set)",
)
@click.option(
"--quantity", type=click.Choice(["fewer", "standard", "more"]), default="standard"
)
@click.option(
"--difficulty", type=click.Choice(["easy", "medium", "hard"]), default="medium"
)
@click.option(
"--wait/--no-wait", default=False, help="Wait for completion (default: no-wait)"
)
@click.option("--quantity", type=click.Choice(["fewer", "standard", "more"]), default="standard")
@click.option("--difficulty", type=click.Choice(["easy", "medium", "hard"]), default="medium")
@click.option("--wait/--no-wait", default=False, help="Wait for completion (default: no-wait)")
@with_client
def generate_flashcards(ctx, description, notebook_id, quantity, difficulty, wait, client_auth):
"""Generate flashcards.
@ -522,9 +496,7 @@ def generate_flashcards(ctx, description, notebook_id, quantity, difficulty, wai
default="standard",
)
@click.option("--language", default="en")
@click.option(
"--wait/--no-wait", default=False, help="Wait for completion (default: no-wait)"
)
@click.option("--wait/--no-wait", default=False, help="Wait for completion (default: no-wait)")
@with_client
def generate_infographic(
ctx, description, notebook_id, orientation, detail, language, wait, client_auth
@ -572,9 +544,7 @@ def generate_infographic(
help="Notebook ID (uses current if not set)",
)
@click.option("--language", default="en")
@click.option(
"--wait/--no-wait", default=False, help="Wait for completion (default: no-wait)"
)
@click.option("--wait/--no-wait", default=False, help="Wait for completion (default: no-wait)")
@with_client
def generate_data_table(ctx, description, notebook_id, language, wait, client_auth):
"""Generate data table.
@ -621,9 +591,7 @@ def generate_mind_map(ctx, notebook_id, client_auth):
mind_map = result.get("mind_map", {})
if isinstance(mind_map, dict):
console.print(f" Root: {mind_map.get('name', '-')}")
console.print(
f" Children: {len(mind_map.get('children', []))} nodes"
)
console.print(f" Children: {len(mind_map.get('children', []))} nodes")
else:
console.print(result)
else:
@ -648,9 +616,7 @@ def generate_mind_map(ctx, notebook_id, client_auth):
default=None,
help="Notebook ID (uses current if not set)",
)
@click.option(
"--wait/--no-wait", default=False, help="Wait for completion (default: no-wait)"
)
@click.option("--wait/--no-wait", default=False, help="Wait for completion (default: no-wait)")
@with_client
def generate_report_cmd(ctx, description, report_format, notebook_id, wait, client_auth):
"""Generate a report (briefing doc, study guide, blog post, or custom).

View file

@ -20,24 +20,28 @@ class SectionedGroup(click.Group):
"""
# Regular commands - show help text
command_sections = OrderedDict([
("Session", ["login", "use", "status", "clear"]),
("Notebooks", ["list", "create", "delete", "rename", "share", "summary"]),
("Chat", ["ask", "configure", "history"]),
])
command_sections = OrderedDict(
[
("Session", ["login", "use", "status", "clear"]),
("Notebooks", ["list", "create", "delete", "rename", "share", "summary"]),
("Chat", ["ask", "configure", "history"]),
]
)
# Command groups - show sorted subcommands instead of help text
command_groups = OrderedDict([
("Command Groups (use: notebooklm <group> <command>)",
["source", "artifact", "note", "research"]),
("Artifact Actions (use: notebooklm <action> <type>)",
["generate", "download"]),
])
command_groups = OrderedDict(
[
(
"Command Groups (use: notebooklm <group> <command>)",
["source", "artifact", "note", "research"],
),
("Artifact Actions (use: notebooklm <action> <type>)", ["generate", "download"]),
]
)
def format_commands(self, ctx, formatter):
"""Override to display commands in sections."""
commands = {name: self.get_command(ctx, name)
for name in self.list_commands(ctx)}
commands = {name: self.get_command(ctx, name) for name in self.list_commands(ctx)}
# Regular command sections (show help text)
for section, cmd_names in self.command_sections.items():
@ -67,9 +71,13 @@ class SectionedGroup(click.Group):
# Safety net: show any commands not in any section
all_listed = set(sum(self.command_sections.values(), []))
all_listed |= set(sum(self.command_groups.values(), []))
unlisted = [(n, c) for n, c in commands.items()
if n not in all_listed and c is not None and not c.hidden]
unlisted = [
(n, c)
for n, c in commands.items()
if n not in all_listed and c is not None and not c.hidden
]
if unlisted:
with formatter.section("Other"):
formatter.write_dl([(n, c.get_short_help_str(limit=formatter.width))
for n, c in unlisted])
formatter.write_dl(
[(n, c.get_short_help_str(limit=formatter.width)) for n, c in unlisted]
)

View file

@ -18,10 +18,10 @@ from rich.console import Console
from ..auth import (
AuthTokens,
load_auth_from_storage,
fetch_tokens,
load_auth_from_storage,
)
from ..paths import get_context_path, get_browser_profile_dir
from ..paths import get_browser_profile_dir, get_context_path
console = Console()
@ -116,7 +116,7 @@ def get_current_notebook() -> str | None:
try:
data = json.loads(context_file.read_text())
return data.get("notebook_id")
except (json.JSONDecodeError, IOError):
except (OSError, json.JSONDecodeError):
return None
@ -154,7 +154,7 @@ def get_current_conversation() -> str | None:
try:
data = json.loads(context_file.read_text())
return data.get("conversation_id")
except (json.JSONDecodeError, IOError):
except (OSError, json.JSONDecodeError):
return None
@ -170,7 +170,7 @@ def set_current_conversation(conversation_id: str | None):
elif "conversation_id" in data:
del data["conversation_id"]
context_file.write_text(json.dumps(data, indent=2))
except (json.JSONDecodeError, IOError):
except (OSError, json.JSONDecodeError):
pass
@ -247,8 +247,7 @@ async def _resolve_partial_id(
return partial_id
items = await list_fn()
matches = [item for item in items
if item.id.lower().startswith(partial_id.lower())]
matches = [item for item in items if item.id.lower().startswith(partial_id.lower())]
if len(matches) == 1:
if matches[0].id != partial_id:
@ -315,13 +314,9 @@ def handle_error(e: Exception):
def handle_auth_error(json_output: bool = False):
"""Handle authentication errors."""
if json_output:
json_error_response(
"AUTH_REQUIRED", "Auth not found. Run 'notebooklm login' first."
)
json_error_response("AUTH_REQUIRED", "Auth not found. Run 'notebooklm login' first.")
else:
console.print(
"[red]Not logged in. Run 'notebooklm login' first.[/red]"
)
console.print("[red]Not logged in. Run 'notebooklm login' first.[/red]")
raise SystemExit(1)
@ -358,6 +353,7 @@ def with_client(f):
Returns:
Decorated function with Click pass_context
"""
@wraps(f)
@click.pass_context
def wrapper(ctx, *args, **kwargs):
@ -374,6 +370,7 @@ def with_client(f):
json_error_response("ERROR", str(e))
else:
handle_error(e)
return wrapper
@ -389,9 +386,7 @@ def json_output_response(data: dict) -> None:
def json_error_response(code: str, message: str) -> None:
"""Print JSON error and exit."""
console.print(
json.dumps({"error": True, "code": code, "message": message}, indent=2)
)
console.print(json.dumps({"error": True, "code": code, "message": message}, indent=2))
raise SystemExit(1)

View file

@ -68,7 +68,7 @@ def note_list(ctx, notebook_id, client_auth):
table.add_row(
n.id,
n.title or "Untitled",
preview + "..." if len(n.content or "") > 50 else preview
preview + "..." if len(n.content or "") > 50 else preview,
)
console.print(table)

View file

@ -14,13 +14,13 @@ from rich.table import Table
from ..client import NotebookLMClient
from .helpers import (
console,
require_notebook,
with_client,
json_output_response,
get_current_notebook,
clear_context,
console,
get_current_notebook,
json_output_response,
require_notebook,
resolve_notebook_id,
with_client,
)
@ -32,6 +32,7 @@ def register_notebook_commands(cli):
@with_client
def list_cmd(ctx, json_output, client_auth):
"""List all notebooks."""
async def _run():
async with NotebookLMClient(client_auth) as client:
notebooks = await client.notebooks.list()
@ -44,9 +45,7 @@ def register_notebook_commands(cli):
"id": nb.id,
"title": nb.title,
"is_owner": nb.is_owner,
"created_at": nb.created_at.isoformat()
if nb.created_at
else None,
"created_at": nb.created_at.isoformat() if nb.created_at else None,
}
for i, nb in enumerate(notebooks, 1)
],
@ -76,6 +75,7 @@ def register_notebook_commands(cli):
@with_client
def create_cmd(ctx, title, json_output, client_auth):
"""Create a new notebook."""
async def _run():
async with NotebookLMClient(client_auth) as client:
nb = await client.notebooks.create(title)
@ -85,9 +85,7 @@ def register_notebook_commands(cli):
"notebook": {
"id": nb.id,
"title": nb.title,
"created_at": nb.created_at.isoformat()
if nb.created_at
else None,
"created_at": nb.created_at.isoformat() if nb.created_at else None,
}
}
json_output_response(data)

View file

@ -13,9 +13,9 @@ from rich.table import Table
from ..client import NotebookLMClient
from .helpers import (
console,
json_output_response,
require_notebook,
with_client,
json_output_response,
)
@ -102,9 +102,7 @@ def research_status(ctx, notebook_id, json_output, client_auth):
if summary:
console.print(f"\n[bold]Summary:[/bold]\n{summary[:500]}")
console.print(
"\n[dim]Use 'research wait --import-all' to import sources[/dim]"
)
console.print("\n[dim]Use 'research wait --import-all' to import sources[/dim]")
else:
console.print(f"[yellow]Status: {status_val}[/yellow]")
@ -134,9 +132,7 @@ def research_status(ctx, notebook_id, json_output, client_auth):
@click.option("--import-all", is_flag=True, help="Import all found sources when done")
@click.option("--json", "json_output", is_flag=True, help="Output as JSON")
@with_client
def research_wait(
ctx, notebook_id, timeout, interval, import_all, json_output, client_auth
):
def research_wait(ctx, notebook_id, timeout, interval, import_all, json_output, client_auth):
"""Wait for research to complete.
Blocks until research is completed or timeout is reached.
@ -195,9 +191,7 @@ def research_wait(
"sources": sources,
}
if import_all and sources and task_id:
imported = await client.research.import_sources(
nb_id, task_id, sources
)
imported = await client.research.import_sources(nb_id, task_id, sources)
result["imported"] = len(imported)
result["imported_sources"] = imported
json_output_response(result)
@ -207,9 +201,7 @@ def research_wait(
if import_all and sources and task_id:
with console.status("Importing sources..."):
imported = await client.research.import_sources(
nb_id, task_id, sources
)
imported = await client.research.import_sources(nb_id, task_id, sources)
console.print(f"[green]Imported {len(imported)} sources[/green]")
return _run()

View file

@ -17,20 +17,20 @@ from rich.table import Table
from ..auth import AuthTokens
from ..client import NotebookLMClient
from ..paths import (
get_storage_path,
get_context_path,
get_browser_profile_dir,
get_context_path,
get_path_info,
get_storage_path,
)
from .helpers import (
clear_context,
console,
run_async,
get_client,
get_current_notebook,
set_current_notebook,
clear_context,
json_output_response,
resolve_notebook_id,
run_async,
set_current_notebook,
)
@ -216,7 +216,9 @@ def register_session_commands(cli):
# Show if NOTEBOOKLM_AUTH_JSON is set
if os.environ.get("NOTEBOOKLM_AUTH_JSON"):
console.print("[yellow]Note: NOTEBOOKLM_AUTH_JSON is set (inline auth active)[/yellow]\n")
console.print(
"[yellow]Note: NOTEBOOKLM_AUTH_JSON is set (inline auth active)[/yellow]\n"
)
console.print(table)
return
@ -254,11 +256,9 @@ def register_session_commands(cli):
if conversation_id:
table.add_row("Conversation", conversation_id)
else:
table.add_row(
"Conversation", "[dim]None (will auto-select on next ask)[/dim]"
)
table.add_row("Conversation", "[dim]None (will auto-select on next ask)[/dim]")
console.print(table)
except (json.JSONDecodeError, IOError):
except (OSError, json.JSONDecodeError):
if json_output:
json_data = {
"has_context": True,

View file

@ -6,19 +6,17 @@ Commands for managing the Claude Code skill integration.
import re
from importlib import resources
from pathlib import Path
from typing import Optional
import click
from .helpers import console
# Skill paths
SKILL_DEST_DIR = Path.home() / ".claude" / "skills" / "notebooklm"
SKILL_DEST = SKILL_DEST_DIR / "SKILL.md"
def get_skill_source_content() -> Optional[str]:
def get_skill_source_content() -> str | None:
"""Read the skill source file from package data."""
try:
# Python 3.9+ way to read package data (use / operator for path traversal)
@ -31,6 +29,7 @@ def get_package_version() -> str:
"""Get the current package version."""
try:
from .. import __version__
return __version__
except ImportError:
return "unknown"
@ -44,7 +43,7 @@ def get_skill_version(skill_path: Path) -> str | None:
with open(skill_path) as f:
content = f.read(500) # Read first 500 chars
match = re.search(r'notebooklm-py v([\d.]+)', content)
match = re.search(r"notebooklm-py v([\d.]+)", content)
return match.group(1) if match else None
@ -116,7 +115,9 @@ def status():
if skill_version and skill_version != cli_version:
console.print("")
console.print("[yellow]Version mismatch![/yellow] Run [cyan]notebooklm skill install[/cyan] to update.")
console.print(
"[yellow]Version mismatch![/yellow] Run [cyan]notebooklm skill install[/cyan] to update."
)
@skill.command()

View file

@ -22,11 +22,11 @@ from rich.table import Table
from ..client import NotebookLMClient
from .helpers import (
console,
get_source_type_display,
json_output_response,
require_notebook,
resolve_source_id,
with_client,
json_output_response,
get_source_type_display,
)
@ -85,9 +85,7 @@ def source_list(ctx, notebook_id, json_output, client_auth):
"title": src.title,
"type": src.source_type,
"url": src.url,
"created_at": src.created_at.isoformat()
if src.created_at
else None,
"created_at": src.created_at.isoformat() if src.created_at else None,
}
for i, src in enumerate(sources, 1)
],
@ -104,9 +102,7 @@ def source_list(ctx, notebook_id, json_output, client_auth):
for src in sources:
type_display = get_source_type_display(src.source_type)
created = (
src.created_at.strftime("%Y-%m-%d %H:%M") if src.created_at else "-"
)
created = src.created_at.strftime("%Y-%m-%d %H:%M") if src.created_at else "-"
table.add_row(src.id, src.title or "-", type_display, created)
console.print(table)
@ -186,9 +182,7 @@ def source_add(ctx, content, notebook_id, source_type, title, mime_type, json_ou
async def _run():
async with NotebookLMClient(client_auth) as client:
if detected_type == "url":
src = await client.sources.add_url(nb_id, content)
elif detected_type == "youtube":
if detected_type == "url" or detected_type == "youtube":
src = await client.sources.add_url(nb_id, content)
elif detected_type == "text":
text_content = file_content if file_content is not None else content
@ -242,9 +236,7 @@ def source_get(ctx, source_id, notebook_id, client_auth):
if src:
console.print(f"[bold cyan]Source:[/bold cyan] {src.id}")
console.print(f"[bold]Title:[/bold] {src.title}")
console.print(
f"[bold]Type:[/bold] {get_source_type_display(src.source_type)}"
)
console.print(f"[bold]Type:[/bold] {get_source_type_display(src.source_type)}")
if src.url:
console.print(f"[bold]URL:[/bold] {src.url}")
if src.created_at:
@ -443,9 +435,7 @@ def source_add_research(
async def _run():
async with NotebookLMClient(client_auth) as client:
console.print(
f"[yellow]Starting {mode} research on {search_source}...[/yellow]"
)
console.print(f"[yellow]Starting {mode} research on {search_source}...[/yellow]")
result = await client.research.start(nb_id, query, search_source, mode)
if not result:
console.print("[red]Research failed to start[/red]")
@ -479,9 +469,7 @@ def source_add_research(
console.print(f"\n[green]Found {len(sources)} sources[/green]")
if import_all and sources and task_id:
imported = await client.research.import_sources(
nb_id, task_id, sources
)
imported = await client.research.import_sources(nb_id, task_id, sources)
console.print(f"[green]Imported {len(imported)} sources[/green]")
else:
console.print(f"[yellow]Status: {status.get('status', 'unknown')}[/yellow]")
@ -634,7 +622,7 @@ def source_wait(ctx, source_id, notebook_id, timeout, json_output, client_auth):
notebooklm source add https://example.com
# Subagent runs: notebooklm source wait <source_id>
"""
from ..types import SourceProcessingError, SourceTimeoutError, SourceNotFoundError
from ..types import SourceNotFoundError, SourceProcessingError, SourceTimeoutError
nb_id = require_notebook(notebook_id)

View file

@ -20,16 +20,15 @@ Example:
"""
from pathlib import Path
from typing import Optional
from .auth import AuthTokens
from ._core import ClientCore, DEFAULT_TIMEOUT
from ._notebooks import NotebooksAPI
from ._sources import SourcesAPI
from ._artifacts import ArtifactsAPI
from ._chat import ChatAPI
from ._research import ResearchAPI
from ._core import DEFAULT_TIMEOUT, ClientCore
from ._notebooks import NotebooksAPI
from ._notes import NotesAPI
from ._research import ResearchAPI
from ._sources import SourcesAPI
from .auth import AuthTokens
class NotebookLMClient:
@ -102,7 +101,7 @@ class NotebookLMClient:
@classmethod
async def from_storage(
cls, path: Optional[str] = None, timeout: float = DEFAULT_TIMEOUT
cls, path: str | None = None, timeout: float = DEFAULT_TIMEOUT
) -> "NotebookLMClient":
"""Create a client from Playwright storage state file.
@ -146,9 +145,7 @@ class NotebookLMClient:
# Check for redirect to login page
final_url = str(response.url)
if "accounts.google.com" in final_url:
raise ValueError(
"Authentication expired. Run 'notebooklm login' to re-authenticate."
)
raise ValueError("Authentication expired. Run 'notebooklm login' to re-authenticate.")
# Extract SNlM0e (CSRF token) - REQUIRED
csrf_match = re.search(r'"SNlM0e":"([^"]+)"', response.text)

View file

@ -32,17 +32,17 @@ from .auth import DEFAULT_STORAGE_PATH
# Import command groups from cli package
from .cli import (
source,
artifact,
generate,
download,
generate,
note,
skill,
research,
register_chat_commands,
register_notebook_commands,
# Register functions for top-level commands
register_session_commands,
register_notebook_commands,
register_chat_commands,
research,
skill,
source,
)
from .cli.grouped import SectionedGroup

View file

@ -1,36 +1,36 @@
"""RPC protocol implementation for NotebookLM batchexecute API."""
from .decoder import (
RPCError,
collect_rpc_ids,
decode_response,
extract_rpc_result,
parse_chunked_response,
strip_anti_xssi,
)
from .encoder import build_request_body, encode_rpc_request
from .types import (
RPCMethod,
BATCHEXECUTE_URL,
QUERY_URL,
UPLOAD_URL,
StudioContentType,
ArtifactStatus,
AudioFormat,
AudioLength,
VideoFormat,
VideoStyle,
QuizQuantity,
QuizDifficulty,
InfographicOrientation,
InfographicDetail,
SlideDeckFormat,
SlideDeckLength,
ReportFormat,
ChatGoal,
ChatResponseLength,
DriveMimeType,
ExportType,
)
from .encoder import encode_rpc_request, build_request_body
from .decoder import (
strip_anti_xssi,
parse_chunked_response,
extract_rpc_result,
collect_rpc_ids,
decode_response,
RPCError,
InfographicDetail,
InfographicOrientation,
QuizDifficulty,
QuizQuantity,
ReportFormat,
RPCMethod,
SlideDeckFormat,
SlideDeckLength,
StudioContentType,
VideoFormat,
VideoStyle,
)
__all__ = [

View file

@ -3,7 +3,7 @@
import json
import logging
import re
from typing import Any, Optional
from typing import Any
logger = logging.getLogger(__name__)
@ -14,9 +14,9 @@ class RPCError(Exception):
def __init__(
self,
message: str,
rpc_id: Optional[str] = None,
code: Optional[Any] = None,
found_ids: Optional[list[str]] = None,
rpc_id: str | None = None,
code: Any | None = None,
found_ids: list[str] | None = None,
):
self.rpc_id = rpc_id
self.code = code

View file

@ -1,7 +1,7 @@
"""Encode RPC requests for NotebookLM batchexecute API."""
import json
from typing import Any, Optional
from typing import Any
from urllib.parse import quote
from .types import RPCMethod
@ -33,8 +33,8 @@ def encode_rpc_request(method: RPCMethod, params: list[Any]) -> list:
def build_request_body(
rpc_request: list,
csrf_token: Optional[str] = None,
session_id: Optional[str] = None,
csrf_token: str | None = None,
session_id: str | None = None,
) -> str:
"""
Build form-encoded request body for batchexecute.
@ -67,8 +67,8 @@ def build_request_body(
def build_url_params(
rpc_method: RPCMethod,
source_path: str = "/",
session_id: Optional[str] = None,
bl: Optional[str] = None,
session_id: str | None = None,
bl: str | None = None,
) -> dict[str, str]:
"""
Build URL query parameters for batchexecute request.

View file

@ -2,7 +2,6 @@
from enum import Enum
# NotebookLM API endpoints
BATCHEXECUTE_URL = "https://notebooklm.google.com/_/LabsTailwindUi/data/batchexecute"
QUERY_URL = "https://notebooklm.google.com/_/LabsTailwindUi/data/google.internal.labs.tailwind.orchestration.v1.LabsTailwindOrchestrationService/GenerateFreeFormStreamed"
@ -93,7 +92,9 @@ class StudioContentType(int, Enum):
"""
AUDIO = 1
REPORT = 2 # Includes: Briefing Doc, Study Guide, Blog Post, White Paper, Research Proposal, etc.
REPORT = (
2 # Includes: Briefing Doc, Study Guide, Blog Post, White Paper, Research Proposal, etc.
)
VIDEO = 3
QUIZ = 4 # Also used for flashcards
QUIZ_FLASHCARD = 4 # Alias for backward compatibility

View file

@ -14,23 +14,23 @@ from typing import Any, Optional
# Re-export enums from rpc/types.py for convenience
from .rpc.types import (
StudioContentType,
AudioFormat,
AudioLength,
VideoFormat,
VideoStyle,
QuizQuantity,
QuizDifficulty,
InfographicOrientation,
InfographicDetail,
SlideDeckFormat,
SlideDeckLength,
ReportFormat,
ChatGoal,
ChatResponseLength,
DriveMimeType,
ExportType,
InfographicDetail,
InfographicOrientation,
QuizDifficulty,
QuizQuantity,
ReportFormat,
SlideDeckFormat,
SlideDeckLength,
SourceStatus,
StudioContentType,
VideoFormat,
VideoStyle,
)
__all__ = [
@ -99,7 +99,7 @@ class Notebook:
id: str
title: str
created_at: Optional[datetime] = None
created_at: datetime | None = None
sources_count: int = 0
is_owner: bool = True
@ -197,7 +197,7 @@ class SourceTimeoutError(SourceError):
last_status: The last observed status before timeout.
"""
def __init__(self, source_id: str, timeout: float, last_status: Optional[int] = None):
def __init__(self, source_id: str, timeout: float, last_status: int | None = None):
self.source_id = source_id
self.timeout = timeout
self.last_status = last_status
@ -231,10 +231,10 @@ class Source:
"""
id: str
title: Optional[str] = None
url: Optional[str] = None
title: str | None = None
url: str | None = None
source_type: str = "text"
created_at: Optional[datetime] = None
created_at: datetime | None = None
status: int = SourceStatus.READY # Default to READY (2)
@property
@ -253,9 +253,7 @@ class Source:
return self.status == SourceStatus.ERROR
@classmethod
def from_api_response(
cls, data: list[Any], notebook_id: Optional[str] = None
) -> "Source":
def from_api_response(cls, data: list[Any], notebook_id: str | None = None) -> "Source":
"""Parse source data from various API response formats.
The API returns different structures for different operations:
@ -294,12 +292,7 @@ class Source:
if len(entry[2]) > 7 and isinstance(entry[2][7], list):
url = entry[2][7][0] if entry[2][7] else None
return cls(
id=str(source_id),
title=title,
url=url,
source_type="text"
)
return cls(id=str(source_id), title=title, url=url, source_type="text")
# Deeply nested: continue with URL extraction
url = None
@ -309,14 +302,14 @@ class Source:
if isinstance(url_list, list) and len(url_list) > 0:
url = url_list[0]
if not url and len(entry[2]) > 0:
if isinstance(entry[2][0], str) and entry[2][0].startswith('http'):
if isinstance(entry[2][0], str) and entry[2][0].startswith("http"):
url = entry[2][0]
# Determine source type
source_type = "text"
if url:
source_type = "youtube" if "youtube.com" in url or "youtu.be" in url else "url"
elif title and (title.endswith('.pdf') or title.endswith('.txt')):
elif title and (title.endswith(".pdf") or title.endswith(".txt")):
source_type = "text_file"
return cls(
@ -350,9 +343,9 @@ class Artifact:
title: str
artifact_type: int # StudioContentType enum value
status: int # 1=processing, 3=completed
created_at: Optional[datetime] = None
url: Optional[str] = None
variant: Optional[int] = None # For type 4: 1=flashcards, 2=quiz
created_at: datetime | None = None
url: str | None = None
variant: int | None = None # For type 4: 1=flashcards, 2=quiz
@classmethod
def from_api_response(cls, data: list[Any]) -> "Artifact":
@ -469,7 +462,7 @@ class Artifact:
return self.artifact_type == 4 and self.variant == 1
@property
def report_subtype(self) -> Optional[str]:
def report_subtype(self) -> str | None:
"""Get the report subtype for type 2 artifacts.
Returns:
@ -493,10 +486,10 @@ class GenerationStatus:
task_id: str
status: str # "pending", "in_progress", "completed", "failed"
url: Optional[str] = None
error: Optional[str] = None
error_code: Optional[str] = None # e.g., "USER_DISPLAYABLE_ERROR" for rate limits
metadata: Optional[dict[str, Any]] = None
url: str | None = None
error: str | None = None
error_code: str | None = None # e.g., "USER_DISPLAYABLE_ERROR" for rate limits
metadata: dict[str, Any] | None = None
@property
def is_complete(self) -> bool:
@ -578,7 +571,7 @@ class Note:
notebook_id: str
title: str
content: str
created_at: Optional[datetime] = None
created_at: datetime | None = None
@classmethod
def from_api_response(cls, data: list[Any], notebook_id: str) -> "Note":

View file

@ -1,8 +1,8 @@
"""Shared test fixtures."""
import pytest
import json
from typing import Union
import pytest
from notebooklm.rpc import RPCMethod
@ -76,7 +76,7 @@ def build_rpc_response():
data: The response data to encode.
"""
def _build(rpc_id: Union[RPCMethod, str], data) -> str:
def _build(rpc_id: RPCMethod | str, data) -> str:
# Convert RPCMethod to string value if needed
rpc_id_str = rpc_id.value if isinstance(rpc_id, RPCMethod) else rpc_id
inner = json.dumps(data)

View file

@ -2,27 +2,28 @@
import os
import warnings
import pytest
import httpx
from collections.abc import AsyncGenerator
from pathlib import Path
from typing import AsyncGenerator
import httpx
import pytest
# Load .env file if python-dotenv is available
try:
from dotenv import load_dotenv
load_dotenv()
except ImportError:
pass # python-dotenv not installed, rely on shell environment
from notebooklm import NotebookLMClient
from notebooklm.auth import (
load_auth_from_storage,
AuthTokens,
extract_csrf_from_html,
extract_session_id_from_html,
AuthTokens,
load_auth_from_storage,
)
from notebooklm.paths import get_home_dir
from notebooklm import NotebookLMClient
# =============================================================================
# Constants
@ -187,8 +188,6 @@ async def cleanup_notebooks(created_notebooks, auth_tokens):
warnings.warn(f"Failed to cleanup notebook {nb_id}: {e}")
# =============================================================================
# Notebook Fixtures
# =============================================================================
@ -203,6 +202,7 @@ async def temp_notebook(client, created_notebooks, cleanup_notebooks):
"""
import asyncio
from uuid import uuid4
notebook = await client.notebooks.create(f"Test-{uuid4().hex[:8]}")
created_notebooks.append(notebook.id)
@ -332,7 +332,7 @@ async def _cleanup_generation_notebook(client: NotebookLMClient, notebook_id: st
notes = await client.notes.list(notebook_id)
for note in notes:
# Skip if no id or if it's a pinned system note
if note.id and not getattr(note, 'pinned', False):
if note.id and not getattr(note, "pinned", False):
try:
await client.notes.delete(notebook_id, note.id)
except Exception:
@ -432,5 +432,3 @@ async def generation_notebook_id(client):
await client.notebooks.delete(notebook_id)
except Exception as e:
warnings.warn(f"Failed to delete generation notebook {notebook_id}: {e}")

View file

@ -8,10 +8,13 @@ Generation tests are in test_generation.py. This file contains:
"""
import asyncio
import pytest
from .conftest import requires_auth, assert_generation_started
from notebooklm import Artifact, ReportSuggestion
from .conftest import assert_generation_started, requires_auth
@requires_auth
class TestArtifactRetrieval:

View file

@ -1,31 +1,31 @@
import os
import tempfile
import pytest
from .conftest import requires_auth
from notebooklm import Artifact
import pytest
from .conftest import requires_auth
# Magic bytes for file type verification
PNG_MAGIC = b'\x89PNG\r\n\x1a\n'
PDF_MAGIC = b'%PDF'
MP4_FTYP = b'ftyp' # At offset 4
PNG_MAGIC = b"\x89PNG\r\n\x1a\n"
PDF_MAGIC = b"%PDF"
MP4_FTYP = b"ftyp" # At offset 4
def is_png(path: str) -> bool:
"""Check if file is a valid PNG by magic bytes."""
with open(path, 'rb') as f:
with open(path, "rb") as f:
return f.read(8) == PNG_MAGIC
def is_pdf(path: str) -> bool:
"""Check if file is a valid PDF by magic bytes."""
with open(path, 'rb') as f:
with open(path, "rb") as f:
return f.read(4) == PDF_MAGIC
def is_mp4(path: str) -> bool:
"""Check if file is a valid MP4 by magic bytes."""
with open(path, 'rb') as f:
with open(path, "rb") as f:
header = f.read(12)
# MP4 has 'ftyp' at offset 4
return len(header) >= 8 and header[4:8] == MP4_FTYP
@ -106,7 +106,9 @@ class TestDownloadSlideDeck:
with tempfile.TemporaryDirectory() as tmpdir:
output_path = os.path.join(tmpdir, "slides.pdf")
try:
result = await client.artifacts.download_slide_deck(read_only_notebook_id, output_path)
result = await client.artifacts.download_slide_deck(
read_only_notebook_id, output_path
)
assert result == output_path
assert os.path.exists(output_path)
assert os.path.getsize(output_path) > 0

View file

@ -1,7 +1,9 @@
import os
import tempfile
import pytest
from pathlib import Path
import pytest
from .conftest import requires_auth

View file

@ -12,37 +12,35 @@ Notebook lifecycle:
"""
import pytest
from .conftest import requires_auth, assert_generation_started
from notebooklm import (
AudioFormat,
AudioLength,
VideoFormat,
VideoStyle,
QuizQuantity,
QuizDifficulty,
InfographicOrientation,
InfographicDetail,
InfographicOrientation,
QuizDifficulty,
QuizQuantity,
SlideDeckFormat,
SlideDeckLength,
VideoFormat,
VideoStyle,
)
from .conftest import assert_generation_started, requires_auth
@requires_auth
class TestAudioGeneration:
"""Audio generation tests."""
@pytest.mark.asyncio
async def test_generate_audio_default(
self, client, generation_notebook_id
):
async def test_generate_audio_default(self, client, generation_notebook_id):
"""Test audio generation with true defaults."""
result = await client.artifacts.generate_audio(generation_notebook_id)
assert_generation_started(result)
@pytest.mark.asyncio
async def test_generate_audio_brief(
self, client, generation_notebook_id
):
async def test_generate_audio_brief(self, client, generation_notebook_id):
"""Test audio generation with non-default format to verify param encoding."""
result = await client.artifacts.generate_audio(
generation_notebook_id,
@ -52,9 +50,7 @@ class TestAudioGeneration:
@pytest.mark.asyncio
@pytest.mark.variants
async def test_generate_audio_deep_dive_long(
self, client, generation_notebook_id
):
async def test_generate_audio_deep_dive_long(self, client, generation_notebook_id):
result = await client.artifacts.generate_audio(
generation_notebook_id,
audio_format=AudioFormat.DEEP_DIVE,
@ -64,9 +60,7 @@ class TestAudioGeneration:
@pytest.mark.asyncio
@pytest.mark.variants
async def test_generate_audio_brief_short(
self, client, generation_notebook_id
):
async def test_generate_audio_brief_short(self, client, generation_notebook_id):
result = await client.artifacts.generate_audio(
generation_notebook_id,
audio_format=AudioFormat.BRIEF,
@ -76,9 +70,7 @@ class TestAudioGeneration:
@pytest.mark.asyncio
@pytest.mark.variants
async def test_generate_audio_critique(
self, client, generation_notebook_id
):
async def test_generate_audio_critique(self, client, generation_notebook_id):
result = await client.artifacts.generate_audio(
generation_notebook_id,
audio_format=AudioFormat.CRITIQUE,
@ -87,9 +79,7 @@ class TestAudioGeneration:
@pytest.mark.asyncio
@pytest.mark.variants
async def test_generate_audio_debate(
self, client, generation_notebook_id
):
async def test_generate_audio_debate(self, client, generation_notebook_id):
result = await client.artifacts.generate_audio(
generation_notebook_id,
audio_format=AudioFormat.DEBATE,
@ -98,9 +88,7 @@ class TestAudioGeneration:
@pytest.mark.asyncio
@pytest.mark.variants
async def test_generate_audio_with_language(
self, client, generation_notebook_id
):
async def test_generate_audio_with_language(self, client, generation_notebook_id):
result = await client.artifacts.generate_audio(
generation_notebook_id,
language="en",
@ -113,9 +101,7 @@ class TestVideoGeneration:
"""Video generation tests."""
@pytest.mark.asyncio
async def test_generate_video_default(
self, client, generation_notebook_id
):
async def test_generate_video_default(self, client, generation_notebook_id):
"""Test video generation with non-default style to verify param encoding."""
result = await client.artifacts.generate_video(
generation_notebook_id,
@ -125,9 +111,7 @@ class TestVideoGeneration:
@pytest.mark.asyncio
@pytest.mark.variants
async def test_generate_video_explainer_anime(
self, client, generation_notebook_id
):
async def test_generate_video_explainer_anime(self, client, generation_notebook_id):
result = await client.artifacts.generate_video(
generation_notebook_id,
video_format=VideoFormat.EXPLAINER,
@ -137,9 +121,7 @@ class TestVideoGeneration:
@pytest.mark.asyncio
@pytest.mark.variants
async def test_generate_video_brief_whiteboard(
self, client, generation_notebook_id
):
async def test_generate_video_brief_whiteboard(self, client, generation_notebook_id):
result = await client.artifacts.generate_video(
generation_notebook_id,
video_format=VideoFormat.BRIEF,
@ -149,9 +131,7 @@ class TestVideoGeneration:
@pytest.mark.asyncio
@pytest.mark.variants
async def test_generate_video_with_instructions(
self, client, generation_notebook_id
):
async def test_generate_video_with_instructions(self, client, generation_notebook_id):
result = await client.artifacts.generate_video(
generation_notebook_id,
video_format=VideoFormat.EXPLAINER,
@ -162,9 +142,7 @@ class TestVideoGeneration:
@pytest.mark.asyncio
@pytest.mark.variants
async def test_generate_video_kawaii_style(
self, client, generation_notebook_id
):
async def test_generate_video_kawaii_style(self, client, generation_notebook_id):
result = await client.artifacts.generate_video(
generation_notebook_id,
video_style=VideoStyle.KAWAII,
@ -173,9 +151,7 @@ class TestVideoGeneration:
@pytest.mark.asyncio
@pytest.mark.variants
async def test_generate_video_watercolor_style(
self, client, generation_notebook_id
):
async def test_generate_video_watercolor_style(self, client, generation_notebook_id):
result = await client.artifacts.generate_video(
generation_notebook_id,
video_style=VideoStyle.WATERCOLOR,
@ -184,9 +160,7 @@ class TestVideoGeneration:
@pytest.mark.asyncio
@pytest.mark.variants
async def test_generate_video_auto_style(
self, client, generation_notebook_id
):
async def test_generate_video_auto_style(self, client, generation_notebook_id):
result = await client.artifacts.generate_video(
generation_notebook_id,
video_style=VideoStyle.AUTO_SELECT,
@ -199,9 +173,7 @@ class TestQuizGeneration:
"""Quiz generation tests."""
@pytest.mark.asyncio
async def test_generate_quiz_default(
self, client, generation_notebook_id
):
async def test_generate_quiz_default(self, client, generation_notebook_id):
"""Test quiz generation with non-default difficulty to verify param encoding."""
result = await client.artifacts.generate_quiz(
generation_notebook_id,
@ -211,9 +183,7 @@ class TestQuizGeneration:
@pytest.mark.asyncio
@pytest.mark.variants
async def test_generate_quiz_with_options(
self, client, generation_notebook_id
):
async def test_generate_quiz_with_options(self, client, generation_notebook_id):
result = await client.artifacts.generate_quiz(
generation_notebook_id,
quantity=QuizQuantity.MORE,
@ -224,9 +194,7 @@ class TestQuizGeneration:
@pytest.mark.asyncio
@pytest.mark.variants
async def test_generate_quiz_fewer_easy(
self, client, generation_notebook_id
):
async def test_generate_quiz_fewer_easy(self, client, generation_notebook_id):
result = await client.artifacts.generate_quiz(
generation_notebook_id,
quantity=QuizQuantity.FEWER,
@ -240,9 +208,7 @@ class TestFlashcardsGeneration:
"""Flashcards generation tests."""
@pytest.mark.asyncio
async def test_generate_flashcards_default(
self, client, generation_notebook_id
):
async def test_generate_flashcards_default(self, client, generation_notebook_id):
"""Test flashcards generation with non-default quantity to verify param encoding."""
result = await client.artifacts.generate_flashcards(
generation_notebook_id,
@ -252,9 +218,7 @@ class TestFlashcardsGeneration:
@pytest.mark.asyncio
@pytest.mark.variants
async def test_generate_flashcards_with_options(
self, client, generation_notebook_id
):
async def test_generate_flashcards_with_options(self, client, generation_notebook_id):
result = await client.artifacts.generate_flashcards(
generation_notebook_id,
quantity=QuizQuantity.STANDARD,
@ -269,9 +233,7 @@ class TestInfographicGeneration:
"""Infographic generation tests."""
@pytest.mark.asyncio
async def test_generate_infographic_default(
self, client, generation_notebook_id
):
async def test_generate_infographic_default(self, client, generation_notebook_id):
"""Test infographic generation with non-default orientation to verify param encoding."""
result = await client.artifacts.generate_infographic(
generation_notebook_id,
@ -281,9 +243,7 @@ class TestInfographicGeneration:
@pytest.mark.asyncio
@pytest.mark.variants
async def test_generate_infographic_portrait_detailed(
self, client, generation_notebook_id
):
async def test_generate_infographic_portrait_detailed(self, client, generation_notebook_id):
result = await client.artifacts.generate_infographic(
generation_notebook_id,
orientation=InfographicOrientation.PORTRAIT,
@ -294,9 +254,7 @@ class TestInfographicGeneration:
@pytest.mark.asyncio
@pytest.mark.variants
async def test_generate_infographic_square_concise(
self, client, generation_notebook_id
):
async def test_generate_infographic_square_concise(self, client, generation_notebook_id):
result = await client.artifacts.generate_infographic(
generation_notebook_id,
orientation=InfographicOrientation.SQUARE,
@ -306,9 +264,7 @@ class TestInfographicGeneration:
@pytest.mark.asyncio
@pytest.mark.variants
async def test_generate_infographic_landscape(
self, client, generation_notebook_id
):
async def test_generate_infographic_landscape(self, client, generation_notebook_id):
result = await client.artifacts.generate_infographic(
generation_notebook_id,
orientation=InfographicOrientation.LANDSCAPE,
@ -321,9 +277,7 @@ class TestSlideDeckGeneration:
"""Slide deck generation tests."""
@pytest.mark.asyncio
async def test_generate_slide_deck_default(
self, client, generation_notebook_id
):
async def test_generate_slide_deck_default(self, client, generation_notebook_id):
"""Test slide deck generation with non-default format to verify param encoding."""
result = await client.artifacts.generate_slide_deck(
generation_notebook_id,
@ -333,9 +287,7 @@ class TestSlideDeckGeneration:
@pytest.mark.asyncio
@pytest.mark.variants
async def test_generate_slide_deck_detailed(
self, client, generation_notebook_id
):
async def test_generate_slide_deck_detailed(self, client, generation_notebook_id):
result = await client.artifacts.generate_slide_deck(
generation_notebook_id,
slide_format=SlideDeckFormat.DETAILED_DECK,
@ -346,9 +298,7 @@ class TestSlideDeckGeneration:
@pytest.mark.asyncio
@pytest.mark.variants
async def test_generate_slide_deck_presenter_short(
self, client, generation_notebook_id
):
async def test_generate_slide_deck_presenter_short(self, client, generation_notebook_id):
result = await client.artifacts.generate_slide_deck(
generation_notebook_id,
slide_format=SlideDeckFormat.PRESENTER_SLIDES,
@ -362,9 +312,7 @@ class TestDataTableGeneration:
"""Data table generation tests."""
@pytest.mark.asyncio
async def test_generate_data_table_default(
self, client, generation_notebook_id
):
async def test_generate_data_table_default(self, client, generation_notebook_id):
"""Test data table generation with instructions to verify param encoding."""
result = await client.artifacts.generate_data_table(
generation_notebook_id,
@ -374,9 +322,7 @@ class TestDataTableGeneration:
@pytest.mark.asyncio
@pytest.mark.variants
async def test_generate_data_table_with_instructions(
self, client, generation_notebook_id
):
async def test_generate_data_table_with_instructions(self, client, generation_notebook_id):
result = await client.artifacts.generate_data_table(
generation_notebook_id,
instructions="Create a comparison table of key concepts",
@ -408,9 +354,7 @@ class TestStudyGuideGeneration:
"""Study guide generation tests."""
@pytest.mark.asyncio
async def test_generate_study_guide(
self, client, generation_notebook_id
):
async def test_generate_study_guide(self, client, generation_notebook_id):
"""Test study guide generation."""
result = await client.artifacts.generate_study_guide(generation_notebook_id)
assert_generation_started(result)

View file

@ -1,6 +1,8 @@
import pytest
from notebooklm import ChatGoal, ChatMode, Notebook, NotebookDescription
from .conftest import requires_auth
from notebooklm import Notebook, NotebookDescription, ChatMode, ChatGoal
@requires_auth

View file

@ -1,6 +1,7 @@
"""E2E tests for NotesAPI."""
import pytest
from .conftest import requires_auth

View file

@ -5,8 +5,10 @@ polling for results, and importing discovered sources.
"""
import asyncio
import pytest
from .conftest import requires_auth, POLL_INTERVAL, POLL_TIMEOUT
from .conftest import POLL_INTERVAL, POLL_TIMEOUT, requires_auth
@requires_auth

View file

@ -1,7 +1,10 @@
import asyncio
import pytest
from notebooklm import Source, SourceStatus
from .conftest import requires_auth
from notebooklm import Source, SourceStatus, SourceTimeoutError
@requires_auth
@ -27,9 +30,7 @@ class TestSourceOperations:
@pytest.mark.asyncio
async def test_add_url_source(self, client, temp_notebook):
"""Test adding a URL source to an owned notebook."""
source = await client.sources.add_url(
temp_notebook.id, "https://httpbin.org/html"
)
source = await client.sources.add_url(temp_notebook.id, "https://httpbin.org/html")
assert isinstance(source, Source)
assert source.id is not None
# URL may or may not be returned in response
@ -59,9 +60,7 @@ class TestSourceOperations:
assert isinstance(source, Source)
# Rename
renamed = await client.sources.rename(
temp_notebook.id, source.id, "Renamed Test Source"
)
renamed = await client.sources.rename(temp_notebook.id, source.id, "Renamed Test Source")
assert isinstance(renamed, Source)
assert renamed.title == "Renamed Test Source"
# No need to restore - temp_notebook is deleted after test
@ -135,9 +134,7 @@ class TestSourceMutations:
async def test_refresh_source(self, client, temp_notebook):
"""Test refreshing a URL source."""
# Add a URL source
source = await client.sources.add_url(
temp_notebook.id, "https://httpbin.org/html"
)
source = await client.sources.add_url(temp_notebook.id, "https://httpbin.org/html")
assert source.id is not None
# Refresh it
@ -151,9 +148,7 @@ class TestSourceMutations:
async def test_check_freshness(self, client, temp_notebook):
"""Test checking source freshness."""
# Add a URL source
source = await client.sources.add_url(
temp_notebook.id, "https://httpbin.org/html"
)
source = await client.sources.add_url(temp_notebook.id, "https://httpbin.org/html")
assert source.id is not None
await asyncio.sleep(2) # Wait for processing

View file

@ -1,7 +1,6 @@
"""Shared fixtures for integration tests."""
import json
from typing import Union
import pytest
@ -34,7 +33,7 @@ def build_rpc_response():
data: The response data to encode.
"""
def _build(rpc_id: Union[RPCMethod, str], data) -> str:
def _build(rpc_id: RPCMethod | str, data) -> str:
# Convert RPCMethod to string value if needed
rpc_id_str = rpc_id.value if isinstance(rpc_id, RPCMethod) else rpc_id
inner = json.dumps(data)

View file

@ -4,7 +4,7 @@ import pytest
from pytest_httpx import HTTPXMock
from notebooklm import NotebookLMClient
from notebooklm.rpc import AudioFormat, AudioLength, VideoFormat, VideoStyle, RPCError, RPCMethod
from notebooklm.rpc import AudioFormat, AudioLength, RPCError, RPCMethod, VideoFormat, VideoStyle
class TestStudioContent:
@ -529,7 +529,15 @@ class TestArtifactsAPI:
RPCMethod.LIST_ARTIFACTS,
[
["art_001", "Quiz", 4, None, 3, None, [None, None, None, None, None, None, 2]],
["art_002", "Flashcards", 4, None, 3, None, [None, None, None, None, None, None, 1]],
[
"art_002",
"Flashcards",
4,
None,
3,
None,
[None, None, None, None, None, None, 1],
],
],
)
httpx_mock.add_response(content=response.encode())
@ -569,7 +577,15 @@ class TestArtifactsAPI:
RPCMethod.LIST_ARTIFACTS,
[
["art_001", "Quiz", 4, None, 3, None, [None, None, None, None, None, None, 2]],
["art_002", "Flashcards", 4, None, 3, None, [None, None, None, None, None, None, 1]],
[
"art_002",
"Flashcards",
4,
None,
3,
None,
[None, None, None, None, None, None, 1],
],
],
)
httpx_mock.add_response(content=response.encode())

View file

@ -4,8 +4,7 @@ import pytest
from pytest_httpx import HTTPXMock
from notebooklm import NotebookLMClient
from notebooklm.rpc import RPCMethod
from notebooklm.rpc import ChatGoal, ChatResponseLength
from notebooklm.rpc import ChatGoal, ChatResponseLength, RPCMethod
from notebooklm.types import ChatMode

View file

@ -5,8 +5,8 @@ to avoid asyncio event loop conflicts with pytest-asyncio.
"""
import pytest
from pathlib import Path
from notebooklm.cli.download_helpers import select_artifact, artifact_title_to_filename
from notebooklm.cli.download_helpers import artifact_title_to_filename, select_artifact
class TestArtifactSelection:
@ -19,7 +19,7 @@ class TestArtifactSelection:
{"id": "a2", "title": "Meeting Notes", "created_at": 2000},
{"id": "a3", "title": "Debate Round 3", "created_at": 3000}, # Latest "debate"
{"id": "a4", "title": "Debate Round 2", "created_at": 2500},
{"id": "a5", "title": "Overview", "created_at": 4000}, # Latest overall
{"id": "a5", "title": "Overview", "created_at": 4000}, # Latest overall
]
selected, reason = select_artifact(artifacts, latest=True, name="debate")
@ -32,9 +32,9 @@ class TestArtifactSelection:
def test_filter_then_select_earliest(self):
"""Should apply name filter BEFORE selecting earliest."""
artifacts = [
{"id": "a1", "title": "Introduction", "created_at": 1000}, # Earliest overall
{"id": "a1", "title": "Introduction", "created_at": 1000}, # Earliest overall
{"id": "a2", "title": "Chapter 2", "created_at": 3000},
{"id": "a3", "title": "Chapter 1", "created_at": 2000}, # Earliest "chapter"
{"id": "a3", "title": "Chapter 1", "created_at": 2000}, # Earliest "chapter"
{"id": "a4", "title": "Chapter 3", "created_at": 4000},
{"id": "a5", "title": "Conclusion", "created_at": 5000},
]
@ -175,7 +175,7 @@ class TestFilenameGeneration:
assert "/" not in filename
assert ":" not in filename
assert '"' not in filename
assert filename == 'Audio_ Part 1 _ _Main_.mp3'
assert filename == "Audio_ Part 1 _ _Main_.mp3"
def test_handle_duplicates(self):
"""Should add (2), (3) suffixes for duplicates."""

View file

@ -3,7 +3,7 @@
import pytest
from pytest_httpx import HTTPXMock
from notebooklm import NotebookLMClient, Notebook
from notebooklm import Notebook, NotebookLMClient
from notebooklm.rpc import RPCMethod
@ -199,9 +199,7 @@ class TestSummary:
httpx_mock: HTTPXMock,
build_rpc_response,
):
response = build_rpc_response(
RPCMethod.SUMMARIZE, ["Summary of the notebook content..."]
)
response = build_rpc_response(RPCMethod.SUMMARIZE, ["Summary of the notebook content..."])
httpx_mock.add_response(content=response.encode())
async with NotebookLMClient(auth_tokens) as client:
@ -257,7 +255,16 @@ class TestRenameNotebook:
# Get notebook response after rename
get_response = build_rpc_response(
RPCMethod.GET_NOTEBOOK,
[["Renamed", [], "nb_123", "📘", None, [None, None, None, None, None, [1704067200, 0]]]],
[
[
"Renamed",
[],
"nb_123",
"📘",
None,
[None, None, None, None, None, [1704067200, 0]],
]
],
)
httpx_mock.add_response(content=get_response.encode())
@ -364,10 +371,12 @@ class TestNotebooksAPIAdditional:
RPCMethod.SUMMARIZE,
[
["This notebook covers AI research."],
[[
["What are the main findings?", "Explain the key findings"],
["How was the study conducted?", "Describe methodology"],
]],
[
[
["What are the main findings?", "Explain the key findings"],
["How was the study conducted?", "Describe methodology"],
]
],
],
)
httpx_mock.add_response(content=response.encode())
@ -464,11 +473,13 @@ class TestNotebookEdgeCases:
RPCMethod.SUMMARIZE,
[
["Summary"],
[[
["Valid question", "Valid prompt"],
["Only question"], # Missing prompt
"not a list", # Not a list
]],
[
[
["Valid question", "Valid prompt"],
["Only question"], # Missing prompt
"not a list", # Not a list
]
],
],
)
httpx_mock.add_response(content=response.encode())

View file

@ -68,7 +68,10 @@ class TestNotesAPI:
[
[
["note_001", ["note_001", "Regular note content", None, None, "Regular Note"]],
["mm_001", ["mm_001", '{"title":"Mind Map","children":[]}', None, None, "Mind Map"]],
[
"mm_001",
["mm_001", '{"title":"Mind Map","children":[]}', None, None, "Mind Map"],
],
]
],
)
@ -204,7 +207,10 @@ class TestNotesAPI:
[
[
["note_001", ["note_001", "Regular note", None, None, "Note"]],
["mm_001", ["mm_001", '{"title":"Mind Map 1","children":[]}', None, None, "MM1"]],
[
"mm_001",
["mm_001", '{"title":"Mind Map 1","children":[]}', None, None, "MM1"],
],
["mm_002", ["mm_002", '{"nodes":[{"id":"1"}]}', None, None, "MM2"]],
]
],

View file

@ -17,9 +17,7 @@ class TestResearchAPI:
build_rpc_response,
):
"""Test starting fast web research."""
response = build_rpc_response(
"Ljjv0c", ["task_123", "report_456"]
)
response = build_rpc_response("Ljjv0c", ["task_123", "report_456"])
httpx_mock.add_response(content=response.encode())
async with NotebookLMClient(auth_tokens) as client:
@ -43,9 +41,7 @@ class TestResearchAPI:
build_rpc_response,
):
"""Test starting fast drive research."""
response = build_rpc_response(
"Ljjv0c", ["task_789", None]
)
response = build_rpc_response("Ljjv0c", ["task_789", None])
httpx_mock.add_response(content=response.encode())
async with NotebookLMClient(auth_tokens) as client:
@ -65,15 +61,11 @@ class TestResearchAPI:
build_rpc_response,
):
"""Test starting deep web research."""
response = build_rpc_response(
"QA9ei", ["task_deep", "report_deep"]
)
response = build_rpc_response("QA9ei", ["task_deep", "report_deep"])
httpx_mock.add_response(content=response.encode())
async with NotebookLMClient(auth_tokens) as client:
result = await client.research.start(
"nb_123", "AI ethics", source="web", mode="deep"
)
result = await client.research.start("nb_123", "AI ethics", source="web", mode="deep")
assert result is not None
assert result["mode"] == "deep"
@ -90,9 +82,7 @@ class TestResearchAPI:
"""Test that deep research on drive raises ValueError."""
async with NotebookLMClient(auth_tokens) as client:
with pytest.raises(ValueError, match="Deep Research only supports Web"):
await client.research.start(
"nb_123", "query", source="drive", mode="deep"
)
await client.research.start("nb_123", "query", source="drive", mode="deep")
@pytest.mark.asyncio
async def test_start_invalid_source_raises(
@ -228,9 +218,7 @@ class TestResearchAPI:
{"url": "https://example.com/quantum", "title": "Quantum Computing Guide"},
{"url": "https://example.com/ai", "title": "AI Research Paper"},
]
result = await client.research.import_sources(
"nb_123", "task_123", sources_to_import
)
result = await client.research.import_sources("nb_123", "task_123", sources_to_import)
assert len(result) == 2
assert result[0]["id"] == "src_001"

View file

@ -53,9 +53,7 @@ class TestAddSource:
httpx_mock.add_response(content=response.encode())
async with NotebookLMClient(auth_tokens) as client:
source = await client.sources.add_text(
"nb_123", "My Document", "This is the content"
)
source = await client.sources.add_text("nb_123", "My Document", "This is the content")
assert isinstance(source, Source)
assert source.id == "source_id"
@ -148,9 +146,37 @@ class TestSourcesAPI:
[
"Test Notebook",
[
[["src_001"], "My Article", [None, 11, [1704067200, 0], None, 5, None, None, ["https://example.com"]], [None, 2]],
[
["src_001"],
"My Article",
[
None,
11,
[1704067200, 0],
None,
5,
None,
None,
["https://example.com"],
],
[None, 2],
],
[["src_002"], "My Text", [None, 0, [1704153600, 0]], [None, 2]],
[["src_003"], "YouTube Video", [None, 11, [1704240000, 0], None, 5, None, None, ["https://youtube.com/watch?v=abc"]], [None, 2]],
[
["src_003"],
"YouTube Video",
[
None,
11,
[1704240000, 0],
None,
5,
None,
None,
["https://youtube.com/watch?v=abc"],
],
[None, 2],
],
],
"nb_123",
"📘",
@ -180,7 +206,16 @@ class TestSourcesAPI:
"""Test listing sources from empty notebook."""
response = build_rpc_response(
RPCMethod.GET_NOTEBOOK,
[["Empty Notebook", [], "nb_123", "📘", None, [None, None, None, None, None, [1704067200, 0]]]],
[
[
"Empty Notebook",
[],
"nb_123",
"📘",
None,
[None, None, None, None, None, [1704067200, 0]],
]
],
)
httpx_mock.add_response(content=response.encode())
@ -199,7 +234,16 @@ class TestSourcesAPI:
"""Test getting a non-existent source."""
response = build_rpc_response(
RPCMethod.GET_NOTEBOOK,
[["Notebook", [[["src_001"], "Source 1", [None, 0], [None, 2]]], "nb_123", "📘", None, [None, None, None, None, None, [1704067200, 0]]]],
[
[
"Notebook",
[[["src_001"], "Source 1", [None, 0], [None, 2]]],
"nb_123",
"📘",
None,
[None, None, None, None, None, [1704067200, 0]],
]
],
)
httpx_mock.add_response(content=response.encode())
@ -367,11 +411,7 @@ class TestAddFileSource:
# Step 1: Mock RPC registration response (o4cbdc)
rpc_response = build_rpc_response(
RPCMethod.ADD_SOURCE_FILE,
[
[
[["file_source_123"], "test_document.txt", [None, None, None, None, 0]]
]
],
[[[["file_source_123"], "test_document.txt", [None, None, None, None, 0]]]],
)
httpx_mock.add_response(
url=re.compile(r".*batchexecute.*"),
@ -504,6 +544,7 @@ class TestAddFileSource:
# Verify body contains metadata
import json
body = json.loads(start_request.content.decode())
assert body["PROJECT_ID"] == "nb_123"
assert body["SOURCE_NAME"] == "document.txt"

View file

@ -1,8 +1,9 @@
"""Shared fixtures for CLI unit tests."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from click.testing import CliRunner
from unittest.mock import AsyncMock, patch, MagicMock
@pytest.fixture
@ -72,6 +73,24 @@ def create_mock_client():
return mock_client
def get_cli_module(module_path: str):
"""Get the actual CLI module by path, bypassing shadowed names.
In cli/__init__.py, module names are shadowed by click groups with the same name
(e.g., `from .source import source`). This function uses importlib to get the
actual module for Python 3.10 compatibility.
Args:
module_path: The module name within notebooklm.cli (e.g., "source", "skill")
Returns:
The actual module object
"""
import importlib
return importlib.import_module(f"notebooklm.cli.{module_path}")
def patch_client_for_module(module_path: str):
"""Create a context manager that patches NotebookLMClient in the given module.
@ -86,8 +105,16 @@ def patch_client_for_module(module_path: str):
mock_client = create_mock_client()
mock_cls.return_value = mock_client
# ... run test
Note:
Uses importlib to get the actual module, not the click group that shadows
the module name in cli/__init__.py. This is required for Python 3.10
compatibility where mock.patch's string path resolution gets the wrong object.
"""
return patch(f"notebooklm.cli.{module_path}.NotebookLMClient")
import importlib
module = importlib.import_module(f"notebooklm.cli.{module_path}")
return patch.object(module, "NotebookLMClient")
class MultiMockProxy:
@ -96,15 +123,16 @@ class MultiMockProxy:
When you set return_value on this proxy, it propagates to all mocks.
Other attribute access is delegated to the primary mock.
"""
def __init__(self, mocks):
object.__setattr__(self, '_mocks', mocks)
object.__setattr__(self, '_primary', mocks[0])
object.__setattr__(self, "_mocks", mocks)
object.__setattr__(self, "_primary", mocks[0])
def __getattr__(self, name):
return getattr(self._primary, name)
def __setattr__(self, name, value):
if name == 'return_value':
if name == "return_value":
# Propagate return_value to all mocks
for m in self._mocks:
m.return_value = value
@ -118,6 +146,7 @@ class MultiPatcher:
After refactoring, commands are spread across multiple modules, so we need
to patch NotebookLMClient in all of them.
"""
def __init__(self):
self.patches = [
patch("notebooklm.cli.notebook.NotebookLMClient"),

View file

@ -1,14 +1,14 @@
"""Tests for artifact CLI commands."""
import json
import pytest
from datetime import datetime
from unittest.mock import AsyncMock, patch, MagicMock
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from click.testing import CliRunner
from notebooklm.notebooklm_cli import cli
from notebooklm.types import Artifact, ReportSuggestion
from notebooklm.types import Artifact
from .conftest import create_mock_client, patch_client_for_module
@ -84,9 +84,7 @@ class TestArtifactList:
]
)
mock_client.notes.list_mind_maps = AsyncMock(return_value=[])
mock_client.notebooks.get = AsyncMock(
return_value=MagicMock(title="Test Notebook")
)
mock_client.notebooks.get = AsyncMock(return_value=MagicMock(title="Test Notebook"))
mock_client_cls.return_value = mock_client
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
@ -120,7 +118,7 @@ class TestArtifactGet:
title="Test Artifact",
artifact_type=4,
status=3,
created_at=datetime(2024, 1, 1)
created_at=datetime(2024, 1, 1),
)
)
mock_client_cls.return_value = mock_client
@ -161,9 +159,7 @@ class TestArtifactRename:
mock_client = create_mock_client()
# Mock list for partial ID resolution
mock_client.artifacts.list = AsyncMock(
return_value=[
Artifact(id="art_123", title="Old Title", artifact_type=4, status=3)
]
return_value=[Artifact(id="art_123", title="Old Title", artifact_type=4, status=3)]
)
mock_client.notes.list_mind_maps = AsyncMock(return_value=[])
mock_client.artifacts.rename = AsyncMock(
@ -185,9 +181,7 @@ class TestArtifactRename:
mock_client = create_mock_client()
# Mock list for partial ID resolution (include the mind map)
mock_client.artifacts.list = AsyncMock(
return_value=[
Artifact(id="mm_123", title="Old Title", artifact_type=5, status=3)
]
return_value=[Artifact(id="mm_123", title="Old Title", artifact_type=5, status=3)]
)
mock_client.notes.list_mind_maps = AsyncMock(
return_value=[
@ -227,9 +221,7 @@ class TestArtifactDelete:
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(
cli, ["artifact", "delete", "art_123", "-n", "nb_123", "-y"]
)
result = runner.invoke(cli, ["artifact", "delete", "art_123", "-n", "nb_123", "-y"])
assert result.exit_code == 0
assert "Deleted artifact" in result.output
@ -253,9 +245,7 @@ class TestArtifactDelete:
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(
cli, ["artifact", "delete", "mm_456", "-n", "nb_123", "-y"]
)
result = runner.invoke(cli, ["artifact", "delete", "mm_456", "-n", "nb_123", "-y"])
assert result.exit_code == 0
assert "Cleared mind map" in result.output
@ -273,9 +263,7 @@ class TestArtifactExport:
mock_client = create_mock_client()
# Mock list for partial ID resolution
mock_client.artifacts.list = AsyncMock(
return_value=[
Artifact(id="art_123", title="Doc", artifact_type=2, status=3)
]
return_value=[Artifact(id="art_123", title="Doc", artifact_type=2, status=3)]
)
mock_client.artifacts.export = AsyncMock(
return_value={"url": "https://docs.google.com/document/d/123"}
@ -294,6 +282,7 @@ class TestArtifactExport:
mock_client.artifacts.export.assert_called_once()
call_args = mock_client.artifacts.export.call_args
from notebooklm.rpc import ExportType
# call_args[0] = (notebook_id, artifact_id, content, title, export_type)
assert call_args[0][2] is None, "content should be None (backend retrieves it)"
assert call_args[0][4] == ExportType.DOCS, "export_type should be ExportType.DOCS"
@ -303,9 +292,7 @@ class TestArtifactExport:
mock_client = create_mock_client()
# Mock list for partial ID resolution
mock_client.artifacts.list = AsyncMock(
return_value=[
Artifact(id="art_123", title="Table", artifact_type=9, status=3)
]
return_value=[Artifact(id="art_123", title="Table", artifact_type=9, status=3)]
)
mock_client.artifacts.export = AsyncMock(
return_value={"url": "https://sheets.google.com/spreadsheets/d/123"}
@ -315,7 +302,18 @@ class TestArtifactExport:
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(
cli, ["artifact", "export", "art_123", "--title", "My Sheet", "--type", "sheets", "-n", "nb_123"]
cli,
[
"artifact",
"export",
"art_123",
"--title",
"My Sheet",
"--type",
"sheets",
"-n",
"nb_123",
],
)
assert result.exit_code == 0
@ -324,6 +322,7 @@ class TestArtifactExport:
mock_client.artifacts.export.assert_called_once()
call_args = mock_client.artifacts.export.call_args
from notebooklm.rpc import ExportType
# call_args[0] = (notebook_id, artifact_id, content, title, export_type)
assert call_args[0][2] is None, "content should be None (backend retrieves it)"
assert call_args[0][4] == ExportType.SHEETS, "export_type should be ExportType.SHEETS"
@ -333,9 +332,7 @@ class TestArtifactExport:
mock_client = create_mock_client()
# Mock list for partial ID resolution
mock_client.artifacts.list = AsyncMock(
return_value=[
Artifact(id="art_123", title="Doc", artifact_type=2, status=3)
]
return_value=[Artifact(id="art_123", title="Doc", artifact_type=2, status=3)]
)
mock_client.artifacts.export = AsyncMock(return_value=None)
mock_client_cls.return_value = mock_client
@ -384,24 +381,18 @@ class TestArtifactWait:
mock_client = create_mock_client()
# Mock list for partial ID resolution
mock_client.artifacts.list = AsyncMock(
return_value=[
Artifact(id="art_123", title="Test", artifact_type=1, status=3)
]
return_value=[Artifact(id="art_123", title="Test", artifact_type=1, status=3)]
)
mock_client.artifacts.wait_for_completion = AsyncMock(
return_value=MagicMock(
status="completed",
url="https://example.com/audio.mp3",
error=None
status="completed", url="https://example.com/audio.mp3", error=None
)
)
mock_client_cls.return_value = mock_client
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(
cli, ["artifact", "wait", "art_123", "-n", "nb_123"]
)
result = runner.invoke(cli, ["artifact", "wait", "art_123", "-n", "nb_123"])
assert result.exit_code == 0
assert "Artifact completed" in result.output
@ -411,24 +402,18 @@ class TestArtifactWait:
with patch_client_for_module("artifact") as mock_client_cls:
mock_client = create_mock_client()
mock_client.artifacts.list = AsyncMock(
return_value=[
Artifact(id="art_123", title="Test", artifact_type=1, status=1)
]
return_value=[Artifact(id="art_123", title="Test", artifact_type=1, status=1)]
)
mock_client.artifacts.wait_for_completion = AsyncMock(
return_value=MagicMock(
status="failed",
url=None,
error="Generation failed due to content policy"
status="failed", url=None, error="Generation failed due to content policy"
)
)
mock_client_cls.return_value = mock_client
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(
cli, ["artifact", "wait", "art_123", "-n", "nb_123"]
)
result = runner.invoke(cli, ["artifact", "wait", "art_123", "-n", "nb_123"])
assert result.exit_code == 1
assert "Generation failed" in result.output
@ -438,9 +423,7 @@ class TestArtifactWait:
with patch_client_for_module("artifact") as mock_client_cls:
mock_client = create_mock_client()
mock_client.artifacts.list = AsyncMock(
return_value=[
Artifact(id="art_123", title="Test", artifact_type=1, status=1)
]
return_value=[Artifact(id="art_123", title="Test", artifact_type=1, status=1)]
)
mock_client.artifacts.wait_for_completion = AsyncMock(
side_effect=TimeoutError("Timed out")
@ -461,15 +444,11 @@ class TestArtifactWait:
with patch_client_for_module("artifact") as mock_client_cls:
mock_client = create_mock_client()
mock_client.artifacts.list = AsyncMock(
return_value=[
Artifact(id="art_123", title="Test", artifact_type=1, status=3)
]
return_value=[Artifact(id="art_123", title="Test", artifact_type=1, status=3)]
)
mock_client.artifacts.wait_for_completion = AsyncMock(
return_value=MagicMock(
status="completed",
url="https://example.com/audio.mp3",
error=None
status="completed", url="https://example.com/audio.mp3", error=None
)
)
mock_client_cls.return_value = mock_client
@ -490,9 +469,7 @@ class TestArtifactWait:
with patch_client_for_module("artifact") as mock_client_cls:
mock_client = create_mock_client()
mock_client.artifacts.list = AsyncMock(
return_value=[
Artifact(id="art_123", title="Test", artifact_type=1, status=1)
]
return_value=[Artifact(id="art_123", title="Test", artifact_type=1, status=1)]
)
mock_client.artifacts.wait_for_completion = AsyncMock(
side_effect=TimeoutError("Timed out")

View file

@ -1,19 +1,24 @@
"""Tests for download CLI commands."""
import pytest
from datetime import datetime
from pathlib import Path
from unittest.mock import AsyncMock, patch
import pytest
from click.testing import CliRunner
from notebooklm.notebooklm_cli import cli
from notebooklm.types import Artifact
from .conftest import create_mock_client, patch_client_for_module
from .conftest import create_mock_client, get_cli_module, patch_client_for_module
# Get the actual download module (not the click group that shadows it)
download_module = get_cli_module("download")
def make_artifact(id: str, title: str, artifact_type: int, status: int = 3, created_at: datetime = None) -> Artifact:
def make_artifact(
id: str, title: str, artifact_type: int, status: int = 3, created_at: datetime = None
) -> Artifact:
"""Create an Artifact for testing."""
return Artifact(
id=id,
@ -65,8 +70,10 @@ class TestDownloadAudio:
mock_client.artifacts.download_audio = mock_download_audio
mock_client_cls.return_value = mock_client
with patch("notebooklm.cli.download.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
with patch("notebooklm.cli.download.load_auth_from_storage") as mock_load:
with patch.object(
download_module, "fetch_tokens", new_callable=AsyncMock
) as mock_fetch:
with patch.object(download_module, "load_auth_from_storage") as mock_load:
mock_load.return_value = {"SID": "test", "HSID": "test", "SSID": "test"}
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(
@ -84,13 +91,13 @@ class TestDownloadAudio:
)
mock_client_cls.return_value = mock_client
with patch("notebooklm.cli.download.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
with patch("notebooklm.cli.download.load_auth_from_storage") as mock_load:
with patch.object(
download_module, "fetch_tokens", new_callable=AsyncMock
) as mock_fetch:
with patch.object(download_module, "load_auth_from_storage") as mock_load:
mock_load.return_value = {"SID": "test", "HSID": "test", "SSID": "test"}
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(
cli, ["download", "audio", "--dry-run", "-n", "nb_123"]
)
result = runner.invoke(cli, ["download", "audio", "--dry-run", "-n", "nb_123"])
assert result.exit_code == 0
assert "DRY RUN" in result.output
@ -101,8 +108,10 @@ class TestDownloadAudio:
mock_client.artifacts.list = AsyncMock(return_value=[])
mock_client_cls.return_value = mock_client
with patch("notebooklm.cli.download.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
with patch("notebooklm.cli.download.load_auth_from_storage") as mock_load:
with patch.object(
download_module, "fetch_tokens", new_callable=AsyncMock
) as mock_fetch:
with patch.object(download_module, "load_auth_from_storage") as mock_load:
mock_load.return_value = {"SID": "test", "HSID": "test", "SSID": "test"}
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(cli, ["download", "audio", "-n", "nb_123"])
@ -133,8 +142,10 @@ class TestDownloadVideo:
mock_client.artifacts.download_video = mock_download_video
mock_client_cls.return_value = mock_client
with patch("notebooklm.cli.download.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
with patch("notebooklm.cli.download.load_auth_from_storage") as mock_load:
with patch.object(
download_module, "fetch_tokens", new_callable=AsyncMock
) as mock_fetch:
with patch.object(download_module, "load_auth_from_storage") as mock_load:
mock_load.return_value = {"SID": "test", "HSID": "test", "SSID": "test"}
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(
@ -168,8 +179,10 @@ class TestDownloadInfographic:
mock_client.artifacts.download_infographic = mock_download_infographic
mock_client_cls.return_value = mock_client
with patch("notebooklm.cli.download.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
with patch("notebooklm.cli.download.load_auth_from_storage") as mock_load:
with patch.object(
download_module, "fetch_tokens", new_callable=AsyncMock
) as mock_fetch:
with patch.object(download_module, "load_auth_from_storage") as mock_load:
mock_load.return_value = {"SID": "test", "HSID": "test", "SSID": "test"}
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(
@ -204,8 +217,10 @@ class TestDownloadSlideDeck:
mock_client.artifacts.download_slide_deck = mock_download_slide_deck
mock_client_cls.return_value = mock_client
with patch("notebooklm.cli.download.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
with patch("notebooklm.cli.download.load_auth_from_storage") as mock_load:
with patch.object(
download_module, "fetch_tokens", new_callable=AsyncMock
) as mock_fetch:
with patch.object(download_module, "load_auth_from_storage") as mock_load:
mock_load.return_value = {"SID": "test", "HSID": "test", "SSID": "test"}
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(
@ -235,15 +250,21 @@ class TestDownloadFlags:
# Set up artifacts namespace (pre-created by create_mock_client)
mock_client.artifacts.list = AsyncMock(
return_value=[
make_artifact("audio_old", "Old Audio", 1, created_at=datetime.fromtimestamp(1000000000)),
make_artifact("audio_new", "New Audio", 1, created_at=datetime.fromtimestamp(2000000000)),
make_artifact(
"audio_old", "Old Audio", 1, created_at=datetime.fromtimestamp(1000000000)
),
make_artifact(
"audio_new", "New Audio", 1, created_at=datetime.fromtimestamp(2000000000)
),
]
)
mock_client.artifacts.download_audio = mock_download_audio
mock_client_cls.return_value = mock_client
with patch("notebooklm.cli.download.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
with patch("notebooklm.cli.download.load_auth_from_storage") as mock_load:
with patch.object(
download_module, "fetch_tokens", new_callable=AsyncMock
) as mock_fetch:
with patch.object(download_module, "load_auth_from_storage") as mock_load:
mock_load.return_value = {"SID": "test"}
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(
@ -266,15 +287,21 @@ class TestDownloadFlags:
# Set up artifacts namespace (pre-created by create_mock_client)
mock_client.artifacts.list = AsyncMock(
return_value=[
make_artifact("audio_old", "Old Audio", 1, created_at=datetime.fromtimestamp(1000000000)),
make_artifact("audio_new", "New Audio", 1, created_at=datetime.fromtimestamp(2000000000)),
make_artifact(
"audio_old", "Old Audio", 1, created_at=datetime.fromtimestamp(1000000000)
),
make_artifact(
"audio_new", "New Audio", 1, created_at=datetime.fromtimestamp(2000000000)
),
]
)
mock_client.artifacts.download_audio = mock_download_audio
mock_client_cls.return_value = mock_client
with patch("notebooklm.cli.download.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
with patch("notebooklm.cli.download.load_auth_from_storage") as mock_load:
with patch.object(
download_module, "fetch_tokens", new_callable=AsyncMock
) as mock_fetch:
with patch.object(download_module, "load_auth_from_storage") as mock_load:
mock_load.return_value = {"SID": "test"}
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(
@ -302,8 +329,10 @@ class TestDownloadFlags:
mock_client.artifacts.download_audio = mock_download_audio
mock_client_cls.return_value = mock_client
with patch("notebooklm.cli.download.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
with patch("notebooklm.cli.download.load_auth_from_storage") as mock_load:
with patch.object(
download_module, "fetch_tokens", new_callable=AsyncMock
) as mock_fetch:
with patch.object(download_module, "load_auth_from_storage") as mock_load:
mock_load.return_value = {"SID": "test"}
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(
@ -326,8 +355,10 @@ class TestDownloadFlags:
mock_client_cls.return_value = mock_client
with patch("notebooklm.cli.download.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
with patch("notebooklm.cli.download.load_auth_from_storage") as mock_load:
with patch.object(
download_module, "fetch_tokens", new_callable=AsyncMock
) as mock_fetch:
with patch.object(download_module, "load_auth_from_storage") as mock_load:
mock_load.return_value = {"SID": "test"}
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(

View file

@ -1,9 +1,9 @@
"""Tests for generate CLI commands."""
import json
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from click.testing import CliRunner
from notebooklm.notebooklm_cli import cli
@ -60,7 +60,9 @@ class TestGenerateAudio:
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(cli, ["generate", "audio", "--format", "debate", "-n", "nb_123"])
result = runner.invoke(
cli, ["generate", "audio", "--format", "debate", "-n", "nb_123"]
)
assert result.exit_code == 0
mock_client.artifacts.generate_audio.assert_called()
@ -75,7 +77,9 @@ class TestGenerateAudio:
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(cli, ["generate", "audio", "--length", "long", "-n", "nb_123"])
result = runner.invoke(
cli, ["generate", "audio", "--length", "long", "-n", "nb_123"]
)
assert result.exit_code == 0
@ -160,7 +164,9 @@ class TestGenerateVideo:
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(cli, ["generate", "video", "--style", "kawaii", "-n", "nb_123"])
result = runner.invoke(
cli, ["generate", "video", "--style", "kawaii", "-n", "nb_123"]
)
assert result.exit_code == 0
@ -196,7 +202,17 @@ class TestGenerateQuiz:
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(
cli, ["generate", "quiz", "--quantity", "more", "--difficulty", "hard", "-n", "nb_123"]
cli,
[
"generate",
"quiz",
"--quantity",
"more",
"--difficulty",
"hard",
"-n",
"nb_123",
],
)
assert result.exit_code == 0
@ -254,7 +270,17 @@ class TestGenerateSlideDeck:
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(
cli, ["generate", "slide-deck", "--format", "presenter", "--length", "short", "-n", "nb_123"]
cli,
[
"generate",
"slide-deck",
"--format",
"presenter",
"--length",
"short",
"-n",
"nb_123",
],
)
assert result.exit_code == 0
@ -291,7 +317,17 @@ class TestGenerateInfographic:
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(
cli, ["generate", "infographic", "--orientation", "portrait", "--detail", "detailed", "-n", "nb_123"]
cli,
[
"generate",
"infographic",
"--orientation",
"portrait",
"--detail",
"detailed",
"-n",
"nb_123",
],
)
assert result.exit_code == 0

View file

@ -88,9 +88,15 @@ class TestSectionedHelp:
# (it may still appear if Click adds it, but our sections should dominate)
lines = result.output.split("\n")
# Count section headers
section_count = sum(1 for line in lines if line.strip().endswith(":") and
any(s in line for s in ["Session", "Notebooks", "Chat",
"Command Groups", "Artifact Actions"]))
section_count = sum(
1
for line in lines
if line.strip().endswith(":")
and any(
s in line
for s in ["Session", "Notebooks", "Chat", "Command Groups", "Artifact Actions"]
)
)
assert section_count >= 4 # At least 4 of our sections should appear (no Insights anymore)

View file

@ -1,39 +1,37 @@
"""Tests for CLI helper functions."""
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from io import StringIO
from unittest.mock import patch, MagicMock, AsyncMock
from notebooklm.cli.helpers import (
ARTIFACT_TYPE_MAP,
clear_context,
detect_source_type,
# Type display helpers
get_artifact_type_display,
get_source_type_display,
detect_source_type,
ARTIFACT_TYPE_DISPLAY,
ARTIFACT_TYPE_MAP,
# Output helpers
json_output_response,
json_error_response,
# Context helpers
get_current_notebook,
set_current_notebook,
clear_context,
get_current_conversation,
set_current_conversation,
require_notebook,
# Error handling
handle_error,
handle_auth_error,
get_auth_tokens,
# Auth helpers
get_client,
get_auth_tokens,
get_current_conversation,
# Context helpers
get_current_notebook,
get_source_type_display,
handle_auth_error,
# Error handling
handle_error,
json_error_response,
# Output helpers
json_output_response,
require_notebook,
run_async,
set_current_conversation,
set_current_notebook,
# Decorator
with_client,
)
# =============================================================================
# ARTIFACT TYPE DISPLAY TESTS
# =============================================================================
@ -92,15 +90,27 @@ class TestGetArtifactTypeDisplay:
class TestDetectSourceType:
def test_youtube_url(self):
src = ["id", "Video Title", [None, None, None, None, None, None, None, ["https://youtube.com/watch?v=abc"]]]
src = [
"id",
"Video Title",
[None, None, None, None, None, None, None, ["https://youtube.com/watch?v=abc"]],
]
assert detect_source_type(src) == "🎥 YouTube"
def test_youtu_be_url(self):
src = ["id", "Video Title", [None, None, None, None, None, None, None, ["https://youtu.be/abc"]]]
src = [
"id",
"Video Title",
[None, None, None, None, None, None, None, ["https://youtu.be/abc"]],
]
assert detect_source_type(src) == "🎥 YouTube"
def test_web_url(self):
src = ["id", "Web Page", [None, None, None, None, None, None, None, ["https://example.com/article"]]]
src = [
"id",
"Web Page",
[None, None, None, None, None, None, None, ["https://example.com/article"]],
]
assert detect_source_type(src) == "🔗 Web URL"
def test_pdf_file(self):
@ -256,7 +266,9 @@ class TestJsonErrorResponse:
class TestContextManagement:
def test_get_current_notebook_no_file(self, tmp_path):
with patch("notebooklm.cli.helpers.get_context_path", return_value=tmp_path / "nonexistent.json"):
with patch(
"notebooklm.cli.helpers.get_context_path", return_value=tmp_path / "nonexistent.json"
):
result = get_current_notebook()
assert result is None
@ -271,10 +283,7 @@ class TestContextManagement:
context_file = tmp_path / "context.json"
with patch("notebooklm.cli.helpers.get_context_path", return_value=context_file):
set_current_notebook(
"nb_test123",
title="Test Notebook",
is_owner=True,
created_at="2024-01-01T00:00:00"
"nb_test123", title="Test Notebook", is_owner=True, created_at="2024-01-01T00:00:00"
)
data = json.loads(context_file.read_text())
assert data["notebook_id"] == "nb_test123"
@ -296,7 +305,9 @@ class TestContextManagement:
clear_context() # Should not raise
def test_get_current_conversation_no_file(self, tmp_path):
with patch("notebooklm.cli.helpers.get_context_path", return_value=tmp_path / "nonexistent.json"):
with patch(
"notebooklm.cli.helpers.get_context_path", return_value=tmp_path / "nonexistent.json"
):
result = get_current_conversation()
assert result is None
@ -327,7 +338,9 @@ class TestContextManagement:
class TestRequireNotebook:
def test_returns_provided_notebook_id(self, tmp_path):
with patch("notebooklm.cli.helpers.get_context_path", return_value=tmp_path / "context.json"):
with patch(
"notebooklm.cli.helpers.get_context_path", return_value=tmp_path / "context.json"
):
result = require_notebook("nb_provided")
assert result == "nb_provided"
@ -339,7 +352,9 @@ class TestRequireNotebook:
assert result == "nb_context"
def test_raises_system_exit_when_no_notebook(self, tmp_path):
with patch("notebooklm.cli.helpers.get_context_path", return_value=tmp_path / "nonexistent.json"):
with patch(
"notebooklm.cli.helpers.get_context_path", return_value=tmp_path / "nonexistent.json"
):
with patch("notebooklm.cli.helpers.console") as mock_console:
with pytest.raises(SystemExit) as exc_info:
require_notebook(None)
@ -395,14 +410,15 @@ class TestHandleAuthError:
class TestWithClientDecorator:
def test_decorator_passes_auth_to_function(self):
"""Test that @with_client properly injects client_auth"""
from click.testing import CliRunner
import click
from click.testing import CliRunner
@click.command()
@with_client
def test_cmd(ctx, client_auth):
async def _run():
click.echo(f"Got auth: {client_auth is not None}")
return _run()
runner = CliRunner()
@ -417,14 +433,15 @@ class TestWithClientDecorator:
def test_decorator_handles_no_auth(self):
"""Test that @with_client handles missing auth gracefully"""
from click.testing import CliRunner
import click
from click.testing import CliRunner
@click.command()
@with_client
def test_cmd(ctx, client_auth):
async def _run():
pass
return _run()
runner = CliRunner()
@ -437,14 +454,15 @@ class TestWithClientDecorator:
def test_decorator_handles_exception_non_json(self):
"""Test error handling in non-JSON mode"""
from click.testing import CliRunner
import click
from click.testing import CliRunner
@click.command()
@with_client
def test_cmd(ctx, client_auth):
async def _run():
raise ValueError("Test error")
return _run()
runner = CliRunner()
@ -459,8 +477,8 @@ class TestWithClientDecorator:
def test_decorator_handles_exception_json_mode(self):
"""Test error handling in JSON mode"""
from click.testing import CliRunner
import click
from click.testing import CliRunner
@click.command()
@click.option("--json", "json_output", is_flag=True)
@ -468,6 +486,7 @@ class TestWithClientDecorator:
def test_cmd(ctx, json_output, client_auth):
async def _run():
raise ValueError("Test error")
return _run()
runner = CliRunner()

View file

@ -1,8 +1,8 @@
"""Tests for note CLI commands."""
import pytest
from unittest.mock import AsyncMock, patch
import pytest
from click.testing import CliRunner
from notebooklm.notebooklm_cli import cli
@ -267,9 +267,7 @@ class TestNoteDelete:
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(
cli, ["note", "delete", "note_123", "-n", "nb_123", "-y"]
)
result = runner.invoke(cli, ["note", "delete", "note_123", "-n", "nb_123", "-y"])
assert result.exit_code == 0
assert "Deleted note" in result.output

View file

@ -1,17 +1,17 @@
"""Tests for notebook CLI commands (now top-level commands)."""
import json
import pytest
from datetime import datetime
from pathlib import Path
from unittest.mock import AsyncMock, patch, MagicMock
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from click.testing import CliRunner
from notebooklm.notebooklm_cli import cli
from notebooklm.types import Notebook, NotebookDescription, SuggestedTopic, AskResult
from notebooklm.types import AskResult, Notebook
from .conftest import create_mock_client, patch_main_cli_client, patch_client_for_module
from .conftest import create_mock_client, patch_client_for_module, patch_main_cli_client
@pytest.fixture
@ -56,8 +56,18 @@ class TestNotebookList:
mock_client = create_mock_client()
mock_client.notebooks.list = AsyncMock(
return_value=[
Notebook(id="nb_1", title="First Notebook", created_at=datetime(2024, 1, 1), is_owner=True),
Notebook(id="nb_2", title="Second Notebook", created_at=datetime(2024, 1, 2), is_owner=False),
Notebook(
id="nb_1",
title="First Notebook",
created_at=datetime(2024, 1, 1),
is_owner=True,
),
Notebook(
id="nb_2",
title="Second Notebook",
created_at=datetime(2024, 1, 2),
is_owner=False,
),
]
)
mock_client_cls.return_value = mock_client
@ -75,7 +85,12 @@ class TestNotebookList:
mock_client = create_mock_client()
mock_client.notebooks.list = AsyncMock(
return_value=[
Notebook(id="nb_1", title="Test Notebook", created_at=datetime(2024, 1, 1), is_owner=True),
Notebook(
id="nb_1",
title="Test Notebook",
created_at=datetime(2024, 1, 1),
is_owner=True,
),
]
)
mock_client_cls.return_value = mock_client
@ -101,7 +116,9 @@ class TestNotebookCreate:
with patch_main_cli_client() as mock_client_cls:
mock_client = create_mock_client()
mock_client.notebooks.create = AsyncMock(
return_value=Notebook(id="new_nb_id", title="Test Notebook", created_at=datetime(2024, 1, 1))
return_value=Notebook(
id="new_nb_id", title="Test Notebook", created_at=datetime(2024, 1, 1)
)
)
mock_client_cls.return_value = mock_client
@ -116,7 +133,9 @@ class TestNotebookCreate:
with patch_main_cli_client() as mock_client_cls:
mock_client = create_mock_client()
mock_client.notebooks.create = AsyncMock(
return_value=Notebook(id="new_nb_id", title="Test Notebook", created_at=datetime(2024, 1, 1))
return_value=Notebook(
id="new_nb_id", title="Test Notebook", created_at=datetime(2024, 1, 1)
)
)
mock_client_cls.return_value = mock_client
@ -141,7 +160,12 @@ class TestNotebookDelete:
# Mock list for partial ID resolution (returns the notebook to be deleted)
mock_client.notebooks.list = AsyncMock(
return_value=[
Notebook(id="nb_to_delete", title="Test Notebook", created_at=datetime(2024, 1, 1), is_owner=True),
Notebook(
id="nb_to_delete",
title="Test Notebook",
created_at=datetime(2024, 1, 1),
is_owner=True,
),
]
)
mock_client.notebooks.delete = AsyncMock(return_value=True)
@ -164,16 +188,25 @@ class TestNotebookDelete:
# Mock list for partial ID resolution
mock_client.notebooks.list = AsyncMock(
return_value=[
Notebook(id="nb_to_delete", title="Test Notebook", created_at=datetime(2024, 1, 1), is_owner=True),
Notebook(
id="nb_to_delete",
title="Test Notebook",
created_at=datetime(2024, 1, 1),
is_owner=True,
),
]
)
mock_client.notebooks.delete = AsyncMock(return_value=True)
mock_client_cls.return_value = mock_client
with patch("notebooklm.cli.helpers.get_context_path", return_value=context_file):
with patch("notebooklm.cli.notebook.get_current_notebook", return_value="nb_to_delete"):
with patch(
"notebooklm.cli.notebook.get_current_notebook", return_value="nb_to_delete"
):
with patch("notebooklm.cli.notebook.clear_context") as mock_clear:
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
with patch(
"notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock
) as mock_fetch:
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(cli, ["delete", "-n", "nb_to_delete", "-y"])
@ -186,7 +219,12 @@ class TestNotebookDelete:
# Mock list for partial ID resolution
mock_client.notebooks.list = AsyncMock(
return_value=[
Notebook(id="nb_123", title="Test Notebook", created_at=datetime(2024, 1, 1), is_owner=True),
Notebook(
id="nb_123",
title="Test Notebook",
created_at=datetime(2024, 1, 1),
is_owner=True,
),
]
)
mock_client.notebooks.delete = AsyncMock(return_value=False)
@ -212,7 +250,12 @@ class TestNotebookRename:
# Mock list for partial ID resolution
mock_client.notebooks.list = AsyncMock(
return_value=[
Notebook(id="nb_123", title="Test Notebook", created_at=datetime(2024, 1, 1), is_owner=True),
Notebook(
id="nb_123",
title="Test Notebook",
created_at=datetime(2024, 1, 1),
is_owner=True,
),
]
)
mock_client.notebooks.rename = AsyncMock(return_value=None)
@ -239,7 +282,12 @@ class TestNotebookShare:
# Mock list for partial ID resolution
mock_client.notebooks.list = AsyncMock(
return_value=[
Notebook(id="nb_123", title="Test Notebook", created_at=datetime(2024, 1, 1), is_owner=True),
Notebook(
id="nb_123",
title="Test Notebook",
created_at=datetime(2024, 1, 1),
is_owner=True,
),
]
)
mock_client.notebooks.share = AsyncMock(
@ -266,7 +314,12 @@ class TestNotebookShare:
# Mock list for partial ID resolution
mock_client.notebooks.list = AsyncMock(
return_value=[
Notebook(id="nb_123", title="Test Notebook", created_at=datetime(2024, 1, 1), is_owner=True),
Notebook(
id="nb_123",
title="Test Notebook",
created_at=datetime(2024, 1, 1),
is_owner=True,
),
]
)
mock_client.notebooks.share = AsyncMock(
@ -299,7 +352,12 @@ class TestNotebookSummary:
# Mock list for partial ID resolution
mock_client.notebooks.list = AsyncMock(
return_value=[
Notebook(id="nb_123", title="Test Notebook", created_at=datetime(2024, 1, 1), is_owner=True),
Notebook(
id="nb_123",
title="Test Notebook",
created_at=datetime(2024, 1, 1),
is_owner=True,
),
]
)
mock_desc = MagicMock()
@ -322,7 +380,12 @@ class TestNotebookSummary:
# Mock list for partial ID resolution
mock_client.notebooks.list = AsyncMock(
return_value=[
Notebook(id="nb_123", title="Test Notebook", created_at=datetime(2024, 1, 1), is_owner=True),
Notebook(
id="nb_123",
title="Test Notebook",
created_at=datetime(2024, 1, 1),
is_owner=True,
),
]
)
mock_desc = MagicMock()
@ -347,7 +410,12 @@ class TestNotebookSummary:
# Mock list for partial ID resolution
mock_client.notebooks.list = AsyncMock(
return_value=[
Notebook(id="nb_123", title="Test Notebook", created_at=datetime(2024, 1, 1), is_owner=True),
Notebook(
id="nb_123",
title="Test Notebook",
created_at=datetime(2024, 1, 1),
is_owner=True,
),
]
)
mock_client.notebooks.get_description = AsyncMock(return_value=None)
@ -370,9 +438,7 @@ class TestNotebookHistory:
def test_notebook_history(self, runner, mock_auth):
with patch_main_cli_client() as mock_client_cls:
mock_client = create_mock_client()
mock_client.chat.get_history = AsyncMock(
return_value=[[["conv_1"], ["conv_2"]]]
)
mock_client.chat.get_history = AsyncMock(return_value=[[["conv_1"], ["conv_2"]]])
mock_client_cls.return_value = mock_client
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
@ -429,8 +495,13 @@ class TestNotebookAsk:
mock_client.chat.get_history = AsyncMock(return_value=None)
mock_client_cls.return_value = mock_client
with patch("notebooklm.cli.helpers.get_context_path", return_value=Path("/nonexistent/context.json")):
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
with patch(
"notebooklm.cli.helpers.get_context_path",
return_value=Path("/nonexistent/context.json"),
):
with patch(
"notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock
) as mock_fetch:
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(cli, ["ask", "-n", "nb_123", "What is this?"])
@ -455,7 +526,9 @@ class TestNotebookAsk:
result = runner.invoke(cli, ["ask", "-n", "nb_123", "--new", "Fresh question"])
assert result.exit_code == 0
assert "Starting new conversation" in result.output or "New conversation" in result.output
assert (
"Starting new conversation" in result.output or "New conversation" in result.output
)
def test_notebook_ask_continue_conversation(self, runner, mock_auth):
with patch_main_cli_client() as mock_client_cls:
@ -492,7 +565,9 @@ class TestNotebookConfigure:
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(cli, ["configure", "-n", "nb_123", "--mode", "learning-guide"])
result = runner.invoke(
cli, ["configure", "-n", "nb_123", "--mode", "learning-guide"]
)
assert result.exit_code == 0
assert "Chat mode set to: learning-guide" in result.output
@ -505,7 +580,9 @@ class TestNotebookConfigure:
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(cli, ["configure", "-n", "nb_123", "--persona", "Act as a tutor"])
result = runner.invoke(
cli, ["configure", "-n", "nb_123", "--persona", "Act as a tutor"]
)
assert result.exit_code == 0
assert "Chat configured" in result.output
@ -519,7 +596,9 @@ class TestNotebookConfigure:
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(cli, ["configure", "-n", "nb_123", "--response-length", "longer"])
result = runner.invoke(
cli, ["configure", "-n", "nb_123", "--response-length", "longer"]
)
assert result.exit_code == 0
assert "response length: longer" in result.output
@ -542,7 +621,9 @@ class TestSourceAddResearch:
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(cli, ["source", "add-research", "AI research", "-n", "nb_123"])
result = runner.invoke(
cli, ["source", "add-research", "AI research", "-n", "nb_123"]
)
assert result.exit_code == 0
assert "Found 1 sources" in result.output
@ -555,7 +636,9 @@ class TestSourceAddResearch:
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(cli, ["source", "add-research", "AI research", "-n", "nb_123"])
result = runner.invoke(
cli, ["source", "add-research", "AI research", "-n", "nb_123"]
)
assert result.exit_code == 1
assert "Research failed to start" in result.output
@ -572,7 +655,9 @@ class TestSourceAddResearch:
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(cli, ["source", "add-research", "AI research", "-n", "nb_123", "--import-all"])
result = runner.invoke(
cli, ["source", "add-research", "AI research", "-n", "nb_123", "--import-all"]
)
assert result.exit_code == 0
assert "Imported 1 sources" in result.output
@ -619,5 +704,7 @@ class TestNotebookCommandsExist:
assert "ask" in result.output
# Verify there's no "notebook" command in the Commands section
# (it should only appear as part of "NotebookLM" in the description)
commands_section = result.output.split("Commands:")[1] if "Commands:" in result.output else ""
commands_section = (
result.output.split("Commands:")[1] if "Commands:" in result.output else ""
)
assert " notebook " not in commands_section.lower()

View file

@ -1,10 +1,10 @@
"""Tests for resolve_notebook_id and resolve_source_id partial ID matching."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch
import click
import pytest
from notebooklm.cli.helpers import resolve_notebook_id, resolve_source_id
from notebooklm.types import Notebook, Source
@ -22,9 +22,24 @@ def mock_client():
def sample_notebooks():
"""Sample notebooks for testing."""
return [
Notebook(id="abc123def456ghi789", title="First Notebook", created_at=datetime(2024, 1, 1), is_owner=True),
Notebook(id="xyz789uvw456rst123", title="Second Notebook", created_at=datetime(2024, 1, 2), is_owner=False),
Notebook(id="abc999zzz888yyy777", title="Third Notebook", created_at=datetime(2024, 1, 3), is_owner=True),
Notebook(
id="abc123def456ghi789",
title="First Notebook",
created_at=datetime(2024, 1, 1),
is_owner=True,
),
Notebook(
id="xyz789uvw456rst123",
title="Second Notebook",
created_at=datetime(2024, 1, 2),
is_owner=False,
),
Notebook(
id="abc999zzz888yyy777",
title="Third Notebook",
created_at=datetime(2024, 1, 3),
is_owner=True,
),
]
@ -127,9 +142,16 @@ class TestResolveNotebookId:
mock_client.notebooks.list = AsyncMock(return_value=sample_notebooks)
# Create a notebook with a short ID that we'll match exactly
mock_client.notebooks.list = AsyncMock(return_value=[
Notebook(id="shortid", title="Short ID Notebook", created_at=datetime(2024, 1, 1), is_owner=True),
])
mock_client.notebooks.list = AsyncMock(
return_value=[
Notebook(
id="shortid",
title="Short ID Notebook",
created_at=datetime(2024, 1, 1),
is_owner=True,
),
]
)
with patch("notebooklm.cli.helpers.console") as mock_console:
result = await resolve_notebook_id(mock_client, "shortid")
@ -146,7 +168,12 @@ class TestResolveNotebookIdAmbiguityDisplay:
async def test_shows_up_to_five_matches(self, mock_client):
"""Ambiguous error shows up to 5 matching notebooks."""
notebooks = [
Notebook(id=f"abc{i}00000000000000", title=f"Notebook {i}", created_at=datetime(2024, 1, i + 1), is_owner=True)
Notebook(
id=f"abc{i}00000000000000",
title=f"Notebook {i}",
created_at=datetime(2024, 1, i + 1),
is_owner=True,
)
for i in range(7)
]
mock_client.notebooks.list = AsyncMock(return_value=notebooks)
@ -219,7 +246,9 @@ class TestResolveSourceId:
mock_console.print.assert_called()
@pytest.mark.asyncio
async def test_ambiguous_prefix_raises_exception(self, mock_client_with_sources, sample_sources):
async def test_ambiguous_prefix_raises_exception(
self, mock_client_with_sources, sample_sources
):
"""Ambiguous prefix (matches multiple) raises ClickException."""
mock_client_with_sources.sources.list = AsyncMock(return_value=sample_sources)
@ -318,7 +347,9 @@ class TestResolveSourceIdAmbiguityDisplay:
assert "... and 2 more" in error_msg
@pytest.mark.asyncio
async def test_shows_source_titles_in_ambiguous_error(self, mock_client_with_sources, sample_sources):
async def test_shows_source_titles_in_ambiguous_error(
self, mock_client_with_sources, sample_sources
):
"""Ambiguous error includes source titles."""
mock_client_with_sources.sources.list = AsyncMock(return_value=sample_sources)

View file

@ -1,10 +1,10 @@
"""Tests for session CLI commands (login, use, status, clear)."""
import json
import pytest
from datetime import datetime
from unittest.mock import AsyncMock, patch, MagicMock
from unittest.mock import AsyncMock, patch
import pytest
from click.testing import CliRunner
from notebooklm.notebooklm_cli import cli
@ -101,9 +101,7 @@ class TestUseCommand:
)
mock_client_cls.return_value = mock_client
with patch(
"notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock
) as mock_fetch:
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = ("csrf", "session")
# Patch in session module where it's imported
@ -131,9 +129,7 @@ class TestUseCommand:
)
mock_client_cls.return_value = mock_client
with patch(
"notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock
) as mock_fetch:
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = ("csrf", "session")
# Patch in session module where it's imported
@ -174,9 +170,7 @@ class TestUseCommand:
)
mock_client_cls.return_value = mock_client
with patch(
"notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock
) as mock_fetch:
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = ("csrf", "session")
# Patch in session module where it's imported
@ -319,14 +313,10 @@ class TestSessionEdgeCases:
"""Test 'use' command handles API errors gracefully."""
with patch_main_cli_client() as mock_client_cls:
mock_client = create_mock_client()
mock_client.notebooks.get = AsyncMock(
side_effect=Exception("API Error: Rate limited")
)
mock_client.notebooks.get = AsyncMock(side_effect=Exception("API Error: Rate limited"))
mock_client_cls.return_value = mock_client
with patch(
"notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock
) as mock_fetch:
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = ("csrf", "session")
# Patch in session module where it's imported

View file

@ -1,11 +1,17 @@
"""Tests for skill CLI commands."""
import pytest
from unittest.mock import patch
import pytest
from click.testing import CliRunner
from notebooklm.notebooklm_cli import cli
from .conftest import get_cli_module
# Get the actual skill module (not the click group that shadows it)
skill_module = get_cli_module("skill")
@pytest.fixture
def runner():
@ -20,10 +26,13 @@ class TestSkillInstall:
skill_dest = tmp_path / "skills" / "notebooklm" / "SKILL.md"
mock_source_content = "---\nname: notebooklm\n---\n# Test"
with patch("notebooklm.cli.skill.SKILL_DEST", skill_dest), \
patch("notebooklm.cli.skill.SKILL_DEST_DIR", skill_dest.parent), \
patch("notebooklm.cli.skill.get_skill_source_content", return_value=mock_source_content):
with (
patch.object(skill_module, "SKILL_DEST", skill_dest),
patch.object(skill_module, "SKILL_DEST_DIR", skill_dest.parent),
patch.object(
skill_module, "get_skill_source_content", return_value=mock_source_content
),
):
result = runner.invoke(cli, ["skill", "install"])
assert result.exit_code == 0
@ -34,9 +43,11 @@ class TestSkillInstall:
"""Test error when source file doesn't exist."""
skill_dest = tmp_path / "skills" / "notebooklm" / "SKILL.md"
with patch("notebooklm.cli.skill.SKILL_DEST", skill_dest), \
patch("notebooklm.cli.skill.SKILL_DEST_DIR", skill_dest.parent), \
patch("notebooklm.cli.skill.get_skill_source_content", return_value=None):
with (
patch.object(skill_module, "SKILL_DEST", skill_dest),
patch.object(skill_module, "SKILL_DEST_DIR", skill_dest.parent),
patch.object(skill_module, "get_skill_source_content", return_value=None),
):
result = runner.invoke(cli, ["skill", "install"])
assert result.exit_code == 1
@ -50,7 +61,7 @@ class TestSkillStatus:
"""Test status when skill is not installed."""
skill_dest = tmp_path / "skills" / "notebooklm" / "SKILL.md"
with patch("notebooklm.cli.skill.SKILL_DEST", skill_dest):
with patch.object(skill_module, "SKILL_DEST", skill_dest):
result = runner.invoke(cli, ["skill", "status"])
assert result.exit_code == 0
@ -62,7 +73,7 @@ class TestSkillStatus:
skill_dest.parent.mkdir(parents=True)
skill_dest.write_text("<!-- notebooklm-py v0.1.0 -->\n# Test")
with patch("notebooklm.cli.skill.SKILL_DEST", skill_dest):
with patch.object(skill_module, "SKILL_DEST", skill_dest):
result = runner.invoke(cli, ["skill", "status"])
assert result.exit_code == 0
@ -78,8 +89,10 @@ class TestSkillUninstall:
skill_dest.parent.mkdir(parents=True)
skill_dest.write_text("# Test")
with patch("notebooklm.cli.skill.SKILL_DEST", skill_dest), \
patch("notebooklm.cli.skill.SKILL_DEST_DIR", skill_dest.parent):
with (
patch.object(skill_module, "SKILL_DEST", skill_dest),
patch.object(skill_module, "SKILL_DEST_DIR", skill_dest.parent),
):
result = runner.invoke(cli, ["skill", "uninstall"])
assert result.exit_code == 0
@ -89,7 +102,7 @@ class TestSkillUninstall:
"""Test uninstall when skill doesn't exist."""
skill_dest = tmp_path / "skills" / "notebooklm" / "SKILL.md"
with patch("notebooklm.cli.skill.SKILL_DEST", skill_dest):
with patch.object(skill_module, "SKILL_DEST", skill_dest):
result = runner.invoke(cli, ["skill", "uninstall"])
assert result.exit_code == 0
@ -105,7 +118,7 @@ class TestSkillShow:
skill_dest.parent.mkdir(parents=True)
skill_dest.write_text("# NotebookLM Skill\nTest content")
with patch("notebooklm.cli.skill.SKILL_DEST", skill_dest):
with patch.object(skill_module, "SKILL_DEST", skill_dest):
result = runner.invoke(cli, ["skill", "show"])
assert result.exit_code == 0
@ -115,7 +128,7 @@ class TestSkillShow:
"""Test show when skill doesn't exist."""
skill_dest = tmp_path / "skills" / "notebooklm" / "SKILL.md"
with patch("notebooklm.cli.skill.SKILL_DEST", skill_dest):
with patch.object(skill_module, "SKILL_DEST", skill_dest):
result = runner.invoke(cli, ["skill", "show"])
assert result.exit_code == 0

View file

@ -1,11 +1,10 @@
"""Tests for source CLI commands."""
import json
import pytest
from datetime import datetime
from unittest.mock import AsyncMock, patch, MagicMock
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from click.testing import CliRunner
from notebooklm.notebooklm_cli import cli
@ -60,12 +59,15 @@ class TestSourceList:
mock_client = create_mock_client()
mock_client.sources.list = AsyncMock(
return_value=[
Source(id="src_1", title="Test Source", source_type="url", url="https://example.com"),
Source(
id="src_1",
title="Test Source",
source_type="url",
url="https://example.com",
),
]
)
mock_client.notebooks.get = AsyncMock(
return_value=MagicMock(title="Test Notebook")
)
mock_client.notebooks.get = AsyncMock(return_value=MagicMock(title="Test Notebook"))
mock_client_cls.return_value = mock_client
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
@ -89,7 +91,9 @@ class TestSourceAdd:
with patch_client_for_module("source") as mock_client_cls:
mock_client = create_mock_client()
mock_client.sources.add_url = AsyncMock(
return_value=Source(id="src_new", title="Example", url="https://example.com", source_type="url")
return_value=Source(
id="src_new", title="Example", url="https://example.com", source_type="url"
)
)
mock_client_cls.return_value = mock_client
@ -105,7 +109,12 @@ class TestSourceAdd:
with patch_client_for_module("source") as mock_client_cls:
mock_client = create_mock_client()
mock_client.sources.add_url = AsyncMock(
return_value=Source(id="src_yt", title="YouTube Video", url="https://youtube.com/watch?v=abc", source_type="youtube")
return_value=Source(
id="src_yt",
title="YouTube Video",
url="https://youtube.com/watch?v=abc",
source_type="youtube",
)
)
mock_client_cls.return_value = mock_client
@ -147,7 +156,17 @@ class TestSourceAdd:
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(
cli,
["source", "add", "My notes", "--type", "text", "--title", "Custom Title", "-n", "nb_123"],
[
"source",
"add",
"My notes",
"--type",
"text",
"--title",
"Custom Title",
"-n",
"nb_123",
],
)
assert result.exit_code == 0
@ -177,7 +196,9 @@ class TestSourceAdd:
with patch_client_for_module("source") as mock_client_cls:
mock_client = create_mock_client()
mock_client.sources.add_url = AsyncMock(
return_value=Source(id="src_new", title="Example", url="https://example.com", source_type="url")
return_value=Source(
id="src_new", title="Example", url="https://example.com", source_type="url"
)
)
mock_client_cls.return_value = mock_client
@ -203,9 +224,7 @@ class TestSourceGet:
mock_client = create_mock_client()
# Mock sources.list for resolve_source_id
mock_client.sources.list = AsyncMock(
return_value=[
Source(id="src_123", title="Test Source", source_type="url")
]
return_value=[Source(id="src_123", title="Test Source", source_type="url")]
)
mock_client.sources.get = AsyncMock(
return_value=Source(
@ -213,7 +232,7 @@ class TestSourceGet:
title="Test Source",
source_type="url",
url="https://example.com",
created_at=datetime(2024, 1, 1)
created_at=datetime(2024, 1, 1),
)
)
mock_client_cls.return_value = mock_client
@ -254,18 +273,14 @@ class TestSourceDelete:
mock_client = create_mock_client()
# Mock sources.list for resolve_source_id
mock_client.sources.list = AsyncMock(
return_value=[
Source(id="src_123", title="Test Source", source_type="url")
]
return_value=[Source(id="src_123", title="Test Source", source_type="url")]
)
mock_client.sources.delete = AsyncMock(return_value=True)
mock_client_cls.return_value = mock_client
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(
cli, ["source", "delete", "src_123", "-n", "nb_123", "-y"]
)
result = runner.invoke(cli, ["source", "delete", "src_123", "-n", "nb_123", "-y"])
assert result.exit_code == 0
assert "Deleted source" in result.output
@ -276,18 +291,14 @@ class TestSourceDelete:
mock_client = create_mock_client()
# Mock sources.list for resolve_source_id
mock_client.sources.list = AsyncMock(
return_value=[
Source(id="src_123", title="Test Source", source_type="url")
]
return_value=[Source(id="src_123", title="Test Source", source_type="url")]
)
mock_client.sources.delete = AsyncMock(return_value=False)
mock_client_cls.return_value = mock_client
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(
cli, ["source", "delete", "src_123", "-n", "nb_123", "-y"]
)
result = runner.invoke(cli, ["source", "delete", "src_123", "-n", "nb_123", "-y"])
assert result.exit_code == 0
assert "Delete may have failed" in result.output
@ -304,9 +315,7 @@ class TestSourceRename:
mock_client = create_mock_client()
# Mock sources.list for resolve_source_id
mock_client.sources.list = AsyncMock(
return_value=[
Source(id="src_123", title="Old Title", source_type="url")
]
return_value=[Source(id="src_123", title="Old Title", source_type="url")]
)
mock_client.sources.rename = AsyncMock(
return_value=Source(id="src_123", title="New Title", source_type="url")
@ -335,9 +344,7 @@ class TestSourceRefresh:
mock_client = create_mock_client()
# Mock sources.list for resolve_source_id
mock_client.sources.list = AsyncMock(
return_value=[
Source(id="src_123", title="Original Source", source_type="url")
]
return_value=[Source(id="src_123", title="Original Source", source_type="url")]
)
mock_client.sources.refresh = AsyncMock(
return_value=Source(id="src_123", title="Refreshed Source", source_type="url")
@ -356,9 +363,7 @@ class TestSourceRefresh:
mock_client = create_mock_client()
# Mock sources.list for resolve_source_id
mock_client.sources.list = AsyncMock(
return_value=[
Source(id="src_123", title="Original Source", source_type="url")
]
return_value=[Source(id="src_123", title="Original Source", source_type="url")]
)
mock_client.sources.refresh = AsyncMock(return_value=None)
mock_client_cls.return_value = mock_client
@ -405,7 +410,17 @@ class TestSourceAddDrive:
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(
cli, ["source", "add-drive", "file_id", "PDF Title", "--mime-type", "pdf", "-n", "nb_123"]
cli,
[
"source",
"add-drive",
"file_id",
"PDF Title",
"--mime-type",
"pdf",
"-n",
"nb_123",
],
)
assert result.exit_code == 0
@ -426,14 +441,12 @@ class TestSourceGuide:
with patch_client_for_module("source") as mock_client_cls:
mock_client = create_mock_client()
mock_client.sources.list = AsyncMock(
return_value=[
Source(id="src_123", title="Test Source", source_type="url")
]
return_value=[Source(id="src_123", title="Test Source", source_type="url")]
)
mock_client.sources.get_guide = AsyncMock(
return_value={
"summary": "This is a **test** summary about AI.",
"keywords": ["AI", "machine learning", "data science"]
"keywords": ["AI", "machine learning", "data science"],
}
)
mock_client_cls.return_value = mock_client
@ -452,13 +465,9 @@ class TestSourceGuide:
with patch_client_for_module("source") as mock_client_cls:
mock_client = create_mock_client()
mock_client.sources.list = AsyncMock(
return_value=[
Source(id="src_123", title="Test Source", source_type="url")
]
)
mock_client.sources.get_guide = AsyncMock(
return_value={"summary": "", "keywords": []}
return_value=[Source(id="src_123", title="Test Source", source_type="url")]
)
mock_client.sources.get_guide = AsyncMock(return_value={"summary": "", "keywords": []})
mock_client_cls.return_value = mock_client
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
@ -472,21 +481,18 @@ class TestSourceGuide:
with patch_client_for_module("source") as mock_client_cls:
mock_client = create_mock_client()
mock_client.sources.list = AsyncMock(
return_value=[
Source(id="src_123", title="Test Source", source_type="url")
]
return_value=[Source(id="src_123", title="Test Source", source_type="url")]
)
mock_client.sources.get_guide = AsyncMock(
return_value={
"summary": "Test summary",
"keywords": ["keyword1", "keyword2"]
}
return_value={"summary": "Test summary", "keywords": ["keyword1", "keyword2"]}
)
mock_client_cls.return_value = mock_client
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(cli, ["source", "guide", "src_123", "-n", "nb_123", "--json"])
result = runner.invoke(
cli, ["source", "guide", "src_123", "-n", "nb_123", "--json"]
)
assert result.exit_code == 0
data = json.loads(result.output)
@ -499,9 +505,7 @@ class TestSourceGuide:
with patch_client_for_module("source") as mock_client_cls:
mock_client = create_mock_client()
mock_client.sources.list = AsyncMock(
return_value=[
Source(id="src_123", title="Test Source", source_type="url")
]
return_value=[Source(id="src_123", title="Test Source", source_type="url")]
)
mock_client.sources.get_guide = AsyncMock(
return_value={"summary": "Summary without keywords", "keywords": []}
@ -522,9 +526,7 @@ class TestSourceGuide:
with patch_client_for_module("source") as mock_client_cls:
mock_client = create_mock_client()
mock_client.sources.list = AsyncMock(
return_value=[
Source(id="src_123", title="Test Source", source_type="url")
]
return_value=[Source(id="src_123", title="Test Source", source_type="url")]
)
mock_client.sources.get_guide = AsyncMock(
return_value={"summary": "", "keywords": ["AI", "ML", "Data"]}
@ -552,9 +554,7 @@ class TestSourceStale:
with patch_client_for_module("source") as mock_client_cls:
mock_client = create_mock_client()
mock_client.sources.list = AsyncMock(
return_value=[
Source(id="src_123", title="Test Source", source_type="url")
]
return_value=[Source(id="src_123", title="Test Source", source_type="url")]
)
mock_client.sources.check_freshness = AsyncMock(return_value=False) # Not fresh = stale
mock_client_cls.return_value = mock_client
@ -572,9 +572,7 @@ class TestSourceStale:
with patch_client_for_module("source") as mock_client_cls:
mock_client = create_mock_client()
mock_client.sources.list = AsyncMock(
return_value=[
Source(id="src_123", title="Test Source", source_type="url")
]
return_value=[Source(id="src_123", title="Test Source", source_type="url")]
)
mock_client.sources.check_freshness = AsyncMock(return_value=True) # Fresh
mock_client_cls.return_value = mock_client

View file

@ -1,15 +1,16 @@
"""Unit tests for new API coverage features."""
from unittest.mock import AsyncMock, MagicMock
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from notebooklm import NotebookLMClient
from notebooklm.auth import AuthTokens
from notebooklm.rpc.types import (
RPCMethod,
ChatGoal,
ChatResponseLength,
DriveMimeType,
RPCMethod,
)

View file

@ -1,11 +1,11 @@
"""Unit tests for artifact download methods."""
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
import tempfile
import os
import tempfile
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from notebooklm import NotebookLMClient
from notebooklm._artifacts import ArtifactsAPI
from notebooklm.auth import AuthTokens
@ -42,26 +42,34 @@ class TestDownloadAudio:
"""Test successful audio download."""
api, mock_core = mock_artifacts_api
# Mock artifact list response - type 1 (audio), status 3 (completed)
mock_core.rpc_call.return_value = [[
mock_core.rpc_call.return_value = [
[
"audio_001", # id
"Audio Title", # title
1, # type (audio)
None, # ?
3, # status (completed)
None, # ?
[None, None, None, None, None, [ # metadata[6][5] = media list
["https://example.com/audio.mp4", None, "audio/mp4"]
]],
[
"audio_001", # id
"Audio Title", # title
1, # type (audio)
None, # ?
3, # status (completed)
None, # ?
[
None,
None,
None,
None,
None,
[ # metadata[6][5] = media list
["https://example.com/audio.mp4", None, "audio/mp4"]
],
],
]
]
]]
]
with tempfile.TemporaryDirectory() as tmpdir:
output_path = os.path.join(tmpdir, "audio.mp4")
with patch.object(
api, '_download_url',
new_callable=AsyncMock, return_value=output_path
api, "_download_url", new_callable=AsyncMock, return_value=output_path
):
result = await api.download_audio("nb_123", output_path)
@ -80,22 +88,20 @@ class TestDownloadAudio:
async def test_download_audio_specific_id_not_found(self, mock_artifacts_api):
"""Test error when specific audio ID not found."""
api, mock_core = mock_artifacts_api
mock_core.rpc_call.return_value = [[
["other_id", "Audio", 1, None, 3, None, [None] * 6]
]]
mock_core.rpc_call.return_value = [[["other_id", "Audio", 1, None, 3, None, [None] * 6]]]
with pytest.raises(ValueError, match="Audio artifact audio_001 not found"):
await api.download_audio(
"nb_123", "/tmp/audio.mp4", artifact_id="audio_001"
)
await api.download_audio("nb_123", "/tmp/audio.mp4", artifact_id="audio_001")
@pytest.mark.asyncio
async def test_download_audio_invalid_metadata(self, mock_artifacts_api):
"""Test error on invalid metadata structure."""
api, mock_core = mock_artifacts_api
mock_core.rpc_call.return_value = [[
["audio_001", "Audio", 1, None, 3, None, "not_a_list"] # metadata should be list
]]
mock_core.rpc_call.return_value = [
[
["audio_001", "Audio", 1, None, 3, None, "not_a_list"] # metadata should be list
]
]
with pytest.raises(ValueError, match="Invalid audio metadata|Failed to parse"):
await api.download_audio("nb_123", "/tmp/audio.mp4")
@ -113,16 +119,24 @@ class TestDownloadVideo:
output_path = os.path.join(tmpdir, "video.mp4")
# Patch _list_raw to return video artifact data
with patch.object(api, '_list_raw', new_callable=AsyncMock) as mock_list:
with patch.object(api, "_list_raw", new_callable=AsyncMock) as mock_list:
# Type 3 (video), status 3 (completed), metadata at index 8
mock_list.return_value = [[
"video_001", "Video Title", 3, None, 3, None, None, None,
[[["https://example.com/video.mp4", 4, "video/mp4"]]]
]]
mock_list.return_value = [
[
"video_001",
"Video Title",
3,
None,
3,
None,
None,
None,
[[["https://example.com/video.mp4", 4, "video/mp4"]]],
]
]
with patch.object(
api, '_download_url',
new_callable=AsyncMock, return_value=output_path
api, "_download_url", new_callable=AsyncMock, return_value=output_path
):
result = await api.download_video("nb_123", output_path)
@ -133,7 +147,7 @@ class TestDownloadVideo:
"""Test error when no video artifact exists."""
api, mock_core = mock_artifacts_api
with patch.object(api, '_list_raw', new_callable=AsyncMock) as mock_list:
with patch.object(api, "_list_raw", new_callable=AsyncMock) as mock_list:
mock_list.return_value = []
with pytest.raises(ValueError, match="No completed video"):
@ -144,15 +158,11 @@ class TestDownloadVideo:
"""Test error when specific video ID not found."""
api, mock_core = mock_artifacts_api
with patch.object(api, '_list_raw', new_callable=AsyncMock) as mock_list:
mock_list.return_value = [
["other_id", "Video", 3, None, 3, None, None, None, []]
]
with patch.object(api, "_list_raw", new_callable=AsyncMock) as mock_list:
mock_list.return_value = [["other_id", "Video", 3, None, 3, None, None, None, []]]
with pytest.raises(ValueError, match="Video artifact video_001 not found"):
await api.download_video(
"nb_123", "/tmp/video.mp4", artifact_id="video_001"
)
await api.download_video("nb_123", "/tmp/video.mp4", artifact_id="video_001")
class TestDownloadInfographic:
@ -167,17 +177,25 @@ class TestDownloadInfographic:
output_path = os.path.join(tmpdir, "infographic.png")
# Patch _list_raw to return infographic data
with patch.object(api, '_list_raw', new_callable=AsyncMock) as mock_list:
with patch.object(api, "_list_raw", new_callable=AsyncMock) as mock_list:
# Type 7 (infographic), status 3, metadata with nested URL structure
mock_list.return_value = [[
"infographic_001", "Infographic Title", 7, None, 3,
None, None, None, None,
[[], [], [[None, ["https://example.com/infographic.png"]]]]
]]
mock_list.return_value = [
[
"infographic_001",
"Infographic Title",
7,
None,
3,
None,
None,
None,
None,
[[], [], [[None, ["https://example.com/infographic.png"]]]],
]
]
with patch.object(
api, '_download_url',
new_callable=AsyncMock, return_value=output_path
api, "_download_url", new_callable=AsyncMock, return_value=output_path
):
result = await api.download_infographic("nb_123", output_path)
@ -188,7 +206,7 @@ class TestDownloadInfographic:
"""Test error when no infographic artifact exists."""
api, mock_core = mock_artifacts_api
with patch.object(api, '_list_raw', new_callable=AsyncMock) as mock_list:
with patch.object(api, "_list_raw", new_callable=AsyncMock) as mock_list:
mock_list.return_value = []
with pytest.raises(ValueError, match="No completed infographic"):
@ -208,23 +226,24 @@ class TestDownloadSlideDeck:
# Patch _list_raw to return slide deck artifact data
# Structure: artifact[16] = [config, title, slides_list, pdf_url]
with patch.object(api, '_list_raw', new_callable=AsyncMock) as mock_list:
with patch.object(api, "_list_raw", new_callable=AsyncMock) as mock_list:
# Create artifact with 17+ elements, type 8 (slide deck), status 3
artifact = ["slide_001", "Slide Deck Title", 8, None, 3]
# Pad to index 16
artifact.extend([None] * 11)
# Index 16: metadata with PDF URL at position 3
artifact.append([
["config"],
"Slide Deck Title",
[["slide1"], ["slide2"]], # slides_list
"https://contribution.usercontent.google.com/download?filename=test.pdf"
])
artifact.append(
[
["config"],
"Slide Deck Title",
[["slide1"], ["slide2"]], # slides_list
"https://contribution.usercontent.google.com/download?filename=test.pdf",
]
)
mock_list.return_value = [artifact]
with patch.object(
api, '_download_url',
new_callable=AsyncMock, return_value=output_path
api, "_download_url", new_callable=AsyncMock, return_value=output_path
):
result = await api.download_slide_deck("nb_123", output_path)
@ -235,7 +254,7 @@ class TestDownloadSlideDeck:
"""Test error when no slide deck artifact exists."""
api, mock_core = mock_artifacts_api
with patch.object(api, '_list_raw', new_callable=AsyncMock) as mock_list:
with patch.object(api, "_list_raw", new_callable=AsyncMock) as mock_list:
mock_list.return_value = []
with pytest.raises(ValueError, match="No completed slide"):
@ -246,7 +265,7 @@ class TestDownloadSlideDeck:
"""Test error when specific slide deck ID not found."""
api, mock_core = mock_artifacts_api
with patch.object(api, '_list_raw', new_callable=AsyncMock) as mock_list:
with patch.object(api, "_list_raw", new_callable=AsyncMock) as mock_list:
# Need at least 17 elements for valid structure
artifact = ["other_id", "Slides", 8, None, 3]
artifact.extend([None] * 11)
@ -261,7 +280,7 @@ class TestDownloadSlideDeck:
"""Test error on invalid metadata structure."""
api, mock_core = mock_artifacts_api
with patch.object(api, '_list_raw', new_callable=AsyncMock) as mock_list:
with patch.object(api, "_list_raw", new_callable=AsyncMock) as mock_list:
# Create artifact with invalid metadata (less than 4 elements)
artifact = ["slide_001", "Slides", 8, None, 3]
artifact.extend([None] * 11)
@ -284,11 +303,13 @@ class TestMindMapGeneration:
# _get_source_ids response
[None, None, None, None, None, [[["src_001"]]]],
# generate_mind_map response
[[
'{"nodes": [{"id": "1", "text": "Root"}]}', # JSON string
None,
["note_123"], # note info (not used anymore, note is created explicitly)
]],
[
[
'{"nodes": [{"id": "1", "text": "Root"}]}', # JSON string
None,
["note_123"], # note info (not used anymore, note is created explicitly)
]
],
]
result = await api.generate_mind_map("nb_123")
@ -304,11 +325,13 @@ class TestMindMapGeneration:
api, mock_core = mock_artifacts_api
mock_core.rpc_call.side_effect = [
[None, None, None, None, None, [[["src_001"]]]],
[[
{"nodes": [{"id": "1"}]}, # Already a dict
None,
["note_456"], # note info (not used anymore)
]],
[
[
{"nodes": [{"id": "1"}]}, # Already a dict
None,
["note_456"], # note info (not used anymore)
]
],
]
result = await api.generate_mind_map("nb_123")
@ -359,12 +382,10 @@ class TestDownloadUrl:
# Mock load_httpx_cookies to avoid requiring real auth files
mock_cookies = MagicMock()
with patch.object(real_httpx, 'AsyncClient', return_value=mock_client), \
patch('notebooklm._artifacts.load_httpx_cookies', return_value=mock_cookies):
result = await api._download_url(
"https://other.example.com/file.mp4", output_path
)
with (
patch.object(real_httpx, "AsyncClient", return_value=mock_client),
patch("notebooklm._artifacts.load_httpx_cookies", return_value=mock_cookies),
):
result = await api._download_url("https://other.example.com/file.mp4", output_path)
assert result == output_path

View file

@ -1,8 +1,9 @@
"""Tests for authentication module."""
import pytest
import json
from pathlib import Path
import pytest
from pytest_httpx import HTTPXMock
from notebooklm.auth import (
@ -10,9 +11,9 @@ from notebooklm.auth import (
extract_cookies_from_storage,
extract_csrf_from_html,
extract_session_id_from_html,
fetch_tokens,
load_auth_from_storage,
load_httpx_cookies,
fetch_tokens,
)
@ -254,7 +255,9 @@ class TestLoadAuthFromEnvVar:
# Set NOTEBOOKLM_HOME to tmp_path and create a file there
monkeypatch.setenv("NOTEBOOKLM_HOME", str(tmp_path))
file_storage = {"cookies": [{"name": "SID", "value": "from_home_file", "domain": ".google.com"}]}
file_storage = {
"cookies": [{"name": "SID", "value": "from_home_file", "domain": ".google.com"}]
}
storage_file = tmp_path / "storage_state.json"
storage_file.write_text(json.dumps(file_storage))
@ -266,28 +269,36 @@ class TestLoadAuthFromEnvVar:
"""Test that empty string NOTEBOOKLM_AUTH_JSON raises ValueError."""
monkeypatch.setenv("NOTEBOOKLM_AUTH_JSON", "")
with pytest.raises(ValueError, match="NOTEBOOKLM_AUTH_JSON environment variable is set but empty"):
with pytest.raises(
ValueError, match="NOTEBOOKLM_AUTH_JSON environment variable is set but empty"
):
load_auth_from_storage()
def test_env_var_whitespace_only_raises_value_error(self, monkeypatch):
"""Test that whitespace-only NOTEBOOKLM_AUTH_JSON raises ValueError."""
monkeypatch.setenv("NOTEBOOKLM_AUTH_JSON", " \n\t ")
with pytest.raises(ValueError, match="NOTEBOOKLM_AUTH_JSON environment variable is set but empty"):
with pytest.raises(
ValueError, match="NOTEBOOKLM_AUTH_JSON environment variable is set but empty"
):
load_auth_from_storage()
def test_env_var_missing_cookies_key_raises_value_error(self, monkeypatch):
"""Test that NOTEBOOKLM_AUTH_JSON without 'cookies' key raises ValueError."""
monkeypatch.setenv("NOTEBOOKLM_AUTH_JSON", '{"origins": []}')
with pytest.raises(ValueError, match="must contain valid Playwright storage state with a 'cookies' key"):
with pytest.raises(
ValueError, match="must contain valid Playwright storage state with a 'cookies' key"
):
load_auth_from_storage()
def test_env_var_non_dict_raises_value_error(self, monkeypatch):
"""Test that non-dict NOTEBOOKLM_AUTH_JSON raises ValueError."""
monkeypatch.setenv("NOTEBOOKLM_AUTH_JSON", '["not", "a", "dict"]')
with pytest.raises(ValueError, match="must contain valid Playwright storage state with a 'cookies' key"):
with pytest.raises(
ValueError, match="must contain valid Playwright storage state with a 'cookies' key"
):
load_auth_from_storage()
@ -327,7 +338,9 @@ class TestLoadHttpxCookiesWithEnvVar:
"""Test that empty string NOTEBOOKLM_AUTH_JSON raises ValueError."""
monkeypatch.setenv("NOTEBOOKLM_AUTH_JSON", "")
with pytest.raises(ValueError, match="NOTEBOOKLM_AUTH_JSON environment variable is set but empty"):
with pytest.raises(
ValueError, match="NOTEBOOKLM_AUTH_JSON environment variable is set but empty"
):
load_httpx_cookies()
def test_env_var_missing_required_cookies_raises(self, monkeypatch):
@ -640,7 +653,6 @@ class TestDefaultStoragePath:
def test_default_storage_path_is_correct(self):
"""Test DEFAULT_STORAGE_PATH constant is defined correctly."""
from notebooklm.auth import DEFAULT_STORAGE_PATH
from pathlib import Path
assert DEFAULT_STORAGE_PATH is not None
assert isinstance(DEFAULT_STORAGE_PATH, Path)

View file

@ -1,14 +1,12 @@
"""Tests for NotebookLMClient class."""
import json
import pytest
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pytest_httpx import HTTPXMock
from notebooklm.client import NotebookLMClient
from notebooklm.auth import AuthTokens
from notebooklm.client import NotebookLMClient
@pytest.fixture
@ -172,7 +170,7 @@ class TestRefreshAuth:
client = NotebookLMClient(mock_auth)
# Mock the homepage response with new tokens
html = '''
html = """
<html>
<script>
window.WIZ_global_data = {
@ -181,7 +179,7 @@ class TestRefreshAuth:
};
</script>
</html>
'''
"""
httpx_mock.add_response(
url="https://notebooklm.google.com/",
content=html.encode(),
@ -208,7 +206,7 @@ class TestRefreshAuth:
# The refresh_auth checks if "accounts.google.com" is in the final URL
# We can't easily mock a real redirect with httpx, so we test the URL check
# by providing a response that doesn't contain the expected tokens
html = '<html><body>Please sign in</body></html>' # No tokens
html = "<html><body>Please sign in</body></html>" # No tokens
httpx_mock.add_response(
url="https://notebooklm.google.com/",
content=html.encode(),

View file

@ -1,10 +1,10 @@
"""Tests for conversation functionality."""
import pytest
from unittest.mock import AsyncMock, patch
import json
from notebooklm import NotebookLMClient, AskResult
import pytest
from notebooklm import AskResult, NotebookLMClient
from notebooklm.auth import AuthTokens

View file

@ -1,15 +1,16 @@
"""Unit tests for RPC response decoder."""
import pytest
import json
import pytest
from notebooklm.rpc.decoder import (
strip_anti_xssi,
parse_chunked_response,
extract_rpc_result,
RPCError,
collect_rpc_ids,
decode_response,
RPCError,
extract_rpc_result,
parse_chunked_response,
strip_anti_xssi,
)
from notebooklm.rpc.types import RPCMethod
@ -192,9 +193,7 @@ class TestExtractRPCResult:
def test_user_displayable_error_sets_code(self):
"""Test UserDisplayableError sets code to USER_DISPLAYABLE_ERROR."""
error_info = [8, None, [["UserDisplayableError", []]]]
chunks = [
["wrb.fr", RPCMethod.LIST_NOTEBOOKS.value, None, None, None, error_info]
]
chunks = [["wrb.fr", RPCMethod.LIST_NOTEBOOKS.value, None, None, None, error_info]]
with pytest.raises(RPCError) as exc_info:
extract_rpc_result(chunks, RPCMethod.LIST_NOTEBOOKS.value)
@ -203,18 +202,14 @@ class TestExtractRPCResult:
def test_null_result_without_error_info_returns_none(self):
"""Test null result without UserDisplayableError returns None normally."""
chunks = [
["wrb.fr", RPCMethod.LIST_NOTEBOOKS.value, None, None, None, None]
]
chunks = [["wrb.fr", RPCMethod.LIST_NOTEBOOKS.value, None, None, None, None]]
result = extract_rpc_result(chunks, RPCMethod.LIST_NOTEBOOKS.value)
assert result is None
def test_null_result_with_non_error_info_returns_none(self):
"""Test null result with non-error data at index 5 returns None."""
chunks = [
["wrb.fr", RPCMethod.LIST_NOTEBOOKS.value, None, None, None, [1, 2, 3]]
]
chunks = [["wrb.fr", RPCMethod.LIST_NOTEBOOKS.value, None, None, None, [1, 2, 3]]]
result = extract_rpc_result(chunks, RPCMethod.LIST_NOTEBOOKS.value)
assert result is None
@ -229,9 +224,7 @@ class TestExtractRPCResult:
"type": "type.googleapis.com/google.internal.labs.tailwind.orchestration.v1.UserDisplayableError",
"details": {"code": 1},
}
chunks = [
["wrb.fr", RPCMethod.LIST_NOTEBOOKS.value, None, None, None, error_info]
]
chunks = [["wrb.fr", RPCMethod.LIST_NOTEBOOKS.value, None, None, None, error_info]]
with pytest.raises(RPCError, match="rate limiting"):
extract_rpc_result(chunks, RPCMethod.LIST_NOTEBOOKS.value)
@ -267,9 +260,7 @@ class TestDecodeResponse:
def test_decode_complex_nested_data(self):
"""Test decoding complex nested data structures."""
data = {
"notebooks": [{"id": "nb1", "title": "Test", "sources": [{"id": "s1"}]}]
}
data = {"notebooks": [{"id": "nb1", "title": "Test", "sources": [{"id": "s1"}]}]}
inner = json.dumps(data)
chunk = json.dumps(["wrb.fr", RPCMethod.LIST_NOTEBOOKS.value, inner, None, None])
raw_response = f")]}}'\n{len(chunk)}\n{chunk}\n"

View file

@ -1,7 +1,8 @@
"""Tests for download helper functions."""
import pytest
from notebooklm.cli.download_helpers import select_artifact, artifact_title_to_filename
from notebooklm.cli.download_helpers import artifact_title_to_filename, select_artifact
class TestSelectArtifact:

View file

@ -1,10 +1,8 @@
"""Unit tests for RPC request encoder."""
import pytest
import json
from urllib.parse import unquote
from notebooklm.rpc.encoder import encode_rpc_request, build_request_body, build_url_params
from notebooklm.rpc.encoder import build_request_body, build_url_params, encode_rpc_request
from notebooklm.rpc.types import RPCMethod
@ -134,28 +132,21 @@ class TestBuildUrlParams:
def test_with_source_path(self):
"""Test URL params with custom source path."""
result = build_url_params(
RPCMethod.GET_NOTEBOOK,
source_path="/notebook/abc123"
)
result = build_url_params(RPCMethod.GET_NOTEBOOK, source_path="/notebook/abc123")
assert result["rpcids"] == RPCMethod.GET_NOTEBOOK.value
assert result["source-path"] == "/notebook/abc123"
def test_with_session_id(self):
"""Test URL params with session ID."""
result = build_url_params(
RPCMethod.LIST_NOTEBOOKS,
session_id="session_12345"
)
result = build_url_params(RPCMethod.LIST_NOTEBOOKS, session_id="session_12345")
assert result["f.sid"] == "session_12345"
def test_with_build_label(self):
"""Test URL params with build label."""
result = build_url_params(
RPCMethod.LIST_NOTEBOOKS,
bl="boq_labs-tailwind-frontend_20250101"
RPCMethod.LIST_NOTEBOOKS, bl="boq_labs-tailwind-frontend_20250101"
)
assert result["bl"] == "boq_labs-tailwind-frontend_20250101"
@ -166,7 +157,7 @@ class TestBuildUrlParams:
RPCMethod.CREATE_NOTEBOOK,
source_path="/notebook/xyz789",
session_id="sess_abc",
bl="build_label_123"
bl="build_label_123",
)
assert result["rpcids"] == RPCMethod.CREATE_NOTEBOOK.value

View file

@ -1,10 +1,10 @@
"""Unit tests for NotesAPI private helpers and edge cases."""
from unittest.mock import AsyncMock, MagicMock
import pytest
from unittest.mock import MagicMock, AsyncMock
from notebooklm._notes import NotesAPI
from notebooklm.types import Note
@pytest.fixture

View file

@ -1,16 +1,15 @@
"""Tests for path resolution module."""
import os
import pytest
from pathlib import Path
from unittest.mock import patch
from notebooklm.paths import (
get_home_dir,
get_storage_path,
get_context_path,
get_browser_profile_dir,
get_context_path,
get_home_dir,
get_path_info,
get_storage_path,
)

View file

@ -1,8 +1,10 @@
"""Tests for research functionality."""
import pytest
import json
import re
import pytest
from notebooklm import NotebookLMClient
from notebooklm.auth import AuthTokens
from notebooklm.rpc import RPCMethod
@ -55,9 +57,7 @@ class TestResearch:
]
response_json = json.dumps([[["task_123", task_info]]])
chunk = json.dumps(
["wrb.fr", RPCMethod.POLL_RESEARCH.value, response_json, None, None]
)
chunk = json.dumps(["wrb.fr", RPCMethod.POLL_RESEARCH.value, response_json, None, None])
response_body = f")]}}'\n{len(chunk)}\n{chunk}\n"
httpx_mock.add_response(content=response_body.encode(), method="POST")
@ -73,9 +73,7 @@ class TestResearch:
@pytest.mark.asyncio
async def test_import_research(self, auth_tokens, httpx_mock):
response_json = json.dumps([[[["src_new"], "Imported Title"]]])
chunk = json.dumps(
["wrb.fr", RPCMethod.IMPORT_RESEARCH.value, response_json, None, None]
)
chunk = json.dumps(["wrb.fr", RPCMethod.IMPORT_RESEARCH.value, response_json, None, None])
response_body = f")]}}'\n{len(chunk)}\n{chunk}\n"
httpx_mock.add_response(content=response_body.encode(), method="POST")

View file

@ -1,11 +1,10 @@
"""Unit tests for RPC types and constants."""
import pytest
from notebooklm.rpc.types import (
RPCMethod,
StudioContentType,
BATCHEXECUTE_URL,
QUERY_URL,
RPCMethod,
StudioContentType,
)
@ -13,8 +12,7 @@ class TestRPCConstants:
def test_batchexecute_url(self):
"""Test batchexecute URL is correct."""
assert (
BATCHEXECUTE_URL
== "https://notebooklm.google.com/_/LabsTailwindUi/data/batchexecute"
BATCHEXECUTE_URL == "https://notebooklm.google.com/_/LabsTailwindUi/data/batchexecute"
)
def test_query_url(self):

View file

@ -1,18 +1,19 @@
"""Unit tests for source status and polling functionality."""
import asyncio
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from notebooklm._sources import SourcesAPI
from notebooklm.types import (
Source,
SourceStatus,
SourceError,
SourceProcessingError,
SourceTimeoutError,
SourceNotFoundError,
SourceProcessingError,
SourceStatus,
SourceTimeoutError,
)
from notebooklm._sources import SourcesAPI
class TestSourceStatus:
@ -27,7 +28,7 @@ class TestSourceStatus:
def test_status_is_int_enum(self):
"""Test that SourceStatus values can be compared with ints."""
assert SourceStatus.READY == 2
assert 2 == SourceStatus.READY
assert SourceStatus.READY == 2
class TestSourceStatusProperties:
@ -235,9 +236,7 @@ class TestWaitForSources:
raise SourceNotFoundError(source_id)
with patch.object(sources_api, "wait_until_ready", side_effect=mock_wait):
results = await sources_api.wait_for_sources(
"nb_1", ["src_1", "src_2"], timeout=10.0
)
results = await sources_api.wait_for_sources("nb_1", ["src_1", "src_2"], timeout=10.0)
assert len(results) == 2
assert all(s.is_ready for s in results)
@ -253,6 +252,4 @@ class TestWaitForSources:
with patch.object(sources_api, "wait_until_ready", side_effect=mock_wait):
with pytest.raises(SourceProcessingError):
await sources_api.wait_for_sources(
"nb_1", ["src_1", "src_2"], timeout=10.0
)
await sources_api.wait_for_sources("nb_1", ["src_1", "src_2"], timeout=10.0)

View file

@ -1,8 +1,8 @@
"""Unit tests for SourcesAPI file upload pipeline and YouTube detection."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from unittest.mock import MagicMock, AsyncMock, patch, mock_open
from pathlib import Path
from notebooklm._sources import SourcesAPI
@ -223,7 +223,9 @@ class TestStartResumableUpload:
assert body["SOURCE_ID"] == "src_abc"
@pytest.mark.asyncio
async def test_start_resumable_upload_raises_on_missing_url_header(self, sources_api, mock_core):
async def test_start_resumable_upload_raises_on_missing_url_header(
self, sources_api, mock_core
):
"""Test that missing upload URL header raises ValueError."""
mock_response = MagicMock()
mock_response.headers = {} # No x-goog-upload-url
@ -286,7 +288,9 @@ class TestUploadFileStreaming:
mock_client.post.assert_called_once()
@pytest.mark.asyncio
async def test_upload_file_streaming_includes_correct_headers(self, sources_api, mock_core, tmp_path):
async def test_upload_file_streaming_includes_correct_headers(
self, sources_api, mock_core, tmp_path
):
"""Test that streaming upload includes correct headers."""
test_file = tmp_path / "test.txt"
test_file.write_bytes(b"content")
@ -325,9 +329,7 @@ class TestUploadFileStreaming:
mock_client.post.return_value = mock_response
mock_client_cls.return_value = mock_client
await sources_api._upload_file_streaming(
"https://upload.example.com", test_file
)
await sources_api._upload_file_streaming("https://upload.example.com", test_file)
call_kwargs = mock_client.post.call_args[1]
# Content should be a generator, not bytes
@ -337,7 +339,9 @@ class TestUploadFileStreaming:
assert b"".join(chunks) == test_content
@pytest.mark.asyncio
async def test_upload_file_streaming_raises_on_http_error(self, sources_api, mock_core, tmp_path):
async def test_upload_file_streaming_raises_on_http_error(
self, sources_api, mock_core, tmp_path
):
"""Test that HTTP error raises exception."""
import httpx
@ -354,9 +358,7 @@ class TestUploadFileStreaming:
mock_client_cls.return_value = mock_client
with pytest.raises(httpx.HTTPStatusError):
await sources_api._upload_file_streaming(
"https://upload.example.com", test_file
)
await sources_api._upload_file_streaming("https://upload.example.com", test_file)
# =============================================================================

View file

@ -1,20 +1,18 @@
"""Unit tests for types module dataclasses and parsing."""
import pytest
from datetime import datetime
from notebooklm.types import (
Notebook,
NotebookDescription,
SuggestedTopic,
Source,
Artifact,
GenerationStatus,
ReportSuggestion,
Note,
ConversationTurn,
AskResult,
ChatMode,
ConversationTurn,
GenerationStatus,
Note,
Notebook,
NotebookDescription,
ReportSuggestion,
Source,
)
@ -110,7 +108,13 @@ class TestSource:
def test_from_api_response_nested_format(self):
"""Test parsing medium nested format."""
data = [[["src_456"], "Nested Source", [None, None, None, None, None, None, None, ["https://example.com"]]]]
data = [
[
["src_456"],
"Nested Source",
[None, None, None, None, None, None, None, ["https://example.com"]],
]
]
source = Source.from_api_response(data)
assert source.id == "src_456"
@ -119,7 +123,15 @@ class TestSource:
def test_from_api_response_deeply_nested(self):
"""Test parsing deeply nested format."""
data = [[[["src_789"], "Deep Source", [None, None, None, None, None, None, None, ["https://deep.example.com"]]]]]
data = [
[
[
["src_789"],
"Deep Source",
[None, None, None, None, None, None, None, ["https://deep.example.com"]],
]
]
]
source = Source.from_api_response(data)
assert source.id == "src_789"
@ -128,14 +140,30 @@ class TestSource:
def test_from_api_response_youtube_url(self):
"""Test that YouTube URLs are detected."""
data = [[[["src_yt"], "YouTube Video", [None, None, None, None, None, None, None, ["https://youtube.com/watch?v=abc"]]]]]
data = [
[
[
["src_yt"],
"YouTube Video",
[None, None, None, None, None, None, None, ["https://youtube.com/watch?v=abc"]],
]
]
]
source = Source.from_api_response(data)
assert source.source_type == "youtube"
def test_from_api_response_youtu_be_short_url(self):
"""Test that youtu.be short URLs are detected."""
data = [[[["src_yt2"], "Short Video", [None, None, None, None, None, None, None, ["https://youtu.be/abc"]]]]]
data = [
[
[
["src_yt2"],
"Short Video",
[None, None, None, None, None, None, None, ["https://youtu.be/abc"]],
]
]
]
source = Source.from_api_response(data)
assert source.source_type == "youtube"
@ -165,7 +193,24 @@ class TestArtifact:
def test_from_api_response_with_timestamp(self):
"""Test parsing artifact with timestamp."""
ts = 1704067200
data = ["art_123", "Audio", 1, None, 3, None, None, None, None, None, None, None, None, None, None, [ts]]
data = [
"art_123",
"Audio",
1,
None,
3,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
[ts],
]
artifact = Artifact.from_api_response(data)
assert artifact.created_at is not None

View file

@ -1,8 +1,9 @@
"""Unit tests for YouTube URL extraction."""
import pytest
from unittest.mock import MagicMock
import pytest
from notebooklm import NotebookLMClient